import numpy as np
import pytorch_lightning as pl
import torch
from aif360.datasets import GermanDataset
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch_geometric.data import DataLoader
from torch_geometric.utils import degree
from torchvision import transforms as transform_lib

from datasets.german import GermanSCM
from datasets.transforms import ToOneHot, ToTensor
from utils.constants import Cte


class TensorScaler:
    """Adapter that returns the wrapped scaler's results as ``torch`` tensors.

    The wrapped object only needs to provide ``transform`` and
    ``inverse_transform``; whatever they return is handed to ``torch.tensor``.
    """

    def __init__(self, scaler):
        # Underlying scaler whose outputs get converted.
        self.scaler = scaler

    def transform(self, x):
        """Apply the wrapped scaler's ``transform`` to *x* and return a tensor."""
        scaled = self.scaler.transform(x)
        return torch.tensor(scaled)

    def inverse_transform(self, x):
        """Apply the wrapped scaler's ``inverse_transform`` to *x* and return a tensor."""
        restored = self.scaler.inverse_transform(x)
        return torch.tensor(restored)



class MaskedTensorStandardScaler:
    """Standard-scale only a selected subset of the columns of a 2-D input.

    Two column lists are kept. Which one is used depends on the width of the
    input: if ``x.shape[1]`` equals ``total_num_dimensions`` the columns in
    ``list_dim_to_scale`` are scaled, otherwise those in
    ``list_dim_to_scale_x0``. All other columns pass through untouched.

    A single underlying ``StandardScaler`` is shared by both cases, so the
    scaler must be fitted and applied on inputs that select the same number
    of columns.
    """

    def __init__(self, list_dim_to_scale_x0, list_dim_to_scale, total_num_dimensions):
        self.list_dim_to_scale_x0 = list_dim_to_scale_x0  # [0, 1, 4, 5 ,7 ,8]
        self.list_dim_to_scale = list_dim_to_scale  # [0, 1, 4, 5 ,7 ,8]
        self.total_num_dimensions = total_num_dimensions
        self.scaler = preprocessing.StandardScaler()

    def _dims_for(self, x):
        """Return the column indices to scale for an input of ``x``'s width."""
        if x.shape[1] != self.total_num_dimensions:
            return self.list_dim_to_scale_x0
        return self.list_dim_to_scale

    @staticmethod
    def _copy(x):
        """Return an independent copy of ``x`` (tensor ``clone`` or array ``copy``)."""
        if isinstance(x, torch.Tensor):
            return x.clone()
        return x.copy()

    def fit(self, x):
        """Fit the underlying scaler on the selected columns of ``x``."""
        self.scaler.fit(x[:, self._dims_for(x)])

    def transform(self, x):
        """Return a tensor equal to ``x`` with the selected columns scaled.

        The input is left unmodified: scaled values are written into a copy,
        so passing a view of stored data does not overwrite that data.
        """
        dims = self._dims_for(x)
        out = self._copy(x)
        out[:, dims] = self.scaler.transform(x[:, dims])
        return torch.tensor(out)

    def inverse_transform(self, x):
        """Return a copy of ``x`` with the selected columns un-scaled.

        The result has the same container type as the input copy; the input
        itself is left unmodified.
        """
        dims = self._dims_for(x)
        out = self._copy(x)
        out[:, dims] = torch.tensor(self.scaler.inverse_transform(x[:, dims]))
        return out


class RealSCMDataModule(pl.LightningDataModule):
    """Lightning data module serving ``GermanSCM`` train/valid/test datasets.

    The data comes from AIF360's ``GermanDataset``; it is split 80/10/10 with
    ``train_test_split`` and exposed through ``torch_geometric`` data loaders.
    Only ``dataset_name == Cte.GERMAN`` is supported.
    """
    name = 'real_scm'

    def __init__(
            self,
            data_dir: str = "./",
            dataset_name: str = 'german',
            num_samples_tr: int = 10000,
            num_workers: int = 16,
            normalize: str = None,
            normalize_A: str = None,
            seed: int = 42,
            batch_size: int = 32,
            one_hot: bool = False,
            num_samples_vl: int = 0,
            num_samples_ts: int = 0,
            equations_type='linear',
            *args,
            **kwargs,
    ):
        """Store the configuration and build the three dataset splits.

        Args:
            data_dir: Stored on the instance; not read elsewhere in this class.
            dataset_name: Dataset identifier; must equal ``Cte.GERMAN``.
            num_samples_tr: Stored on the instance. The real split sizes are
                determined by ``train_test_split``, not by this value.
            num_workers: Worker count passed to every data loader.
            normalize: ``'std'`` selects ``MaskedTensorStandardScaler`` in
                ``prepare_data``; ``'lip'`` is not implemented; any other
                value yields an identity transformer.
            normalize_A: Forwarded to each dataset's ``prepare_data``.
            seed: Stored on the instance.
                NOTE(review): the split itself uses a fixed ``random_state=1``.
            batch_size: Batch size of the train/valid/test loaders.
            one_hot: When true, ``target_transform`` is set to ``ToOneHot(10)``.
            num_samples_vl: Stored; ``0`` means half of ``num_samples_tr``.
            num_samples_ts: Stored; ``0`` means half of ``num_samples_tr``.
            equations_type: Stored on the instance.
            *args, **kwargs: Forwarded to ``pl.LightningDataModule``.

        Raises:
            NotImplementedError: If ``dataset_name`` is not ``Cte.GERMAN``.
        """
        super().__init__(*args, **kwargs)
        self.dims = 1  # Dimension of the features
        self.data_dir = data_dir
        self.num_samples_tr = num_samples_tr
        self.num_samples_vl = num_samples_vl
        self.num_samples_ts = num_samples_ts
        if num_samples_vl == 0:
            self.num_samples_vl = int(self.num_samples_tr * 0.5)
        if num_samples_ts == 0:
            self.num_samples_ts = int(self.num_samples_tr * 0.5)

        self.data_is_toy = False
        self.num_workers = num_workers
        self.normalize = normalize
        self.normalize_A = normalize_A
        self.scaler = None  # created in prepare_data()
        self.seed = seed
        self.batch_size = batch_size
        self.target_transform = ToOneHot(10) if one_hot else None
        self.topological_nodes = None
        self.topological_parents = None
        self.dataset_name = dataset_name
        self.equations_type = equations_type
        # ``attributes_dict`` is the name that actually receives the value;
        # ``attribute_dict`` is the historical (misspelled) placeholder and is
        # kept as an alias so existing readers of either name keep working.
        self.attributes_dict = None
        self.attribute_dict = None

        self._shuffle_train = True

        if dataset_name != Cte.GERMAN:
            raise NotImplementedError(f"Unsupported dataset_name: {dataset_name!r}")

        self._build_german_datasets()

        self.topological_nodes, self.topological_parents = self.train_dataset.get_topological_nodes_pa()

    def _build_german_datasets(self):
        """Load the German credit data and create the train/valid/test ``GermanSCM`` datasets."""
        nodes_list = ['sex',  # A
                      'age',  # C
                      'credit_amount',  # R
                      'month',  # R repayment duration
                      'housing=A151', 'housing=A152', 'housing=A153',  # S
                      'savings=A61', 'savings=A62', 'savings=A63',
                      'savings=A64', 'savings=A65',  # S savings
                      'status=A11', 'status=A12',
                      'status=A13', 'status=A14']  ## S

        dataset = GermanDataset(protected_attribute_names=['sex'])
        # Recode labels: 2 -> 0, everything else -> 1.
        dataset.labels = np.where(dataset.labels == 2, 0, 1)  # this is for y
        dataset.unfavorable_label = 0.0

        df = dataset.convert_to_dataframe()[0]

        X = df[nodes_list]
        y = df['credit']

        # 80% train, then the remaining 20% is halved into test and valid.
        X_train, X_test_valid, y_train, y_test_valid = train_test_split(
            X, y, test_size=0.2, random_state=1)
        X_test, X_valid, y_test, y_valid = train_test_split(
            X_test_valid, y_test_valid, test_size=0.5, random_state=1)

        self.train_dataset = GermanSCM(X_train, y_train, transform=None)
        self.valid_dataset = GermanSCM(X_valid, y_valid, transform=None)
        self.test_dataset = GermanSCM(X_test, y_test, transform=None)

        self.attributes_dict = self.train_dataset.get_attributes_dict()
        self.attribute_dict = self.attributes_dict

    @property
    def num_features(self):
        """Feature dimension (``self.dims``)."""
        return self.dims

    @property
    def num_nodes(self):
        """``num_nodes`` of the training dataset."""
        return self.train_dataset.num_nodes

    @property
    def edge_dimension(self):
        """``num_edges`` of the training dataset."""
        return self.train_dataset.num_edges

    @property
    def edge_dimension_ancestors(self):
        """``num_edges_ancestors`` of the training dataset."""
        return self.train_dataset.num_edges_ancestors

    @property
    def num_features_list(self):
        """Result of the training dataset's ``get_num_features_list()``."""
        return self.train_dataset.get_num_features_list()

    @property
    def likelihood_list(self):
        """Result of the training dataset's ``get_likelihood_list()``."""
        return self.train_dataset.get_likelihood_list()

    def set_shuffle_train(self, value):
        """Set whether ``train_dataloader`` shuffles the training data."""
        self._shuffle_train = value

    def get_random_train_sampler(self):
        """Return a function ``f(num_samples)`` yielding one shuffled training batch.

        Each call of the returned function builds a fresh shuffled loader over
        the training dataset and returns its first batch.
        """
        self.train_dataset.set_transform(self._default_transforms())

        def sample_batch(num_samples):
            dataloader = DataLoader(self.train_dataset, batch_size=num_samples, shuffle=True)
            return next(iter(dataloader))

        return sample_batch

    def get_deg(self, indegree=True, bincount=False):
        """Collect node degrees over every sample of the training dataset.

        Args:
            indegree: Use row 1 of ``edge_index`` when true, row 0 otherwise.
            bincount: When true, return the histogram of the degrees
                (``torch.bincount``) instead of the raw concatenated degrees.

        Returns:
            A float tensor with the degrees or their bin counts.
        """
        idx = 1 if indegree else 0
        per_sample = [
            degree(data.edge_index[idx], num_nodes=data.num_nodes, dtype=torch.long)
            for data in self.train_dataset
        ]

        deg = torch.cat(per_sample)
        if bincount:
            deg = torch.bincount(deg, minlength=deg.numel())

        return deg.float()

    def get_normalized_X(self, mode='test'):
        """Return the scaled copy of ``X`` for the chosen split.

        Args:
            mode: One of ``'train'``, ``'test'`` or ``'valid'``.

        Raises:
            NotImplementedError: For any other ``mode``.
        """
        if mode == 'train':
            dataset = self.train_dataset
        elif mode == 'test':
            dataset = self.test_dataset
        elif mode == 'valid':
            dataset = self.valid_dataset
        else:
            raise NotImplementedError(f"Unknown mode: {mode!r}")
        # Scale a copy so the dataset's own ``X`` is not handed to the scaler.
        return self.scaler.transform(dataset.X.copy())

    def prepare_data(self):
        """Prepare the three datasets and create ``self.scaler``.

        Raises:
            NotImplementedError: If ``self.normalize == 'lip'``.
        """
        self.train_dataset.prepare_data(normalize_A=self.normalize_A)
        self.valid_dataset.prepare_data(normalize_A=self.normalize_A)
        self.test_dataset.prepare_data(normalize_A=self.normalize_A)

        if self.normalize == 'std':
            self.scaler = MaskedTensorStandardScaler(list_dim_to_scale_x0=self.train_dataset.get_dim_to_scale_x0(),
                                                     list_dim_to_scale=self.train_dataset.get_dim_to_scale(),
                                                     total_num_dimensions=self.train_dataset.num_features)
            # Statistics are fitted on the training split only.
            self.scaler.fit(self.train_dataset.X0)
        elif self.normalize == 'lip':
            raise NotImplementedError()
        else:
            # No normalisation requested: use an identity transformer.
            self.scaler = preprocessing.FunctionTransformer(func=lambda x: x,
                                                            inverse_func=lambda x: x)

    def _build_loader(self, dataset, shuffle):
        """Attach the default transform to ``dataset`` and wrap it in a loader."""
        dataset.set_transform(self._default_transforms())
        # NOTE(review): drop_last=True also applies to the valid/test loaders,
        # so a trailing partial batch is skipped there too; kept as-is.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Loader over the training split; shuffling follows ``set_shuffle_train``."""
        return self._build_loader(self.train_dataset, shuffle=self._shuffle_train)

    def val_dataloader(self):
        """Non-shuffled loader over the validation split."""
        return self._build_loader(self.valid_dataset, shuffle=False)

    def test_dataloader(self):
        """Non-shuffled loader over the test split."""
        return self._build_loader(self.test_dataset, shuffle=False)

    def _default_transforms(self):
        """Return the per-sample transform: scale (if a scaler exists) then ``ToTensor``."""
        if self.scaler is None:
            return ToTensor()

        def scale(x):
            # The scaler works on 2-D input, so present the sample as one row.
            return self.scaler.transform(x.reshape(1, self.train_dataset.total_num_features_x0))

        return transform_lib.Compose([scale, ToTensor()])
