import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Batch
from torch_geometric.data.data import Data

from datasets._adjacency import Adjacency
from datasets.utils import normalize_adj
from utils.constants import Cte
from utils.distributions import *
from datasets.transforms import ToTensor
'''
code partially from
https://github.com/adastra21/fairness_aware_preprocessing/blob/master/tutorials/03_tutorial_german_logistic-regression.ipynb
'''

# Nodes of the causal graph, in topological order:
#   sex -> A (sensitive attribute), age -> C,
#   R   -> 2 continuous features (see GermanSCM.get_node_columns),
#   S   -> 12 one-hot columns from 3 categorical features.
# NOTE(review): earlier comments described S as "status of checking account"
# and also mentioned "repayment duration" -- confirm what R and S aggregate.
nodes_list = ['sex', 'age', 'R', 'S']

# Display label of each node; key order matches `nodes_list` and is used to
# index rows/columns of the adjacency matrix.
cols_dict = dict(
    sex='Obs 1',
    age='Obs 2',
    R='Obs 3',
    S='Obs 4',
)

# Children of each node: both roots (sex, age) point to R and S;
# R and S are sinks.
adj_edges = dict(
    sex=['R', 'S'],
    age=['R', 'S'],
    R=[],
    S=[],
)


# %%
class GermanSCM(torch.utils.data.Dataset):
    """German Credit data exposed as a 4-node structural causal model.

    Nodes, in topological order: sex (A), age (C), R, S. The raw columns of
    every node are copied into a fixed-width slot of ``largest_dim`` entries,
    so each sample can be viewed as a ``(num_nodes, largest_dim)`` matrix.
    ``mask_X0`` marks which slot entries hold real data.

    Sampling and counterfactual methods are stubs that return empty lists;
    only loading the observed data as graphs is implemented.
    """

    def __init__(self,
                 X, Y,
                 transform=None
                 ):
        """
        Args:
            X: feature table exposing ``.to_numpy()`` (e.g. a pandas
                DataFrame) with the 16 columns laid out as in
                ``get_node_columns``.
            Y: labels exposing ``.to_numpy()``; stored as ``self.Y`` but not
                used by any method of this class.
            transform: optional callable applied to each flat sample in
                ``__getitem__``. Its result must support ``.view`` (e.g. a
                torch tensor).
        """
        self.transform = transform

        self.num_nodes = 4
        self.largest_dim = 12  # slot width = dimension of the widest node (S)

        self.X = X.to_numpy()  # [n_samples x 16]
        self.Y = Y.to_numpy()

        # Zero-padded copy of X: one slot of `largest_dim` columns per node.
        self.X0, self.mask_X0 = self.fill_up_with_zeros(self.X)  # [n_samples x 48]

        self.total_num_features_x0 = self.X0.shape[1]
        # Intervention state; x_I is None when no intervention is active.
        self.x_I = None  # {node index: value} of intervened nodes
        self.I_noise = False  # True: add the value instead of setting it

        self.adj_object = None  # built by prepare_adj / prepare_data

        self.nodes_list = nodes_list

    @property
    def num_edges(self):
        """Number of edges of the prepared adjacency object."""
        return self.adj_object.num_edges

    @property
    def num_samples(self):
        """Number of rows in the observed data."""
        return self.X.shape[0]

    @property
    def num_features(self):
        """Total number of features summed over all nodes (16)."""
        return sum(sum(dims) for dims in self.get_num_features_list())

    def get_dim_to_scale_x0(self):
        """Columns of the padded X0 to rescale: age (12) and both R entries (24, 25)."""
        return [12, 24, 25]

    def get_dim_to_scale(self):
        """Columns of the raw X to rescale: age (1) and both R columns (2, 3)."""
        return [1, 2, 3]

    def get_likelihood_list(self):
        """Likelihood type of each feature group, per node in topological order (A, C, R, S)."""
        return [[Cte.BERNOULLI], [Cte.DELTA], [Cte.DELTA, ] * 2, [Cte.CATEGORICAL, ] * 3]

    def get_num_features_list(self):
        """Width of each feature group, per node, aligned with ``get_likelihood_list``."""
        return [[1], [1], [1, 1], [3, 5, 4]]

    def fill_up_with_zeros(self, X):
        """Copy each node's columns of ``X`` into its own zero-padded slot.

        Args:
            X: array of shape ``[n_samples, 16]``.

        Returns:
            ``(X0, mask)``. ``X0`` is a numpy array of shape
            ``[n_samples, num_nodes * largest_dim]``. ``mask`` is a bool
            tensor of shape ``[1, num_nodes * largest_dim]`` that is True
            where ``X0`` holds real data.
        """
        node_dims = self.get_node_dimensions()
        node_cols = self.get_node_columns()
        width = self.num_nodes * self.largest_dim
        X0 = np.zeros([X.shape[0], width])
        mask_X0 = np.zeros([1, width])
        # Nodes are laid out in topological order (A, C, R, S).
        for node_idx in range(self.num_nodes):
            start = node_idx * self.largest_dim
            stop = start + node_dims[node_idx]
            X0[:, start:stop] = X[:, node_cols[node_idx]]
            mask_X0[:, start:stop] = 1.
        return X0, torch.tensor(mask_X0).type(torch.bool)

    def get_node_columns(self):
        """Column indices of the raw X that belong to each node (dims 1, 1, 2, 12)."""
        return [[0], [1], [2, 3], list(range(4, 16))]

    def get_node_dimensions(self):
        """Number of raw columns per node."""
        return [1, 1, 2, 12]

    def get_topological_nodes_pa(self):
        """Return ``(node indices in topological order, parent indices of each node)``."""
        return list(range(self.num_nodes)), [[], [], [0, 1], [0, 1]]

    def get_attributes_dict(self):
        """Map attribute roles to column indices of the raw X.

        0 = sex (sensitive), 1 = age (fair), 2..15 = R and S (unfair).
        """
        unfair_attributes = list(range(2, 16))  # R, S
        fair_attributes = [1]  # Age
        sensitive_attributes = [0]  # Sex
        return {'unfair_attributes': unfair_attributes,
                'fair_attributes': fair_attributes,
                'sens_attributes': sensitive_attributes}

    def get_intervention_list(self, in_distribution=True):
        """Interventions to evaluate: setting sex to 0 and to 1.

        ``in_distribution`` is accepted for interface compatibility and ignored.
        """
        # In fairness we intervene on the root node `sex` only.
        list_Int = [({'sex': 0}, 'sex_0'), ({'sex': 1}, 'sex_1')]
        return list_Int

    def set_intervention(self, x_I, is_noise=False):
        """Activate an intervention.

        Args:
            x_I: ``{variable name: value}``; only ``'sex'`` is supported
                (other names raise ``KeyError``).
            is_noise: if True, values are added to the node instead of
                replacing it.
        """
        self.x_I = {}
        var_to_idx = {'sex': 0}  # only intervention is on node sex
        self.I_noise = is_noise

        node_id_list = []

        for var, value in x_I.items():
            self.x_I[var_to_idx[var]] = value
            node_id_list.append(var_to_idx[var])

        self.adj_object.set_intervention(node_id_list)

    def diagonal_SCM(self):
        """Switch to the diagonal graph via an empty intervention (see ``__getitem__``)."""
        self.x_I = {}

        self.adj_object.set_diagonal()

    def clean_intervention(self):
        """Deactivate any intervention and restore the observational graph."""
        self.x_I = None
        self.I_noise = False
        self.adj_object.clean_intervention()

    def set_transform(self, transform):
        """Replace the per-sample transform used by ``__getitem__``."""
        self.transform = transform

    def prepare_adj(self, normalize_A=None, add_self_loop=True):
        """Build the SCM adjacency matrix (``self.dag``) and its ``Adjacency`` wrapper.

        Entry ``[i, j] == 1`` means node ``i`` is a parent of node ``j``.
        Rows and columns follow the key order of ``cols_dict``.

        Args:
            normalize_A: must be None; normalization is not implemented.
            add_self_loop: start from the identity (self loops) instead of zeros.
        """
        assert normalize_A is None, 'Normalization on A is not implemented'
        self.normalize_A = normalize_A
        self.add_self_loop = add_self_loop

        if add_self_loop:
            SCM_adj = np.eye(self.num_nodes, self.num_nodes)
        else:
            SCM_adj = np.zeros([self.num_nodes, self.num_nodes])

        node_names = list(cols_dict.keys())
        for parent, children in adj_edges.items():
            row_idx = node_names.index(parent)
            for child in children:
                SCM_adj[row_idx, node_names.index(child)] = 1
        # Create Adjacency Object
        self.dag = SCM_adj
        self.adj_object = Adjacency(SCM_adj)

    def prepare_data(self, normalize_A=None, add_self_loop=True):
        """Prepare everything needed before indexing; currently only the adjacency."""
        self.prepare_adj(normalize_A, add_self_loop)

    def sample_intervention(self, x_I, n_samples=1, return_set_nodes=False):
        """Stub: interventional sampling is not available for real data."""
        if return_set_nodes:
            return [], []
        else:
            return []

    def get_counterfactual(self, x_factual, u_factual, x_I, is_noise=False, return_set_nodes=False):
        """Stub: true counterfactuals are not available for real data."""
        if return_set_nodes:
            return [], []
        else:
            return []

    def get_set_nodes(self):
        """Node roles for the sex intervention: node 0 intervened, nodes 2 and 3 children."""
        parent_nodes = []
        children_nodes = []
        intervened_nodes = []

        intervened_nodes.append(0)
        # todo: check, if (1,16) or (2,16)
        children_nodes.extend([2, 3])

        set_nodes = {'parents': parent_nodes,
                     'intervened': intervened_nodes,
                     'children': children_nodes}
        return set_nodes

    def sample_outcome(self, parents_dict=None, n_samples=1, u=None):
        """Stub: outcome sampling is not implemented."""
        return [], []

    def _apply_intervention(self, x):
        """Return a copy of the flat padded sample ``x`` with ``self.x_I`` applied.

        Each intervened node's value goes to the first entry of that node's
        slot (``node_idx * largest_dim``). The value is added when
        ``self.I_noise`` is set and assigned otherwise.
        """
        x_i = x.copy()
        for node_idx, value in self.x_I.items():
            start = node_idx * self.largest_dim
            if self.I_noise:
                x_i[start] = x_i[start] + value
            else:
                x_i[start] = value
        return x_i

    def __getitem__(self, index):
        """Return sample ``index`` as a torch_geometric ``Data`` graph.

        ``x`` is the padded sample viewed as ``(num_nodes, largest_dim)``.
        When an intervention is active, ``x_i`` is the intervened copy.
        With an empty intervention and no noise (set by ``diagonal_SCM``),
        the adjacency's ``edge_index_i``/``edge_attr_i`` replace
        ``edge_index``/``edge_attr``. Otherwise they are returned separately
        as ``edge_index_i``/``edge_attr_i``.
        """
        # astype returns a new array, so X0 itself is never modified.
        x = self.X0[index].astype(np.float32)

        edge_index = self.adj_object.edge_index
        edge_attr = self.adj_object.edge_attr

        x_i, edge_index_i, edge_attr_i = None, None, None
        if self.x_I is not None:
            x_i = self._apply_intervention(x)
            if not self.I_noise and len(self.x_I) == 0:
                # Diagonal SCM: swap the graph itself.
                edge_index = self.adj_object.edge_index_i
                edge_attr = self.adj_object.edge_attr_i
            else:
                edge_index_i = self.adj_object.edge_index_i
                edge_attr_i = self.adj_object.edge_attr_i

        if self.transform:
            x = self.transform(x).view(self.num_nodes, -1)
            if x_i is not None:
                x_i = self.transform(x_i).view(self.num_nodes, -1)
        else:
            x = ToTensor()(x).view(self.num_nodes, -1)
            if x_i is not None:
                x_i = ToTensor()(x_i).view(self.num_nodes, -1)

        data = Data(x=x,
                    mask=self.mask_X0.view(self.num_nodes, -1),
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    node_ids=torch.eye(self.num_nodes),
                    x_i=x_i,
                    edge_index_i=edge_index_i,
                    edge_attr_i=edge_attr_i,
                    num_nodes=self.num_nodes)

        return data

    def __len__(self):
        return len(self.X)

    def _feature_column_names(self, n_cols):
        """Column labels for a data matrix with ``n_cols`` columns.

        Uses one label per node when ``n_cols`` equals the number of nodes.
        Otherwise it labels each raw feature, e.g. ``R_0``, ``S_11``.

        Raises:
            ValueError: if ``n_cols`` matches neither layout.
        """
        if n_cols == len(self.nodes_list):
            return list(self.nodes_list)
        names = []
        for name, dim in zip(self.nodes_list, self.get_node_dimensions()):
            if dim == 1:
                names.append(name)
            else:
                names.extend(f'{name}_{k}' for k in range(dim))
        if len(names) != n_cols:
            raise ValueError(f'Cannot label {n_cols} columns; expected '
                             f'{len(self.nodes_list)} or {len(names)}.')
        return names

    def pairplot(self, X=None):
        """Seaborn pair plot of ``X`` (defaults to the observed data)."""
        X = self.X if X is None else X
        df = pd.DataFrame(data=X, columns=self._feature_column_names(np.shape(X)[1]))
        g = sns.pairplot(df)
        _ = g.fig.suptitle("German Credit with 4 (16) variables", y=1.08)
        return g

    def plot_intervention(self, x_I, n_samples=10000, only=False):
        """Plot interventional vs. observational marginals of non-intervened nodes.

        NOTE(review): ``sample_intervention`` is a stub that returns ``[]``,
        so this cannot currently produce a plot. Column lookup also assumes
        one column per node. ``plt`` is expected to come from the star import
        of ``utils.distributions`` -- confirm.
        """
        X_inter = self.sample_intervention(x_I=x_I, n_samples=n_samples)

        all_vars = nodes_list

        # Keep topological order so subplot and legend placement are deterministic.
        vars_to_plot = [v for v in all_vars if v not in x_I]
        n_vars = len(vars_to_plot)

        f = plt.figure(figsize=(4 * n_vars, self.num_nodes))  # why 4?

        for i, v in enumerate(vars_to_plot):
            ax = f.add_subplot(1, n_vars, i + 1)
            col = all_vars.index(v)
            if only:
                _ = sns.distplot(X_inter[:, col], ax=ax)
            elif i == 0:
                # Label only the first subplot so the figure legend has one entry each.
                _ = sns.distplot(X_inter[:, col], ax=ax, label='Intervention')
                _ = sns.distplot(self.X[:, col], ax=ax, label='Observations')
            else:
                _ = sns.distplot(X_inter[:, col], ax=ax)
                _ = sns.distplot(self.X[:, col], ax=ax)
            ax.set_title(f'Variable {v}')
        if not only:
            f.legend(loc='right')

        return f

    # Below: legacy code, largely unused.

    def build_torch_geometric_dataset(self, n_samples=1000, norm_adj=True):
        """Batch the first ``n_samples`` rows of ``self.X`` as one graph each.

        The adjacency is always row-normalized to obtain edge weights;
        ``norm_adj`` is accepted but ignored.

        NOTE(review): each graph's ``x`` has one row per column of ``self.X``
        while the edges index the ``num_nodes``-node graph -- confirm this
        legacy helper is still meaningful before relying on it.
        """
        # https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data
        # `SCM_adj` is never set by this class; fall back to the matrix
        # stored by prepare_adj.
        scm_adj = getattr(self, 'SCM_adj', None)
        if scm_adj is None:
            scm_adj = self.dag
        adj_norm = normalize_adj(scm_adj, how='row')
        X, A = torch.Tensor(self.X[:n_samples]), torch.Tensor(scm_adj)

        A_indices = torch.nonzero(A).t()
        edge_weight = []
        row_list, col_list = adj_norm.nonzero()
        for row, col in zip(row_list, col_list):
            assert adj_norm[row, col] > 0
            edge_weight.append(adj_norm[row, col])
        edge_weight = torch.Tensor(np.array(edge_weight))

        data_list = [Data(x=X[i].unsqueeze(1),
                          edge_index=A_indices,
                          edge_weight=edge_weight)
                     for i in range(n_samples)]

        return Batch.from_data_list(data_list)
