import jpype
import jpype.imports

# Start the JVM with the Tetrad jar on the classpath (required before the
# `import java.util` / `import edu.cmu.tetrad...` lines below can resolve).
# Checking isJVMStarted() instead of catching OSError avoids reporting genuine
# start-up failures that jpype also raises as OSError (e.g. a JVM library that
# fails to load) as "already started".
if jpype.isJVMStarted():
    print('JVM already started')
else:
    # please change to the abs path of tetrad-gui-7.2.2-launch.jar
    jpype.startJVM(classpath=['please change to the abs path of tetrad-gui-7.2.2-launch.jar'])

import java.util
import edu.cmu.tetrad.data as td
import edu.cmu.tetrad.graph as tg
import edu.cmu.tetrad.algcomparison.statistic as tas

import numpy as np
import networkx as nx


def numpy_data_to_tetrad(data):
    """Convert a 2-D numeric array into a Tetrad continuous ``BoxDataSet``.

    Each column becomes a ``ContinuousVariable`` named by its column index
    (``'0'``, ``'1'``, ...).

    Args:
        data: array-like of shape (n_samples, n_variables). Anything accepted
            by ``np.asarray`` (ndarray, nested lists, ...) works; values are
            cast to float64 so they map onto a Java ``double[][]``.

    Returns:
        A Tetrad ``BoxDataSet`` backed by a ``DoubleDataBox``.

    Raises:
        ValueError: if ``data`` is not two-dimensional or not numeric.
    """
    # Normalise input: accept lists as well as arrays, and make sure integer
    # (or other) dtypes become Python floats before crossing into Java.
    arr = np.asarray(data, dtype=float)
    if arr.ndim != 2:
        raise ValueError(
            'expected a 2-D array (n_samples, n_variables), got shape {}'.format(arr.shape)
        )

    variables = java.util.ArrayList()
    for col in range(arr.shape[1]):
        variables.add(td.ContinuousVariable(str(col)))

    databox = td.DoubleDataBox(jpype.JArray(jpype.JDouble, 2)(arr.tolist()))
    return td.BoxDataSet(databox, variables)


def tetrad_graph_to_networkx(t_graph):
    """Convert a Tetrad graph into a ``networkx.DiGraph``.

    Node names (as Python strings) become the networkx node keys. Edge
    endpoints are mapped as follows:

    * ``TAIL`` / ``ARROW`` -> one directed edge ``n1 -> n2``
    * ``ARROW`` / ``TAIL`` -> one directed edge ``n2 -> n1``
    * ``TAIL`` / ``TAIL``  -> two opposing edges (an undirected edge)

    Any other endpoint combination cannot be represented; it is reported with
    a printed message and skipped.

    Args:
        t_graph: a Tetrad graph exposing ``getNodes()`` and ``getEdges()``.

    Returns:
        networkx.DiGraph with string node labels.
    """
    nx_graph = nx.DiGraph()
    nx_graph.add_nodes_from(str(node.getName()) for node in t_graph.getNodes())

    for edge in t_graph.getEdges():
        n1 = str(edge.getNode1().getName())
        n2 = str(edge.getNode2().getName())
        # Convert the Java enum names to plain Python strings so the
        # comparisons below don't depend on java.lang.String equality rules.
        ep1 = str(edge.getEndpoint1().name())
        ep2 = str(edge.getEndpoint2().name())

        if ep1 == 'TAIL' and ep2 == 'ARROW':
            nx_graph.add_edge(n1, n2)
        elif ep1 == 'ARROW' and ep2 == 'TAIL':
            # Same kind of directed edge, just stored with its nodes reversed.
            nx_graph.add_edge(n2, n1)
        elif ep1 == 'TAIL' and ep2 == 'TAIL':
            nx_graph.add_edge(n1, n2)
            nx_graph.add_edge(n2, n1)
        else:
            print('something wrong... {}, {}, {}, {}'.format(n1, n2, ep1, ep2))
    return nx_graph


def networkx_graph_to_tetrad(nx_graph, t_template=None):
    """Convert a networkx graph into a Tetrad ``EdgeListGraph``.

    An edge ``u -> v`` whose reverse ``v -> u`` is also present becomes a
    single undirected Tetrad edge; an edge without its reverse becomes a
    directed Tetrad edge. Self-loops are skipped.

    Args:
        nx_graph: a networkx graph; its ``nodes``, ``edges`` and ``has_edge``
            are used. Works for both ``DiGraph`` and undirected ``Graph``.
        t_template: optional Tetrad graph whose node objects are reused. Its
            nodes are matched to networkx nodes by name, i.e. ``str(node)``.
            When ``None``, a fresh ``GraphNode`` named ``str(node)`` is
            created for every networkx node.

    Returns:
        A Tetrad ``EdgeListGraph``.

    Raises:
        KeyError: if an edge endpoint has no matching Tetrad node (e.g. its
            name is missing from ``t_template``).
    """
    # Key every Tetrad node by its string name so lookups behave the same way
    # with or without a template, even when networkx nodes are ints.
    node_mapping = {}
    if t_template is None:
        node_list = java.util.ArrayList()
        for node in nx_graph.nodes:
            t_node = tg.GraphNode(str(node))
            node_mapping[str(node)] = t_node
            node_list.add(t_node)
    else:
        node_list = t_template.getNodes()
        for t_node in node_list:
            node_mapping[str(t_node.getName())] = t_node

    t_graph = tg.EdgeListGraph(node_list)
    # Pairs already emitted as undirected edges. Using a set (instead of an
    # ordering test like node1 < node2) works for non-orderable node types and
    # for undirected nx.Graph, which reports each edge in only one orientation.
    added_undirected = set()
    for node1, node2 in nx_graph.edges:
        if node1 == node2:
            continue  # self-loops are not converted
        t_node1 = node_mapping[str(node1)]
        t_node2 = node_mapping[str(node2)]
        if nx_graph.has_edge(node2, node1):
            pair = frozenset((str(node1), str(node2)))
            if pair not in added_undirected:
                added_undirected.add(pair)
                t_graph.addUndirectedEdge(t_node1, t_node2)
        else:
            t_graph.addDirectedEdge(t_node1, t_node2)
    return t_graph


def statistic(true_graph, est_graph):
    """Score an estimated Tetrad graph against the true graph.

    Args:
        true_graph: the reference Tetrad graph.
        est_graph: the estimated Tetrad graph.

    Returns:
        dict mapping ``'f1_adj'`` to the Tetrad ``F1Adj`` value and
        ``'f1_arrow'`` to the Tetrad ``F1Arrow`` value. No data model is
        supplied to either statistic (``None`` is passed).
    """
    scorers = (
        ('f1_adj', tas.F1Adj),
        ('f1_arrow', tas.F1Arrow),
    )
    return {
        label: scorer_cls().getValue(true_graph, est_graph, None)
        for label, scorer_cls in scorers
    }
