# NOTE: we should either keep this or use 'causally' for data generation. Remove one.

from typing import Iterable, Tuple, Any

import cdt
import networkx as nx
import numpy as np
import torch
from cdt.data.causal_mechanisms import gaussian_cause
from torch.utils.data import Dataset


class TrivialGraphDataset(Dataset):
    """
    Lazily generated collection of datasets that are all drawn from one shared ``AcyclicGraphGenerator``.

    Every entry is a tuple ``(data, adjacency)`` where ``data`` is the generated sample table as a float tensor and
    ``adjacency`` is the flattened adjacency matrix of the generating graph as a float tensor. Because a single
    generator instance is reused, all entries are expected to carry the same adjacency vector; ``__getitem__``
    asserts this for consecutive entries.

    :param num_datasets: number of entries reported by ``len()`` and accepted by ``__getitem__``.
    :param num_nodes: number of graph nodes passed to the generator.
    :param num_samples: number of data points per dataset passed to the generator.
    :param mechanism: causal mechanism name forwarded to ``cdt.data.AcyclicGraphGenerator``.
    :param noise: noise type name forwarded to ``cdt.data.AcyclicGraphGenerator``.
    """

    def __init__(self, num_datasets: int = 1000, num_nodes=4, num_samples=100, mechanism: str = 'linear',
                 noise: str = 'uniform'):
        self.num_nodes = num_nodes
        self.num_samples = num_samples
        self.mechanism = mechanism
        self.noise = noise
        self.cache = []  # entries generated so far, in index order
        self.num_datasets = num_datasets

        self.generator = cdt.data.AcyclicGraphGenerator(self.mechanism, self.noise, nodes=self.num_nodes,
                                                        npoints=self.num_samples,
                                                        noise_coeff=1.,
                                                        # It seems like there is a bug in cdt and the graph is too
                                                        # dense if we don't divide by two
                                                        expected_degree=2 / 2., dag_type='erdos',
                                                        initial_variable_generator=gaussian_cause)

    def __len__(self):
        return self.num_datasets

    def __getitem__(self, idx):
        """
        Return entry ``idx``, generating and caching all entries up to ``idx`` on first access.

        :param idx: position of the entry; negative values count from the end.
        :return: tuple ``(data_tensor, flattened_adjacency_tensor)``.
        :raises IndexError: if ``idx`` is outside ``[-len(self), len(self))``. Without this, plain iteration over
            the dataset (which relies on IndexError to stop) would never terminate.
        """
        if idx < 0:
            idx = idx + len(self)
        if not 0 <= idx < len(self):
            raise IndexError(f'index {idx} is out of range for dataset of length {len(self)}')

        while len(self.cache) <= idx:
            data, ground_truth = self.generator.generate()
            # toarray() always yields a plain ndarray, whereas todense() may yield np.matrix
            adj_matrix = nx.adjacency_matrix(ground_truth).toarray()
            self.cache.append((torch.tensor(data.to_numpy()).float(), torch.tensor(adj_matrix).flatten().float()))

        # Sanity check: consecutive entries must share the same graph (also covers the pair 0/1).
        if idx > 0:
            assert torch.all(self.cache[idx][1].eq(self.cache[idx - 1][1])), \
                'consecutive datasets were generated from different graphs'

        return self.cache[idx]


class GraphDatasetGenerator(Dataset):
    """
    Lazily generated collection of datasets, each drawn from a freshly constructed ``AcyclicGraphGenerator``.

    Every entry is a tuple ``(data, adjacency)`` where ``data`` is the generated sample table as a float tensor and
    ``adjacency`` is the flattened adjacency matrix of the generating graph as a float tensor.

    :param num_graphs: number of entries; divided by ten (rounded down) when ``test`` is true.
    :param num_nodes: number of graph nodes passed to the generator.
    :param num_samples: number of data points per dataset passed to the generator.
    :param test: if true, the dataset holds only a tenth of ``num_graphs`` entries.
    :param mechanism: causal mechanism name forwarded to ``cdt.data.AcyclicGraphGenerator``.
    :param noise: noise type name forwarded to ``cdt.data.AcyclicGraphGenerator``.
    """

    def __init__(self, num_graphs=1000, num_nodes=4, num_samples=100, test=False, mechanism: str = 'linear',
                 noise: str = 'uniform'):
        self.num_graphs = num_graphs if not test else int(num_graphs / 10)
        self.num_nodes = num_nodes
        self.num_samples = num_samples
        self.mechanism = mechanism
        self.noise = noise
        self.cache = []  # entries generated so far, in index order

    def __len__(self):
        return self.num_graphs

    def __getitem__(self, idx):
        """
        Return entry ``idx``, generating and caching all entries up to ``idx`` on first access.

        :param idx: position of the entry; negative values count from the end.
        :return: tuple ``(data_tensor, flattened_adjacency_tensor)``.
        :raises IndexError: if ``idx`` is outside ``[-len(self), len(self))``. Without this, plain iteration over
            the dataset (which relies on IndexError to stop) would never terminate.
        """
        if idx < 0:
            idx = idx + len(self)
        if not 0 <= idx < len(self):
            raise IndexError(f'index {idx} is out of range for dataset of length {len(self)}')

        while len(self.cache) <= idx:
            # A new generator per entry, so every entry gets its own graph.
            generator = cdt.data.AcyclicGraphGenerator(self.mechanism, self.noise, nodes=self.num_nodes,
                                                       npoints=self.num_samples,
                                                       noise_coeff=1.,
                                                       # It seems like there is a bug in cdt and the graph is too dense
                                                       # if we don't divide by two
                                                       expected_degree=2 / 2., dag_type='erdos',
                                                       initial_variable_generator=gaussian_cause)
            data, ground_truth = generator.generate()
            # toarray() always yields a plain ndarray, whereas todense() may yield np.matrix
            adj_matrix = nx.adjacency_matrix(ground_truth).toarray()
            self.cache.append((torch.tensor(data.to_numpy()).float(), torch.tensor(adj_matrix).flatten().float()))

        return self.cache[idx]


class ConfoundedGraphDataset(GraphDatasetGenerator):
    """
    Lazily generated datasets in which some graph nodes are hidden, so the observed variables may be confounded.

    For each entry a graph with ``num_nodes + k`` nodes is generated, ``k`` distinct nodes are chosen at random and
    hidden, and the remaining ``num_nodes`` observed variables are returned together with the marginalised graph.
    Every entry is a tuple ``(data, admg)`` where ``data`` is the float tensor of the observed columns and ``admg``
    is the float tensor ``[flattened adjacency matrix, sep, flattened confounding matrix]``.

    :param num_graphs: number of entries; divided by ten (rounded down) when ``test`` is true.
    :param num_nodes: number of observed variables per entry.
    :param num_samples: number of data points per dataset passed to the generator.
    :param test: if true, the dataset holds only a tenth of ``num_graphs`` entries.
    :param max_num_confounders: largest number of hidden nodes (inclusive) a confounded entry may have.
    :param fraction_confounded_datasets: fraction of entries that get at least one hidden node.
    :param mechanism: causal mechanism name forwarded to ``cdt.data.AcyclicGraphGenerator``.
    :param noise: noise type name forwarded to ``cdt.data.AcyclicGraphGenerator``.
    """

    def __init__(self, num_graphs=1000, num_nodes=4, num_samples=100, test=False, max_num_confounders: int = 2,
                 fraction_confounded_datasets: float = .5, mechanism: str = 'linear', noise: str = 'uniform'):
        super().__init__(num_graphs, num_nodes, num_samples, test, mechanism, noise)
        self.max_num_confounders = max_num_confounders
        confounded_threshold = self.num_graphs * fraction_confounded_datasets
        # randint's upper bound is exclusive, hence the +1 to make max_num_confounders attainable.
        self.num_confounders = [
            np.random.randint(1, max_num_confounders + 1) if i < confounded_threshold else 0
            for i in range(self.num_graphs)]
        # Shuffle so confounded and unconfounded entries are interleaved.
        np.random.shuffle(self.num_confounders)
        self.sep = -2  # Separator token between adjacency and confounding matrix

    def __getitem__(self, idx):
        """
        Return entry ``idx``, generating and caching all entries up to ``idx`` on first access.

        :param idx: position of the entry; negative values count from the end.
        :return: tuple ``(data_tensor, admg_tensor)``.
        :raises IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx = idx + len(self)
        if not 0 <= idx < len(self):
            raise IndexError(f'index {idx} is out of range for dataset of length {len(self)}')

        while len(self.cache) <= idx:
            # Use the confounder count of the entry being generated, not of the requested one, so that the
            # content of entry k does not depend on the order in which entries are accessed.
            num_confounders = self.num_confounders[len(self.cache)]
            generator = cdt.data.AcyclicGraphGenerator(self.mechanism, self.noise,
                                                       nodes=self.num_nodes + num_confounders,
                                                       npoints=self.num_samples,
                                                       noise_coeff=1.,
                                                       # It seems like there is a bug in cdt and the graph is too dense
                                                       # if we don't divide by two
                                                       expected_degree=2 / 2., dag_type='erdos',
                                                       initial_variable_generator=gaussian_cause)
            data, ground_truth = generator.generate()
            nodes = list(ground_truth.nodes)
            # Sample without replacement: duplicates would hide fewer nodes than requested and leave more than
            # num_nodes observed variables.
            hidden_positions = np.random.choice(len(nodes), num_confounders, replace=False)
            hidden_confounders = {nodes[pos] for pos in hidden_positions}
            # Keep graph order so the variable order does not depend on set/hash ordering.
            observed_variables = [node for node in nodes if node not in hidden_confounders]

            data = data[observed_variables]
            data_tensor = torch.tensor(data.to_numpy()).float()
            adjacency_matrix, confounding_matrix = marginalize(ground_truth, observed_variables)
            # Concatenate adjacency and confounding matrix to fully specify the ADMG. With n observed variables the
            # vector has 2 * n * n + 1 entries. np.asarray guards against np.matrix, whose flatten() stays 2-D.
            admg_vector = np.concatenate([np.asarray(adjacency_matrix).flatten(), [self.sep],
                                          np.asarray(confounding_matrix).flatten()])
            admg_tensor = torch.tensor(admg_vector).float()
            self.cache.append((data_tensor, admg_tensor))
        return self.cache[idx]


def _marginalise_node(graph: nx.DiGraph, node: Any):
    """
    Marginalise 'node' out of 'graph' in place.

    Every path x -> node -> y is replaced by a direct edge x -> y, and every pair of distinct successors of 'node'
    is connected in both directions to mark that they share 'node' as a (now hidden) common cause.
    :param graph: graph from which 'node' is removed; it is modified in place
    :param node: node to remove from the graph
    :return: None, the result is the mutated 'graph'
    """
    parents = list(graph.predecessors(node))
    children = list(graph.successors(node))

    # Bypass the removed node: parent -> node -> child becomes parent -> child
    for parent in parents:
        for child in children:
            if parent == child:
                continue
            graph.add_edge(parent, child)

    # Encode confounding between siblings as a pair of opposite directed edges
    for first in children:
        for second in children:
            if first == second:
                continue
            graph.add_edge(first, second)
            graph.add_edge(second, first)

    graph.remove_node(node)


def marginalize(graph: nx.DiGraph, observed_variables: Iterable[Any]) -> Tuple[Any, Any]:
    """
    Remove all nodes not in 'observed_variables' and return two matrices. One is the marginalised adjacency matrix, the
    other one contains a 1 whenever two nodes have a hidden common cause. This coding is equivalent to an ADMG.
    'graph' itself is not modified; the marginalisation is done on a copy.
    :param graph: The graph from which to remove the unobserved confounders.
    :param observed_variables: The nodes that are not removed. Their order defines the row/column order of both
    returned matrices. Any iterable is accepted, including one-shot generators.
    :return: Tuple (adjacency_matrix, confounding_matrix) of numpy arrays, where adjacency_matrix[i, j] = 1 iff there
    is a directed edge from node i to node j and confounding_matrix[i, j] = 1 iff node i and node j have a hidden
    common cause.
    """
    # Materialise once: the iterable is consumed several times below.
    observed = list(observed_variables)
    observed_lookup = set(observed)

    # Generate marginalised graph that contains a bidirected edge (i.e. both directed edges) whenever there is a
    # hidden confounder. Hidden nodes are removed in graph order so the result does not depend on set ordering.
    g_marginalised = graph.copy()
    for node in [candidate for candidate in graph.nodes if candidate not in observed_lookup]:
        _marginalise_node(g_marginalised, node)

    # toarray() always yields a plain, writable ndarray (todense() may yield np.matrix).
    adjacency_matrix = nx.adjacency_matrix(g_marginalised, nodelist=observed).toarray()
    # Marginalised graph contains bidirected edges to indicate confounding. The adjacency of the original graph is
    # used to restore the original orientation of such pairs.
    original_adjacency = nx.adjacency_matrix(graph, nodelist=observed).toarray()

    # A pair is confounded iff the marginalised graph has edges in both directions between its nodes.
    bidirected = (adjacency_matrix == 1) & (adjacency_matrix.T == 1)
    confounding_matrix = bidirected.astype(float)
    # Reset bidirected edges to original orientation / remove if originally no edge between them.
    # NOTE(review): a directed edge that only exists because a hidden mediator was bypassed is not present in
    # 'original_adjacency' and is therefore dropped here when the pair is also marked as confounded.
    adjacency_matrix[bidirected & (original_adjacency == 0)] = 0
    return adjacency_matrix, confounding_matrix
