import { Add, Remove, Replay } from '@mui/icons-material';
import FullscreenIcon from '@mui/icons-material/Fullscreen';
import { Box, ButtonGroup, IconButton, Stack } from '@mui/material';
import * as d3 from 'd3';
import d3Tip from 'd3-tip';
import { useEffect, useMemo, useRef } from 'react';
import './node-tree-graph.css';

/**
 * Default rendering options for NodeTreeGraph. Callers override individual
 * keys through the `config` prop; the component spreads the overrides on top
 * of this object, so it is frozen to guard against accidental mutation of the
 * shared defaults.
 */
const defaultConfig = Object.freeze({
  // Adds the `preview-node` class to every node and shows the expand button.
  isPreview: false,
  // Attaches d3-tip tooltips (showing each node's `desc`) to the node circles.
  toolTipEnabled: true,
  // Lower bound for the computed content height (before heightMultiplier is applied).
  minHeight: 300,
  // Lower bound for the computed viewBox width.
  minWidth: 300,
  // `height` / `width` attributes of the rendered <svg> element.
  height: 300,
  width: 400,
  // Fraction of the computed content height added on top as extra viewBox height.
  heightMultiplier: 0.1,
  // Radius of each node circle.
  nodeRadius: 15,
  // viewBox units budgeted per node when sizing the viewBox.
  nodePxHeight: 100,
  // Horizontal distance between consecutive depth levels (y = depth * ratio).
  nodesNormalizeRatio: 100,
  // Label x offset; the sign is flipped for nodes that have children.
  textX: -10,
  // Label `dy` attribute.
  textDy: -20,
  // Label font-size and font-weight styles.
  textSize: '1rem',
  textWeight: 600,
  // Value returned to d3-tip's `direction()` for every node.
  tipPosition: 'nw',
  // Attaches the d3 zoom behavior to the <svg>.
  zoomEnabled: true,
  // Scale increment applied by the zoom in / zoom out buttons.
  zoomFactor: 0.5,
  // Upper / lower scale limits (zoom behavior scaleExtent and button clamps).
  zoomMax: 5,
  zoomMin: 0.5
});

const NodeTreeGraph = ({ treeData, targetNode, config = {}, onExpandClick = {} }) => {
  const svgRef = useRef(null);
  const zoomRef = useRef(null);

  const computedConfig = useMemo(() => {
    return { ...defaultConfig, ...config };
  }, [config]);

  useEffect(() => {
    const thisSvgRef = svgRef.current;

    let maxNodeDepth = 0;

    let tree = d3.layout.tree().size([computedConfig.height, computedConfig.width]);
    const diagonal = d3.svg.diagonal().projection(function (d) {
      return [d.y, d.x];
    });

    const svg = d3
      .select(thisSvgRef)
      .attr('width', computedConfig.width)
      .attr('height', computedConfig.height)
      .attr('viewBox', '0 0 ' + computedConfig.width + ' ' + computedConfig.height)
      .append('g')
      .attr('class', 'container-g')
      .attr('transform', 'translate(75,0)');

    // ensure everything is loaded, then add the graph
    const root = treeData[0];
    root.x0 = computedConfig.height / 2;
    root.y0 = 0;

    let duration = d3.event && d3.event.altKey ? 2500 : 500;

    // compute the new height
    let levelWidth = [1];
    let childCount = function (level, n) {
      maxNodeDepth = Math.max(level, maxNodeDepth);
      if (n.children && n.children.length > 0) {
        if (levelWidth.length <= level + 1) levelWidth.push(0);

        levelWidth[level + 1] += n.children.length;
        n.children.forEach(function (d) {
          childCount(level + 1, d);
        });
      }
    };
    childCount(0, root);
    levelWidth.push(maxNodeDepth);

    const heightCalc = Math.max(
      d3.max(levelWidth) * computedConfig.nodePxHeight,
      computedConfig.minHeight
    );

    const viewBoxHeight = heightCalc + heightCalc * computedConfig.heightMultiplier;
    const viewBoxWidth = Math.max(
      computedConfig.nodePxHeight * 2 * maxNodeDepth,
      computedConfig.minWidth
    );

    d3.select(thisSvgRef)
      .attr('height', computedConfig.height)
      .attr('width', computedConfig.width)
      .attr('viewBox', `0 0 ${viewBoxWidth} ${viewBoxHeight}`);
    tree = tree.size([viewBoxHeight, viewBoxWidth]);
    tree.separation(function (a, b) {
      return a.parent === b.parent ? 0.5 : 0.5;
    });

    // Compute the new tree layout.
    let nodes = tree.nodes(root).reverse();
    let links = tree.links(nodes);

    // Normalize for fixed-depth.
    nodes.forEach(function (d) {
      d.y = d.depth * computedConfig.nodesNormalizeRatio;
    });

    // Update the nodes…
    let node = svg.selectAll('g.node').data(nodes, function (d) {
      return d.id;
    });

    // Enter any new nodes at the parent's previous position.
    let nodeEnter = node
      .enter()
      .append('g')
      .attr('class', function (d) {
        return (
          'node level-' +
          d.depth +
          (d.targetNode ? ' target-node' : '') +
          (computedConfig.isPreview ? ' preview-node' : '')
        );
      })
      .attr('transform', function (d) {
        return 'translate(' + root.y0 + ',' + root.x0 + ')';
      });

    if (computedConfig.zoomEnabled) {
      let zoomable_layer = d3.select(thisSvgRef).select('g.container-g');
      let zoom = d3.behavior
        .zoom()
        .scaleExtent([computedConfig.zoomMin, computedConfig.zoomMax])
        .on('zoom', function () {
          return zoomable_layer.attr({
            transform: 'translate(' + zoom.translate() + ') scale(' + zoom.scale() + ')'
          });
        });

      d3.select(thisSvgRef).call(zoom);
      zoomRef.current = zoom;
    }

    if (computedConfig.toolTipEnabled) {
      const tip = d3Tip()
        .attr('class', 'd3-tip')
        .html(function (d) {
          return d.desc;
        })
        .direction(function (d) {
          return computedConfig.tipPosition;
        })
        .offset([-10, -10]);

      svg.call(tip);

      nodeEnter
        .append('circle')
        .attr('r', computedConfig.nodeRadius)
        .style('fill', function (d) {
          return d._children ? 'lightsteelblue' : '#fff';
        })
        //   .on('click', handleClick)
        .on('mouseover', function (d) {
          if (d && d.desc) {
            tip.attr('class', 'd3-tip animate').show(d, this);
          }
        })
        .on('mouseout', function (d) {
          tip.attr('class', 'd3-tip').show(d, this);
          tip.hide();
        });
    } else {
      nodeEnter
        .append('circle')
        .attr('r', computedConfig.nodeRadius)
        .style('fill', function (d) {
          return d._children ? 'lightsteelblue' : '#fff';
        });
    }

    nodeEnter
      .append('text')
      .attr('x', function (d) {
        return d.children || d._children ? computedConfig.textX * -1 : computedConfig.textX;
      })
      .attr('dy', computedConfig.textDy)
      .attr('text-anchor', function (d) {
        return d.children || d._children ? 'end' : 'start';
      })
      .text(function (d) {
        return d.name;
      })
      .style('fill-opacity', 1e-6)
      .style('font-size', computedConfig.textSize)
      .style('font-weight', computedConfig.textWeight);
    // .on('click', handleTextClick);

    // Transition nodes to their new position.
    let nodeUpdate = node
      .transition()
      .duration(duration)
      .attr('transform', function (d) {
        return 'translate(' + d.y + ',' + d.x + ')';
      });

    nodeUpdate
      .select('circle')
      .attr('r', computedConfig.nodeRadius)
      .style('fill', function (d) {
        let isTarget = d.name === targetNode;
        return isTarget ? 'lightblue' : d._children ? 'lightsteelblue' : '#fff';
      });

    nodeUpdate.select('text').style('fill-opacity', 1);

    // Transition exiting nodes to the parent's new position.
    let nodeExit = node
      .exit()
      .transition()
      .duration(duration)
      .attr('transform', function (d) {
        return 'translate(' + root.y + ',' + root.x + ')';
      })
      .remove();

    nodeExit.select('circle').attr('r', computedConfig.nodeRadius);

    nodeExit.select('text').style('fill-opacity', 1e-6);

    // Update the links…
    let link = svg.selectAll('path.link').data(links, function (d) {
      return d.target.id;
    });

    // Enter any new links at the parent's previous position.
    link
      .enter()
      .insert('path', 'g')
      .attr('class', function (d) {
        return 'link level-' + d.source.depth;
      })
      .attr('d', function (d) {
        let o = {
          x: root.x0,
          y: root.y0
        };
        return diagonal({
          source: o,
          target: o
        });
      });

    // Transition links to their new position.
    link.transition().duration(duration).attr('d', diagonal);

    // Transition exiting nodes to the parent's new position.
    link
      .exit()
      .transition()
      .duration(duration)
      .attr('d', function (d) {
        let o = {
          x: root.x,
          y: root.y
        };
        return diagonal({
          source: o,
          target: o
        });
      })
      .remove();

    // Stash the old positions for transition.
    nodes.forEach(function (d) {
      d.x0 = d.x;
      d.y0 = d.y;
    });

    return () => {
      // TODO remove this, update data using d3 only, likely with another useEffect
      d3.selectAll('.d3-tip').remove();
      d3.select(thisSvgRef).selectAll('*').remove();
    };
  }, [computedConfig, treeData, targetNode]);

  const onResetZoom = () => {
    zoomRef.current.scale(1);
    zoomRef.current.translate([75, 0]);

    d3.select(svgRef.current)
      .select('g.container-g')
      .attr(
        'transform',
        'translate(' + zoomRef.current.translate() + ') scale(' + zoomRef.current.scale() + ')'
      );
  };

  return (
    <Box
      sx={{
        height: '100%',
        width: '100%',
        display: 'flex',
        alignItems: 'center',
        justifyContent: 'center',
        cursor: 'grab',
      }}
    >
      <svg className="svg-graph" ref={svgRef}></svg>
      <Stack spacing={2} sx={{ alignSelf: 'flex-start', justifySelf: 'flex-end' }}>
        {computedConfig.isPreview && (
          <IconButton
            color="primary"
            sx={{ '& svg': { width: '1.75rem', height: '1.75rem' } }}
            onClick={onExpandClick}
          >
            <FullscreenIcon />
          </IconButton>
        )}
        <ButtonGroup orientation="vertical" aria-label="Vertical button group">
          <IconButton
            color="primary"
            sx={{ '& svg': { width: '1.75rem', height: '1.75rem' } }}
            onClick={() => {
              zoomRef.current.scale(
                Math.min(
                  zoomRef.current.scale() + computedConfig.zoomFactor,
                  computedConfig.zoomMax
                )
              );
              if (zoomRef.current.scale() < computedConfig.zoomMax) {
                const translate = [...zoomRef.current.translate()];
                zoomRef.current.translate([translate[0] - 30, translate[1] - 30]);
              }

              d3.select(svgRef.current)
                .select('g.container-g')
                .attr(
                  'transform',
                  'translate(' +
                    zoomRef.current.translate() +
                    ') scale(' +
                    zoomRef.current.scale() +
                    ')'
                );
            }}
          >
            <Add />
          </IconButton>
          <IconButton
            color="primary"
            sx={{ '& svg': { width: '1.75rem', height: '1.75rem' } }}
            onClick={() => {
              zoomRef.current.scale(
                Math.max(
                  zoomRef.current.scale() - computedConfig.zoomFactor,
                  computedConfig.zoomMin
                )
              );
              if (zoomRef.current.scale() > computedConfig.zoomMin) {
                const translate = [...zoomRef.current.translate()];
                zoomRef.current.translate([translate[0] + 15, translate[1] + 15]);
              }

              d3.select(svgRef.current)
                .select('g.container-g')
                .attr(
                  'transform',
                  'translate(' +
                    zoomRef.current.translate() +
                    ') scale(' +
                    zoomRef.current.scale() +
                    ')'
                );
            }}
          >
            <Remove />
          </IconButton>
          <IconButton
            color="primary"
            sx={{ '& svg': { width: '1.75rem', height: '1.75rem' } }}
            onClick={onResetZoom}
          >
            <Replay />
          </IconButton>
        </ButtonGroup>
      </Stack>
    </Box>
  );
};

export default NodeTreeGraph;
