import {
    Color,
    WebGLRenderer,
    RenderTarget,
    DoubleSide,
    WebGLRenderTarget,
} from "three";
import {EffectPass} from "./../EffectPass.js";
import {Web3DCamera} from "./../Web3DCamera.js";
import {OutlineEffectNodeMaterial} from "./OutlineEffectNodeMaterial.js";
import {SelectionEffectPass} from "../SelectionEffectPass.js";
import {Settings} from "../../common.js";
import {
    MeshBasicNodeMaterial,
    wgslFn,
    uniform,
    color,
    float,
    uv,
    texture,
} from "../../Three.WebGPU.js";
import {GeometryObject3D} from "../../Model.js";
import {DisplayGroup, TypedRanges} from "../DisplayGroup.js";

/**
 * Selection highlight pass for the WebGPU backend.
 *
 * While the selection scene has children (or once more after `shouldClear` is set),
 * {@link render} draws the selected objects into a mask target, runs an outline pass
 * over that mask, and then composites a fill + outline colour onto the write buffer
 * using a node material driven by a WGSL function.
 */
export class WebGPUSelectionEffectPass extends SelectionEffectPass {
    /** Double-sided node material with a constant `maskColor`, created for drawing the selection mask. */
    maskMaterial: MeshBasicNodeMaterial;
    /** Camera the selection scene is rendered with when producing the mask. */
    private readonly mainCamera: Web3DCamera;
    /** Scratch colour used by {@link render} to save and restore the renderer's clear colour. */
    private readonly savedClearColor = new Color();

    /**
     * @param camera   Camera used to render the selected objects into the mask target.
     * @param settings Viewer settings; `selectionColor` is used as the highlight colour.
     */
    constructor(camera: Web3DCamera, settings: Settings) {
        const material = new MeshBasicNodeMaterial();

        // @ts-ignore
        super(material, settings);

        // Required to not break on WebGPU: TODO: Is this Three.js bug or can it be prevented?
        this.renderTargetOutline.depthBuffer = true;

        // NOTE(review): these uniforms are built from the values current at construction time;
        // later changes to selectionColor / fillStrength / edgeStrength are presumably not
        // reflected unless the uniform nodes are updated elsewhere — confirm.
        const uniforms = [
            uniform(color(settings.selectionColor)).label('color'),
            uniform(float(this.fillStrength)).label('fillStrength'),
            uniform(float(this.edgeStrength)).label('edgeStrength'),
            texture(this.renderTargetMask.texture).label('maskTexture'),
            texture(this.renderTargetOutline.texture).label('outlineTexture'),
        ];

        // Samples the outline and mask targets (with a flipped V coordinate) and returns the
        // selection colour whose alpha is the clamped edge term plus the fill term.
        // language=wgsl
        const calculateColorNode = wgslFn(`
             fn calculateColor(uv: vec2<f32>) -> vec4<f32> {
                var U = object;
                var correctedUv: vec2<f32> = vec2(uv.x, 1.0 - uv.y);
                var edge: f32 = textureSample(outlineTexture, outlineTexture_sampler, correctedUv).r;
                var mask: f32 = textureSample(maskTexture, maskTexture_sampler, correctedUv).r;
                
                edge = clamp(edge * U.edgeStrength * mask, 0.0, 1.0);
                mask = (1.0 - mask) * U.fillStrength;
                return vec4<f32>(U.color, edge + mask);
            }
        `, uniforms);

        // The composite material is an overlay: blended, and it neither tests nor writes depth.
        material.colorNode = calculateColorNode({ uv: uv() });
        material.transparent = true;
        material.depthTest = false;
        material.depthWrite = false;
        material.needsUpdate = true;

        this.mainCamera = camera;
        this.maskMaterial = new MeshBasicNodeMaterial();
        this.maskMaterial.colorNode = color(this.maskColor);
        this.maskMaterial.side = DoubleSide;
        // @ts-ignore
        this.outlinePass = new EffectPass(new OutlineEffectNodeMaterial(this.renderTargetMask.texture));
    }

    /**
     * Renders the selection mask and outline, then composites them onto `writeBuffer`.
     *
     * Does nothing when the selection scene is empty and no clear is pending. The renderer's
     * clear colour/alpha and auto-clear flags are modified while rendering and restored
     * before returning, so they do not leak into subsequent passes or frames.
     *
     * @param renderer    Renderer to draw with (typed as WebGLRenderer; used with the WebGPU renderer).
     * @param writeBuffer Target the fill/outline overlay is composited onto (not cleared first).
     */
    override render(renderer: WebGLRenderer, writeBuffer: RenderTarget): void {
        if (this.selectionScene.children.length === 0 && !this.shouldClear) {
            return;
        }

        // Snapshot the renderer state this pass changes so it can be restored afterwards.
        renderer.getClearColor(this.savedClearColor);
        const savedClearAlpha = renderer.getClearAlpha();
        const savedAutoClear = renderer.autoClear;
        const savedAutoClearColor = renderer.autoClearColor;
        const savedAutoClearDepth = renderer.autoClearDepth;

        try {
            // Render selection mask onto an opaque white background:
            renderer.setRenderTarget(this.renderTargetMask as any);
            renderer.setClearColor(0xffffff, 1);
            renderer.clear();
            renderer.render(this.selectionScene, this.mainCamera);

            // Render selection outline:
            this.outlinePass.render(renderer, this.renderTargetOutline as WebGLRenderTarget);

            // Render both onto writeBuffer, compositing over its existing content (no clear):
            renderer.autoClear = false;
            renderer.autoClearColor = false;
            renderer.autoClearDepth = false;
            renderer.setRenderTarget(writeBuffer as any);
            renderer.render(this.scene, this.camera);

            this.shouldClear = false;
        } finally {
            renderer.setClearColor(this.savedClearColor, savedClearAlpha);
            renderer.autoClear = savedAutoClear;
            renderer.autoClearColor = savedAutoClearColor;
            renderer.autoClearDepth = savedAutoClearDepth;
        }
    }

    /**
     * Adds a clone of `object` to the selection and makes it draw its whole geometry.
     *
     * `groups` is accepted for interface compatibility but ignored: the clone's geometry
     * groups are replaced by a single group covering everything, using material index 0.
     *
     * NOTE(review): if `addClone` shares the geometry with the source object (as
     * `Object3D.clone` does by default), `clearGroups()` here also alters the source's
     * groups — confirm.
     *
     * @returns this pass, for chaining.
     */
    // TODO: Remove this override when render ranges are supported in Web3D's WebGPU implementation
    override addObjectGroups(modelId: string, id: number | string, object: GeometryObject3D, groups: DisplayGroup[] | TypedRanges = [{start: 0, count: Infinity}]): SelectionEffectPass {
        const cloned = this.addClone(modelId, id, object);
        const geom = cloned.geometry;
        // Groups index into a material array, so make sure the clone has one.
        if (!Array.isArray(cloned.material))
            cloned.material = [cloned.material];
        geom.clearGroups();
        geom.addGroup(0, Infinity, 0); // Render everything
        return this;
    }
}
