import networkx as nx
import numpy as np
import numpy.random as rand
from pathlib import Path
import sys
from typing import List
from causalAssembly.models_dag import ProductionLineGraph

import utils


# TODO: consider an Enum for graph_type instead of plain strings ("dag", "cpdag").
class CausalGraph:
    """A causal graph over nodes ``0..p-1`` with directed and undirected edges.

    Edges are stored as adjacency lists. ``dir_neighbors[u]`` holds every
    ``v`` with a directed edge ``u -> v``. ``undir_neighbors[u]`` holds every
    ``v`` with an undirected edge ``u - v``; each undirected edge is stored at
    both endpoints. ``graph_type`` is a free-form tag, and several methods give
    special meaning to ``"dag"`` and ``"cpdag"``.
    """

    # dir_neighbors[u]: heads v of the directed edges u -> v
    dir_neighbors: List[List[int]]
    # undir_neighbors[u]: other endpoints of undirected edges at u (stored symmetrically)
    undir_neighbors: List[List[int]]
    # tag such as "dag" or "cpdag"
    graph_type: str

    def __init__(self, p, dir_edges, undir_edges, graph_type) -> None:
        """Build a graph with ``p`` nodes.

        Args:
            p: number of nodes; node ids are ``0..p-1``.
            dir_edges: iterable of ``(a, b)`` pairs, each meaning ``a -> b``.
            undir_edges: iterable of ``(a, b)`` pairs, each meaning ``a - b``.
                Every pair is recorded at both endpoints, so pass each
                undirected edge only once.
            graph_type: tag stored as-is (e.g. ``"dag"`` or ``"cpdag"``).
        """
        self.dir_neighbors = [[] for _ in range(p)]
        self.undir_neighbors = [[] for _ in range(p)]
        for a, b in dir_edges:
            self.dir_neighbors[a].append(b)
        for a, b in undir_edges:
            self.undir_neighbors[a].append(b)
            self.undir_neighbors[b].append(a)
        self.graph_type = graph_type

    def __eq__(self, other):
        """Return whether both graphs have the same type, nodes and edge sets.

        Neighbour order is ignored. Comparing with a non-``CausalGraph``
        returns ``NotImplemented``, so Python falls back to its default.
        """
        if not isinstance(other, CausalGraph):
            return NotImplemented
        p = self.num_nodes()
        if p != other.num_nodes():
            return False
        if self.num_edges() != other.num_edges():
            return False
        if self.graph_type != other.graph_type:
            return False
        for u in range(p):
            if set(self.undir_neighbors[u]) != set(other.undir_neighbors[u]) or set(
                self.dir_neighbors[u]
            ) != set(other.dir_neighbors[u]):
                return False
        return True

    def __str__(self):
        """Return ``str(self.to_edgelist())``."""
        return str(self.to_edgelist())

    def num_nodes(self) -> int:
        """Return the number of nodes ``p``."""
        return len(self.dir_neighbors)

    def num_edges(self) -> int:
        """Return the number of directed plus undirected edges.

        Each undirected edge is stored twice, hence the halving.
        """
        return (
            sum(len(adj) for adj in self.dir_neighbors)
            + sum(len(adj) for adj in self.undir_neighbors) // 2
        )

    def to_edgelist(self):
        """Return a list of ``(u, v, "directed" | "undirected")`` tuples.

        Each undirected edge appears once, as ``(u, v)`` with ``u < v``.
        """
        p = self.num_nodes()
        edge_list = []
        for u in range(p):
            for v in self.dir_neighbors[u]:
                edge_list.append((u, v, "directed"))
            for v in self.undir_neighbors[u]:
                if u < v:
                    edge_list.append((u, v, "undirected"))
        return edge_list

    def to_matrix(self):
        """Return a ``p x p`` ``int8`` adjacency matrix ``gm``.

        ``gm[u, v] == 1`` encodes ``u -> v``. ``gm[u, v] == gm[v, u] == 2``
        encodes ``u - v``. All other entries are 0.
        """
        p = self.num_nodes()
        gm = np.zeros((p, p), dtype=np.int8)
        for u in range(p):
            for v in self.dir_neighbors[u]:
                gm[u, v] = 1
            for v in self.undir_neighbors[u]:
                gm[u, v] = 2
                gm[v, u] = 2
        return gm

    def write(self, file=sys.stdout):
        """Write the graph as text to ``file``.

        The first line is the header ``"<n> <m> <graph_type>"``. It is followed
        by one ``"<u> <v> <edge_type>"`` line per edge, in ``to_edgelist`` order.
        This is the format parsed by the module-level ``read``.
        """
        n = self.num_nodes()
        m = self.num_edges()
        print(n, m, self.graph_type, file=file)
        edge_list = self.to_edgelist()
        for u, v, edge_type in edge_list:
            print(u, v, edge_type, file=file)

    def write_to_file(self, file_path):
        """Write the graph to ``file_path`` (overwriting it) using ``write``."""
        with open(file_path, "w") as file:
            self.write(file)

    def is_valid(self):
        """Check structural validity for the ``"dag"`` and ``"cpdag"`` tags.

        * ``"dag"``: the directed edges must contain no cycle.
        * ``"cpdag"``: orienting the graph with ``to_dag`` must give an acyclic
          graph whose ``to_cpdag`` equals this graph.
        * Any other tag: always ``True``.
        """
        if self.graph_type == "dag":
            if not self.is_acyclic():
                return False

        if self.graph_type == "cpdag":
            dag = self.to_dag()
            if not dag.is_acyclic():
                return False
            g = dag.to_cpdag()
            if self != g:
                return False
        return True

    def transpose_graph(self):
        """Return a new graph with every directed edge reversed.

        Undirected edges are kept as they are, and each one is passed to the
        constructor only once. ``graph_type`` is copied unchanged.
        """
        p = self.num_nodes()
        dir_edges = []
        undir_edges = []
        for u in range(p):
            for v in self.dir_neighbors[u]:
                dir_edges.append((v, u))
            for v in self.undir_neighbors[u]:
                # Each undirected edge appears at both endpoints. Emit it once;
                # the constructor stores it symmetrically again.
                if u < v:
                    undir_edges.append((u, v))
        return CausalGraph(p, dir_edges, undir_edges, self.graph_type)

    def topological_order_dfs(self, vis, order, u):
        """Run a depth-first search from ``u`` over the directed edges.

        Nodes reached for the first time are marked in ``vis`` and appended to
        ``order`` in post-order (a node after all of its unvisited
        descendants). Children are explored in ``dir_neighbors`` order, exactly
        as a recursive DFS would. The search uses an explicit stack, so long
        paths do not hit Python's recursion limit.
        """
        if vis[u]:
            return
        vis[u] = True
        stack = [(u, iter(self.dir_neighbors[u]))]
        while stack:
            node, children = stack[-1]
            for v in children:
                if not vis[v]:
                    vis[v] = True
                    stack.append((v, iter(self.dir_neighbors[v])))
                    break
            else:
                # All children handled: node is finished.
                stack.pop()
                order.append(node)

    def topological_order(self):
        """Return the nodes in reversed DFS post-order.

        For an acyclic graph this is a topological order: every edge
        ``u -> v`` has ``u`` before ``v``. For a cyclic graph a list is still
        returned, so use ``is_acyclic`` to check.

        Raises:
            ValueError: if ``graph_type`` is not ``"dag"``.
        """
        if self.graph_type != "dag":
            raise ValueError("method expected DAG")
        p = self.num_nodes()
        vis = [False] * p
        order = []
        for u in range(p):
            if not vis[u]:
                self.topological_order_dfs(vis, order, u)
        return list(reversed(order))

    def is_acyclic(self):
        """Return ``True`` iff the directed edges contain no cycle.

        A cycle exists exactly when some edge ``u -> v`` does not point forward
        in ``topological_order()``. That includes self-loops ``u -> u``.

        Raises:
            ValueError: if ``graph_type`` is not ``"dag"``.
        """
        p = self.num_nodes()
        to = self.topological_order()
        inv_to = np.argsort(to)  # inv_to[u] = position of u in `to`
        for u in range(p):
            for v in self.dir_neighbors[u]:
                # `>=` (not `>`) so that a self-loop u -> u counts as a cycle.
                if inv_to[u] >= inv_to[v]:
                    return False
        return True

    def to_dag(self):
        """Return a graph tagged ``"dag"`` obtained by orienting undirected edges.

        If the graph is already tagged ``"dag"``, ``self`` is returned (not a
        copy). Otherwise each undirected edge ``u - v`` becomes ``u -> v`` when
        ``u`` comes before ``v`` in a maximum cardinality search over the
        undirected part. All directed edges are kept.
        """
        if self.graph_type == "dag":
            return self

        p = self.num_nodes()
        undir_subgraph = ChordalGraph(self.undir_neighbors)
        mcs_order = undir_subgraph.mcs()
        inv_mcs_order = np.argsort(mcs_order)  # position of each node in mcs_order

        dir_edges = []
        for u in range(p):
            for v in self.undir_neighbors[u]:
                if inv_mcs_order[u] < inv_mcs_order[v]:
                    dir_edges.append((u, v))

        for u in range(p):
            for v in self.dir_neighbors[u]:
                dir_edges.append((u, v))

        return CausalGraph(p, dir_edges, [], "dag")

    def to_cpdag(self):
        """Return the CPDAG of this DAG.

        If the graph is already tagged ``"cpdag"``, ``self`` is returned.
        Otherwise every edge ``v -> y`` is labelled either compelled (kept
        directed) or reversible (emitted undirected). Nodes ``y`` are processed
        in topological order. For each one, the parent ``x`` that is latest in
        that order decides the labels of all edges into ``y``. The structure
        appears to follow Chickering's (1995) Find-Compelled procedure.

        Raises:
            ValueError: if the graph is tagged neither ``"cpdag"`` nor
                ``"dag"`` (raised by ``topological_order``).
        """
        if self.graph_type == "cpdag":
            return self
        p = self.num_nodes()
        # tg.dir_neighbors[y] is the parent list of y. It is sorted so that
        # utils.binary_search can be used for membership tests.
        tg = self.transpose_graph()
        for parents in tg.dir_neighbors:
            parents.sort()

        to = self.topological_order()
        inv_to = np.argsort(to)  # inv_to[u] = position of u in `to`

        is_compelled = [False] * p
        # compelled_ingoing[y]: parents v of y whose edge v -> y was labelled compelled
        compelled_ingoing = [[] for _ in range(p)]

        dir_edges = []
        undir_edges = []

        for y in to:
            parents = tg.dir_neighbors[y]
            if not parents:
                continue

            # is_compelled is shared across nodes: reset the labels for y's parents.
            for u in parents:
                is_compelled[u] = False

            # x: the parent of y that comes last in the topological order.
            x = max(parents, key=lambda node: inv_to[node])

            done = False
            for w in compelled_ingoing[x]:
                if not utils.binary_search(parents, w):
                    # w -> x is compelled and w is not a parent of y:
                    # every edge into y is compelled.
                    for u in parents:
                        is_compelled[u] = True
                    done = True
                    break
                else:
                    is_compelled[w] = True

            if not done:
                for z in parents:
                    if z == x:
                        continue
                    if not utils.binary_search(tg.dir_neighbors[x], z):
                        # Another parent z of y is not a parent of x:
                        # every edge into y is compelled.
                        for u in parents:
                            is_compelled[u] = True
                        break

            for v in parents:
                if is_compelled[v]:
                    compelled_ingoing[y].append(v)
                    dir_edges.append((v, y))
                else:
                    undir_edges.append((v, y))
        return CausalGraph(p, dir_edges, undir_edges, "cpdag")


def read(file=None):
    """Parse a graph in the text format produced by ``CausalGraph.write``.

    The first line is a header ``"<n> <m> <graph_type>"``; the edge count
    ``m`` is not used. Each following non-blank line is ``"<a> <b> [type]"``.
    A missing type is treated as ``"directed"``, so plain DAG edge lists such
    as the bnlearn instances can be read. Any type other than ``"directed"``
    is treated as undirected.

    Args:
        file: readable text stream. Defaults to ``sys.stdin``, looked up at
            call time. The previous default was ``sys.stdout``, which is not
            readable.

    Returns:
        CausalGraph: the parsed graph.
    """
    if file is None:
        file = sys.stdin
    tokens = file.readline().split()
    n = int(tokens[0])
    graph_type = tokens[2]
    dir_edges = []
    undir_edges = []
    for line in file:
        tokens = line.split()
        if not tokens:
            # Tolerate blank lines, e.g. a trailing empty line.
            continue
        a, b = int(tokens[0]), int(tokens[1])
        edge_type = tokens[2] if len(tokens) > 2 else "directed"
        if edge_type == "directed":
            dir_edges.append((a, b))
        else:
            undir_edges.append((a, b))
    return CausalGraph(n, dir_edges, undir_edges, graph_type)


def read_from_file(file_path):
    """Open ``file_path`` as text and return the graph parsed by ``read``."""
    with open(file_path, mode="r") as handle:
        graph = read(handle)
    return graph


def sample(id, rng=rand.default_rng()):
    """Create a graph from a dash-separated spec string.

    The first token selects the source:

    * ``er-<num_nodes>-<avg_degree>``: ``sample_er_dag``.
    * ``sf-<num_nodes>-<avg_degree>``: ``sample_sf_dag`` with
      ``m = avg_degree // 2``.
    * ``bnlearn-<name>``: ``get_bnlearn(name)``.
    * ``chain-<num_nodes>``: ``get_chain(num_nodes)``.
    * ``sachs``: ``external_data/sachs_graph.txt``.
    * ``causalAssembly``: ground truth of causalAssembly's ``ProductionLineGraph``.
    * ``syntren``: ``external_data/syntren/DAG1.npy``.

    Args:
        id: the spec string. The name shadows the builtin but is kept for
            keyword callers.
        rng: generator forwarded to the random samplers. NOTE: the default is
            created once at import time and shared between calls.

    Raises:
        ValueError: if the first token is not a supported source.
    """
    tokens = id.split("-")
    graph_type = tokens[0]
    if graph_type == "er":
        num_nodes, avg_degree = int(tokens[1]), int(tokens[2])
        g = sample_er_dag(num_nodes, avg_degree, rng)
    elif graph_type == "sf":
        num_nodes, avg_degree = int(tokens[1]), int(tokens[2])
        g = sample_sf_dag(num_nodes, avg_degree // 2, rng)
    elif graph_type == "bnlearn":
        name = tokens[1]
        g = get_bnlearn(name)
    elif graph_type == "chain":
        num_nodes = int(tokens[1])
        g = get_chain(num_nodes)
    elif graph_type == "sachs":
        g = read_from_file("external_data/sachs_graph.txt")
    elif graph_type == "causalAssembly":
        assembly_line_graph = (
            ProductionLineGraph.get_ground_truth().ground_truth.to_numpy()
        )
        # Entry [u, v] == 1 is read as the edge u -> v.
        dir_edges = []
        p, _ = assembly_line_graph.shape
        for u in range(p):
            for v in range(p):
                if assembly_line_graph[u, v] == 1:
                    dir_edges.append((u, v))
        return CausalGraph(p, dir_edges, [], "dag")
    elif graph_type == "syntren":
        mat = np.load("external_data/syntren/DAG1.npy")
        dir_edges = []
        p, _ = mat.shape
        # TODO: check matrix orientation!!! Entry [v, u] == 1 is currently
        # read as the edge u -> v.
        for u in range(p):
            for v in range(p):
                if mat[v, u] == 1:
                    dir_edges.append((u, v))
        return CausalGraph(p, dir_edges, [], "dag")
    else:
        raise ValueError(f"graph sampler {graph_type} not supported")
    return g


def sample_er_dag(num_nodes, avg_degree, rng=rand.default_rng()):
    """Sample a random DAG from an Erdős–Rényi model.

    Each of the ``num_nodes * (num_nodes - 1) / 2`` node pairs becomes an edge
    independently with probability ``avg_degree / (num_nodes - 1)``, so the
    expected degree of a node is ``avg_degree``. Edges follow a random
    permutation of the nodes, which makes the result acyclic.

    Args:
        num_nodes: number of nodes. With fewer than two nodes the result has
            no edges.
        avg_degree: target expected degree per node.
        rng: ``numpy.random.Generator`` used for all randomness (the node
            permutation and the edge draws), so a seeded generator gives a
            reproducible graph. NOTE: the default is created once at import
            time and shared between calls.

    Returns:
        CausalGraph: a graph tagged ``"dag"``.
    """
    dir_edges = []
    if num_nodes > 1:  # avoids division by zero for a single node
        p = avg_degree / (num_nodes - 1)
        # Use the caller's generator, not numpy's global state.
        order = rng.permutation(num_nodes)
        for u in range(num_nodes):
            for v in range(u + 1, num_nodes):
                if rng.random() < p:
                    dir_edges.append((int(order[u]), int(order[v])))

    return CausalGraph(num_nodes, dir_edges, [], "dag")


def sample_sf_dag(num_nodes, m, rng=rand.default_rng()):
    """Sample a scale-free DAG from a Barabási–Albert graph.

    The undirected graph is built by ``networkx.barabasi_albert_graph`` with
    ``m`` edges per newly attached node. Every edge is then oriented along a
    random relabelling of the nodes.

    Args:
        num_nodes: number of nodes.
        m: edges attached from each new node.
        rng: generator passed to networkx as ``seed`` and used for the node
            permutation, so a seeded generator gives a reproducible graph.
            NOTE: the default is created once at import time and shared.

    Returns:
        CausalGraph: a graph tagged ``"dag"``.
    """
    g = nx.barabasi_albert_graph(num_nodes, m, rng)
    # Use the caller's generator, not numpy's global state.
    order = rng.permutation(num_nodes)

    # Orient each edge from its smaller to its larger original label, then
    # relabel through `order`. A total order guarantees acyclicity whatever
    # endpoint order networkx reports the edge in.
    dir_edges = [
        (int(order[min(u, v)]), int(order[max(u, v)])) for u, v in g.edges
    ]

    return CausalGraph(num_nodes, dir_edges, [], "dag")


def get_bnlearn(name):
    """Load the bnlearn network stored at ``bnlearn/<name>.net`` (relative path)."""
    path = f"bnlearn/{name}.net"
    return read_from_file(path)


def get_chain(num_nodes: int):
    """Return the directed path ``0 -> 1 -> ... -> num_nodes - 1`` as a DAG."""
    path_edges = [(i, i + 1) for i in range(num_nodes - 1)]
    return CausalGraph(num_nodes, path_edges, [], "dag")


# For now, a thin wrapper around an adjacency list.
class ChordalGraph:
    """Undirected graph given by an adjacency list.

    ``neighbors[u]`` lists the neighbours of node ``u``. The list is stored,
    not copied. It is presumably symmetric (each edge listed at both
    endpoints); this is not checked here.
    """

    # neighbors[u]: nodes adjacent to u
    neighbors: List[List[int]]

    def __init__(self, neighbors: List[List[int]]) -> None:
        self.neighbors = neighbors

    def mcs(self) -> List[int]:
        """Return a maximum cardinality search (MCS) ordering of all nodes.

        At each step the unvisited node with the most visited neighbours is
        chosen. Ties go to the node most recently appended to that count's
        bucket, so initially the highest node id. Nodes sit in buckets
        indexed by their count. Outdated bucket entries stay in place and are
        skipped when popped.

        Returns:
            list of every node id, in visiting order (``[]`` for no nodes).
        """
        p = len(self.neighbors)
        if p == 0:
            # Without this guard, `sets[0] = ...` fails on an empty bucket list.
            return []
        ordering = []
        # sets[k]: candidate nodes with k visited neighbours (may hold stale entries)
        sets = [[] for _ in range(p)]
        # cardinality[u]: visited neighbours of u, or -1 once u itself is visited
        cardinality = [0] * p
        # upper bound on the highest non-empty bucket
        max_cardinality = 0

        sets[0] = list(range(p))

        idx = 0
        while idx < p:
            while max_cardinality > 0 and not sets[max_cardinality]:
                max_cardinality -= 1

            u = sets[max_cardinality].pop()
            if cardinality[u] < 0:
                # stale entry of an already-visited node
                continue

            idx += 1

            ordering.append(u)

            cardinality[u] = -1

            for v in self.neighbors[u]:
                if cardinality[v] >= 0:
                    cardinality[v] += 1
                    sets[cardinality[v]].append(v)

            # Visiting u raises any count by at most one.
            max_cardinality += 1

        return ordering
