import { ChevronRightRounded, ExpandMoreRounded } from "@mui/icons-material";
import { TreeView as MuiTreeView } from "@mui/x-tree-view";
import { ITreeNode } from "../models";
import { TreeViewItem } from "./TreeViewItem";

/** Props accepted by the TreeView component. */
interface IProps {
  /** Flat list of all tree nodes; the hierarchy is derived from each node's `parentNodeId`. */
  nodes: ITreeNode[];
  /** Ids of the currently selected nodes; forwarded to the MUI tree view's `selected` prop. */
  selectedNodeIds: string[];
  /** Called with the nodes that make up the selection after a node has been toggled. */
  onNodeSelected: (selectedNodes: ITreeNode[]) => void;
}

/**
 * Builds a lookup from a parent node id to its direct children, preserving the
 * order in which the children appear in `nodes`. Nodes without a parent
 * (`parentNodeId == null`) are not indexed.
 */
function indexChildrenByParentId(nodes: ITreeNode[]): Map<string, ITreeNode[]> {
  const childrenByParentId = new Map<string, ITreeNode[]>();

  for (const node of nodes) {
    if (node.parentNodeId == null) {
      continue;
    }

    const siblings = childrenByParentId.get(node.parentNodeId);
    if (siblings) {
      siblings.push(node);
    } else {
      childrenByParentId.set(node.parentNodeId, [node]);
    }
  }

  return childrenByParentId;
}

/**
 * Collects the id of the given node together with the ids of all of its descendants.
 *
 * @param nodeId Id of the node at the top of the subtree.
 * @param nodes All known nodes; used to verify that `nodeId` exists.
 * @param childrenByParentId Index produced by `indexChildrenByParentId`.
 * @returns The subtree's ids, or an empty set when `nodeId` is not present in `nodes`.
 */
function collectSubtreeNodeIds(
  nodeId: string,
  nodes: ITreeNode[],
  childrenByParentId: Map<string, ITreeNode[]>
): Set<string> {
  const subtreeNodeIds = new Set<string>();

  if (!nodes.some((x) => x.nodeId === nodeId)) {
    return subtreeNodeIds;
  }

  // Iterative walk with a visited check so malformed data (cyclic parent links,
  // duplicate ids) cannot cause unbounded recursion.
  const pendingNodeIds = [nodeId];
  while (pendingNodeIds.length > 0) {
    const currentNodeId = pendingNodeIds.pop();
    if (currentNodeId === undefined || subtreeNodeIds.has(currentNodeId)) {
      continue;
    }

    subtreeNodeIds.add(currentNodeId);
    for (const child of childrenByParentId.get(currentNodeId) ?? []) {
      pendingNodeIds.push(child.nodeId);
    }
  }

  return subtreeNodeIds;
}

/**
 * Renders the flat `props.nodes` list as a multi-select MUI tree.
 *
 * Nodes whose `parentNodeId` is null/undefined become the top-level items
 * (sorted by label); every other node is nested under its parent. When a
 * `TreeViewItem` reports a node through its `onNodeSelected` callback, that
 * node and all of its descendants are toggled together and the resulting
 * selection is passed to `props.onNodeSelected` as node objects, in the order
 * they appear in `props.nodes`.
 */
const TreeView = (props: IProps) => {
  const { nodes, selectedNodeIds } = props;

  // Built once per render so child lookups do not rescan the whole node list.
  const childrenByParentId = indexChildrenByParentId(nodes);

  /** Toggles the node and its descendants, then reports the new selection. */
  function handleNodeToggled(nodeId: string) {
    const subtreeNodeIds = collectSubtreeNodeIds(nodeId, nodes, childrenByParentId);
    const currentlySelected = new Set(selectedNodeIds);
    const wasSelected = currentlySelected.has(nodeId);

    const result = nodes.filter((x) =>
      wasSelected
        ? // Remove the node and its descendants.
          currentlySelected.has(x.nodeId) && !subtreeNodeIds.has(x.nodeId)
        : // Add the node and its descendants.
          currentlySelected.has(x.nodeId) || subtreeNodeIds.has(x.nodeId)
    );

    props.onNodeSelected(result);
  }

  /** Renders a node and, recursively, all of its descendants. */
  function renderTreeItem(node: ITreeNode) {
    // NOTE(review): only root nodes are sorted by label; children keep their
    // order from `nodes` — confirm that is intended.
    const children = childrenByParentId.get(node.nodeId) ?? [];

    return (
      <TreeViewItem
        key={node.nodeId}
        nodeId={node.nodeId}
        label={node.label}
        nodes={nodes}
        selectedNodeIds={selectedNodeIds}
        onNodeSelected={(nodeId) => handleNodeToggled(nodeId)}
      >
        {children.map((child) => renderTreeItem(child))}
      </TreeViewItem>
    );
  }

  return (
    <MuiTreeView
      defaultCollapseIcon={<ExpandMoreRounded />}
      defaultExpandIcon={<ChevronRightRounded />}
      multiSelect
      selected={selectedNodeIds}
    >
      {nodes
        // `filter` returns a fresh array, so the in-place `sort` never mutates props.
        .filter((x) => x.parentNodeId == null)
        .sort((a, b) => a.label.localeCompare(b.label))
        .map((root) => renderTreeItem(root))}
    </MuiTreeView>
  );
};

export { TreeView };
