import React, {useCallback, useState} from 'react';
import Graphin, {Behaviors, GraphinContext} from '@antv/graphin';
import {cutoffValueAtom, inputOutputSelectedAtom, selectedDatasetAtom, selectedSampleAtom} from "./settings";
import {atom, useAtom} from "jotai";

/**
 * Attaches Graphin style definitions to every node and edge of `sample`.
 *
 * The sample is modified in place; nothing is returned. Each node gets a
 * key shape whose stroke and fill colour are taken from the palette entry
 * selected by `node["intermediate_outputs"][layer]`, plus a label showing
 * `node.nodeLabel` (falling back to `node.id`). Each edge is drawn with an
 * arrow at its end only.
 *
 * @param {{nodes: Array<Object>, edges: Array<Object>}} sample - graph data to decorate.
 * @param {number} layer - index into each node's `intermediate_outputs` list.
 */
export const preprocessSample = (sample, layer) => {
    // Builds a fresh style object for one node.
    const buildNodeStyle = (node) => {
        const layerColor = colors[node["intermediate_outputs"][layer]];
        const keyshape = {
            size: 20,
            stroke: layerColor,
            lineWidth: 1,
            fill: layerColor,
            fillOpacity: 0.5,
        };
        const label = {
            value: node.nodeLabel || node.id,
            fontSize: 20,
            offset: [0, 12],
        };
        return {keyshape, label};
    };

    // Builds a fresh style object per edge so no two edges share one instance.
    const buildEdgeStyle = () => ({
        keyshape: {endArrow: true, startArrow: false},
    });

    sample.nodes.forEach((node) => {
        node.style = buildNodeStyle(node);
    });

    sample.edges.forEach((edge) => {
        edge.style = buildEdgeStyle();
    });
};

// Palette used for node stroke/fill; indexed by a node's `intermediate_outputs` entry.
const colors = [
    'yellow',
    'blue',
    'red',
    'green',
    'orange',
    'purple',
    'lime',
    'cyan',
    '#f33eda',
    '#C8F3A9',
];

// Canvas zoom behaviour mounted inside the <Graphin> element below.
const ZoomCanvas = Behaviors.ZoomCanvas;

// Index of the currently highlighted node within the rendered node list; -1 when none.
export const hooveredNodeAtom = atom(-1);
// Node model most recently passed to the highlight/un-highlight handlers (initially undefined).
export const selectedAtom = atom(undefined);
// Decision path of the highlighted node for the selected layer; empty when nothing is highlighted.
export const decisionPathAtom = atom([]);

const GraphCard = (props) => {

    const {selectedLayer, onHooverNode, graphData, selectedTreePruning} = props
    const [state,] = React.useState({
        cutoffValueEdges: 0.01,
        showEdgeImportance: false,
    });

    const [, setShapleyValues] = useState("")
    const [cutoffValue] = useAtom(cutoffValueAtom)
    const [selectedDataSet] = useAtom(selectedDatasetAtom)
    const [currentSample] = useAtom(selectedSampleAtom)
    const [data, setData] = useState({...graphData.data[selectedTreePruning], edges: graphData.edges})
    const [hooveredNode, setHooveredNode] = useAtom(hooveredNodeAtom)
    const [selected, setSelected] = useAtom(selectedAtom)
    const [, setSelectedDecisionPath] = useAtom(decisionPathAtom)
    const [inputOutputSelected] = useAtom(inputOutputSelectedAtom);

    const [nodeClickSelected, setNodeClickSelected] = useState(false)

    const {cutoffValueEdges, showEdgeImportance} = state;
    React.useContext(GraphinContext);
    const graphRef = React.createRef(null);

    const highlightNode = useCallback(
        (n) => {
            if (graphRef.current) {
                let layerToUse = selectedLayer
                layerToUse = layerToUse > 0 ? layerToUse : 0
                const intermediateTopNodesScores = n["intermediate_top_nodes_score"][layerToUse]
                const intermediateTopNodesIndex = n["intermediate_top_nodes"][layerToUse]
                const nodeId = n.id
                const {graph} = graphRef.current;
                data.nodes.forEach(node => {
                    graph.updateItem(node.id, {
                        style: {
                            keyshape: {
                                size: intermediateTopNodesIndex.includes(parseInt(node.id)) && Math.abs(intermediateTopNodesScores[intermediateTopNodesIndex.indexOf(parseInt(node.id))]) > cutoffValue ? 20 + Math.abs(intermediateTopNodesScores[intermediateTopNodesIndex.indexOf(parseInt(node.id))]) * 60 : 20,
                                lineWidth: intermediateTopNodesIndex.includes(parseInt(node.id)) && Math.abs(intermediateTopNodesScores[intermediateTopNodesIndex.indexOf(parseInt(node.id))]) > cutoffValue ? 2 : 1,
                                opacity: intermediateTopNodesIndex.includes(parseInt(node.id)) && Math.abs(intermediateTopNodesScores[intermediateTopNodesIndex.indexOf(parseInt(node.id))]) > cutoffValue ? 1.0 : 0.1,
                                fillOpacity: intermediateTopNodesIndex.includes(parseInt(node.id)) && Math.abs(intermediateTopNodesScores[intermediateTopNodesIndex.indexOf(parseInt(node.id))]) > cutoffValue ? 0.8 : 0.1,
                            },
                            label: {
                                ...node.style.label,
                                value: intermediateTopNodesIndex.includes(parseInt(node.id)) && Math.abs(intermediateTopNodesScores[intermediateTopNodesIndex.indexOf(parseInt(node.id))]) > cutoffValue ? intermediateTopNodesScores[intermediateTopNodesIndex.indexOf(parseInt(node.id))].toFixed(3) : node.id,
                                opacity: intermediateTopNodesIndex.includes(parseInt(node.id)) && Math.abs(intermediateTopNodesScores[intermediateTopNodesIndex.indexOf(parseInt(node.id))]) > cutoffValue ? 1.0 : 0.1,
                                fill: intermediateTopNodesIndex.includes(parseInt(node.id)) && Math.abs(intermediateTopNodesScores[intermediateTopNodesIndex.indexOf(parseInt(node.id))]) > cutoffValue && intermediateTopNodesScores[intermediateTopNodesIndex.indexOf(parseInt(node.id))] < 0 ? 'red' : 'black',
                            },
                        },
                    });
                });
                setSelected(n)
                setData({
                    ...data
                });
                // console.log(data)
                setHooveredNode(data.nodes.findIndex(node => node.id === nodeId))
            }
        },
        [currentSample, inputOutputSelected, cutoffValue, cutoffValueEdges, data, graphRef, selectedDataSet, selectedLayer, setData, setHooveredNode, setSelected, showEdgeImportance],
    );

    const unHighlightNode = useCallback(
        (n) => {
            if (graphRef.current) {
                const {graph} = graphRef.current;
                const newGraphData = {...graphData.data[selectedTreePruning], edges: graphData.edges};
                newGraphData.nodes.forEach(node => {
                    graph.updateItem(node.id, {
                        style: {
                            keyshape: {
                                size: 20,
                                lineWidth: 1,
                                opacity: 1.0,
                                fillOpacity: 0.5,
                            },
                            label: {
                                ...node.style.label,
                                opacity: 1.0,
                                value: node.nodeLabel || node.id,
                                fill: 'black'
                            },
                        },
                    });
                });
                setSelected(n)
                setHooveredNode(-1);
            }
        },
        [graphData.nodes, graphRef, setHooveredNode, setSelected, selectedDataSet],
    );


    React.useEffect(() => {
        const {graph} = graphRef.current;
        const handleNodeClick = e => {
            if (!nodeClickSelected) {
                highlightNode(e.item.get('model'))
            }
        };
        graph.on('node:mouseenter', handleNodeClick);
        return () => {
            graph.off('node:mouseenter', handleNodeClick);
        };
    }, [graphRef, highlightNode]);

    React.useEffect(() => {
        const {graph} = graphRef.current;
        let layerToUse = inputOutputSelected === 'output' ? selectedLayer : selectedLayer - 1
        layerToUse = layerToUse > 0 ? layerToUse : 0
        const newGraphData = {...graphData.data[selectedTreePruning], edges: graphData.edges};
        preprocessSample(newGraphData, layerToUse);
        setData(newGraphData);
        graph.clear();
        graph.render();
        setHooveredNode(-1);
        setNodeClickSelected(false);
    }, [graphData, setData, setHooveredNode]);


    React.useEffect(() => {
        const {graph} = graphRef.current;
        const handleNodeClick = e => {
            if (e.select) {
                highlightNode(e.target.get('model'))
            } else {
                unHighlightNode(undefined)
            }
            setNodeClickSelected(e.select)
        };
        graph.on('nodeselectchange', handleNodeClick);
        return () => {
            graph.off('nodeselectchange', handleNodeClick);
        };
    }, [graphRef, highlightNode, unHighlightNode]);

    React.useEffect(() => {

        onHooverNode(hooveredNode)
        if (hooveredNode !== -1) {
            const newGraphData = {...graphData.data[selectedTreePruning], edges: graphData.edges};
            let shapValues = newGraphData.nodes[hooveredNode]["intermediate_features_used"][selectedLayer][newGraphData.nodes[hooveredNode]["intermediate_outputs"][selectedLayer]]
            shapValues = shapValues.map(function (each_element) {
                return Number(each_element.toFixed(2));
            });
            setShapleyValues(shapValues.join())
            const path = newGraphData.nodes[hooveredNode]["decision_paths"][selectedLayer]
            setSelectedDecisionPath(path)
        } else {
            setSelectedDecisionPath([])
        }
    }, [graphData, hooveredNode, selectedLayer, selectedTreePruning, onHooverNode]);


    React.useEffect(() => {
        const {graph} = graphRef.current;
        let layerToUse = inputOutputSelected === 'output' ? selectedLayer : selectedLayer - 1
        layerToUse = layerToUse > 0 ? layerToUse : 0
        const newGraphData = {...graphData.data[selectedTreePruning], edges: graphData.edges};
        preprocessSample(newGraphData, layerToUse);
        newGraphData.nodes.forEach(node => {
            graph.updateItem(node.id, {
                intermediate_top_nodes_score: node["intermediate_top_nodes_score"],
                intermediate_top_nodes: node["intermediate_top_nodes"],
                style: {
                    keyshape: {
                        stroke: colors[node["intermediate_outputs"][layerToUse]],
                        fill: colors[node["intermediate_outputs"][layerToUse]],
                        fillOpacity: 0.5,
                    },
                    label: {
                        ...node.style.label,
                        value: node.nodeLabel || node.id,
                        fontSize: 20,
                        offset: [0, 12],
                    },
                },
            });
        });
        if (nodeClickSelected) {
            highlightNode(selected)
        }
    }, [selectedLayer, selectedTreePruning, inputOutputSelected]);

    React.useEffect(() => {
        const {graph} = graphRef.current;
        const handleNodeLeave = e => {
            if (!nodeClickSelected) {
                unHighlightNode(e.item.get('model'));
            }
        };
        graph.on('node:mouseleave', handleNodeLeave);
        return () => {
            graph.off('node:mouseleave', handleNodeLeave);
        };
    }, [graphRef, nodeClickSelected, unHighlightNode]);

    return (

        <Graphin data={data} layout={{type: 'graphin-force', animation: false}} ref={graphRef} fitView
                 fitViewPadding={32} height="100%" theme='light'>
            <ZoomCanvas/>
        </Graphin>

    );
};

export default GraphCard;
