import { Theme, alpha } from "@mui/material";
import { NodeObject } from "react-force-graph-2d";
import { INetworkGraphNode, NetworkGraphNodeType } from "../../models";
import { NODE_RADIUS, NODE_LABEL_Y_OFFSET } from "./constants";

// Animation frames to wait per level of node depth before a node begins fading in (used by calculateOpacity).
const DELAY_FACTOR = 30;
// Number of animation frames a node's fade-in from opacity 0 to 1 lasts (used by calculateOpacity).
const FADE_DURATION = 40;

/**
 * Renders a single network-graph node onto the canvas.
 *
 * Layers are painted in this order:
 * 1. an optional translucent halo;
 * 2. the coloured node circle;
 * 3. either the node's small profile photo (clipped to a circle inside a white border) or,
 *    for member nodes without a drawn photo, the member's initials;
 * 4. the upper-cased label beneath the node.
 *
 * @param node - The graph node. Nothing is drawn until it has been assigned `x`/`y` coordinates.
 * @param context - Canvas 2D context to draw into.
 * @param scale - Current zoom scale. Not used by this implementation.
 * @param theme - MUI theme supplying palette colours.
 * @param drawImage - Whether to draw `node.profilePhotoSmall` when one is present.
 * @param haloRadius - Radius of the halo. No halo is drawn unless this is greater than 1.
 * @param opacity - Alpha applied to every layer. Defaults to 1 (fully opaque).
 */
export function drawNode(
  node: NodeObject<INetworkGraphNode>,
  context: CanvasRenderingContext2D,
  scale: number,
  theme: Theme,
  drawImage: boolean,
  haloRadius?: number,
  opacity?: number
) {
  // Nodes without coordinates have not been positioned yet; there is nowhere to draw them.
  if (node.x === undefined || node.y === undefined) {
    return;
  }

  const nodeOpacity = opacity ?? 1;

  // Draw the halo if requested. Its alpha is capped at 0.2 so it stays subtle behind the node.
  // It follows the same `?? 1` default as the other layers, so omitting `opacity` still shows the halo.
  if (haloRadius !== undefined && haloRadius > 1) {
    context.beginPath();
    context.arc(node.x, node.y, haloRadius, 0, 2 * Math.PI, true);
    context.fillStyle = getNodeColor(context, theme, node, Math.min(nodeOpacity, 0.2));
    context.fill();
  }

  // Draw coloured circle
  context.beginPath();
  context.arc(node.x, node.y, NODE_RADIUS, 0, 2 * Math.PI, true);
  context.fillStyle = getNodeColor(context, theme, node, nodeOpacity);
  context.fill();

  // Draw image
  if (drawImage && node.profilePhotoSmall != null) {
    // White ring between the coloured circle and the photo.
    context.beginPath();
    context.arc(node.x, node.y, NODE_RADIUS - 2, 0, 2 * Math.PI, true);
    context.fillStyle = alpha("#fff", nodeOpacity);
    context.fill();

    // Clip to a circle slightly smaller than the ring. save() also scopes the globalAlpha change below.
    context.save();
    context.beginPath();
    context.arc(node.x, node.y, NODE_RADIUS - 4, 0, 2 * Math.PI, true);
    context.clip();

    const imageSize = NODE_RADIUS * 2;
    context.globalAlpha = nodeOpacity;
    context.drawImage(node.profilePhotoSmall, node.x - NODE_RADIUS, node.y - NODE_RADIUS, imageSize, imageSize);

    // Restore context to remove the clipping region and globalAlpha.
    context.restore();
  }

  // Draw initials for members whose photo is not being drawn.
  if (node.type === NetworkGraphNodeType.Member && (node.profilePhotoSmall == null || !drawImage)) {
    context.fillStyle = alpha(theme.palette.common.white, nodeOpacity);
    // Same font family as the label below.
    context.font = "16px 'Open Sans', sans-serif";
    context.textAlign = "center";
    context.textBaseline = "middle";
    // +2px vertical nudge kept from the original layout (presumably optical centering — confirm visually).
    context.fillText(getInitials(node.label), node.x, node.y + 2);
  }

  // Draw label
  context.fillStyle = alpha(theme.palette.common.black, nodeOpacity);
  context.font = "10px 'Open Sans', sans-serif";
  context.textAlign = "center";
  context.textBaseline = "bottom";
  context.fillText(node.label.toUpperCase(), node.x, node.y + NODE_RADIUS + NODE_LABEL_Y_OFFSET);
}

/**
 * Builds up to two initials from the first two whitespace-separated words of `label`.
 *
 * Runs of whitespace and leading/trailing whitespace are ignored, and an empty label yields "".
 * `Array.from` iterates by code point, so a word starting with an astral character
 * (e.g. an emoji or a CJK Extension B ideograph) keeps its whole first character.
 */
function getInitials(label: string): string {
  return label
    .split(/\s+/)
    .filter((word) => word.length > 0)
    .slice(0, 2)
    .map((word) => Array.from(word)[0] ?? "")
    .join("");
}

/**
 * Picks the canvas fill style for a node according to its type.
 *
 * Team and role nodes get a flat palette colour. Member nodes get a linear gradient from
 * `palette.error.main` to `palette.wine.main` running corner to corner across the node's
 * bounding square.
 *
 * @param context - Canvas context used to construct the member gradient.
 * @param theme - MUI theme supplying palette colours.
 * @param node - Node to colour. Member nodes are expected to already have `x`/`y` assigned.
 * @param opacity - Alpha applied to the resulting colour(s).
 * @returns A colour string or a `CanvasGradient` for use as `fillStyle`.
 */
function getNodeColor(
  context: CanvasRenderingContext2D,
  theme: Theme,
  node: NodeObject<INetworkGraphNode>,
  opacity: number
) {
  const { palette } = theme;

  switch (node.type) {
    case NetworkGraphNodeType.RootTeam:
    case NetworkGraphNodeType.ChildTeam:
      return alpha(palette.primary.main, opacity);

    case NetworkGraphNodeType.ParentTeam:
      return alpha("#cbe1e3", opacity);

    case NetworkGraphNodeType.Role:
      return alpha(palette.orange.main, opacity);

    case NetworkGraphNodeType.Member: {
      // Top-left corner of the node's bounding square; the gradient ends at the bottom-right corner.
      const left = node.x! - NODE_RADIUS;
      const top = node.y! - NODE_RADIUS;
      const span = NODE_RADIUS * 2;

      const memberFill = context.createLinearGradient(left, top, left + span, top + span);
      memberFill.addColorStop(0, alpha(palette.error.main, opacity));
      memberFill.addColorStop(1, alpha(palette.wine.main, opacity));
      return memberFill;
    }
  }
}

/**
 * Computes a node's fade-in opacity for the current animation frame.
 *
 * Nodes are revealed in waves by depth. A node at `nodeDepth`:
 * - stays fully transparent until frame `DELAY_FACTOR * nodeDepth`;
 * - then fades in linearly over `FADE_DURATION` frames;
 * - then stays fully opaque.
 *
 * @param currentFrame - Current frame of the reveal animation.
 * @param nodeDepth - Depth of the node; deeper nodes start fading later.
 * @returns Opacity between 0 and 1. Always a number, never `undefined`.
 */
export function calculateOpacity(currentFrame: number, nodeDepth: number): number {
  const delay = DELAY_FACTOR * nodeDepth;

  // Not started fading yet.
  if (currentFrame < delay) {
    return 0;
  }

  // Inside the fade window: linear ramp from 0 (at `delay`) to 1 (at `delay + FADE_DURATION`).
  if (currentFrame <= delay + FADE_DURATION) {
    return (currentFrame - delay) / FADE_DURATION;
  }

  // Past the fade window. This is also reached when either input is NaN (every comparison above is
  // false); returning 1 draws such nodes fully visible instead of leaking `undefined` to callers.
  return 1;
}
