import {Box3, Camera, Frustum, Plane, Ray, Raycaster, Vector2, Vector3} from "three";
import {IIntersection} from "./IIntersection.js";
import {Models} from "../Models.js";
import {Caster} from "./Caster.js";
import {Web3DCamera} from "../Rendering/Web3DCamera.js";
import {Model, NavigationPickableModel, PickableModel, SelectableModel, SnappedPickableModel} from "../Model.js";
import {PickComplexity, PickerCache} from "./PickerCache.js";
import {Settings, SnapType} from "../common.js";
import {SettingsDispatcher} from "../SettingsDispatcher.js";
import {RenderingManager} from "../Rendering/RenderingManager.js";
import {GPUPicker} from "./GPUPicker.js";
import {WebGPUGPUPicker} from "../Rendering/WebGPU/WebGPUGPUPicker.js";
import {WebGLGPUPicker} from "./WebGLGPUPicker.js";
import {frustumFromRay} from "../Helpers/common-utils.js";

/**
 * Layer number that `Picker.createCaster` disables on every caster it builds, so objects placed on this
 * layer are meant to be skipped by picking (the actual filtering depends on how Caster/models honor layers).
 */
export const NO_PICK_LAYER = 1;

export class Picker {
    /** Stored by setWorldBoundingBox. NOTE(review): not read anywhere inside this class. */
    private _worldBoundingBox: Box3;
    /** True while pickAllModelsSnapped is awaiting the models' snapped picks. */
    pickerBusy: boolean;
    private cache = new PickerCache();
    /** Scratch plane used by pickSnapped to discard snap points lying behind the picked face. */
    private facePlane = new Plane();
    private gpuPicker: GPUPicker;

    /**
     * @param _camera camera used to build pick rays/frustums; every camera notification clears the pick cache.
     * @param _models source of the models that are picked.
     * @param _container element whose client rect maps screen positions to normalized device coordinates.
     * @param settingsDispatcher provides the `snapDistance` setting used for the pick frustum.
     * @param renderingManager decides whether the WebGPU or the WebGL GPU picker is created.
     */
    constructor(
        private _camera: Web3DCamera,
        private _models: Models,
        private _container: HTMLElement,
        private settingsDispatcher: SettingsDispatcher<Settings>,
        private renderingManager: RenderingManager,
    ) {
        _camera.subscribe(() => this.cache.clear());
        this.gpuPicker = renderingManager.isWebgpu() ?
            new WebGPUGPUPicker(this.renderingManager, _camera, _container, settingsDispatcher) :
            new WebGLGPUPicker(this.renderingManager, _camera, _container, settingsDispatcher);
    }

    public setWorldBoundingBox(box: Box3): void {
        this._worldBoundingBox = box;
    }

    /**
     * Builds a Caster for a screen position.
     *
     * The caster holds a ray from the camera through the point and a square frustum of half-size `snapDistance`
     * (taken from settings, in the same units as `screenPosition`) around it. NO_PICK_LAYER is disabled on its
     * layer mask.
     */
    createCaster(screenPosition: Vector2): Caster {
        const caster = new Caster();
        caster.screenPosition = screenPosition;
        const p = screenPositionToRayTracePoint(caster.screenPosition, this._container);
        caster.setFromCamera(p, this._camera);
        caster.snapDistance = this.settingsDispatcher.settings.snapDistance;
        caster.frustum = this.createFrustumFromScreenPoint(caster.screenPosition, caster.snapDistance);
        caster.layers.disable(NO_PICK_LAYER);
        return caster;
    }

    /** Frustum through the square of half-size `snapDistance` centered on `screenPoint`. */
    private createFrustumFromScreenPoint(screenPoint: Vector2, snapDistance: number): Frustum {
        return this.createFrustumFromScreenRect(screenPoint.clone().subScalar(snapDistance), screenPoint.clone().addScalar(snapDistance));
    }

    /**
     * Builds a frustum bounded by the four planes through the edges of a screen rectangle, plus a near
     * plane along the view direction and a far plane that never rejects anything.
     *
     * Every call returns a Frustum with its own Plane instances. `Frustum` keeps references to the planes it
     * is constructed with, so reusing scratch planes here would let a later call rewrite the frustum of a
     * caster that an in-flight asynchronous pick is still using.
     */
    private createFrustumFromScreenRect = (() => {
        // Scratch vectors: only used while computing the planes, never handed out.
        const topLeft = new Vector3();
        const topRight = new Vector3();
        const bottomLeft = new Vector3();
        const bottomRight = new Vector3();
        const topLeftNear = new Vector3();
        const bottomRightNear = new Vector3();
        const topLeftPoint = new Vector2();
        const bottomRightPoint = new Vector2();

        return (screenPointTopLeft: Vector2, screenPointBottomRight: Vector2): Frustum => {
            screenPositionToRayTracePoint(screenPointTopLeft, this._container, topLeftPoint);
            screenPositionToRayTracePoint(screenPointBottomRight, this._container, bottomRightPoint);

            const camera = this._camera;
            // Rectangle corners unprojected at NDC depth 1 ...
            topLeft.set(topLeftPoint.x, topLeftPoint.y, 1).unproject(camera);
            topRight.set(bottomRightPoint.x, topLeftPoint.y, 1).unproject(camera);
            bottomLeft.set(topLeftPoint.x, bottomRightPoint.y, 1).unproject(camera);
            bottomRight.set(bottomRightPoint.x, bottomRightPoint.y, 1).unproject(camera);

            // ... and two opposite corners at NDC depth 0, to give each side plane a third point.
            topLeftNear.set(topLeftPoint.x, topLeftPoint.y, 0).unproject(camera);
            bottomRightNear.set(bottomRightPoint.x, bottomRightPoint.y, 0).unproject(camera);

            const top = new Plane().setFromCoplanarPoints(topLeftNear, topLeft, topRight);
            const right = new Plane().setFromCoplanarPoints(bottomRightNear, topRight, bottomRight);
            const bottom = new Plane().setFromCoplanarPoints(bottomRightNear, bottomRight, bottomLeft);
            const left = new Plane().setFromCoplanarPoints(topLeftNear, bottomLeft, topLeft);

            const near = new Plane();
            camera.getWorldDirection(near.normal);
            // NOTE(review): this constant equals the camera's distance from the world origin, which puts the
            // plane through the camera only when the camera looks at the origin; the general form would be
            // -near.normal.dot(cameraWorldPosition). Kept unchanged pending confirmation of intent.
            near.constant = camera.position.length();

            // Same normal as the near plane with an infinite constant, so its signed distance is always +Infinity.
            const far = new Plane();
            far.normal.copy(near.normal);
            far.constant = Infinity;

            return new Frustum(top, bottom, left, right, far, near);
        };
    })();

    /**
     * Flips `inn.normal` in place when it points away from the caster's ray origin. This handles double-sided
     * materials picked from the back side. Returns `inn` unchanged otherwise (including when it is falsy).
     */
    private normalTowardsCamera(inn: IIntersection, caster: Caster): IIntersection {
        if (inn && inn.normal && inn.point &&
            inn.normal.dot(inn.point.clone().sub(caster.ray.origin)) > 0) inn.normal.negate();
        return inn;
    }

    /**
     * Returns the bounding-box center of the model whose center lies closest to the pick ray through
     * `screenPosition`. Models with empty bounding boxes are ignored. Returns undefined when no model qualifies.
     */
    pickNearestModelCenter(screenPosition: Vector2): Vector3 {
        const caster = this.createCaster(screenPosition);
        const center = new Vector3();
        const nearestCenter = new Vector3();
        let found = false;
        let minDist = Infinity;
        for (const m of this._models.getModels()) {
            const box = m.getModelBoundingBox();
            if (box.isEmpty()) continue;
            box.getCenter(center);
            const dist = caster.ray.distanceSqToPoint(center);
            if (dist < minDist) {
                minDist = dist;
                nearestCenter.copy(center);
                found = true;
            }
        }
        return found ? nearestCenter : undefined;
    }

    /**
     * Pick the scene for navigation, when only an intersection point is required.
     *
     * Combines every model's `pickNavigation` result with the GPU pick result (when it has a point and a
     * non-zero tolerance) and returns the best point per reduceIntersections. When nothing was hit, it falls
     * back to the GPU point reported with tolerance 0. Otherwise it resolves to undefined.
     */
    async pickForNavigation(screenPosition: Vector2): Promise<Vector3 | undefined> {
        const caster = this.createCaster(screenPosition);
        const results: IIntersection[] = [];
        for (const model of this._models.getModels()) {
            const m = model as Model & NavigationPickableModel;
            if (m.pickNavigation) {
                const ii = m.pickNavigation(caster);
                if (ii)
                    results.push(ii);
            }
        }
        const gpuResult = await this.gpuPicker.pick(screenPosition);
        if (gpuResult.point && gpuResult.tolerance) {
            results.push({
                point: gpuResult.point,
                distance: gpuResult.depth,
                distanceToRay: 0
            } as IIntersection);
        }
        if (results.length > 0) {
            return this.reduceIntersections(results).point;
        }
        // nothing was picked, use gpu pick result calculated on whole screen
        if (gpuResult.tolerance === 0 && gpuResult.point) {
            return gpuResult.point;
        }
        return undefined;
    }

    /** Best intersection of `models` (all models when undefined) at `screenPosition`, cached per position/models. */
    pick(screenPosition: Vector2, models: Model[]): Promise<IIntersection> {
        return this.cache.execCached(screenPosition, models, PickComplexity.NORMAL, async () => {
            const caster = this.createCaster(screenPosition);
            return this.normalTowardsCamera(this.reduceIntersections(await this.pickAllModels(caster, models)), caster);
        });
    }

    /**
     * Best snapped intersection at `screenPosition`, or null when the (possibly cached) best hit's snap type
     * is not in `snapTypes`.
     *
     * Snap points lying behind the nearest face hit are discarded. Center-line snaps of the face's own model
     * are the exception.
     */
    async pickSnapped(screenPosition: Vector2, snapTypes: SnapType[]): Promise<IIntersection> {
        // The face plane is pushed back along its normal by this amount so points lying on the face survive.
        const tolerance = 0.01;
        const intersection = await this.cache.execCached(screenPosition, undefined, PickComplexity.SNAPPED, async () => {
            const caster = this.createCaster(screenPosition);
            let intersections = await this.pickAllModelsSnapped(caster, snapTypes);

            // if we have a face hit, filter out all points behind this face
            if (intersections && intersections.length) {
                // TODO: use GPU pick instead
                const faceHits = (await this.pickAllModels(caster, undefined)).filter(inn => !!inn.normal);
                const faceIntersection = this.normalTowardsCamera(this.reduceIntersections(faceHits), caster);
                if (faceIntersection) {
                    // Work on a copy so the tolerance offset is not written back into the intersection's point.
                    const planePoint = faceIntersection.point.clone()
                        .sub(faceIntersection.normal.clone().multiplyScalar(tolerance));
                    this.facePlane.setFromNormalAndCoplanarPoint(faceIntersection.normal, planePoint);
                    // The face hit may carry no model; compare against undefined instead of throwing.
                    const faceModelId = faceIntersection.model?.modelId;
                    intersections = intersections.filter(inn =>
                        !inn.model ||
                        // Skip plane filtering for center line picks, as those are typically inside the geometry, behind the face plane
                        (inn.model.modelId === faceModelId && inn.snapType === SnapType.CENTER_LINE) ||
                        this.facePlane.distanceToPoint(inn.point) >= 0);
                }
            }
            return this.normalTowardsCamera(this.reduceIntersections(intersections), caster);
        });
        // Filter cached intersection with unwanted snap type
        // TODO: avoid caching unwanted snap types, instead of filtering (cache each model separately maybe?)
        return (intersection && snapTypes.includes(intersection.snapType)) ? intersection : null;
    }

    /**
     * Area pick.
     *
     * Builds a frustum-only caster (its ray is deleted) for the screen rectangle and collects the non-empty
     * `areaPick` results of every model whose `isSelectable` is truthy.
     */
    public async getIntersectionFromScreenRect(screenPointTopLeft: Vector2, screenPointBottomRight: Vector2, excludeIntersected: boolean = false): Promise<IIntersection[]> {
        const caster = new Caster();
        delete caster.ray;
        caster.frustum = this.createFrustumFromScreenRect(screenPointTopLeft, screenPointBottomRight);
        const picks = [];
        const models = this._models.getModels();

        for (const model of models) {
            const m = model as Model & SelectableModel;
            if (m.isSelectable) picks.push(m.areaPick(caster, excludeIntersected));
        }

        let intersections: (IIntersection)[] = await Promise.all(picks);
        intersections = intersections.filter(item => !!item);

        return intersections;
    }

    /** Best intersection along an arbitrary ray (not cached); the frustum is derived via frustumFromRay. */
    async pickRay(ray: Ray, models?: Model[]): Promise<IIntersection> {
        const caster = new Caster();
        caster.ray = ray;
        caster.frustum = new Frustum();
        frustumFromRay(caster.frustum, caster.ray);
        return this.normalTowardsCamera(this.reduceIntersections(await this.pickAllModels(caster, models)), caster);
    }

    /** Runs `pick` on every given model (all models when undefined) that has one; drops empty results. */
    private async pickAllModels(caster: Caster, models: Model[]): Promise<IIntersection[]> {
        const promises: Array<Promise<IIntersection>> = [];
        for (const model of models ? models : this._models.getModels()) {
            const m = model as Model & PickableModel;
            if (m.pick)
                promises.push(m.pick(caster));
        }
        const intersections = await Promise.all(promises);
        return intersections.filter(i => !!i);
    }

    /**
     * Runs `pickSnapped` on every model that has one and flattens the non-empty results.
     * `pickerBusy` is true while waiting and is reset even if a model's pick rejects.
     */
    private async pickAllModelsSnapped(caster: Caster, snapTypes: SnapType[]): Promise<IIntersection[]> {
        this.pickerBusy = true;
        try {
            const promises: Array<Promise<IIntersection[]>> = [];
            for (const model of this._models.getModels()) {
                const m = model as Model & SnappedPickableModel;
                if (m.pickSnapped)
                    promises.push(m.pickSnapped(caster, snapTypes));
            }
            const ins = await Promise.all(promises);
            return ins.flat().filter(i => !!i);
        } finally {
            this.pickerBusy = false;
        }
    }

    /**
     * Selects the single best intersection, or undefined for a missing/empty list.
     *
     * Ranking: higher `pickPriority` wins. Then a higher `snapType` value wins. Then a smaller distance wins,
     * where the distance is `distanceToRay * 5 + distance` when both sides define `distanceToRay`. Ties keep
     * the earlier element.
     */
    reduceIntersections(intersections: IIntersection[]): IIntersection {
        return !intersections || intersections.length === 0 ? undefined : intersections.reduce((a, b) => {
            // when comparing edges or points that are closer than all the faces, compare by distance to ray not distance to ray origin.
            const deltaDistance = (a.distanceToRay !== undefined && b.distanceToRay !== undefined) ?
                (a.distanceToRay * 5 + a.distance) - (b.distanceToRay * 5 + b.distance) :
                a.distance - b.distance;
            const deltaPriority = (a.pickPriority ?? 0) - (b.pickPriority ?? 0);
            const deltaSnapType = (a.snapType ?? 0) - (b.snapType ?? 0);

            return (deltaPriority === 0 ? (deltaSnapType === 0 ? deltaDistance : -deltaSnapType) : -deltaPriority) <= 0 ? a : b;
        });
    }

    /** Releases the GPU picker. */
    dispose(): void {
        this.gpuPicker.dispose();
    }
}

const rayCaster = new Raycaster();

/**
 * Builds the pick ray from `camera` through a screen position over `container`.
 * The module-level Raycaster is only scratch space; the ray is copied into `out` (a new Ray by default),
 * which is also returned.
 */
export function screenPositionToRay(point: { x: number; y: number }, container: HTMLElement, camera: Camera, out: Ray = new Ray()): Ray {
    const ndc = screenPositionToRayTracePoint(point, container);
    rayCaster.setFromCamera(ndc, camera);
    out.copy(rayCaster.ray);
    return out;
}

/**
 * Maps a screen position to normalized device coordinates relative to `container`.
 *
 * The container's top-left corner maps to (-1, 1) and the point one client width/height further maps to
 * (1, -1), so x grows to the right and y grows upward. The result is written into `out`, which is returned.
 *
 * `point` is presumably in client coordinates (e.g. from a pointer event), since the container's
 * getBoundingClientRect() origin is subtracted. NOTE(review): the offset comes from the border-box rect while
 * the scale uses clientWidth/clientHeight; the two only agree when the container has no border, scrollbar
 * or CSS transform.
 */
export function screenPositionToRayTracePoint(point: { x: number; y: number }, container: HTMLElement, out: Vector2 = new Vector2()): Vector2 {
    const {left, top} = container.getBoundingClientRect();
    const u = (point.x - left) / container.clientWidth;
    const v = (point.y - top) / container.clientHeight;
    out.x = u * 2 - 1;
    out.y = 1 - v * 2;
    return out;
}

export const worldToScreenPoint = (() => {
    // Scratch vector shared by every call. The returned Vector3 is this same instance, so callers that need
    // to keep the value must copy it before calling again.
    const scratch = new Vector3();

    /**
     * Projects a world-space point to canvas coordinates: origin at the top-left corner, y growing downward.
     * Returns undefined when the point has positive z in view space (behind the camera).
     * The returned z keeps the value produced by the projection matrix.
     */
    return (point: Vector3, camera: Camera, canvasWidth: number, canvasHeight: number) => {
        const viewSpace = scratch.copy(point).applyMatrix4(camera.matrixWorldInverse);
        const behindCamera = viewSpace.z > 0;
        if (behindCamera) {
            return;
        }

        const projected = viewSpace.applyMatrix4(camera.projectionMatrix);
        projected.x = ((projected.x + 1) * canvasWidth) / 2;
        projected.y = ((1 - projected.y) * canvasHeight) / 2;
        return projected;
    };
})();
