import {Intersection, Mesh, Quaternion, Sphere, Vector2, Vector3, Vector4} from "three";
import {Caster} from "../Picker/Caster.js";
import {Web3DCamera} from "../Rendering/Web3DCamera.js";
import {Web3DMeshPointsMaterial} from "../Rendering/Web3DMeshPointsMaterial.js";
import {MeshPointsGeometry} from "./MeshPointsGeometry.js";
import {worldToScreenPoint} from "../Picker/Picker.js";
import {Vector3Const} from "../Helpers/common-utils.js";

/**
 * Mesh drawing one point per entry of the geometry's `instancePosition` attribute,
 * with custom picking: either by world-space distance between the pick ray and each
 * point, or (when {@link MeshPoints.useBillboardPicking} is set) by testing the caster's
 * screen position against each point's screen-space rectangle.
 */
export class MeshPoints extends Mesh {
    /** Scratch world-space bounding sphere reused by the early-out test in `raycast`. */
    protected sphere = new Sphere();
    /** Scratch vector holding the world-space position of the instance under test. */
    private point = new Vector3();
    declare material: Web3DMeshPointsMaterial | Web3DMeshPointsMaterial[];
    declare geometry: MeshPointsGeometry;

    /**
     * Picking mode.
     * - `false` (default): a point is hit when the ray passes within its snap radius.
     * - `true`: a point is hit when `raycaster.screenPosition` lies inside its screen rectangle.
     */
    useBillboardPicking = false;

    /**
     * @param geometry   Instanced point geometry; picking reads its `instancePosition` attribute.
     * @param material   Point material(s); picking only consults the first material's
     *                   `size` and `sizeAttenuation`.
     * @param _camera    Camera used to convert between pixel sizes and world sizes.
     * @param _container Element whose client width/height are used as the screen size in pixels.
     */
    constructor(geometry: MeshPointsGeometry, material: Web3DMeshPointsMaterial | Web3DMeshPointsMaterial[], private _camera: Web3DCamera, private _container: HTMLElement) {
        super(geometry, material);
    }

    /**
     * Appends an intersection to `intersects` for every point instance hit by `raycaster`.
     *
     * In distance mode, a world-space bounding sphere inflated by the snap radius is tested
     * first, so the per-instance loop is skipped when the ray misses the whole set. When the
     * geometry has groups, only the instance ranges they describe are tested; otherwise every
     * instance of `instancePosition` is tested.
     */
    override raycast(raycaster: Caster, intersects: Intersection[]): void {
        const geometry = this.geometry;
        const matrixWorld = this.matrixWorld;
        const ray = raycaster.ray;

        // Early-out: check the (inflated) bounding sphere against the ray.
        if (!this.useBillboardPicking) {
            if (geometry.boundingSphere === null) geometry.computeBoundingSphere();
            const boundingSphere = geometry.boundingSphere;
            // If no bounding sphere is available, skip culling and test every instance.
            if (boundingSphere !== null) {
                const sphere = this.sphere.copy(boundingSphere).applyMatrix4(matrixWorld);
                // Each point's snap radius is computed at that point's own camera distance
                // (see _raycastPoint). Size the margin at the farthest distance any point in
                // the sphere can have, so that points on the far side are not culled.
                // Presumably getViewWorldSize is non-decreasing with distance; it is constant
                // for size-attenuated materials.
                const farthestDistance = sphere.center.distanceTo(ray.origin) + sphere.radius;
                sphere.radius += this._getSnapWorldRadius(farthestDistance, raycaster.snapDistance);
                if (!ray.intersectsSphere(sphere)) return;
            }
        }

        const posAttribute = geometry.attributes.instancePosition;
        const positions = posAttribute.array;
        const instanceTotal = posAttribute.count;

        const testInstance = (i: number): void => {
            if (this.useBillboardPicking) {
                // _raycastBillboard derives the world position itself from the index.
                this._raycastBillboard(i, raycaster, intersects);
            } else {
                this.point.fromArray(positions, i * 3).applyMatrix4(matrixWorld);
                this._raycastPoint(this.point, i, raycaster, intersects);
            }
        };

        const groups = geometry.groups;
        if (groups && groups.length > 0) {
            // NOTE(review): groups of MeshPointsGeometry are expected to carry
            // `instanceOffset` / `instanceCount` (start index + number of instances).
            for (const g of groups as any[]) {
                const start = Math.max(0, g.instanceOffset);
                // The range end is offset + count. Clamp it to the attribute size so an
                // oversized (or Infinity) count cannot read past the position array.
                const end = Math.min(g.instanceOffset + g.instanceCount, instanceTotal);
                for (let i = start; i < end; i++) testInstance(i);
            }
        } else {
            for (let i = 0; i < instanceTotal; i++) testInstance(i);
        }
    }

    /**
     * Returns the world-space picking radius for a point at `cameraDistance` from the ray origin.
     *
     * With a size-attenuated (first) material, the material `size` is returned unchanged and
     * `snapDistance` is ignored. Otherwise `size + snapDistance` is scaled by
     * `getViewWorldSize(cameraDistance) / container.clientWidth`, which presumably converts
     * pixels to world units at that distance.
     */
    protected _getSnapWorldRadius(cameraDistance: number, snapDistance: number): number {
        const firstMaterial = Array.isArray(this.material) ? this.material[0] : this.material;
        if (firstMaterial.sizeAttenuation) return firstMaterial.size;
        return this._camera.getViewWorldSize(cameraDistance) * (firstMaterial.size + (snapDistance || 0)) / this._container.clientWidth;
    }

    /**
     * Distance-mode test for one instance.
     *
     * Pushes an intersection when the squared ray-to-point distance is below the squared snap
     * radius. The point is cloned because callers pass a reused scratch vector.
     *
     * @param point World-space position of the instance.
     * @param index Instance index, reported as `Intersection.index`.
     */
    protected _raycastPoint(point: Vector3, index: number, raycaster: Caster, intersects: Intersection[]): void {
        const cameraDistance = point.distanceTo(raycaster.ray.origin);
        const snapRadius = this._getSnapWorldRadius(cameraDistance, raycaster.snapDistance);
        const rayPointDistanceSq = raycaster.ray.distanceSqToPoint(point);
        if (rayPointDistanceSq < snapRadius * snapRadius) {
            intersects.push({
                distance: cameraDistance,
                distanceToRay: Math.sqrt(rayPointDistanceSq),
                point: point.clone(),
                index: index,
                face: null,
                object: this,
            });
        }
    }

    /**
     * Billboard-mode test for one instance.
     *
     * Builds the instance's screen rectangle from its per-instance `pointSize` attribute (four
     * corner offsets). Without that attribute, it uses a square of `±0.5 * material.size`.
     * An intersection is pushed when `raycaster.screenPosition` falls inside the rectangle.
     * The reported point lies on the ray at the point's distance from the camera.
     * Does nothing when the caster has no screen position.
     */
    protected _raycastBillboard = (() => {
        // Scratch objects shared by all calls.
        const point = new Vector3();
        const pointSize = new Vector4();
        const cornerUL = new Vector2();
        const cornerLR = new Vector2();
        const camPos = new Vector3();

        return (index: number, raycaster: Caster, intersects: Intersection[]): void => {
            if (!raycaster.screenPosition) return; // Billboard picking not supported e.g. in XR

            const posAttribute = this.geometry.attributes.instancePosition;
            const sizeAttribute = this.geometry.attributes.pointSize;
            const firstMaterial = Array.isArray(this.material) ? this.material[0] : this.material;

            point.fromArray(posAttribute.array, index * 3).applyMatrix4(this.matrixWorld);
            if (sizeAttribute) pointSize.fromArray(sizeAttribute.array, index * 4);
            else pointSize.set(-1, -1, 1, 1).multiplyScalar(0.5 * firstMaterial.size);

            const rectFound = firstMaterial.sizeAttenuation ?
                this.getAttenuatedScreenRect(point, pointSize, cornerUL, cornerLR) :
                this.getNonAttenuatedScreenRect(point, pointSize, cornerUL, cornerLR);
            if (!rectFound) return;

            if (
                cornerUL.x <= raycaster.screenPosition.x && raycaster.screenPosition.x <= cornerLR.x &&
                cornerUL.y <= raycaster.screenPosition.y && raycaster.screenPosition.y <= cornerLR.y
            ) {
                const d = point.distanceTo(this._camera.getWorldPosition(camPos));

                intersects.push({
                    distance: d,
                    distanceToRay: 0,
                    point: raycaster.ray.origin.clone().addScaledVector(raycaster.ray.direction, d),
                    index: index,
                    face: null,
                    object: this,
                });
            }

        };
    })();

    /**
     * Screen rectangle for a non-attenuated point.
     *
     * `pointSize` (x, y) and (z, w) are added directly to the projected centre as the
     * upper-left and lower-right corner offsets. Returns false when the point has no
     * screen projection.
     */
    private getNonAttenuatedScreenRect(point: Vector3, pointSize: Vector4, dstUL: Vector2, dstLR: Vector2): boolean {
        const screenP = worldToScreenPoint(point, this._camera, this._container.clientWidth, this._container.clientHeight);
        if (!screenP) return false;
        dstUL.set(screenP.x + pointSize.x, screenP.y + pointSize.y);
        dstLR.set(screenP.x + pointSize.z, screenP.y + pointSize.w);
        return true;
    }

    /**
     * Screen rectangle for a size-attenuated point.
     *
     * The two corners are offset in world space along the camera's right/up axes by twice the
     * `pointSize` components. The y components are negated because screen y and world up point
     * in opposite directions. Each corner is then projected to the screen. Returns false if
     * either corner has no screen projection.
     */
    private getAttenuatedScreenRect = (() => {
        // Scratch objects shared by all calls.
        const ul = new Vector3();
        const lr = new Vector3();
        const camX = new Vector3();
        const camY = new Vector3();
        const quat = new Quaternion();

        return (point: Vector3, pointSize: Vector4, dstUL: Vector2, dstLR: Vector2): boolean => {
            this._camera.getWorldQuaternion(quat);
            camX.copy(Vector3Const.right).applyQuaternion(quat);
            camY.copy(Vector3Const.threejsUp).applyQuaternion(quat);
            const w = this._container.clientWidth;
            const h = this._container.clientHeight;

            ul.copy(point).addScaledVector(camX, 2 * pointSize.x).addScaledVector(camY, 2 * -pointSize.y);
            const screenUL = worldToScreenPoint(ul, this._camera, w, h);
            if (!screenUL) return false;
            dstUL.set(screenUL.x, screenUL.y);

            lr.copy(point).addScaledVector(camX, 2 * pointSize.z).addScaledVector(camY, 2 * -pointSize.w);
            const screenLR = worldToScreenPoint(lr, this._camera, w, h);
            if (!screenLR) return false;
            dstLR.set(screenLR.x, screenLR.y);

            return true;
        };
    })();
}
