/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.graph;

import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.TextTable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Static helpers that tabulate how the endpoints and edges of an estimated graph compare with
 * those of a reference graph. Results are rendered as {@link TextTable} count matrices with
 * reference types as rows and estimated types as columns.
 */
public class MisclassificationUtils {

    /** Row and column labels of the endpoint table, in {@link #getIndex(Endpoint)} order. */
    private static final String[] ENDPOINT_LABELS = {"-o", "->", "--", "no endpoint"};

    /** Row labels of the edge table (reference edge types), in {@code getTypeLeft} order. */
    private static final String[] EDGE_ROW_LABELS = {
        "---", "o-o", "o->", "<-o", "-->", "<--", "<->", "no edge"
    };

    /** Column labels of the edge table (estimated edge types), in {@code getTypeTop} order. */
    private static final String[] EDGE_COLUMN_LABELS = {
        "---", "o-o", "o->", "-->", "<->", "no edge"
    };

    /** Edge-table row for "no reference edge". */
    private static final int NO_EDGE_ROW = 7;

    /** Edge-table column for "no estimated edge" (also used for unrecognized estimated edges). */
    private static final int NO_EDGE_COLUMN = 5;

    /**
     * Maps an endpoint to its row/column index in the endpoint misclassification table.
     *
     * @param endpoint the endpoint to classify; {@code null} and {@link Endpoint#NULL} both mean
     *     "no endpoint"
     * @return 0 for circle, 1 for arrow, 2 for tail, 3 for no endpoint
     * @throws IllegalArgumentException for any other endpoint
     */
    public static int getIndex(Endpoint endpoint) {
        if (endpoint == Endpoint.CIRCLE) {
            return 0;
        }
        if (endpoint == Endpoint.ARROW) {
            return 1;
        }
        if (endpoint == Endpoint.TAIL) {
            return 2;
        }
        // Endpoint.NULL is the explicit "absent" marker this class builds placeholder edges with,
        // so treat it like a missing endpoint rather than rejecting it.
        if (endpoint == null || endpoint == Endpoint.NULL) {
            return 3;
        }
        throw new IllegalArgumentException("Unsupported endpoint: " + endpoint);
    }

    /**
     * Rebuilds {@code edges} over the nodes of {@code newVariables}, matching nodes by name.
     * Endpoints are copied unchanged. A node whose name is not found is kept as the edge's own
     * node object and registered in a scratch graph, presumably so that later edges naming it
     * resolve to that same object (depends on {@code getNode(String)}; confirm).
     *
     * @param edges the edges to convert
     * @param newVariables the replacement nodes, looked up by name
     * @return a new set containing one rebuilt edge per input edge
     */
    public static Set<Edge> convertNodes(Set<Edge> edges, List<Node> newVariables) {
        Set<Edge> newEdges = new HashSet<>();
        // Scratch graph used only as a name -> node lookup; it is never returned.
        EdgeListGraph lookup = new EdgeListGraph(newVariables);

        for (Edge edge : edges) {
            // Resolve both names before registering either node.
            Node node1 = lookup.getNode(edge.getNode1().getName());
            Node node2 = lookup.getNode(edge.getNode2().getName());

            if (node1 == null) {
                node1 = edge.getNode1();
                addIfAbsent(lookup, node1);
            }
            if (node2 == null) {
                node2 = edge.getNode2();
                addIfAbsent(lookup, node2);
            }

            newEdges.add(new Edge(node1, node2, edge.getEndpoint1(), edge.getEndpoint2()));
        }
        return newEdges;
    }

    /** Adds {@code node} to {@code graph} unless the graph already contains it. */
    private static void addIfAbsent(EdgeListGraph graph, Node node) {
        if (!graph.containsNode(node)) {
            graph.addNode(node);
        }
    }

    /**
     * Builds a 4x4 endpoint confusion table. For every ordered pair (a, b) of distinct nodes, it
     * counts {@code refGraph.getEndpoint(a, b)} (row) against {@code estGraph.getEndpoint(a, b)}
     * (column). Both graphs are first re-expressed over {@code refGraph}'s nodes via
     * {@link GraphUtils#replaceNodes}.
     *
     * @param estGraph the estimated graph
     * @param refGraph the reference ("true") graph
     * @return the rendered table; the "no endpoint"/"no endpoint" cell is shown as {@code *}
     * @throws IllegalArgumentException if either graph reports an endpoint {@link #getIndex}
     *     does not support
     */
    public static String endpointMisclassification(Graph estGraph, Graph refGraph) {
        List<Node> refNodes = refGraph.getNodes();
        Graph est = GraphUtils.replaceNodes(estGraph, refNodes);
        Graph ref = GraphUtils.replaceNodes(refGraph, refNodes);
        List<Node> nodes = est.getNodes();

        int[][] counts = new int[ENDPOINT_LABELS.length][ENDPOINT_LABELS.length];
        for (int i = 0; i < nodes.size(); i++) {
            for (int j = 0; j < nodes.size(); j++) {
                if (i == j) {
                    continue;
                }
                Node a = nodes.get(i);
                Node b = nodes.get(j);
                int refIndex = getIndex(ref.getEndpoint(a, b));
                int estIndex = getIndex(est.getEndpoint(a, b));
                counts[refIndex][estIndex]++;
            }
        }
        return buildCountTable(counts, ENDPOINT_LABELS, ENDPOINT_LABELS).toString();
    }

    /**
     * Builds an edge confusion table. Rows are the reference edge type (including reversed
     * orientations and "no edge"); columns are the estimated edge type. {@code estGraph} is first
     * re-expressed over {@code refGraph}'s nodes.
     *
     * @param estGraph the estimated graph
     * @param refGraph the reference ("true") graph
     * @return a newline followed by the rendered table; the "no edge"/"no edge" cell is {@code *}
     * @throws IllegalArgumentException if a reference edge has a type {@code getTypeLeft} does not
     *     support
     */
    public static String edgeMisclassifications(Graph estGraph, Graph refGraph) {
        Graph est = GraphUtils.replaceNodes(estGraph, refGraph.getNodes());
        int[][] counts = new int[EDGE_ROW_LABELS.length][EDGE_COLUMN_LABELS.length];

        // Pass 1: every estimated edge, classified against the reference edge on the same nodes.
        for (Edge estEdge : est.getEdges()) {
            Node x = estEdge.getNode1();
            Node y = estEdge.getNode2();
            Edge trueEdge = refGraph.getEdge(x, y);
            if (trueEdge == null) {
                trueEdge = new Edge(x, y, Endpoint.NULL, Endpoint.NULL);
            }
            // Re-express the reference edge with x first so its endpoints line up with estEdge.
            Edge trueAligned = new Edge(x, y,
                    trueEdge.getProximalEndpoint(x), trueEdge.getProximalEndpoint(y));
            int row = getTypeLeft(trueAligned, estEdge);
            int column = getTypeTop(estEdge);
            counts[row][column]++;
        }

        // Pass 2: reference edges with no estimated counterpart. An estimated edge that exists was
        // already tallied in pass 1, even when getTypeTop files it under "no edge", so skip it
        // here to avoid counting it twice.
        for (Edge trueEdge : refGraph.getEdges()) {
            Node x = trueEdge.getNode1();
            Node y = trueEdge.getNode2();
            if (est.getEdge(x, y) != null) {
                continue;
            }
            Edge missing = new Edge(x, y, Endpoint.NULL, Endpoint.NULL);
            counts[getTypeLeft(trueEdge, missing)][NO_EDGE_COLUMN]++;
        }

        return "\n" + buildCountTable(counts, EDGE_ROW_LABELS, EDGE_COLUMN_LABELS);
    }

    /**
     * Renders {@code counts} as a table with a header row of {@code columnLabels} and a header
     * column of {@code rowLabels}. The bottom-right cell (absent vs. absent) is shown as
     * {@code *} instead of a number.
     */
    private static TextTable buildCountTable(int[][] counts, String[] rowLabels,
                                             String[] columnLabels) {
        int rows = rowLabels.length;
        int columns = columnLabels.length;
        TextTable table = new TextTable(rows + 1, columns + 1);

        for (int c = 0; c < columns; c++) {
            table.setToken(0, c + 1, columnLabels[c]);
        }
        for (int r = 0; r < rows; r++) {
            table.setToken(r + 1, 0, rowLabels[r]);
        }
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < columns; c++) {
                boolean bothAbsent = r == rows - 1 && c == columns - 1;
                table.setToken(r + 1, c + 1, bothAbsent ? "*" : String.valueOf(counts[r][c]));
            }
        }
        return table;
    }

    /**
     * Classifies an estimated edge into an edge-table column: 0 "---", 1 "o-o", 2 "o->" (either
     * direction), 3 "-->" (either direction), 4 "<->". A {@code null} edge, a NULL/NULL
     * placeholder, or any other endpoint combination maps to column 5 ("no edge").
     */
    private static int getTypeTop(Edge edgeTop) {
        if (edgeTop == null) {
            return NO_EDGE_COLUMN;
        }
        Endpoint e1 = edgeTop.getEndpoint1();
        Endpoint e2 = edgeTop.getEndpoint2();

        if (e1 == Endpoint.TAIL && e2 == Endpoint.TAIL) {
            return 0;
        }
        if (e1 == Endpoint.CIRCLE && e2 == Endpoint.CIRCLE) {
            return 1;
        }
        if ((e1 == Endpoint.CIRCLE && e2 == Endpoint.ARROW)
                || (e1 == Endpoint.ARROW && e2 == Endpoint.CIRCLE)) {
            return 2;
        }
        if ((e1 == Endpoint.TAIL && e2 == Endpoint.ARROW)
                || (e1 == Endpoint.ARROW && e2 == Endpoint.TAIL)) {
            return 3;
        }
        if (e1 == Endpoint.ARROW && e2 == Endpoint.ARROW) {
            return 4;
        }
        return NO_EDGE_COLUMN;
    }

    /**
     * Classifies a reference edge into an edge-table row: 0 "---", 1 "o-o", 2 "o->", 3 "<-o",
     * 4 "-->", 5 "<--", 6 "<->", 7 "no edge". Rows 3 and 5 are used when {@code edgeTop} (the
     * estimated edge) equals {@code edgeLeft.reverse()}.
     *
     * <p>NOTE(review): combinations with ARROW at endpoint1 and TAIL/CIRCLE at endpoint2 are not
     * handled and throw. Presumably the {@code Edge} constructor never produces them; confirm.
     *
     * @throws IllegalArgumentException for any unsupported endpoint combination
     */
    private static int getTypeLeft(Edge edgeLeft, Edge edgeTop) {
        if (edgeLeft == null) {
            return NO_EDGE_ROW;
        }
        Endpoint e1 = edgeLeft.getEndpoint1();
        Endpoint e2 = edgeLeft.getEndpoint2();

        if (e1 == Endpoint.TAIL && e2 == Endpoint.TAIL) {
            return 0;
        }
        if (e1 == Endpoint.CIRCLE && e2 == Endpoint.CIRCLE) {
            return 1;
        }
        if (e1 == Endpoint.CIRCLE && e2 == Endpoint.ARROW) {
            return edgeTop.equals(edgeLeft.reverse()) ? 3 : 2;
        }
        if (e1 == Endpoint.TAIL && e2 == Endpoint.ARROW) {
            return edgeTop.equals(edgeLeft.reverse()) ? 5 : 4;
        }
        if (e1 == Endpoint.ARROW && e2 == Endpoint.ARROW) {
            return 6;
        }
        if (e1 == Endpoint.NULL && e2 == Endpoint.NULL) {
            return NO_EDGE_ROW;
        }
        throw new IllegalArgumentException("Unsupported edge type : " + e1 + " " + e2);
    }
}

