import {
    Color,
    Material,
    RenderTarget,
    WebGLRenderTarget
} from "three";
import {Api} from "../../Api.js";
import {RenderingManager} from "../RenderingManager.js";
import {Web3DMeshNodeMaterial} from "./Web3DMeshNodeMaterial.js";
import {EdgesAndSsaoRenderPass} from "../EdgesAndSsaoRenderPass.js";
import {
    WebGPURenderer,
    MeshBasicNodeMaterial,
    ShaderNodeObject,
    UniformNode,
    Node,
    wgslFn,
    uniform,
    float,
    vec2,
    mat4,
    uv,
    texture,
    outputStruct,
    positionView,
    viewportResolution
} from "../../Three.WebGPU.js";

/**
 * WebGPU (node material / WGSL) implementation of the combined edge-detection and SSAO pass.
 *
 * A frame is produced in stages (see `renderWithEdges`):
 *  1. meshes whose material is a `Web3DMeshNodeMaterial` are rendered into `multiRenderTarget`;
 *  2. SSAO is rendered into an SSAO render target, or `ssaoRenderTarget` is cleared when SSAO is off;
 *  3. the edges pass composites color, edges and occlusion into the caller's render target;
 *  4. all remaining (non-Web3D) materials accepted by the filter are rendered on top of that.
 */
export class WebGPUEdgesAndSsaoRenderPass extends EdgesAndSsaoRenderPass {

    /**
     * Seed added to the fragment coordinate when looking up the SSAO noise texture.
     * It is re-randomized only when the camera moves, so a static view keeps a stable noise pattern.
     * Declared rather than initialized (no field initializer is emitted), presumably so the value
     * assigned in `init()` is not reset after the base constructor runs — TODO confirm.
     */
    declare private seedUniform: ShaderNodeObject<UniformNode<Node>>;

    /** Scratch color used to save/restore the renderer's clear color while clearing the SSAO target. */
    private readonly savedClearColor = new Color();

    constructor(api: Api, renderingManager: RenderingManager) {
        super(api, renderingManager);
    }

    /**
     * Creates the SSAO seed uniform, then runs the base initialization.
     * The uniform must exist before `createSsaoMaterial` binds it; presumably the base `init()`
     * is what builds the SSAO materials (TODO confirm).
     */
    override init(): void {
        this.seedUniform = uniform(float(0)).label("seed");
        super.init();
    }

    /**
     * Builds the full-screen material that combines the color, normal, flags/cuts, SSAO and depth
     * buffers into the final edge-shaded image (see `edgeShader`).
     *
     * The material also writes the sampled scene depth to `@builtin(frag_depth)`, so later passes
     * can depth-test against the composited result.
     *
     * @returns the configured edges material.
     */
    createEdgesMaterial(): MeshBasicNodeMaterial {
        const material = new MeshBasicNodeMaterial();

        // Names given via .label() are the identifiers the WGSL code refers to.
        // NOTE: only ssaoRenderTarget is sampled; ssaoSimpleRenderTarget output is not used here.
        const uniforms = [
            uniform(vec2(this.cameraNearFar)).label("cameraNearFar"),
            uniform(vec2(this.texelSize)).label("texelSize"),
            uniform(mat4(this.api.camera.projectionMatrixInverse)).label("uProjectionInverse"),
            texture(this.colorTexture).label("inputBuffer"),
            texture(this.normalTexture).label("normalBuffer"),
            texture(this.flagsAndCutsTexture).label("flagsAndCutsBuffer"),
            texture(this.ssaoRenderTarget.texture).label("ssaoBuffer"),
            texture(this.depthTexture).label("depthBuffer"),
        ];

        const colorNode = wgslFn(edgeShader, uniforms);

        material.transparent = true;
        material.depthTest = true;
        material.depthWrite = true;
        material.premultipliedAlpha = true;

        // The presence of a depth node allows access to custom depth writing.
        // The node just returns 0 here, but we can later write our actual value into @builtin(frag_depth).
        material.depthNode = float(0);

        // We use an output struct because it will also be populated with @builtin(frag_depth), used to write depth:
        material.outputNode = outputStruct(
            colorNode({ vUv: uv() }),
        );

        return material;
    }

    /**
     * Builds the full-screen SSAO material (see `ssaoShader`). The shader outputs the occlusion
     * factor in every channel.
     *
     * @param simple - requested simplified variant. Currently unused: the shader has no simplified
     *   path yet (see the TODO in `ssaoShader`).
     * @returns the configured SSAO material.
     */
    createSsaoMaterial(simple: boolean): MeshBasicNodeMaterial {
        const material = new MeshBasicNodeMaterial();

        const uniforms = [
            this.seedUniform,
            uniform(vec2(this.texelSize)).label("texelSize"),
            uniform(mat4(this.api.camera.projectionMatrixInverse)).label("uProjectionInverse"),
            texture(this.normalTexture).label("normalBuffer"),
            texture(this.noiseTexture).label("noiseBuffer"),
            texture(this.depthTexture).label("depthBuffer"),
        ];

        const colorNode = wgslFn(ssaoShader, uniforms);
        const colorNodeParameters = {
            vPosition: positionView,
            vUv: uv(),
            resolution: viewportResolution,
        };

        material.depthTest = false;
        material.depthWrite = false;

        material.colorNode = colorNode(colorNodeParameters);

        return material;
    }

    /**
     * Renders the scene with edges and (optionally) SSAO into `renderTarget`.
     *
     * Material visibility is toggled to split the scene between the edge/SSAO pipeline and a plain
     * forward render. All materials are made visible again afterwards, even if a render call throws.
     *
     * @param renderer - renderer used for all passes.
     * @param renderTarget - final destination of the composited image.
     * @param renderSsao - whether SSAO may be rendered this frame (also gated by `settings.ssao`).
     * @param shouldRenderMaterial - filter deciding which materials take part in this render.
     */
    protected renderWithEdges(renderer: WebGPURenderer, renderTarget: RenderTarget, renderSsao: boolean, shouldRenderMaterial: (m: Material) => boolean): void {
        try {
            let shouldRenderAnything = false;
            // Stage 1 visibility: only Web3D mesh node materials accepted by the filter.
            // TODO: pass shouldRenderMaterial to web3d/src/Rendering/WebGLRenderer instead of traversing, should be faster
            this.renderingManager.traverseMaterials((o, m) => {
                const srm = shouldRenderMaterial(m);
                shouldRenderAnything = shouldRenderAnything || srm;
                m.visible = !!(m as Web3DMeshNodeMaterial).isWeb3DMeshNodeMaterial && srm;
            });
            if (shouldRenderAnything) { // early optimization
                renderer.setRenderTarget(this.multiRenderTarget);
                renderer.setClearAlpha(0);

                renderer.autoClear = true;
                renderer.autoClearColor = true;
                renderer.autoClearDepth = true;

                renderer.render(this.api.scene, this.api.camera);

                // Later passes composite onto existing content, so automatic clearing stays off from here.
                renderer.autoClear = false;
                renderer.autoClearColor = false;
                renderer.autoClearDepth = false;

                this.cameraNearFar.set(this.api.camera.near, this.api.camera.far);

                this.renderSsao(renderer, renderSsao);
                // @ts-ignore
                this.edgesPass.render(renderer, renderTarget as WebGLRenderTarget);

                // Stage 4 visibility: everything the filter accepts that is NOT a Web3D mesh node material.
                this.renderingManager.traverseMaterials((o, m) => m.visible = !(m as Web3DMeshNodeMaterial).isWeb3DMeshNodeMaterial && shouldRenderMaterial(m));
                renderer.setRenderTarget(renderTarget);
                renderer.render(this.api.scene, this.api.camera);
            }
        } finally {
            // Restore visibility unconditionally so a failed render cannot leave materials hidden.
            this.renderingManager.traverseMaterials((o, m) => m.visible = true);
        }
    }

    /**
     * Renders SSAO for the current frame, or clears the SSAO target when SSAO is disabled.
     *
     * @param renderer - renderer used for the pass.
     * @param renderSsao - caller's request for SSAO this frame; SSAO additionally requires `settings.ssao`.
     */
    private renderSsao(renderer: WebGPURenderer, renderSsao: boolean): void {
        if (!renderSsao || !this.api.settingsDispatcher.settings.ssao) {
            this.clearSsaoTarget(renderer);
            return;
        }

        // In progressive rendering (and not a full render), use the cheaper pass once more than 10
        // (presumably ms, as measured with performance.now()) have elapsed since renderStartTime;
        // otherwise there is time for full SSAO.
        const simple = !this.renderingManager.fullRender && this.api.settingsDispatcher.settings.progressiveRendering &&
            performance.now() - this.renderingManager.renderer.renderStartTime > 10;
        // NOTE(review): the edges material samples only ssaoRenderTarget, so output written to
        // ssaoSimpleRenderTarget is not picked up yet.
        const rt = simple ? this.ssaoSimpleRenderTarget : this.ssaoRenderTarget;
        const pass = simple ? this.ssaoSimplePass : this.ssaoPass;

        renderer.setRenderTarget(rt as any);
        renderer.setClearAlpha(0);
        renderer.clearColor();

        // make noise dynamic with time value change only if camera is moved
        if (!this.prevCameraMatrixWorld.equals(this.api.camera.matrixWorld)) {
            // @ts-ignore
            this.seedUniform.value = performance.now() % 100;
        }
        this.prevCameraMatrixWorld.copy(this.api.camera.matrixWorld);

        // @ts-ignore
        pass.render(renderer, rt);
        // TODO: Set buffer (make the edges material sample ssaoSimpleRenderTarget when the simple pass ran)
    }

    /**
     * Clears `ssaoRenderTarget` to transparent black, i.e. zero occlusion.
     *
     * The edges material samples this target unconditionally. Without this clear, occlusion from the
     * last SSAO frame would keep darkening the image after SSAO is switched off. The renderer's
     * clear color and alpha are restored afterwards; the SSAO target is left bound, as on the SSAO path.
     *
     * @param renderer - renderer used to clear the target.
     */
    private clearSsaoTarget(renderer: WebGPURenderer): void {
        const previousClearAlpha = renderer.getClearAlpha();
        renderer.getClearColor(this.savedClearColor);

        renderer.setRenderTarget(this.ssaoRenderTarget as any);
        renderer.setClearColor(0x000000, 0);
        renderer.clearColor();

        renderer.setClearColor(this.savedClearColor, previousClearAlpha);
    }
}

/**
 * WGSL helpers shared by `edgeShader` and `ssaoShader` (appended to both sources).
 *
 * - `computePosition(coord, depthSample)`: flips `coord.y`, maps `(coord, depthSample)` from [0, 1] to
 *   [-1, 1], and unprojects through `U.uProjectionInverse` with a perspective divide. The result is the
 *   camera-relative position. Callers must pass depth in the convention this mapping expects; the edge
 *   shader pre-maps the raw depth sample as `(d + 1.0) * 0.5` before calling.
 *   Uniforms are read through `object`, presumably the struct three.js generates for the `wgslFn`
 *   uniforms — TODO confirm.
 * - `getNormal(normalSample)`: returns `normalSample.xyz`. An all-zero sample (background) is treated as
 *   a surface facing the camera, `(0, 0, 1)`. The result is not re-normalized.
 * - `getOffsetUvs(texelSize)`: the four axis-aligned one-texel offsets, in the order +x, -x, +y, -y.
 */
// language=wgsl
const commonShaderFunctions = `
    // Camera relative position
    fn computePosition(coord: vec2<f32>, depthSample: f32) -> vec3<f32> {
        var U = object;
        var ndc: vec4<f32> = vec4((vec3(coord.x, 1.0 - coord.y, depthSample) - 0.5) * 2.0, 1.0);
        var clip: vec4<f32> = U.uProjectionInverse * ndc;
        return clip.xyz / clip.w;
    }
    
    // Treat background as surface with normal towards camera
    fn getNormal(normalSample: vec4<f32>) -> vec3<f32> {
        if (normalSample.x == 0.0 && normalSample.y == 0.0 && normalSample.z == 0.0) {
            return vec3(0.0, 0.0, 1.0);
        } else {
            return normalSample.xyz;
        }
    }
    
    fn getOffsetUvs(texelSize: vec2<f32>) -> array<vec2<f32>, 4> {
        var offsetUvs: array<vec2<f32>, 4>;
        offsetUvs[0] = vec2(texelSize.x, 0);
        offsetUvs[1] = vec2(-texelSize.x, 0);
        offsetUvs[2] = vec2(0, texelSize.y);
        offsetUvs[3] = vec2(0, -texelSize.y);
        return offsetUvs;
    }
`;

/**
 * WGSL for the edge-composition pass (`createEdgesMaterial`). The entry point is `mainColor(vUv)`.
 *
 * Per fragment, the shader:
 * - samples color, normal (alpha holds an id), flags/cuts and depth at the y-flipped UV;
 * - writes the raw depth sample to `output.depth`, which becomes `@builtin(frag_depth)`;
 * - pre-maps depth as `(d + 1.0) * 0.5` before `computePosition`;
 * - marks the pixel as an edge if, against any of its 4 one-texel neighbours, one of these holds:
 *   - the ids differ by more than 0.99;
 *   - normals are used (flags `.g > 0.6`, or background) and the normals diverge (|dot| below 0.9,
 *     or below 0.99995 when flags `.g > 0.8`);
 *   - the plane/depth delta exceeds 0.002 and the relative delta exceeds 0.0005;
 * - darkens the color by the occlusion read from `ssaoBuffer.r`;
 * - on edges, mixes toward `color.rgb * 0.7` with alpha at least 0.4, weighted by a far-distance fade.
 *
 * NOTE(review): `renderEdges` (flags `.g > 0.0`) is computed but never read.
 * NOTE(review): the `fadeArtifacts` expression applies `(depthSample + 1.0) / 2.0` to a value that was
 * already remapped the same way earlier in `mainColor`. Confirm whether this double mapping (which
 * shifts the fade threshold) is intended.
 */
// language=wgsl
const edgeShader = `
    fn mainColor(vUv: vec2<f32>) -> vec4<f32> {
        var U = object;

        var correctedUv: vec2<f32> = vec2(vUv.x, 1.0 - vUv.y);
        var color: vec4<f32> = textureSample(inputBuffer, inputBuffer_sampler, correctedUv);
        var normalSample: vec4<f32> = textureSample(normalBuffer, normalBuffer_sampler, correctedUv);
        var normal: vec3<f32> = getNormal(normalSample);
        var flagsAndCutsSample = textureSample(flagsAndCutsBuffer, flagsAndCutsBuffer_sampler, correctedUv);
        var id = normalSample.a;
        var depthSample: f32 = textureSample(depthBuffer, depthBuffer_sampler, correctedUv);
        output.depth = depthSample; // This gets written to @builtin(frag_depth)
        var isBackground: bool = depthSample == 1.0;
        depthSample = (depthSample + 1.0) * 0.5; // Depth correction for WebGPU coordinates
        var position = computePosition(correctedUv, depthSample);
        var planeConstant: f32 = -dot(position, normal);
        var isEdge: bool = false;      
        var renderEdges: bool = flagsAndCutsSample.g > 0.0 || isBackground;
        var useNormals: bool = flagsAndCutsSample.g > 0.6 || isBackground;
        var preciseEdgeLines: bool = flagsAndCutsSample.g > 0.8;
        
        var offsetUvs: array<vec2<f32>, 4> = getOffsetUvs(U.texelSize);
        
        for (var i: i32 = 0; i < 4; i++) {
            var uv: vec2<f32> = correctedUv + offsetUvs[i];
            var offsetNormalSample: vec4<f32> = textureSample(normalBuffer, normalBuffer_sampler, uv);
            var offsetNormal: vec3<f32> = getNormal(offsetNormalSample);
            var offsetDepth: f32 = textureSample(depthBuffer, depthBuffer_sampler, uv);
            offsetDepth = (offsetDepth + 1.0) * 0.5; // Depth correction for WebGPU coordinates
            var offsetPosition: vec3<f32> = computePosition(uv, offsetDepth);
            var offsetFlagsAndCutsSample: vec4<f32> = textureSample(flagsAndCutsBuffer, flagsAndCutsBuffer_sampler, uv);
            var offsetId: f32 = offsetNormalSample.a;
            useNormals = useNormals || offsetFlagsAndCutsSample.g > 0.6;
            
            var normalDot: f32 = abs(dot(normal, offsetNormal));
            // Calculating delta distance to plane, extended from picked pixel, works fine for Flat rendering 
            // does not work well for Gouraud rendering, delta is increased by difference between interpolated normal and non-interpolated depth
            // this gives false positive edge TODO: maybe use derivative to calculate non-interpolated normal value (only for depthDelta calculation)? normalFromDepth = normalize(cross(dFdx(position), dFdy(position))) 
            var depthDelta = select(abs(position.z - offsetPosition.z) * 0.01, // TODO: without normal, some artifacts possible
                distancePointToPlane(normal, planeConstant, offsetPosition),
                useNormals);
            isEdge = isEdge ||
                abs(id - offsetId) > 0.99 || // edges between touching entities 
                (useNormals && normalDot < select(0.9, 0.99995, preciseEdgeLines)) || // edges on corners
                (depthDelta > 0.002 && // edges between parallel planes on different levels
                abs(depthDelta / position.z) > 0.0005); // depth precision fix, removes noise
        }
                
        var occlusion: f32 = textureSample(ssaoBuffer, ssaoBuffer_sampler, correctedUv).r;
        color = vec4(color.rgb * (1.0 - occlusion), color.a);
        
        var fadeArtifacts: f32 = 1.0;
        if (!isBackground) {
            fadeArtifacts = clamp((0.99987 - (depthSample + 1.0) / 2.0) * 10000.0, 0.0, 1.0); // Fade long distance precision artifacs
        }
        
        var weight: f32 = 0.0;
        if (isEdge) { weight = fadeArtifacts; }
        var outputColor: vec4<f32> = mix(color, vec4(color.rgb * 0.7, max(color.a, 0.4)), weight);
        
        return outputColor;
    }
    
    fn distancePointToPlane(planeNormal: vec3<f32>, planeConstant: f32, pointPos: vec3<f32>) -> f32 {
        return abs(dot(planeNormal, pointPos) + planeConstant);
    }
    
    ${ commonShaderFunctions }
`;

/**
 * WGSL for the SSAO pass (`createSsaoMaterial`). The entry point is `mainColor(vPosition, vUv, resolution)`.
 *
 * The shader:
 * - discards background pixels (raw depth == 1.0);
 * - reconstructs the camera-relative position from depth;
 * - picks a per-pixel random vector from `noiseBuffer`, offset by `U.seed`;
 * - accumulates normal-weighted, distance-attenuated occlusion over 4 × 4 sample points around the pixel;
 * - returns the clamped average occlusion in every channel.
 *
 * Depth samples get the same WebGPU correction `(d + 1.0) * 0.5` as in the edge shader before
 * `computePosition`. This keeps the reconstructed positions consistent between the two passes.
 */
// language=wgsl
const ssaoShader = `
    fn mainColor(
        vPosition: vec4<f32>,
        vUv: vec2<f32>,
        resolution: vec2<f32>
    ) -> vec4<f32> {
        var U = object;

        var correctedUv: vec2<f32> = vec2(vUv.x, 1.0 - vUv.y);
        var normalSample: vec4<f32> = textureSample(normalBuffer, normalBuffer_sampler, correctedUv);
        var normal: vec3<f32> = getNormal(normalSample);
        var depthSample: f32 = textureSample(depthBuffer, depthBuffer_sampler, correctedUv);
        if (depthSample == 1.0) { discard; } // background (tested on the raw, uncorrected depth)
        
        var position = computePosition(correctedUv, correctDepth(depthSample));
        var fragCoord = 0.5 * (vPosition.xy + 1.0) * resolution;
        const noiseBufferSize: f32 = 100;
        var coords: vec2<u32> = vec2<u32>(modf((fragCoord + vec2(U.seed)) / noiseBufferSize).fract * noiseBufferSize);
        var rand: vec2<f32> = normalize(textureLoad(noiseBuffer, coords, 0).xy);
        
        const occlusionRadius: f32 = 32.0;
        var occlusion: f32 = 0.0;
        
        var offsetCount: i32 = 4;
        var offsetUvs: array<vec2<f32>, 4> = getOffsetUvs(U.texelSize);
        
        // TODO: Simple version
        
        for (var i: i32 = 0; i < offsetCount; i++) {
            var k1: vec2<f32> = reflect(offsetUvs[i], rand);
            var k2: vec2<f32> = vec2(k1.x * SIN45 - k1.y * SIN45, k1.x * SIN45 + k1.y * SIN45);
            k1 *= occlusionRadius;
            k2 *= occlusionRadius;
            
            occlusion += getOcclusion(position, normal, correctedUv + k1);
            occlusion += getOcclusion(position, normal, correctedUv + k2 * 0.75);
            occlusion += getOcclusion(position, normal, correctedUv + k1 * 0.5);
            occlusion += getOcclusion(position, normal, correctedUv + k2 * 0.25);
        }
        
        occlusion = clamp(occlusion / f32(4 * offsetCount), 0.0, 1.0);
        return vec4(occlusion);
    }
    
    const SIN45 = 0.707107;
    
    // Depth correction for WebGPU coordinates: computePosition maps its depth argument to
    // NDC z = (d - 0.5) * 2, while WebGPU stores NDC z in [0, 1] directly, so the raw sample is
    // pre-mapped to (d + 1) / 2. Must match the correction applied in the edge shader.
    fn correctDepth(depthSample: f32) -> f32 {
        return (depthSample + 1.0) * 0.5;
    }
    
    fn getOcclusion(position: vec3<f32>, normal: vec3<f32>, uv: vec2<f32>) -> f32 {
        const uBias: f32 = 0.04;
        const uAttenuation: vec2<f32> = vec2(1.0, 1.0);

        var offsetDepth: f32 = correctDepth(textureSample(depthBuffer, depthBuffer_sampler, uv));
        var offsetPosition: vec3<f32> = computePosition(uv, offsetDepth);
        var positionVec: vec3<f32> = offsetPosition - position;
        var intensity: f32 = max(dot(normal, normalize(positionVec)) - uBias, 0.0);
        var attenuation: f32 = 1.0 / (uAttenuation.x + uAttenuation.y * length(positionVec));
        return intensity * attenuation;
    }
    
    ${ commonShaderFunctions }
`;
