import {
    Material,
    RenderTarget,
    WebGLRenderTarget
} from "three";
import {Api} from "../../Api.js";
import {RenderingManager} from "../RenderingManager.js";
// @ts-ignore
import {Renderer} from "three/examples/jsm/renderers/common/Renderer.js";
import {Web3DMeshNodeMaterial} from "./Web3DMeshNodeMaterial.js";
import MeshBasicNodeMaterial from "three/examples/jsm/nodes/materials/MeshBasicNodeMaterial.js";
import * as Nodes from "three/examples/jsm/nodes/Nodes.js";
import {EdgesAndSsaoRenderPass} from "../EdgesAndSsaoRenderPass.js";

/**
 * WebGPU (three.js node material) implementation of {@link EdgesAndSsaoRenderPass}.
 *
 * Builds WGSL node materials for the edge-composite and SSAO full-screen passes and drives the render sequence:
 * filtered Web3D materials into `multiRenderTarget`, optional SSAO, the edge composite into the caller's target,
 * and finally the remaining (non-Web3D) filtered materials on top.
 */
export class WebGPUEdgesAndSsaoRenderPass extends EdgesAndSsaoRenderPass {

    constructor(api: Api, renderingManager: RenderingManager) {
        super(api, renderingManager);
    }

    /**
     * Edges are drawn only when enabled in the settings and no XR session is running.
     */
    get renderEdges(): boolean {
        return this.api.settingsDispatcher.settings.renderEdges &&
            !this.renderingManager.xr.isStarted;
    }

    /**
     * Creates the full-screen material that darkens the color buffer by SSAO and highlights edges.
     *
     * The WGSL code (`edgeShader`) reads `NodeUniforms.texelSize` and `NodeUniforms.uProjectionInverse`, which are
     * the uniforms labelled below. It samples the SSAO result from `ssaoRenderTarget` only.
     *
     * @returns a transparent, premultiplied-alpha material with depth test and depth write enabled.
     */
    createEdgesMaterial(): MeshBasicNodeMaterial {
        const material = new MeshBasicNodeMaterial();

        const uniforms = [
            // @ts-ignore
            Nodes.uniform(Nodes.vec2(this.cameraNearFar)).label('cameraNearFar'),
            // @ts-ignore
            Nodes.uniform(Nodes.vec2(this.texelSize)).label('texelSize'),
            // @ts-ignore
            Nodes.uniform(Nodes.mat4(this.api.camera.projectionMatrixInverse)).label('uProjectionInverse'),
        ];

        // @ts-ignore
        const colorNode = Nodes.wgslFn(edgeShader, uniforms);
        const colorNodeParameters = {
            vUv: Nodes.uv(),
            inputBuffer: Nodes.texture(this.colorTexture),
            inputBuffer_sampler: Nodes.texture(this.colorTexture),
            normalBuffer: Nodes.texture(this.normalTexture),
            normalBuffer_sampler: Nodes.texture(this.normalTexture),
            idAndCutsBuffer: Nodes.texture(this.idAndCutsTexture),
            idAndCutsBuffer_sampler: Nodes.texture(this.idAndCutsTexture),
            ssaoBuffer: Nodes.texture(this.ssaoRenderTarget.texture),
            ssaoBuffer_sampler: Nodes.texture(this.ssaoRenderTarget.texture),
            // NB: This parameter is required by the Three.js node system to pass the actual depth buffer to the
            // shader. The parameter gets converted to a depth sample (float). The label, however, allows us to
            // access the entire depth buffer by that identifier.
            // TODO: Do this without hacks when supported by Three.js.
            // @ts-ignore
            hack: Nodes.texture(this.depthTexture).label('depthBuffer'),
            hack_sampler: Nodes.texture(this.depthTexture),
        };

        material.transparent = true;
        material.depthTest = true;
        material.depthWrite = true;
        material.premultipliedAlpha = true;
        material.extensions.fragDepth = true;   // TODO: Probably not relevant for WebGPU

        // @ts-ignore
        material.outputNode = Nodes.outputStruct(
            colorNode(colorNodeParameters),
        );

        return material;
    }

    /**
     * Creates the full-screen SSAO material.
     *
     * The `seed` uniform node is also stored in `material.uniforms.seed`, so that `renderSsao` can re-randomise
     * the noise pattern when the camera moves.
     *
     * @param simple currently unused: the SSAO shader has no simplified variant yet (see the TODO inside `ssaoShader`).
     * @returns a material with depth test and depth write disabled.
     */
    createSsaoMaterial(simple: boolean): MeshBasicNodeMaterial {
        const material = new MeshBasicNodeMaterial();

        // @ts-ignore
        const seedUniform = Nodes.uniform(Nodes.float(0)).label('seed');
        const uniforms = [
            seedUniform,
            // @ts-ignore
            Nodes.uniform(Nodes.vec2(this.texelSize)).label('texelSize'),
            // @ts-ignore
            Nodes.uniform(Nodes.mat4(this.api.camera.projectionMatrixInverse)).label('uProjectionInverse'),
        ];

        material.uniforms.seed = seedUniform;

        // @ts-ignore
        const colorNode = Nodes.wgslFn(ssaoShader, uniforms);
        const colorNodeParameters = {
            vPosition: Nodes.positionView,
            vUv: Nodes.uv(),
            resolution: Nodes.viewportResolution,
            normalBuffer: Nodes.texture(this.normalTexture),
            normalBuffer_sampler: Nodes.texture(this.normalTexture),
            noiseBuffer: Nodes.texture(this.noiseTexture),
            // TODO: Avoid this hack when supported by three.js (see equivalent hack for edge line material)
            // @ts-ignore
            hack: Nodes.texture(this.depthTexture).label('depthBuffer'),
            hack_sampler: Nodes.texture(this.depthTexture),
        };

        material.depthTest = false;
        material.depthWrite = false;

        material.colorNode = colorNode(colorNodeParameters);

        return material;
    }

    /**
     * Renders the scene with SSAO and edge highlighting into `renderTarget`.
     *
     * 1. Only Web3D mesh node materials accepted by `shouldRenderMaterial` are drawn into `multiRenderTarget`.
     * 2. SSAO is rendered (or its buffer cleared), then the edge composite is drawn into `renderTarget`.
     * 3. The remaining non-Web3D materials accepted by `shouldRenderMaterial` are drawn on top.
     *
     * Material visibility is always restored to `true` afterwards, even if a render call throws.
     */
    protected renderWithEdges(renderer: Renderer, renderTarget: RenderTarget, renderSsao: boolean, shouldRenderMaterial: (m: Material) => boolean): void {
        try {
            let shouldRenderAnything = false;
            // TODO: pass shouldRenderMaterial to web3d/src/Rendering/WebGLRenderer instead of traversing, should be faster
            this.renderingManager.traverseMaterials((_object, m) => {
                const srm = shouldRenderMaterial(m);
                shouldRenderAnything = shouldRenderAnything || srm;
                m.visible = !!(m as Web3DMeshNodeMaterial).isWeb3DMeshNodeMaterial && srm;
            });
            if (!shouldRenderAnything) {
                return; // early optimization: nothing passes the filter
            }

            renderer.setRenderTarget(this.multiRenderTarget);
            renderer.setClearAlpha(0);

            renderer.autoClear = true;
            renderer.autoClearColor = true;
            renderer.autoClearDepth = true;

            renderer.render(this.api.scene, this.api.camera);

            renderer.autoClear = false;
            renderer.autoClearColor = false;
            renderer.autoClearDepth = false;

            this.cameraNearFar.set(this.api.camera.near, this.api.camera.far);

            this.renderSsao(renderer, renderSsao);
            this.edgesPass.render(renderer, renderTarget as WebGLRenderTarget);

            // Second pass: everything that is not a Web3D mesh node material, drawn over the edge composite.
            this.renderingManager.traverseMaterials((_object, m) => m.visible = !(m as Web3DMeshNodeMaterial).isWeb3DMeshNodeMaterial && shouldRenderMaterial(m));
            renderer.setRenderTarget(renderTarget);
            renderer.render(this.api.scene, this.api.camera);
        } finally {
            // Never leave the scene with hidden materials, even when a render call above throws.
            this.renderingManager.traverseMaterials((_object, m) => m.visible = true);
        }
    }

    /**
     * Renders SSAO, or clears the SSAO buffer when SSAO is disabled for this render.
     *
     * @param renderSsao caller's request to render SSAO; SSAO additionally requires the `ssao` setting.
     */
    private renderSsao(renderer: Renderer, renderSsao: boolean): void {
        if (!renderSsao || !this.api.settingsDispatcher.settings.ssao) {
            // The edge material always reads occlusion from ssaoRenderTarget's alpha channel. Clear it to alpha 0
            // (no occlusion) so that stale occlusion from an earlier frame is not applied.
            renderer.setRenderTarget(this.ssaoRenderTarget as any);
            renderer.setClearAlpha(0);
            renderer.clearColor();
            return;
        }

        const simple = !this.renderingManager.fullRender && this.api.settingsDispatcher.settings.progressiveRendering &&
            performance.now() - this.renderingManager.renderer.renderStartTime > 10; // render full ssao if we have time for it
        const rt = simple ? this.ssaoSimpleRenderTarget : this.ssaoRenderTarget;
        const pass = simple ? this.ssaoSimplePass : this.ssaoPass;

        renderer.setRenderTarget(rt as any);
        renderer.setClearAlpha(0);
        renderer.clearColor();

        // make noise dynamic with time value change only if camera is moved
        if (!this.prevCameraMatrixWorld.equals(this.api.camera.matrixWorld)) {
            // @ts-ignore
            pass.getFullscreenMaterial().uniforms.seed.node.value = performance.now() % 100;
        }
        this.prevCameraMatrixWorld.copy(this.api.camera.matrixWorld);

        pass.render(renderer, rt as any);
        // TODO: Set buffer
        // NOTE(review): in simple mode the result lands in ssaoSimpleRenderTarget, but the edge material only
        // samples ssaoRenderTarget, so it still shows the last full-quality SSAO result.
    }
}

/**
 * WGSL helper functions, interpolated verbatim at the end of both `edgeShader` and `ssaoShader`.
 *
 * - `computePosition(coord, depthSample)`: flips `coord.y` and remaps `(x, y, depthSample)` from [0, 1] to
 *   [-1, 1]. It then multiplies by `NodeUniforms.uProjectionInverse` and divides by `w`, which gives a
 *   camera-relative (view-space) position. Both materials in this file supply that uniform via
 *   `.label('uProjectionInverse')`.
 * - `getNormal(normalSample)`: returns `normalSample.xyz` as-is, or `(0, 0, 1)` for an all-zero (background)
 *   sample. NOTE(review): no decoding is applied, so the normal buffer presumably stores signed normals
 *   directly — confirm.
 * - `getOffsetUvs(texelSize)`: the four axis-aligned one-texel neighbour offsets (+x, -x, +y, -y).
 *
 * Keep comments outside the template literal: its text is shader source.
 */
// language=wgsl
const commonShaderFunctions = `
    // Camera relative position
    fn computePosition(coord: vec2<f32>, depthSample: f32) -> vec3<f32> {
        var ndc: vec4<f32> = vec4((vec3(coord.x, 1.0 - coord.y, depthSample) - 0.5) * 2.0, 1.0);
        var clip: vec4<f32> = NodeUniforms.uProjectionInverse * ndc;
        return clip.xyz / clip.w;
    }
    
    // Treat background as surface with normal towards camera
    fn getNormal(normalSample: vec4<f32>) -> vec3<f32> {
        if (normalSample.x == 0.0 && normalSample.y == 0.0 && normalSample.z == 0.0) {
            return vec3(0.0, 0.0, 1.0);
        } else {
            return normalSample.xyz;
        }
    }
    
    fn getOffsetUvs(texelSize: vec2<f32>) -> array<vec2<f32>, 4> {
        var offsetUvs: array<vec2<f32>, 4>;
        offsetUvs[0] = vec2(texelSize.x, 0);
        offsetUvs[1] = vec2(-texelSize.x, 0);
        offsetUvs[2] = vec2(0, texelSize.y);
        offsetUvs[3] = vec2(0, -texelSize.y);
        return offsetUvs;
    }
`;

/**
 * WGSL source of `mainColor`, the edge-composite function used by `createEdgesMaterial`.
 *
 * For each fragment it samples the color, normal, id/cuts and depth buffers. The pixel is marked as an edge if
 * any of its four one-texel neighbours has:
 *   - an id (red channel of the id/cuts buffer) differing by more than 0.99 — edges between touching entities;
 *   - a normal with `|dot(normal, offsetNormal)| < 0.9` — edges on corners;
 *   - a view-space position more than 0.002 away from the pixel's tangent plane, where that distance divided
 *     by the pixel's view-space `z` also exceeds 0.0005 — edges between parallel planes on different levels.
 *
 * The color is darkened by the SSAO buffer's alpha (`rgb * (1 - occlusion)`). Edge pixels are then mixed towards
 * `rgb * 0.7` with alpha `max(a, 0.4)`. The mix weight is faded for depth values very close to 1, per the
 * in-shader "long distance precision artifacts" comment.
 *
 * `depthBuffer`/`depthBuffer_sampler` are not declared parameters; they resolve through the `hack` texture node
 * labelled 'depthBuffer' in `createEdgesMaterial`. `fragDepthPlaceholder` receives the raw depth sample but is
 * never read within this string.
 *
 * NOTE(review): `depthSample` is remapped with `(d + 1) * 0.5` before `computePosition`. The fade computation
 * then applies `(d + 1) / 2` again to the already-remapped value — confirm this double remap is intended.
 *
 * Keep comments outside the template literal: its text is shader source.
 */
// language=wgsl
const edgeShader = `
    fn mainColor(
        vUv: vec2<f32>,
        inputBuffer: texture_2d<f32>, inputBuffer_sampler: sampler,
        normalBuffer: texture_2d<f32>, normalBuffer_sampler: sampler,
        idAndCutsBuffer: texture_2d<f32>, idAndCutsBuffer_sampler: sampler,
        ssaoBuffer: texture_2d<f32>, ssaoBuffer_sampler: sampler,
        hack: f32
    ) -> vec4<f32> {
        var correctedUv: vec2<f32> = vec2(vUv.x, 1.0 - vUv.y);
        var color: vec4<f32> = textureSample(inputBuffer, inputBuffer_sampler, correctedUv);
        var normalSample: vec4<f32> = textureSample(normalBuffer, normalBuffer_sampler, correctedUv);
        var normal: vec3<f32> = getNormal(normalSample);
        var idAndCutsSample = textureSample(idAndCutsBuffer, idAndCutsBuffer_sampler, correctedUv);
        var id = idAndCutsSample.r;
        var depthSample: f32 = textureSample(depthBuffer, depthBuffer_sampler, correctedUv);
        fragDepthPlaceholder = depthSample;
        var isBackground: bool = depthSample == 1.0;
        depthSample = (depthSample + 1.0) * 0.5; // Depth correction for WebGPU coordinates
        var position = computePosition(correctedUv, depthSample);
        var planeConstant: f32 = -dot(position, normal);
        var isEdge: bool = false;
        
        var offsetUvs: array<vec2<f32>, 4> = getOffsetUvs(NodeUniforms.texelSize);
        
        for (var i: i32 = 0; i < 4; i++) {
            var uv: vec2<f32> = correctedUv + offsetUvs[i];
            var offsetNormalSample: vec4<f32> = textureSample(normalBuffer, normalBuffer_sampler, uv);
            var offsetNormal: vec3<f32> = getNormal(offsetNormalSample);
            var offsetDepth: f32 = textureSample(depthBuffer, depthBuffer_sampler, uv);
            offsetDepth = (offsetDepth + 1.0) * 0.5; // Depth correction for WebGPU coordinates
            var offsetPosition: vec3<f32> = computePosition(uv, offsetDepth);
            var offsetIdAndCutsSample: vec4<f32> = textureSample(idAndCutsBuffer, idAndCutsBuffer_sampler, uv);
            var offsetId: f32 = offsetIdAndCutsSample.r;
            
            var normalDot: f32 = abs(dot(normal, offsetNormal));
            
            var depthDelta = distancePointToPlane(normal, planeConstant, offsetPosition);
            isEdge = isEdge ||
                abs(id - offsetId) > 0.99 ||            // edges between touching entities 
                normalDot < 0.9 ||                      // edges on corners
                (depthDelta > 0.002 &&                  // edges between parallel planes on different levels
                abs(depthDelta / position.z) > 0.0005); // depth precision fix, removes noise
        }
                
        var occlusion: f32 = textureSample(ssaoBuffer, ssaoBuffer_sampler, correctedUv).a;
        color = vec4(color.rgb * (1.0 - occlusion), color.a);
        
        var fadeArtifacts: f32 = 1.0;
        if (!isBackground) {
            fadeArtifacts = clamp((0.99987 - (depthSample + 1.0) / 2.0) * 10000.0, 0.0, 1.0); // Fade long distance precision artifacs
        }
        
        var weight: f32 = 0.0;
        if (isEdge) { weight = fadeArtifacts; }
        var outputColor: vec4<f32> = mix(color, vec4(color.rgb * 0.7, max(color.a, 0.4)), weight);
        
        return outputColor;
    }
    
    var<private> fragDepthPlaceholder: f32;
    
    fn distancePointToPlane(planeNormal: vec3<f32>, planeConstant: f32, pointPos: vec3<f32>) -> f32 {
        return abs(dot(planeNormal, pointPos) + planeConstant);
    }
    
    ${ commonShaderFunctions }
`;

/**
 * WGSL source of `mainColor`, the SSAO function used by `createSsaoMaterial`.
 *
 * Background fragments (depth == 1.0) are discarded. For the rest, the shader:
 * 1. Reconstructs the view-space position.
 * 2. Reads a random 2D direction from `noiseBuffer`. The pixel coordinate plus `NodeUniforms.seed` is wrapped
 *    modulo `noiseBufferSize` (100).
 * 3. For each of the four one-texel offsets, reflects it by that direction (`k1`) and rotates the result by
 *    45° (`k2`). Both are scaled by `occlusionRadius` (32).
 * 4. Takes four `getOcclusion` samples per offset, 16 in total.
 *
 * Each sample contributes `max(dot(normal, normalize(delta)) - 0.04, 0) / (1 + length(delta))`. The average,
 * clamped to [0, 1], is returned in all four channels; the edge shader reads it from alpha.
 *
 * `fragCoord` is derived as `0.5 * (vPosition.xy + 1) * resolution`. This presumably assumes `vPosition.xy`
 * spans [-1, 1] across the full-screen quad — confirm.
 * Depth is reached through the `hack` texture node labelled 'depthBuffer' in `createSsaoMaterial`.
 *
 * NOTE(review): unlike the edge shader, depth samples here go to `computePosition` without the `(d + 1) * 0.5`
 * remap — confirm which convention is correct. The "simple" variant is not implemented (see the TODO inside).
 *
 * Keep comments outside the template literal: its text is shader source.
 */
// language=wgsl
const ssaoShader = `
    fn mainColor(
        vPosition: vec4<f32>,
        vUv: vec2<f32>,
        resolution: vec2<f32>,
        normalBuffer: texture_2d<f32>, normalBuffer_sampler: sampler,
        noiseBuffer: texture_2d<f32>,
        hack: f32
    ) -> vec4<f32> {
        var correctedUv: vec2<f32> = vec2(vUv.x, 1.0 - vUv.y);
        var normalSample: vec4<f32> = textureSample(normalBuffer, normalBuffer_sampler, correctedUv);
        var normal: vec3<f32> = getNormal(normalSample);
        var depthSample: f32 = textureSample(depthBuffer, depthBuffer_sampler, correctedUv);
        if (depthSample == 1.0) { discard; } // background
        
        var position = computePosition(correctedUv, depthSample);
        var fragCoord = 0.5 * (vPosition.xy + 1.0) * resolution;
        const noiseBufferSize: f32 = 100;
        var coords: vec2<u32> = vec2<u32>(modf((fragCoord + vec2(NodeUniforms.seed)) / noiseBufferSize).fract * noiseBufferSize);
        var rand: vec2<f32> = normalize(textureLoad(noiseBuffer, coords, 0).xy);
        
        const occlusionRadius: f32 = 32.0;
        var occlusion: f32 = 0.0;
        
        var offsetCount: i32 = 4;
        var offsetUvs: array<vec2<f32>, 4> = getOffsetUvs(NodeUniforms.texelSize);
        
        // TODO: Simple version
        
        for (var i: i32 = 0; i < offsetCount; i++) {
            var k1: vec2<f32> = reflect(offsetUvs[i], rand);
            var k2: vec2<f32> = vec2(k1.x * SIN45 - k1.y * SIN45, k1.x * SIN45 + k1.y * SIN45);
            k1 *= occlusionRadius;
            k2 *= occlusionRadius;
            
            occlusion += getOcclusion(position, normal, correctedUv + k1);
            occlusion += getOcclusion(position, normal, correctedUv + k2 * 0.75);
            occlusion += getOcclusion(position, normal, correctedUv + k1 * 0.5);
            occlusion += getOcclusion(position, normal, correctedUv + k2 * 0.25);
        }
        
        occlusion = clamp(occlusion / f32(4 * offsetCount), 0.0, 1.0);
        return vec4(occlusion);
    }
    
    const SIN45 = 0.707107;
    
    fn getOcclusion(position: vec3<f32>, normal: vec3<f32>, uv: vec2<f32>) -> f32 {
        const uBias: f32 = 0.04;
        const uAttenuation: vec2<f32> = vec2(1.0, 1.0);

        var offsetPosition: vec3<f32> = computePosition(uv, textureSample(depthBuffer, depthBuffer_sampler, uv));
        var positionVec: vec3<f32> = offsetPosition - position;
        var intensity: f32 = max(dot(normal, normalize(positionVec)) - uBias, 0.0);
        var attenuation: f32 = 1.0 / (uAttenuation.x + uAttenuation.y * length(positionVec));
        return intensity * attenuation;
    }
    
    ${ commonShaderFunctions }
`;
