import React, { useCallback, useEffect, useState } from "react";
import * as THREE from "three";
import { extendMaterial } from "../../../lib/ExtendMaterial";
import { MeshBVH, acceleratedRaycast } from "three-mesh-bvh";
import {
  IRR_X_VALS,
  IRR_Y_VALS_BOOST,
  IRR_Y_VALS_SETTLED,
  ANG_X_RAD,
  ANG_Y_PERC,
} from "../../../lib/Akima";
import { TooltipData } from "../../../components/HeatmapTooltip";
import {
  createAmbientLight,
  createDirectionalLight,
  createLUTTexture,
  generateRaysInConePointingDownWithAngle,
  calcWeightedIrrFromCeiling,
  isFloor,
  clickRaycast,
} from "../../../utils/threejs";
import { UseRendererRes } from "../useRenderer";
import simHeatmapShaders from "./simHeatmapShaders";

// Route every THREE.Mesh raycast through three-mesh-bvh's `acceleratedRaycast`.
// This patches the shared Mesh prototype as a module-level side effect, so it
// applies to all meshes in the application, not only those handled in this file.
THREE.Mesh.prototype.raycast = acceleratedRaycast;

/**
 * Irradiance bounds derived from a set of raycast sample points.
 */
export interface IrrRange {
  /**
   * Lower bound. `calcIrrRange` reports the smallest strictly positive total
   * irradiance it saw; the hook overrides it with 0 for the overall range.
   */
  minIrr: number;
  /** Upper bound: the largest total irradiance seen among the samples. */
  maxIrr: number;
}

/**
 * Contribution of a single Zener that has line of sight to a sampled point.
 */
export interface ZenerRaycastResult {
  /** Display name, built as `Zener <1-based index>`. */
  name: string;
  /** Straight-line distance between the Zener position and the point. */
  distance: number;
  /** Value returned by `calcWeightedIrrFromCeiling` for this Zener/point pair. */
  irradiance: number;
}

/**
 * A sampled surface point together with the irradiance it receives.
 */
export interface RaycastPoint {
  /** World-space position of the raycast hit. */
  pos: THREE.Vector3;
  /**
   * What the ray hit.
   * NOTE(review): every producer in this file assigns the whole raycast
   * intersection record here, not `intersection.object` — confirm which of the
   * two `isFloor` expects and align the declared type accordingly.
   */
  intersectionObj: THREE.Object3D;
  /** Per-Zener results for the Zeners that have line of sight to `pos`. */
  zenersWithLOSRaycastResults: ZenerRaycastResult[];
  /** Sum of `irradiance` over `zenersWithLOSRaycastResults`. */
  totalIrradiance: number;
}

/**
 * Result of `useIrrHeatmap`. Both ranges are `null` until a sampling pass has
 * produced finite bounds.
 */
export interface UseIrrHeatmapRes {
  /** Range over all sampled points; its `minIrr` is forced to 0 by the hook. */
  irrRange: IrrRange | null;
  /** Range restricted to sampled points for which `isFloor` returns true. */
  planeIrrRange: IrrRange | null;
}

/**
 * Determines which Zeners have an unobstructed line of sight to a point.
 *
 * For every Zener position a ray is cast towards `point`. The Zener counts as
 * blocked when the nearest hit on `model` lies closer to the Zener than the
 * point itself (beyond a small tolerance, because the point normally sits on a
 * surface of `model` and is therefore hit by the very same ray).
 *
 * `raycaster` is re-aimed for every Zener, so on return its ray points from the
 * last Zener towards `point`.
 *
 * @param raycaster Raycaster reused for the visibility checks (mutated).
 * @param zenerPoss Positions of all Zeners.
 * @param point Point whose visibility is tested.
 * @param model Object (descendants included) that may occlude the point.
 * @returns One entry per Zener with line of sight, named by its 1-based index.
 */
function calcZenerLOSRaycastResultsToPoint(
  raycaster: THREE.Raycaster,
  zenerPoss: THREE.Vector3[],
  point: THREE.Vector3,
  model: THREE.Object3D,
): ZenerRaycastResult[] {
  // Distances closer together than this are treated as the same surface hit.
  const losEpsilon = 1e-8;

  const zenersWithLOS: ZenerRaycastResult[] = [];

  zenerPoss.forEach((position, index) => {
    const distanceToPoint = point.distanceTo(position);

    raycaster.set(position, point.clone().sub(position).normalize());
    const shadowIntersects = raycaster.intersectObject(model, true);

    // Compare as numbers: comparing `toFixed()` strings is lexicographic and
    // misorders values with a different number of integer digits
    // (e.g. "9.50000000" < "10.20000000" is false).
    const isBlocked =
      shadowIntersects.length > 0 &&
      shadowIntersects[0].distance < distanceToPoint - losEpsilon;

    if (!isBlocked) {
      zenersWithLOS.push({
        name: `Zener ${index + 1}`,
        distance: distanceToPoint,
        irradiance: calcWeightedIrrFromCeiling(position, point),
      });
    }
  });

  return zenersWithLOS;
}

/**
 * Coarse sampling pass: from every Zener, casts a fixed set of rays produced by
 * `generateRaysInConePointingDownWithAngle` and records the nearest hit of each
 * ray on `model`, together with the irradiance that point receives from all
 * Zeners that can see it.
 *
 * @param zenerPoss Positions of all Zeners; each one is used as a ray origin.
 * @param model Object (descendants included) to intersect the rays with.
 * @returns One sample per ray that hit something; rays without a hit are skipped.
 */
function raycastInBottomHemisphere(
  zenerPoss: THREE.Vector3[],
  model: THREE.Object3D,
): RaycastPoint[] {
  const raycaster = new THREE.Raycaster();
  const rayCount = 1000;
  const coneAngle = Math.PI / 2; // field of view

  const coneDirections = generateRaysInConePointingDownWithAngle(
    coneAngle,
    rayCount,
  );

  const samples: RaycastPoint[] = [];

  for (const origin of zenerPoss) {
    coneDirections.forEach((direction) => {
      // `set` copies both vectors into the raycaster's own ray.
      raycaster.set(origin, direction);

      const [nearestHit] = raycaster.intersectObject(model, true);
      if (nearestHit === undefined) {
        return;
      }

      // The same raycaster is re-aimed while testing line of sight; that is
      // fine because it is set again before the next cone ray is cast.
      const visibleZeners = calcZenerLOSRaycastResultsToPoint(
        raycaster,
        zenerPoss,
        nearestHit.point,
        model,
      );

      let irradianceSum = 0;
      for (const zenerResult of visibleZeners) {
        irradianceSum += zenerResult.irradiance;
      }

      samples.push({
        pos: nearestHit.point,
        intersectionObj: nearestHit,
        zenersWithLOSRaycastResults: visibleZeners,
        totalIrradiance: irradianceSum,
      });
    });
  }

  return samples;
}

/**
 * Fine sampling pass around a target point.
 *
 * For every Zener, `numRays` rays are aimed at points evenly spaced on a circle
 * of the given `radius` around `target` (offsets are applied to x and y only).
 * The nearest hit of each ray on `model` is recorded together with the
 * irradiance it receives from all Zeners that can see it.
 *
 * @param raycaster Raycaster to reuse (mutated).
 * @param model Object (descendants included) to intersect the rays with.
 * @param radius Radius of the sampling circle around `target`.
 * @param numRays Number of rays cast per Zener.
 * @param zenerPoss Positions of all Zeners; each one is used as a ray origin.
 * @param target Centre of the sampling circle.
 * @returns One sample per ray that hit something.
 */
function raycastInConeWithTargetAndRadius(
  raycaster: THREE.Raycaster,
  model: THREE.Object3D,
  radius: number,
  numRays: number,
  zenerPoss: THREE.Vector3[],
  target: THREE.Vector3,
): RaycastPoint[] {
  const samples: RaycastPoint[] = [];

  for (const origin of zenerPoss) {
    for (let step = 0; step < numRays; step += 1) {
      const theta = (step / numRays) * 2 * Math.PI; // Full circle

      // Point on the circle around the target, then direction from the Zener.
      const direction = new THREE.Vector3(
        radius * Math.cos(theta),
        radius * Math.sin(theta),
        0,
      )
        .add(target)
        .sub(origin)
        .normalize();

      raycaster.set(origin, direction);

      const [nearestHit] = raycaster.intersectObject(model, true);
      if (nearestHit === undefined) {
        continue;
      }

      // Check which Zeners have line of sight to the hit point.
      const visibleZeners = calcZenerLOSRaycastResultsToPoint(
        raycaster,
        zenerPoss,
        nearestHit.point,
        model,
      );

      let irradianceSum = 0;
      for (const zenerResult of visibleZeners) {
        irradianceSum += zenerResult.irradiance;
      }

      samples.push({
        pos: nearestHit.point,
        intersectionObj: nearestHit,
        zenersWithLOSRaycastResults: visibleZeners,
        totalIrradiance: irradianceSum,
      });
    }
  }

  return samples;
}

/**
 * Computes the irradiance range over a set of sample points.
 *
 * The minimum only considers strictly positive values; the maximum considers
 * every value. When no point qualifies, the corresponding bound keeps its
 * initial value (`Infinity` for the minimum, `-Infinity` for the maximum).
 *
 * @param points Sample points to scan.
 * @param pointPredicate Optional filter; points for which it returns false are skipped.
 * @returns The smallest positive and the largest total irradiance found.
 */
function calcIrrRange(
  points: RaycastPoint[],
  pointPredicate?: (pt: RaycastPoint) => boolean,
): IrrRange {
  const range: IrrRange = { minIrr: Infinity, maxIrr: -Infinity };

  for (const candidate of points) {
    if (pointPredicate !== undefined && !pointPredicate(candidate)) {
      continue;
    }

    const irradiance = candidate.totalIrradiance;

    // Values of 0 or less never count towards the minimum.
    if (irradiance > 0 && irradiance < range.minIrr) {
      range.minIrr = irradiance;
    }
    if (irradiance > range.maxIrr) {
      range.maxIrr = irradiance;
    }
  }

  return range;
}

/**
 * Builds the light rig used while the heatmap is shown: one light per Zener,
 * created via `createDirectionalLight` and aimed at the point with the same
 * x/y at z = 0, plus a single ambient light.
 *
 * @param zenerPoss Positions of the Zeners; one light is created per entry.
 * @returns A group named "ZenerLights"; each Zener light carries its index in
 *   `userData.zenerIndex`.
 */
function createSimLights(zenerPoss: THREE.Vector3[]) {
  const rig = new THREE.Group();
  rig.name = "ZenerLights";

  for (const [zenerIndex, origin] of zenerPoss.entries()) {
    // Aim straight along z at the z = 0 plane below/above the Zener.
    const aimAt = new THREE.Vector3(origin.x, origin.y, 0);
    const zenerLight = createDirectionalLight(origin, aimAt, 0.9);
    zenerLight.userData.zenerIndex = zenerIndex;
    rig.add(zenerLight);
  }

  rig.add(createAmbientLight(0xffffff, 0.1));

  return rig;
}

/**
 * Creates the small, half-transparent sphere used to mark the hovered point.
 *
 * @param color Marker colour as a hex number; defaults to black.
 * @returns A new mesh owning its own geometry and material.
 */
function createMarkerMesh(color: number = 0x000000) {
  return new THREE.Mesh(
    new THREE.SphereGeometry(0.05, 16, 16),
    new THREE.MeshBasicMaterial({
      color,
      transparent: true,
      opacity: 0.5,
    }),
  );
}

/**
 * Attaches a `MeshBVH` as `boundsTree` to every mesh geometry under `model`
 * that does not have one yet. Geometries that already carry a `boundsTree` are
 * left untouched.
 *
 * @param model Root object whose descendant meshes are processed.
 */
function replaceGeometryWithBvh(model: THREE.Object3D) {
  model.traverse((node) => {
    if (!(node instanceof THREE.Mesh)) {
      return;
    }

    const geometry = node.geometry;
    if (!(geometry instanceof THREE.BufferGeometry) || geometry.boundsTree) {
      return;
    }

    // NOTE(review): the geometry stays in use after this call, so dispose()
    // presumably only forces a re-upload of its GPU buffers — confirm whether
    // it is needed at all.
    geometry.dispose();
    geometry.boundsTree = new MeshBVH(geometry);
  });
}

/**
 * React hook that samples irradiance in the room by raycasting and renders it
 * as a heatmap on the scene object named "RoomModel".
 *
 * It runs three effects:
 *  1. Sampling — casts rays from every Zener, refines around the
 *     highest-irradiance hits and publishes the overall and floor-only ranges.
 *  2. Rendering — while the heatmap is shown, adds the Zener lights and a hover
 *     marker, and swaps every room mesh material for the heatmap shader
 *     material; everything is restored and released on cleanup.
 *  3. Hover — follows the mouse, positions the marker on the hovered surface
 *     point and reports that point through `setTooltip`.
 *
 * @param canvasRef Canvas used to translate mouse events into scene raycasts.
 * @param renderCtx Provides the scene, renderer and camera to operate on.
 * @param zenerPoss Positions of the Zeners.
 * @param setTooltip Receives the hovered point's data, or `null` when the
 *   mouse is not over the room model.
 * @param showIrradiance Enables sampling and heatmap rendering.
 * @param isDraggingZener Suspends re-sampling while true.
 * @param exposureTimeS Optional exposure time; when set and the hovered point
 *   has positive irradiance, the tooltip dose is `exposureTimeS * totalIrradiance`.
 * @returns The latest irradiance ranges; both are `null` until a sampling pass
 *   produced finite bounds.
 */
export default function useIrrHeatmap(
  canvasRef: React.RefObject<HTMLCanvasElement>,
  renderCtx: UseRendererRes,
  zenerPoss: THREE.Vector3[],
  setTooltip: (tooltip: TooltipData | null) => void,
  showIrradiance: boolean,
  isDraggingZener: boolean,
  exposureTimeS?: number,
): UseIrrHeatmapRes {
  const { scene, renderer, camera } = renderCtx;

  // Keep one raycaster for the lifetime of the hook instead of allocating a
  // new one on every render.
  const [raycaster] = useState(() => new THREE.Raycaster());
  const [markerMesh, setMarkerMesh] = useState<THREE.Mesh | null>(null);

  // ALL WEIGHTED AND WITHOUT DUTY CYCLE.
  const [irrRange, setIrrRange] = useState<IrrRange | null>(null);
  const [planeIrrRange, setPlaneIrrRange] = useState<IrrRange | null>(null);

  const enableShadowMap = useCallback(() => {
    renderer.shadowMap.enabled = true;
    renderer.shadowMap.type = THREE.PCFSoftShadowMap;
  }, [renderer]);

  const disableShadowMap = useCallback(() => {
    renderer.shadowMap.enabled = false;
  }, [renderer]);

  // Changes whenever the room model is updated; used to re-run the effects.
  const roomModelHash =
    scene?.getObjectByName("RoomModel")?.userData.uvxRoomModelUpdatedAt;

  // Compute raycast points.
  useEffect(() => {
    if (!scene || !showIrradiance || isDraggingZener) {
      return;
    }

    const roomModel = scene.getObjectByName("RoomModel");
    if (!roomModel) {
      return;
    }

    // --- Perform a coarse raycast from all Zeners.
    const nextRaycastPoints: RaycastPoint[] = raycastInBottomHemisphere(
      zenerPoss,
      roomModel,
    );

    // Compute coarse irradiance range.
    const nextIrrRange: IrrRange = calcIrrRange(nextRaycastPoints);

    // --- Refine raycast for high-irradiance points.

    // Keep the points within 10% of the coarse maximum.
    const highIrrPoints = nextRaycastPoints.filter(
      (point) =>
        point.totalIrradiance > 0 &&
        point.totalIrradiance >= nextIrrRange.maxIrr * 0.9,
    );

    const radius = 0.1; // Radius for circular sampling
    const numRays = 20; // Number of rays per circle

    // Perform fine-grained raycasts around the high-irradiance points.
    for (const tgtPoint of highIrrPoints) {
      const hiPoints = raycastInConeWithTargetAndRadius(
        raycaster,
        roomModel,
        radius,
        numRays,
        zenerPoss,
        tgtPoint.pos,
      );

      // Raise the maximum with the refined samples. (Scanning the coarse
      // `highIrrPoints` here could never change it.)
      for (const refined of hiPoints) {
        if (refined.totalIrradiance > nextIrrRange.maxIrr) {
          nextIrrRange.maxIrr = refined.totalIrradiance;
        }
      }

      // Add fine-grained points to collection.
      nextRaycastPoints.push(...hiPoints);
    }

    if (
      Number.isFinite(nextIrrRange.minIrr) &&
      Number.isFinite(nextIrrRange.maxIrr)
    ) {
      // FIXME: Need performant way to find minimum irradiance.
      // Raycasting misses min irradiance, this is a safe assumption.
      nextIrrRange.minIrr = 0;
      setIrrRange(nextIrrRange);
    }

    const nextPlaneIrrRange = calcIrrRange(nextRaycastPoints, (p) =>
      isFloor(p.intersectionObj),
    );
    if (
      Number.isFinite(nextPlaneIrrRange.minIrr) &&
      Number.isFinite(nextPlaneIrrRange.maxIrr)
    ) {
      setPlaneIrrRange(nextPlaneIrrRange);
    }
  }, [
    scene,
    zenerPoss,
    roomModelHash,
    showIrradiance,
    isDraggingZener,
    raycaster,
  ]);

  // Update textures with heatmap.
  useEffect(() => {
    if (!scene || !zenerPoss.length || !irrRange || !showIrradiance) {
      return;
    }

    const roomModel = scene.getObjectByName("RoomModel");
    if (!roomModel) {
      return;
    }

    console.log("heatmap rendering...");

    enableShadowMap();

    // Create lights.
    const lightGroup = createSimLights(zenerPoss);
    scene.add(lightGroup);

    // Create the hover marker.
    const nextMarkerMesh = createMarkerMesh();
    scene.add(nextMarkerMesh);
    setMarkerMesh(nextMarkerMesh);

    // Replace geometry with BVH.
    replaceGeometryWithBvh(roomModel);

    const lutTexture = createLUTTexture(200);
    lutTexture.needsUpdate = true;

    const combinedMaterial = extendMaterial(THREE.MeshPhongMaterial, {
      class: THREE.ShaderMaterial,
      explicit: true,

      ...simHeatmapShaders(
        zenerPoss,
        lutTexture,
        irrRange,
        IRR_X_VALS,
        IRR_Y_VALS_SETTLED,
        IRR_Y_VALS_BOOST,
        ANG_X_RAD,
        ANG_Y_PERC,
      ),

      material: {
        polygonOffset: true,
        polygonOffsetFactor: -0.1,
        side: THREE.DoubleSide,
      },
    });

    roomModel.traverse((child) => {
      if (child instanceof THREE.Mesh) {
        child.castShadow = true;
        child.receiveShadow = true;

        // Save original material.
        child.userData.originalMaterial = child.material;
        child.material = combinedMaterial;

        child.material.needsUpdate = true;
      }
    });

    return () => {
      disableShadowMap();

      roomModel.traverse((child) => {
        // Only touch meshes whose material was actually swapped above; a mesh
        // added to the model afterwards has no saved material to restore.
        if (
          child instanceof THREE.Mesh &&
          child.userData.originalMaterial !== undefined
        ) {
          child.material.dispose();
          child.material = child.userData.originalMaterial;
          child.userData.originalMaterial = undefined;
          child.material.needsUpdate = true;
        }
      });

      // Release the GPU resources created for this heatmap pass.
      lutTexture.dispose();

      scene.remove(lightGroup);
      scene.remove(nextMarkerMesh);
      nextMarkerMesh.geometry.dispose();
      nextMarkerMesh.material.dispose();
    };
  }, [
    roomModelHash,
    scene,
    zenerPoss,
    irrRange,
    showIrradiance,
    enableShadowMap,
    disableShadowMap,
  ]);

  // Register mouse move handler.
  useEffect(() => {
    if (!scene || !markerMesh || !camera) {
      return;
    }

    const roomModel = scene.getObjectByName("RoomModel");
    if (!roomModel) {
      return;
    }

    const onMouseMove = (event: MouseEvent) => {
      if (!canvasRef.current) {
        return;
      }

      const intersects = clickRaycast(
        canvasRef.current,
        scene,
        camera,
        event,
        roomModel,
      );

      if (intersects.length > 0) {
        const intersect = intersects[0];
        const point = intersect.point;

        const zenersWithLOSRaycastResults = calcZenerLOSRaycastResultsToPoint(
          raycaster,
          zenerPoss,
          point,
          roomModel,
        );

        const raycastPoint: RaycastPoint = {
          pos: point,
          intersectionObj: intersect,
          zenersWithLOSRaycastResults,
          totalIrradiance: zenersWithLOSRaycastResults.reduce(
            (acc: number, curr: ZenerRaycastResult) => acc + curr.irradiance,
            0,
          ),
        };

        let dose: number | undefined = undefined;

        if (exposureTimeS && raycastPoint.totalIrradiance > 0) {
          dose = exposureTimeS * raycastPoint.totalIrradiance;
        }

        // Tooltip coordinates are relative to the canvas.
        const rect = canvasRef.current.getBoundingClientRect();
        setTooltip({
          x: event.clientX - rect.left,
          y: event.clientY - rect.top,
          pointData: {
            raycastPoint,
            dose,
          },
        });

        markerMesh.position.copy(point);
        markerMesh.visible = true;
      } else {
        markerMesh.visible = false;
        setTooltip(null);
      }
    };

    window.addEventListener("mousemove", onMouseMove);
    return () => {
      window.removeEventListener("mousemove", onMouseMove);
    };
  }, [
    roomModelHash,
    scene,
    markerMesh,
    camera,
    exposureTimeS,
    zenerPoss,
    canvasRef,
    setTooltip,
    raycaster,
  ]);

  return {
    irrRange,
    planeIrrRange,
  };
}
