import dagreD3 from 'dagre-d3';
import React, { useEffect, useRef } from 'react';
import * as d3 from 'd3';
import { useTheme } from '@mui/system';
import styles from './workflow.module.css';
import { TaskDefinition, TaskType } from '../../types/task';
import { TaskAndProgress } from '../../types/workflow';

/**
 * Props accepted by {@link WorkflowGraph}.
 */
export interface FlowGraphProps {
  /** Optional listener notified when the user clicks a node in the graph. */
  onEvent?: (event: FlowGraphEvent) => void;
  /** Tasks to draw: each task becomes a node, and each entry in a task's `children` becomes an edge. */
  tasks: TaskAndProgress[];
}

/**
 * Payload passed to a {@link FlowGraphProps} `onEvent` listener when a node is clicked.
 */
export interface FlowGraphEvent {
  /** Graph node id of the clicked node (the task's `taskId`). */
  nodeId: string;
}

/* eslint-disable @typescript-eslint/no-explicit-any */

/** Corner radius (px) given to every node so rectangular nodes are drawn with rounded corners. */
const NODE_CORNER_RADIUS = 3;

/**
 * Join CSS class names with spaces, dropping empty/undefined entries (e.g. a
 * CSS-module key with no matching rule) so no literal "undefined" class is emitted.
 */
const joinClasses = (...names: Array<string | undefined>): string => names.filter(Boolean).join(' ');

/**
 * Build the CSS class list for a task node.
 *
 * The task's type (and, for defined tasks, its status) is converted from
 * `UPPER_SNAKE` to kebab-case and looked up in the CSS module. Defined tasks
 * additionally get the raw kebab-case type as a global class.
 */
const getNodeClass = (node: TaskDefinition): string => {
  const type = node.type.toLowerCase().replace(/_/g, '-');

  if (node.type === TaskType.NOT_YET_DEFINED) {
    // NOT_YET_DEFINED nodes get no status class, so status is not read for them.
    return joinClasses(styles.node, styles[type]);
  }

  const state = node.status.toLowerCase().replace(/_/g, '-');
  return joinClasses(styles.node, styles[type], type, styles[state]);
};

/** Map a task type to the dagre-d3 shape used to draw its node. */
const getNodeShape = (type: TaskType): string => {
  switch (type) {
    case TaskType.DECISION:
      return 'diamond';
    case TaskType.START:
    case TaskType.END:
      return 'circle';
    default:
      // TASK, GROUP, JOIN, NOT_YET_DEFINED and any other type.
      return 'rect';
  }
};

/**
 * Scale the rendered graph down to fit the SVG (never enlarging past 1x) and
 * center it, by applying an initial transform through the zoom behavior.
 *
 * @param svgElement - the mounted SVG element, used for its client size
 * @param graph - the laid-out dagre graph (its `graph()` label holds width/height after render)
 * @param zoom - the d3 zoom behavior attached to `svg`
 * @param svg - d3 selection of `svgElement`
 */
const fitGraphToView = (svgElement: SVGSVGElement, graph: any, zoom: any, svg: any): void => {
  const { width: graphWidth, height: graphHeight } = graph.graph();
  const { clientWidth: svgWidth, clientHeight: svgHeight } = svgElement;

  let scale = Math.min(1, svgWidth / graphWidth, svgHeight / graphHeight);
  // A hidden or not-yet-laid-out SVG reports a 0 size, giving a 0 (or NaN) scale.
  // d3-zoom scales multiplicatively, so the view could never recover from 0.
  if (!Number.isFinite(scale) || scale <= 0) {
    scale = 1;
  }

  const xOffset = (svgWidth - graphWidth * scale) / 2;
  const yOffset = (svgHeight - graphHeight * scale) / 2;
  // Going through zoom.transform (rather than setting the attribute directly)
  // keeps d3-zoom's stored transform in sync with what is displayed.
  svg.call(zoom.transform, d3.zoomIdentity.translate(xOffset, yOffset).scale(scale));
};

/**
 * Renders workflow tasks as a dagre-d3 directed graph in an SVG sized to 100%
 * of its container, with pan/zoom. Clicking a node reports its id via `onEvent`.
 *
 * Nothing is drawn when `tasks` is empty or is the single "Empty Flow" task.
 */
const WorkflowGraph = ({ tasks, onEvent }: FlowGraphProps) => {
  const svgRef = useRef<SVGSVGElement>(null);
  const theme = useTheme();
  const isDark = theme.palette.mode === 'dark';

  // Click handlers are bound inside the drawing effect, which only re-runs when
  // the tasks or the mode change. Reading the callback through a ref makes every
  // click call the latest `onEvent` without redrawing the graph.
  const onEventRef = useRef(onEvent);
  useEffect(() => {
    onEventRef.current = onEvent;
  }, [onEvent]);

  useEffect(() => {
    const svgElement = svgRef.current;
    if (!svgElement) {
      return;
    }

    const svg = d3.select(svgElement);
    // Clear the previous drawing and its zoom listeners before anything else,
    // so an old graph does not linger when `tasks` becomes empty.
    svg.selectAll('*').remove();
    svg.on('.zoom', null);

    if (!tasks || tasks.length === 0 || (tasks.length === 1 && tasks[0].name === 'Empty Flow')) {
      return;
    }

    const g = new dagreD3.graphlib.Graph().setGraph({ compound: true }).setDefaultEdgeLabel(() => ({}));
    const darkClass = isDark ? styles.dark : '';

    tasks.forEach((node) => {
      g.setNode(node.taskId, {
        label: node.name,
        shape: getNodeShape(node.type),
        class: `${getNodeClass(node)} ${darkClass}`,
        rx: NODE_CORNER_RADIUS,
        ry: NODE_CORNER_RADIUS,
        data: { nodeId: node.taskId }
      });
    });

    tasks.forEach((node) => {
      node.children?.forEach((child) => {
        g.setEdge(node.taskId, child.taskId, {
          label: child.name || '',
          curve: d3.curveBasis,
          style: 'fill: none',
          class: styles.edgePath,
          arrowheadClass: styles.edgeArrowhead
        });
      });
    });

    const render = new (dagreD3 as any).render();
    const svgGroup = svg.append('g');
    render(svgGroup, g);

    const zoom = d3.zoom().on('zoom', (event) => {
      svgGroup.attr('transform', event.transform);
    });
    svg.call(zoom);

    // The datum bound to each rendered `g.node` element is read back as the node id.
    svgGroup.selectAll('g.node').on('click', (event: PointerEvent) => {
      const clickedElement: any = event.currentTarget;
      const nodeId = (d3.select(clickedElement) as any).datum();
      onEventRef.current?.({ nodeId });
    });

    fitGraphToView(svgElement, g, zoom, svg);
  }, [tasks, isDark]);

  return (
    <svg
      ref={svgRef}
      width={'100%'}
      height={'100%'}
      className={isDark ? 'flowgraphSvgDark' : 'flowgraphSvg'}
    ></svg>
  );
};

export default WorkflowGraph;
