import {
    Mesh,
    PlaneGeometry,
    Scene,
    ShaderMaterial,
    Texture,
    Vector2,
    Vector3,
    Vector4,
    WebGLRenderTarget,
    RenderTarget,
} from "three";
import { Web3DCamera } from "../Rendering/Web3DCamera.js";
import { SettingsDispatcher } from "../SettingsDispatcher.js";
import { Settings } from "../common.js";
import { RenderingManager } from "../Rendering/RenderingManager.js";
import {GPUPicker} from "./GPUPicker.js";

/**
 * WebGL implementation of {@link GPUPicker}.
 *
 * Picks a world-space point by reading back the composer's depth texture. A full-screen
 * quad (`pickScene`) is rendered with `depthCopyMaterial`, whose `uvTransform` uniform
 * selects the sub-rectangle of the depth texture to sample. The render targets, the
 * read-back buffer (`depths`) and the `pickCamera` used here are inherited from the base class.
 */
export class WebGLGPUPicker extends GPUPicker {
    private readonly pickScene = new Scene();
    private readonly depthCopyMaterial: ShaderMaterial;

    /**
     * @param renderingManager Provides the renderer and the composer whose depth texture is picked
     * @param camera Camera used to unproject the picked depth into world space
     * @param container DOM element the scene is displayed in; used to normalise screen coordinates
     * @param settingsDispatcher Source of the `navigationSnapDistance` setting
     */
    constructor(
        renderingManager: RenderingManager,
        private camera: Web3DCamera,
        private container: HTMLElement,
        settingsDispatcher: SettingsDispatcher<Settings>,
    ) {
        super(renderingManager, settingsDispatcher);

        // 2x2 plane rendered with the depth copy material.
        const geometry = new PlaneGeometry(2, 2);
        this.depthCopyMaterial = new ShaderMaterial({
            uniforms: {
                depthTexture: { value: null as Texture | null },
                // xy = UV offset, zw = UV scale; (0, 0, 1, 1) samples the whole texture
                uvTransform: { value: new Vector4(0, 0, 1, 1) }
            },
            fragmentShader,
            vertexShader,
            premultipliedAlpha: true,
            depthTest: false,
            depthWrite: false
        });
        const plane = new Mesh(geometry, this.depthCopyMaterial);
        this.pickScene.add(plane);
    }

    /**
     * Pick the GPU depth buffer.
     * @param screenPosition The 2D screen coordinate to be picked
     * @returns A 3D vector representing the world space point of the pick, or `undefined`
     *          when the container/renderer has no size or no usable depth was found
     */
    pick = (() => {
        // Scratch vector shared between calls. It is only used synchronously, before the
        // first await, so overlapping picks cannot observe each other's value.
        const viewSize = new Vector2();

        return async (screenPosition: Vector2): Promise<Vector3 | undefined> => {
            const { clientWidth, clientHeight } = this.container;
            if (clientWidth === 0 || clientHeight === 0) return undefined;

            this.renderingManager.renderer.getSize(viewSize);
            // A zero-sized renderer would turn the tolerances below into Infinity.
            if (viewSize.x === 0 || viewSize.y === 0) return undefined;

            const dimensions = this.container.getBoundingClientRect();
            const tolerance = this.settingsDispatcher.settings.navigationSnapDistance || 0;

            // Allocated per call on purpose: this value is read again after the await below,
            // so it must not live in a vector that another concurrent pick() could overwrite.
            const viewPosition = new Vector2(
                (screenPosition.x - dimensions.left) / clientWidth,
                (screenPosition.y - dimensions.top) / clientHeight
            );

            // Performs up to 3 depth samples, stopping at the first that hits geometry:
            // 1px around pointer, tolerance px around pointer and whole screen average
            const depth = await this.pickDepth(viewPosition,
                new Vector2().setScalar(1).divide(viewSize),
                new Vector2().setScalar(tolerance).divide(viewSize));
            if (depth === undefined) return undefined;
            return GPUPicker.viewToWorldPoint(viewPosition, depth, this.camera);
        };
    })();

    /**
     * Sample the depth texture around a view position, progressively widening the sampled area.
     *
     * The renderer's active render target is restored before returning. The clear colour set
     * here is left in place.
     *
     * @param viewPosition Pointer position normalised to [0, 1] with the origin at the top-left
     * @param tolerance1 Half-extent (in UV units) of the first, narrow sample
     * @param tolerance2 Half-extent (in UV units) of the second, wider sample
     * @returns The averaged depth of the first sample that contains geometry, or `undefined`
     */
    async pickDepth(viewPosition: Vector2, tolerance1: Vector2, tolerance2: Vector2): Promise<number | undefined> {
        const renderer = this.renderingManager.renderer;
        const sourceDepthTexture = this.renderingManager.composer.depthTexture;
        const uniforms = this.depthCopyMaterial.uniforms;

        // Remember where the renderer was drawing so the pick does not leave it bound to
        // one of the internal pick targets.
        const previousRenderTarget = renderer.getRenderTarget();
        renderer.setClearColor(0x000000, 0);

        // Select a rectangle of +-tolerance around the pointer; the Y axis is flipped
        // because view coordinates grow downwards while UVs grow upwards.
        const setUVTransformPositionTolerance = (position: Vector2, tolerance: Vector2) => {
            uniforms.uvTransform.value.set(
                position.x - tolerance.x,
                1 - position.y - tolerance.y,
                tolerance.x * 2,
                tolerance.y * 2
            );
        };

        const resetUVTransform = () => {
            uniforms.uvTransform.value.set(0, 0, 1, 1);
        };

        // Render `source` into `rt`, read the pixels back and average them on the CPU.
        const getDepth = (source: Texture, rt: RenderTarget): number | undefined => {
            uniforms.depthTexture.value = source;
            renderer.setRenderTarget(rt as WebGLRenderTarget);
            renderer.clear();
            renderer.render(this.pickScene, this.pickCamera);
            renderer.readRenderTargetPixels(rt as WebGLRenderTarget, 0, 0, rt.width, rt.height, this.depths);

            // Compute average of this.depths, accumulating every pixel into the first one.
            // We could downscale depthTexture to single pixel with GPU, but iOS mipmap
            // generation algorithm doesn't give good average value this way
            const pixelCount = rt.width * rt.height;
            for (let colorComponent = 0; colorComponent < 4; colorComponent++) {
                for (let i = 1; i < pixelCount; i++)
                    this.depths[colorComponent] += this.depths[colorComponent + i * 4];
                this.depths[colorComponent] /= pixelCount;
            }
            if (this.depths[1] === 1 || // indicates void
                this.depths[3] === 0)   // not enough geometry in the view
                return undefined;
            // Channel 1 holds the fraction of void pixels; exclude them from the average depth
            return this.depths[0] / (1 - this.depths[1]);
        };

        try {
            // Render single pixel of depthTexture
            setUVTransformPositionTolerance(viewPosition, tolerance1);
            let depth = getDepth(sourceDepthTexture, this.singlePixelRT);
            if (depth !== undefined) return depth;

            // Copy depthTexture to depthCopyRT; the wider passes below sample this copy
            // (original note: to be able to use mipmaps for downscaling)
            resetUVTransform();
            this.updateRenderTargetSize();
            uniforms.depthTexture.value = sourceDepthTexture;
            renderer.setRenderTarget(this.depthCopyRT as WebGLRenderTarget);
            renderer.clear();
            renderer.render(this.pickScene, this.pickCamera);

            // If void is picked, render part of depthCopyRT around mouse pointer with navigationSnapDistance
            setUVTransformPositionTolerance(viewPosition, tolerance2);
            depth = getDepth(this.depthCopyRT.texture, this.smallRT);
            if (depth !== undefined) return depth;

            // If still void is picked, render whole depthCopyRT
            resetUVTransform();
            return getDepth(this.depthCopyRT.texture, this.smallRT);
        } finally {
            renderer.setRenderTarget(previousRenderTarget);
        }
    }

    /** Release the GPU resources owned by this picker (quad geometry and copy material). */
    override dispose(): void {
        super.dispose();
        (this.pickScene.children[0] as Mesh).geometry.dispose();
        this.depthCopyMaterial.dispose();
    }

}

/**
 * Fragment shader of the depth copy pass.
 *
 * Samples `depthTexture` at the interpolated `vUv` and writes the sample unchanged, except
 * when its first channel is exactly 1.0. That value is treated as "void" and replaced by
 * (0, 1, 0, 0), so the second channel flags void pixels. After the CPU-side averaging in
 * WebGLGPUPicker.pickDepth, the second channel is used as the fraction of void pixels.
 */
// language=GLSL
const fragmentShader = `
    uniform highp sampler2D depthTexture;
    varying vec2 vUv;

    void main() {
        vec4 depth = texture2D(depthTexture, vUv);
        if (depth.x == 1.0) // void
            depth = vec4(0.0, 1.0, 0.0, 0.0); // second channel indicates void
        gl_FragColor = depth;
    }
`;

/**
 * Vertex shader of the depth copy pass.
 *
 * Remaps the quad's UVs with `uvTransform` (xy = offset, zw = scale) so the fragment shader
 * samples only the selected sub-rectangle of the depth texture; (0, 0, 1, 1) selects the
 * whole texture. The vertex position goes through the usual model-view and projection
 * transform.
 */
// language=GLSL
const vertexShader = `
    varying vec2 vUv;
    uniform vec4 uvTransform;
    
    void main() {
        vUv = uvTransform.xy + uv * uvTransform.zw;
        vec4 modelViewPosition = modelViewMatrix * vec4(position, 1.0);
        gl_Position = projectionMatrix * modelViewPosition;
    }
`;
