import os

import pandas as pd
import seaborn as sns
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Batch
from torch_geometric.data.data import Data

from datasets._adjacency import Adjacency
from datasets.utils import get_intervention_list
from datasets.utils import normalize_adj
from utils.args_parser import mkdir
from utils.distributions import *
from utils.constants import Cte

# Structural equations of the linear collider (x1 -> x3 <- x2).
# Every callable takes the node's exogenous noise as first argument; the
# collider x3 additionally receives its parents as keywords `x1` and `x2`.
structural_eq_linear = dict(
    x1=lambda u1: u1,
    x2=lambda u2: u2,
    x3=lambda u3, x1, x2: 0.05 * x1 + 0.25 * x2 + u3,
)

# Exogenous-noise distributions used together with the linear equations:
# a two-component Gaussian mixture for u1, and Normal(0, 1) for u2 and u3.
noises_distr_linear = dict(
    u1=MixtureOfGaussians(probs=[0.5, 0.5], means=[-2, 1.5], vars=[1.5, 1]),
    u2=Normal(0, 1),
    u3=Normal(0, 1),
)



# Structural equations of the non-linear collider: same graph as the linear
# case, but x3 depends on the square of x2. Noise is still additive.
structural_eq_non_linear = dict(
    x1=lambda u1: u1,
    x2=lambda u2: u2,
    x3=lambda u3, x1, x2: 0.05 * x1 + 0.25 * x2 ** 2 + u3,
)

# Exogenous-noise distributions used together with the non-linear equations.
# Only u2 differs from the linear setting (second parameter 0.1 instead of 1).
noises_distr_non_linear = dict(
    u1=MixtureOfGaussians(probs=[0.5, 0.5], means=[-2, 1.5], vars=[1.5, 1]),
    u2=Normal(0, 0.1),
    u3=Normal(0, 1),
)

# Structural equations of the non-additive collider: the noise u3 enters x3
# multiplicatively (sign(u3) * u3 scales the squared norm of the parents).
structural_eq_non_additive = dict(
    x1=lambda u1: u1,
    x2=lambda u2: u2,
    x3=lambda u3, x1, x2: -1 + 0.1 * np.sign(u3) * (x1 ** 2 + x2 ** 2) * u3,
)

# Exogenous-noise distributions used together with the non-additive equations:
# a symmetric two-component mixture for u1 and narrow zero-mean Normals for
# u2 and u3.
noises_distr_non_additive = dict(
    u1=MixtureOfGaussians(probs=[0.5, 0.5], means=[-2.5, 2.5], vars=[1, 1]),
    u2=Normal(0, 0.25),
    u3=Normal(0, 0.25 ** 2),
)

# Display name of every SCM variable, listed in node order.
cols_dict = dict(x1='Obs 1', x2='Obs 2', x3='Obs 3')

# Children of every node: x1 and x2 both point to the collider x3.
adj_edges = dict(x1=['x3'], x2=['x3'], x3=[])

# Feature groups by actionability.
# Actionable and mutable:
actionable_x = ['Obs 1', 'Obs 2', 'Obs 3']
non_actionable_x = []
# Mutable but non-actionable:
mutable_x = []
# Immutable and non-actionable:
immutable_x = []


class ColliderSCM(torch.utils.data.Dataset):
    def __init__(self,
                 equations_type=Cte.LINEAR,
                 transform=None):
        """Select the structural equations and noise distributions of the SCM.

        No data is generated here; `X`, `U` and `adj_object` stay `None` until
        `prepare_data` / `prepare_adj` are called.

        Args:
            equations_type: One of `Cte.LINEAR`, `Cte.NONLINEAR` or
                `Cte.NONADDITIVE`.
            transform: Optional callable applied to each sample in
                `__getitem__`.

        Raises:
            ValueError: If `equations_type` is not one of the supported values.
        """
        super().__init__()

        self.transform = transform
        self.eq_type = equations_type

        if equations_type == Cte.LINEAR:
            self.structural_eq = structural_eq_linear
            self.noises_distr = noises_distr_linear
        elif equations_type == Cte.NONLINEAR:
            self.structural_eq = structural_eq_non_linear
            self.noises_distr = noises_distr_non_linear
        elif equations_type == Cte.NONADDITIVE:
            self.structural_eq = structural_eq_non_additive
            self.noises_distr = noises_distr_non_additive
        else:
            # Fail here rather than with an AttributeError on the first
            # sampling call, far away from the actual mistake.
            raise ValueError(f'Unsupported equations_type: {equations_type!r}')

        self.num_nodes = 3

        # Observations and exogenous noise, filled in by `prepare_data`.
        self.X = None
        self.U = None

        # Intervention state.
        self.x_I = None  # dict {node index: value} of intervened nodes, or None
        self.I_noise = None  # whether the intervention value is added to the factual value

        # Adjacency wrapper, created by `prepare_adj`.
        self.adj_object = None

    @property
    def num_edges(self):
        """Number of edges reported by the adjacency object.

        `adj_object` is `None` until `prepare_adj` has run.
        """
        adjacency = self.adj_object
        return adjacency.num_edges

    @property
    def num_samples(self):
        """Number of rows of the observation matrix `X`."""
        n_rows = self.X.shape[0]
        return n_rows


    def get_topological_nodes_pa(self):
        """Return the node indices in topological order and each node's parents.

        Returns:
            A pair ``(nodes, parents)`` of lists: ``nodes`` is ``[0, 1, 2]``
            and ``parents[i]`` holds the parent indices of node ``i``; only
            the collider (node 2) has parents.
        """
        topological_order = [0, 1, 2]
        parents_per_node = [[], [], [0, 1]]
        return topological_order, parents_per_node


    def get_intervention_list(self, in_distribution=True, std_list=None, node_list=None):
        """Build the list of interventions through `datasets.utils.get_intervention_list`.

        Args:
            in_distribution: Accepted for interface compatibility; not used here.
            std_list: Forwarded unchanged to the helper.
            node_list: Nodes to intervene on; defaults to ``[1, 2]``.
                NOTE(review): presumably zero-based node indices (x2 and x3),
                confirm against the helper.

        Returns:
            Whatever the module-level helper returns.
        """
        nodes = [1, 2] if node_list is None else node_list
        column_std = self.X.std(0)
        return get_intervention_list(node_list=nodes, std=column_std, std_list=std_list)


    def set_intervention(self, x_I, is_noise=False):
        """Register an intervention on the dataset and on its adjacency object.

        Args:
            x_I: Mapping from variable name (``'x1'``, ``'x2'``, ``'x3'``) to
                the intervention value.
            is_noise: Stored in `I_noise`; `__getitem__` adds the value to the
                factual one instead of overwriting it when this is set.

        Raises:
            KeyError: If `x_I` contains an unknown variable name.
        """
        index_of = {'x1': 0, 'x2': 1, 'x3': 2}

        self.x_I = {}
        self.I_noise = is_noise

        for name, value in x_I.items():
            self.x_I[index_of[name]] = value

        # Keys keep insertion order, i.e. the order in which x_I listed them.
        self.adj_object.set_intervention(list(self.x_I))

    def diagonal_SCM(self):
        """Activate the adjacency object's diagonal mode with an empty intervention.

        `x_I` becomes an empty dict (not `None`); `I_noise` is left untouched.
        NOTE(review): `__getitem__` only swaps in the diagonal edges when
        `I_noise == False`, so callers presumably need `clean_intervention()`
        beforehand; confirm.
        """
        self.x_I = dict()
        self.adj_object.set_diagonal()


    def clean_intervention(self):
        """Remove any active intervention from the dataset and its adjacency object."""
        self.x_I, self.I_noise = None, False
        self.adj_object.clean_intervention()

    def set_transform(self, transform):
        """Replace the per-sample transform applied in `__getitem__`.

        Args:
            transform: Callable applied to each sample, or a falsy value to
                disable transformation.
        """
        self.transform = transform

    def prepare_adj(self, normalize_A=None, add_self_loop=True):
        """Build the dense adjacency matrix of the collider and wrap it in `Adjacency`.

        Rows index the source node and columns the target node, following
        `adj_edges`.

        Args:
            normalize_A: Must be `None`; normalisation is not implemented.
            add_self_loop: When true the matrix starts from the identity,
                otherwise from zeros.
        """
        assert normalize_A is None, 'Normalization on A is not implemented'
        self.normalize_A = normalize_A
        self.add_self_loop = add_self_loop

        if add_self_loop:
            SCM_adj = np.eye(self.num_nodes, self.num_nodes)
        else:
            SCM_adj = np.zeros([self.num_nodes, self.num_nodes])

        # One dictionary lookup per node instead of repeated list.index scans.
        node_index = {name: idx for idx, name in enumerate(cols_dict)}
        for parent, children in adj_edges.items():
            for child in children:
                SCM_adj[node_index[parent], node_index[child]] = 1

        # Keep the dense matrix around: build_torch_geometric_dataset reads
        # `self.SCM_adj`, which was never assigned anywhere in this class.
        # A copy is stored in case the Adjacency wrapper modifies its input.
        self.SCM_adj = SCM_adj.copy()

        # Create Adjacency Object
        self.adj_object = Adjacency(SCM_adj)



    def prepare_data(self, n_samples=1000, normalize_A=None, add_self_loop=True, mode='train'):
        """Prepare the adjacency and load, or generate and cache, the samples.

        Samples are cached as ``.npy`` files under ``_data/collider_<eq_type>``,
        keyed by `mode` and `n_samples`. If both files exist they are loaded,
        otherwise `n_samples` draws are generated one at a time and saved.

        Args:
            n_samples: Number of samples to load or generate.
            normalize_A: Forwarded to `prepare_adj`.
            add_self_loop: Forwarded to `prepare_adj`.
            mode: Prefix of the cache file names (e.g. ``'train'``).
        """
        self.prepare_adj(normalize_A, add_self_loop)

        cache_dir = mkdir(os.path.join('_data', f'collider_{self.eq_type}'))
        X_file = os.path.join(cache_dir, f'{mode}_{n_samples}_X.npy')
        U_file = os.path.join(cache_dir, f'{mode}_{n_samples}_U.npy')

        cache_hit = os.path.exists(X_file) and os.path.exists(U_file)
        if cache_hit:
            X, U = np.load(X_file), np.load(U_file)
        else:
            X = np.zeros([n_samples, self.num_nodes])
            U = np.zeros([n_samples, self.num_nodes])
            for row in range(n_samples):
                X[row, :], U[row, :] = self.sample()

            np.save(X_file, X)
            np.save(U_file, U)

        self.X = X.astype(np.float32)
        self.U = U.astype(np.float32)

    def sample(self, n_samples=1):
        """Draw observational samples of all three variables.

        The root nodes are sampled first and then fed to the collider x3.

        Returns:
            A pair ``(x, u)`` of arrays that stack the values of x1, x2, x3
            (respectively u1, u2, u3) along the first axis.
        """
        x1, u1 = self.sample_obs(obs_id=1, n_samples=n_samples)
        x2, u2 = self.sample_obs(obs_id=2, n_samples=n_samples)

        collider_parents = {'x1': x1, 'x2': x2}
        x3, u3 = self.sample_obs(obs_id=3, parents_dict=collider_parents, n_samples=n_samples)

        observations = np.array([x1, x2, x3])
        noises = np.array([u1, u2, u3])
        return observations, noises

    def sample_obs(self, obs_id, parents_dict=None, n_samples=1, u=None):
        """Evaluate the structural equation of a single variable.

        Args:
            obs_id: One-based variable id; selects ``x<obs_id>`` / ``u<obs_id>``.
            parents_dict: Parent values passed as keyword arguments to the
                equation; anything that is not a dict means "no parents".
            n_samples: Number of noise draws; only used when `u` is `None`.
            u: Pre-drawn exogenous noise to use instead of sampling.

        Returns:
            A pair ``(x, u)`` with the variable value and the noise used.
        """
        equation = self.structural_eq[f'x{obs_id}']
        if u is None:
            u = np.array(self.noises_distr[f'u{obs_id}'].sample(n_samples))

        if isinstance(parents_dict, dict):
            return equation(u, **parents_dict), u
        return equation(u), u

    def sample_intervention(self, x_I, n_samples=1, return_set_nodes=False):
        """Sample from the SCM under the hard intervention `x_I`.

        Variables are processed in the order x1, x2, x3. Intervened variables
        are fixed to their intervention value; the rest are sampled from their
        structural equation using the (possibly intervened) parent values.

        Args:
            x_I: Mapping from variable name to the value it is fixed to.
            n_samples: Number of samples to draw.
            return_set_nodes: Also return the node classification.

        Returns:
            An array with one column per variable, plus (optionally) a dict
            with the keys ``'parents'``, ``'intervened'`` and ``'children'``.
            A non-intervened node counts as a "child" as soon as any earlier
            node in the processing order was intervened.
            NOTE(review): with x1 intervened this reports x2 under
            ``'children'`` although `adj_edges` has no edge x1 -> x2; confirm
            this is what consumers expect.
        """
        parent_nodes, intervened_nodes, children_nodes = [], [], []
        seen_intervention = False
        values = {}

        for node_idx, name in enumerate(('x1', 'x2', 'x3')):
            if name in x_I:
                values[name] = np.array([x_I[name], ] * n_samples)
                seen_intervention = True  # later free nodes are reported as children
                intervened_nodes.append(node_idx)
                continue

            # Only the collider x3 has parents.
            parents = {'x1': values['x1'], 'x2': values['x2']} if name == 'x3' else None
            values[name], _ = self.sample_obs(obs_id=node_idx + 1,
                                              parents_dict=parents,
                                              n_samples=n_samples)
            if seen_intervention:
                children_nodes.append(node_idx)
            else:
                parent_nodes.append(node_idx)

        samples = np.array([values['x1'], values['x2'], values['x3']]).T
        if not return_set_nodes:
            return samples

        set_nodes = {'parents': parent_nodes,
                     'intervened': intervened_nodes,
                     'children': children_nodes}
        return samples, set_nodes

    def get_counterfactual(self, x_factual, u_factual, x_I, is_noise=False, return_set_nodes=False):
        """Compute counterfactuals by replaying the factual noise under `x_I`.

        Variables are processed in the order x1, x2, x3. Non-intervened
        variables are recomputed from their structural equation with the
        factual noise; intervened ones are set to the intervention value, or
        to ``factual + value`` when `is_noise` is true.

        Args:
            x_factual: Factual observations, one column per variable. Only
                read when `is_noise` is true.
            u_factual: Factual exogenous noise, one column per variable
                (numpy array or torch tensor).
            x_I: Mapping from variable name to the intervention value.
            is_noise: Treat intervention values as additive shifts.
            return_set_nodes: Also return the node classification.

        Returns:
            The counterfactual samples, one column per variable. The result is
            a torch tensor when `u_factual` was a tensor and a numpy array
            otherwise. With `return_set_nodes` a dict with the keys
            ``'intervened'`` and ``'children'`` is returned as well; a
            non-intervened node counts as a "child" as soon as any earlier
            node in the processing order was intervened.
            NOTE(review): with x1 intervened this reports x2 under
            ``'children'`` although `adj_edges` has no edge x1 -> x2; confirm.
        """
        is_tensor = isinstance(u_factual, torch.Tensor)
        if is_tensor:
            # detach()/cpu() make the conversion work for tensors that require
            # grad or live on an accelerator; clone() keeps the caller's
            # storage untouched.
            u_factual = u_factual.detach().cpu().clone().numpy()
            if isinstance(x_factual, torch.Tensor):
                x_factual = x_factual.detach().cpu().clone().numpy()

        children_nodes = []
        intervened_nodes = []
        seen_intervention = False
        values = {}

        n_samples = u_factual.shape[0]

        # Order matters: parents must be resolved before the collider x3.
        for col, name in enumerate(('x1', 'x2', 'x3')):
            if name in x_I:
                seen_intervention = True
                if is_noise:
                    values[name] = x_factual[:, col] + x_I[name]
                else:
                    values[name] = np.array([x_I[name], ] * n_samples)
                intervened_nodes.append(col)
                continue

            # Only the collider x3 has parents.
            parents = {'x1': values['x1'], 'x2': values['x2']} if name == 'x3' else None
            values[name], _ = self.sample_obs(obs_id=col + 1,
                                              parents_dict=parents,
                                              u=u_factual[:, col])
            if seen_intervention:
                children_nodes.append(col)

        x_out = np.array([values['x1'], values['x2'], values['x3']]).T

        # The conversion result used to be discarded, so tensor inputs always
        # produced a numpy output.
        if is_tensor:
            x_out = torch.tensor(x_out)

        if return_set_nodes:
            set_nodes = {'intervened': intervened_nodes,
                         'children': children_nodes}
            return x_out, set_nodes
        return x_out

    def __getitem__(self, index):
        """Return sample `index` as a `torch_geometric` `Data` object.

        The object always carries the factual sample (`x`, `u`, `edge_index`,
        `edge_attr`, `node_ids`). When an intervention is active (`x_I` is not
        `None`) there are three cases:

        * `I_noise == False` and `x_I` empty (diagonal mode): `x_i` is a copy
          of `x` and the *main* `edge_index` / `edge_attr` are replaced by the
          adjacency object's intervened ones;
        * `I_noise == False` and `x_I` non-empty: `x_i` has the intervened
          entries overwritten, and `edge_index_i` / `edge_attr_i` are filled;
        * otherwise: the intervention values are added to the factual entries
          of `x_i`, and `edge_index_i` / `edge_attr_i` are filled.

        If a transform is set it is applied to `x` (and `x_i`) and the result
        is reshaped to ``(num_nodes, 1)``.
        """
        adj = self.adj_object

        x = self.X[index].copy()
        u = torch.tensor(self.U[index].copy())
        edge_index = adj.edge_index.clone()
        edge_attr = adj.edge_attr.clone()

        x_i = edge_index_i = edge_attr_i = None

        if self.x_I is not None:
            x_i = x.copy()
            # Same comparison as before: anything not equal to False
            # (including None) selects the additive branch.
            additive = not (self.I_noise == False)
            diagonal = (not additive) and len(self.x_I) == 0

            if diagonal:
                # Nothing to overwrite in x_i; the diagonal graph becomes the main graph.
                edge_index = adj.edge_index_i
                edge_attr = adj.edge_attr_i
            else:
                for node, value in self.x_I.items():
                    x_i[node] = (x_i[node] + value) if additive else value
                edge_index_i = adj.edge_index_i
                edge_attr_i = adj.edge_attr_i

        if self.transform:
            x = self.transform(x).view(self.num_nodes, 1)
            if x_i is not None:
                x_i = self.transform(x_i).view(self.num_nodes, 1)

        return Data(x=x,
                    u=u,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    node_ids=torch.eye(self.num_nodes),
                    x_i=x_i,
                    edge_index_i=edge_index_i,
                    edge_attr_i=edge_attr_i,
                    num_nodes=self.num_nodes)

    def __len__(self):
        """Number of samples held in `X`."""
        observations = self.X
        return len(observations)

    def pairplot(self, X=None):
        """Draw a seaborn pair plot of the three variables.

        Args:
            X: Samples to plot, one column per variable; defaults to `self.X`.

        Returns:
            The grid object returned by `sns.pairplot`.
        """
        samples = X if X is not None else self.X
        frame = pd.DataFrame(data=samples, columns=['x1', 'x2', 'x3'])
        grid = sns.pairplot(frame)
        grid.fig.suptitle(f"{self.eq_type} Toy SCM (collider) with 3 variables", y=1.08)
        return grid

    def plot_intervention(self, x_I, n_samples=10000, only=False):
        """Plot the marginals of the non-intervened variables under `x_I`.

        One subplot is drawn per non-intervened variable, always in the order
        x1, x2, x3.

        Args:
            x_I: Mapping from variable name to the intervention value.
            n_samples: Number of interventional samples to draw.
            only: When true, plot only the interventional samples; otherwise
                overlay the observational data `self.X` and add a legend.

        Returns:
            The matplotlib figure.
        """
        X_inter = self.sample_intervention(x_I=x_I, n_samples=n_samples)

        all_vars = ['x1', 'x2', 'x3']

        # An ordered list instead of a set difference: iterating a set of
        # strings has no stable order across interpreter runs, which shuffled
        # the subplots and the panel that owns the legend labels.
        vars_to_plot = [(col, v) for col, v in enumerate(all_vars) if v not in x_I]
        n_vars = len(vars_to_plot)

        # Four inches of width per subplot; the height equals num_nodes.
        f = plt.figure(figsize=(4 * n_vars, self.num_nodes))

        for i, (col, v) in enumerate(vars_to_plot):
            ax = f.add_subplot(1, n_vars, i + 1)
            if only:
                sns.distplot(X_inter[:, col], ax=ax)
            else:
                # Label only the first panel so the figure legend has one
                # entry per curve type.
                first = i == 0
                sns.distplot(X_inter[:, col], ax=ax,
                             label='Intervention' if first else None)
                sns.distplot(self.X[:, col], ax=ax,
                             label='Observations' if first else None)
            ax.set_title(f'Variable {v}')

        if not only:
            f.legend(loc='right')

        return f

    # Below: 99% garbage

    def build_torch_geometric_dataset(self, n_samples=1000, norm_adj=True):
        """Pack the first `n_samples` observations into one `torch_geometric` batch.

        Every sample becomes a `Data` object with node features of shape
        ``(num_nodes, 1)``. All samples share the same `edge_index` (non-zero
        entries of the dense adjacency) and `edge_weight` (the matching
        entries of the row-normalised adjacency).

        See https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data

        Args:
            n_samples: Maximum number of samples to include; if the dataset
                holds fewer, all of them are used.
            norm_adj: Currently unused; edge weights always come from the
                row-normalised adjacency.

        Returns:
            The `Batch` built from the per-sample `Data` objects.

        Raises:
            AttributeError: If the dense adjacency `self.SCM_adj` is not set.
        """
        dense_adj = getattr(self, 'SCM_adj', None)
        if dense_adj is None:
            raise AttributeError(
                "'SCM_adj' is not set on this dataset; the dense adjacency "
                "matrix is required to build the torch_geometric batch")

        adj_norm = normalize_adj(dense_adj, how='row')
        X, A = torch.Tensor(self.X[:n_samples]), torch.Tensor(dense_adj)

        A_indices = torch.nonzero(A).t()
        edge_weight = []
        row_list, col_list = adj_norm.nonzero()
        for row, col in zip(row_list, col_list):
            assert adj_norm[row, col] > 0
            edge_weight.append(adj_norm[row, col])
        edge_weight = torch.Tensor(np.array(edge_weight))

        # Iterate over the rows that actually exist: range(n_samples) ran past
        # the end of X whenever n_samples exceeded the dataset size.
        data_list = [Data(x=x_row.unsqueeze(1),
                          edge_index=A_indices,
                          edge_weight=edge_weight)
                     for x_row in X]

        return Batch.from_data_list(data_list)
