import Plotly, { ColorScale } from "plotly.js-dist-min";
import { PmmData } from "./PmmData"; // PmmData describes the input consumed by plotPmm (surface grids and load combinations)

/**
 * Local description of a Plotly `scatter3d` trace, limited to the fields
 * this file sets. Used by `plotPmm` for the load-point markers, the mesh
 * lines and the axis-arrow shafts.
 */
type Scatter3dTrace = {
  // Discriminant that distinguishes this trace from the other trace types in this file.
  type: "scatter3d";
  // Whether the points are drawn as lines or as markers.
  mode: "lines" | "markers";
  // Point coordinates, one array per axis.
  x: number[];
  y: number[];
  z: number[];
  // Line styling; set on the traces created with mode "lines".
  line?: { color: string; width: number };
  // Marker styling; set on the trace created with mode "markers".
  marker?: { color: string; size: number };
  // Only the literal "skip" is permitted; plotPmm sets it on the axis-arrow shafts.
  hoverinfo?: "skip";
};

/**
 * Local description of a Plotly `surface` trace, limited to the fields this
 * file sets. Used by `plotPmm` for the PMM surface.
 */
type SurfaceTrace = {
  // Discriminant that distinguishes this trace from the other trace types in this file.
  type: "surface";
  // Surface coordinates as 2-D grids (arrays of rows), one grid per axis.
  x: number[][];
  y: number[][];
  z: number[][];
  // Optional opacity; plotPmm passes 0.5.
  opacity?: number;
  // Optional colour scale; plotPmm passes a two-stop scale with the same colour at 0 and 1.
  colorscale?: ColorScale;
  // Optional colour-bar flag; plotPmm passes false.
  showscale?: boolean;
};

/**
 * Local description of a Plotly `cone` trace, limited to the fields this
 * file sets. Used by `plotPmm` for the axis-arrow heads.
 */
type ConeTrace = {
  // Discriminant that distinguishes this trace from the other trace types in this file.
  type: "cone";
  // Cone positions, one array per axis (plotPmm places a single cone per trace).
  x: number[];
  y: number[];
  z: number[];
  // Direction components along x, y and z; plotPmm sets only the component of the arrow's own axis to a non-zero value.
  u: number[];
  v: number[];
  w: number[];
  // NOTE(review): typed as a plain string; the only value used in this file is "scaled".
  sizemode: string;
  // Colour scale; plotPmm passes a two-stop scale with the axis colour at both ends.
  colorscale: ColorScale;
  // Colour-bar flag; plotPmm passes false.
  showscale: boolean;
  // Only the literal "skip" is permitted; plotPmm sets it on every arrow head.
  hoverinfo?: "skip";
};

/**
 * Computes `[min, max, max - min]` across every value of a 2-D grid.
 *
 * The grid is walked value by value instead of being flattened and spread
 * into `Math.min` / `Math.max`, so large grids cannot exceed the engine's
 * limit on the number of call arguments. A grid that contains no values
 * yields `[0, 0, 0]` instead of infinities.
 */
function gridExtent(grid: number[][]): [number, number, number] {
  let min = Infinity;
  let max = -Infinity;
  for (const row of grid) {
    for (const value of row) {
      // Math.min / Math.max keep the NaN propagation of the spread-based form.
      min = Math.min(min, value);
      max = Math.max(max, value);
    }
  }
  // Only true when no value was visited: fall back to a degenerate range at 0.
  if (min > max) {
    return [0, 0, 0];
  }
  return [min, max, max - min];
}

/**
 * Renders the 3D PMM diagram into the given target with `Plotly.react`.
 *
 * The figure is assembled from, in order: the load-combination markers, the
 * translucent PMM surface, one mesh line per grid row, one mesh line per grid
 * column, and a shaft + cone arrow for each of the x, y and z axes. The
 * built-in scene axes are hidden; the arrows and three scene annotations
 * stand in for them.
 *
 * @param pmmData - Supplies the surface grids `xVals`, `yVals`, `zVals` and
 *   the `loadCombos` list (each entry is read as `mx`, `my`, `p`). Empty grids
 *   are tolerated: all ranges collapse to 0 and no mesh lines are emitted.
 * @param elementId - Passed unchanged to `Plotly.react` as the plot target.
 * @param fullScreenMode - When truthy, `width`/`height` are left undefined and
 *   the full-screen margins are used; otherwise the plot is 350 x 350.
 */
export function plotPmm(pmmData: PmmData, elementId: string, fullScreenMode?: boolean): void {
  const { xVals: X, yVals: Y, zVals: Z, loadCombos } = pmmData;

  // Colors and labels for the three axes.
  const axisColors = { x: "#FF0000", y: "#00CC00", z: "#0000FF" };
  // NOTE(review): these are ordinary string literals, so "${phi}" is emitted
  // verbatim rather than interpolated — confirm the labels render as intended.
  const axLabels = { x: "${phi}M_{nx}$", y: "${phi}M_{ny}$", z: "${phi}P_n$" };

  // Padding factors applied around the data when sizing the axis arrows.
  const zFactor = 0.12;
  const xyFactor = 0.05;

  const [, , xSpan] = gridExtent(X);
  const [, , ySpan] = gridExtent(Y);
  const [zLow, zHigh, zSpan] = gridExtent(Z);

  // The z arrow runs from below the lowest to above the highest surface value.
  const zRange = zSpan * (1 + 2 * zFactor);
  const zMin = zLow - zRange * zFactor;
  const zMax = zHigh + zRange * zFactor;

  // NOTE(review): with minAspect = 1 the second Math.min argument is never
  // larger than the first, so each range is half of that axis's own span —
  // confirm this is the intended sizing.
  const maxXY = Math.max(xSpan, ySpan);
  const minAspect = 1;
  const xRange = Math.min((1 + xyFactor) * maxXY, minAspect * xSpan) / 2;
  const yRange = Math.min((1 + xyFactor) * maxXY, minAspect * ySpan) / 2;

  // Builds the shaft (line) and head (cone) traces for one axis arrow.
  const createAxisArrow = (axis: "x" | "y" | "z"): Array<Scatter3dTrace | ConeTrace> => {
    const color = axisColors[axis];

    // The x and y shafts are symmetric about the origin; the z shaft spans zMin..zMax.
    const arrowBody: Scatter3dTrace = {
      type: "scatter3d",
      mode: "lines",
      line: { color, width: 5 },
      x: axis === "x" ? [-xRange, xRange] : [0, 0],
      y: axis === "y" ? [-yRange, yRange] : [0, 0],
      z: axis === "z" ? [zMin, zMax] : [0, 0],
      hoverinfo: "skip",
    };

    // A single cone at the positive end of the shaft, pointing along the axis.
    const arrowHead: ConeTrace = {
      type: "cone",
      x: [axis === "x" ? xRange : 0],
      y: [axis === "y" ? yRange : 0],
      z: [axis === "z" ? zMax : 0],
      u: [axis === "x" ? 0.1 * xRange : 0],
      v: [axis === "y" ? 0.1 * yRange : 0],
      w: [axis === "z" ? 0.1 * zRange : 0],
      sizemode: "scaled",
      // Same color at both stops so the cone is drawn in the axis color.
      colorscale: [
        [0, color],
        [1, color],
      ],
      showscale: false,
      hoverinfo: "skip",
    };

    return [arrowBody, arrowHead];
  };

  // Defaults shared by the three axis-label annotations; each overrides one coordinate.
  const annotationCommon = {
    showarrow: false,
    font: { color: "#1f1f1f", size: 18 },
    x: 0,
    y: 0,
    z: 0,
  };

  // Layout: hidden scene axes, one label at the tip of each arrow, fixed or fluid size.
  const layout = {
    title: {
      text: "PMM Diagram",
      font: { size: 18, color: "#1f1f1f" },
      x: 0.5,
    },
    scene: {
      xaxis: { visible: false },
      yaxis: { visible: false },
      zaxis: { visible: false },
      annotations: [
        {
          ...annotationCommon,
          text: axLabels.x,
          x: xRange,
        },
        {
          ...annotationCommon,
          text: axLabels.y,
          y: yRange,
        },
        {
          ...annotationCommon,
          text: axLabels.z,
          z: zMax,
        },
      ],
      aspectmode: "auto" as "cube" | "auto" | "data" | "manual",
    },
    width: fullScreenMode ? undefined : 350,
    height: fullScreenMode ? undefined : 350,
    margin: fullScreenMode
      ? { l: 20, r: 20, t: 0, b: 0, autoexpand: false }
      : { l: 10, r: 0, t: 30, b: 10, autoexpand: false },
    // autosize: fullScreenMode ? true : false,
  };

  // Column mesh lines are driven by the first row; an empty grid has none.
  const [firstRow = []] = X;

  const traces: Array<Scatter3dTrace | SurfaceTrace | ConeTrace> = [
    // Load points
    {
      type: "scatter3d",
      mode: "markers",
      x: loadCombos.map((combo) => combo.mx),
      y: loadCombos.map((combo) => combo.my),
      z: loadCombos.map((combo) => combo.p),
      marker: { color: "#002095", size: 4 },
    },
    // PMM surface
    {
      type: "surface",
      x: X,
      y: Y,
      z: Z,
      opacity: 0.5,
      colorscale: [
        [0, "#ffbb0f"],
        [1, "#ffbb0f"],
      ],
      showscale: false,
    },
    // PMM mesh lines along each grid row
    ...X.map(
      (_, i): Scatter3dTrace => ({
        type: "scatter3d",
        mode: "lines",
        line: { color: "#f7f7f7", width: 1.5 },
        x: X[i],
        y: Y[i],
        z: Z[i],
      })
    ),
    // PMM mesh lines along each grid column
    ...firstRow.map(
      (_, j): Scatter3dTrace => ({
        type: "scatter3d",
        mode: "lines",
        line: { color: "#f7f7f7", width: 1.5 },
        x: X.map((row) => row[j]),
        y: Y.map((row) => row[j]),
        z: Z.map((row) => row[j]),
      })
    ),
  ];

  // Append the shaft and head of each axis arrow.
  for (const axis of ["x", "y", "z"] as const) {
    traces.push(...createAxisArrow(axis));
  }

  // Render (or update in place) the plot.
  Plotly.react(elementId, traces, layout);
}
