import pytorch_lightning as pl
from sklearn import preprocessing
from torch_geometric.data import DataLoader
from torchvision import transforms as transform_lib

from datasets.triangle import TriangleSCM
from datasets.collider import ColliderSCM
from datasets.mgraph import MGraphSCM
from datasets.loan import LoanSCM
from datasets.chain import ChainSCM


from utils.constants import Cte
from datasets.transforms import ToOneHot, ToTensor

from torch_geometric.utils import degree

import torch


class TensorScaler:
    """Adapter that makes a scaler return ``torch.Tensor`` results.

    Wraps any object exposing ``transform`` and ``inverse_transform`` (in this
    file, a fitted sklearn ``StandardScaler``) and converts whatever those
    methods return with ``torch.tensor``.
    """

    def __init__(self, scaler):
        # The wrapped scaler; its methods are delegated to unchanged.
        self.scaler = scaler

    def transform(self, x):
        """Delegate to ``self.scaler.transform`` and return the result as a tensor."""
        scaled = self.scaler.transform(x)
        return torch.tensor(scaled)

    def inverse_transform(self, x):
        """Delegate to ``self.scaler.inverse_transform`` and return the result as a tensor."""
        restored = self.scaler.inverse_transform(x)
        return torch.tensor(restored)


class ToySCMDataModule(pl.LightningDataModule):
    """Lightning data module for the toy structural causal model (SCM) datasets.

    Builds separate train / validation / test instances of the SCM selected by
    ``dataset_name`` and serves them through ``torch_geometric`` data loaders.
    Samples are optionally normalized with a scaler fitted on the training
    split's ``X`` in :meth:`prepare_data`.
    """
    name = 'toy_scm'

    def __init__(
            self,
            data_dir: str = "./",
            dataset_name: str = 'triangle',
            num_samples_tr: int = 10000,
            num_workers: int = 0,
            normalize: str = None,
            normalize_A: str = None,
            seed: int = 42,
            batch_size: int = 32,
            one_hot: bool = False,
            equations_type: str = 'linear',
            *args,
            **kwargs,
    ):
        """Create the three dataset splits for ``dataset_name``.

        Args:
            data_dir: Stored on the instance; not otherwise used in this class.
            dataset_name: One of ``Cte.TRIANGLE``, ``Cte.COLLIDER``,
                ``Cte.MGRAPH``, ``Cte.LOAN`` or ``Cte.CHAIN``.
            num_samples_tr: Number of training samples; the validation and
                test splits each get ``int(num_samples_tr * 0.5)``.
            num_workers: Worker processes used by every ``DataLoader``.
            normalize: ``'std'``, ``'power'``, ``'norm'``, ``'lip'`` (which
                makes :meth:`prepare_data` raise ``NotImplementedError``);
                any other value, including ``None``, means no scaling.
            normalize_A: Forwarded to each split's ``prepare_data``.
            seed: Stored on the instance; not otherwise used in this class.
            batch_size: Batch size of the train / valid / test loaders.
            one_hot: If true, ``target_transform`` is set to ``ToOneHot(10)``
                (it is not used elsewhere in this class).
            equations_type: Forwarded to the SCM constructor.
            *args, **kwargs: Forwarded to ``LightningDataModule.__init__``.

        Raises:
            NotImplementedError: If ``dataset_name`` is not a supported SCM.
        """
        super().__init__(*args, **kwargs)
        self.dims = 1  # Dimension of the features
        self.data_dir = data_dir
        self.num_samples_tr = num_samples_tr

        self.data_is_toy = True

        self.num_workers = num_workers
        self.normalize = normalize
        self.normalize_A = normalize_A
        self.scaler = None  # set by prepare_data()
        self.seed = seed
        self.batch_size = batch_size
        self.target_transform = ToOneHot(10) if one_hot else None
        self.equations_type = equations_type
        self.topological_nodes = None
        self.topological_parents = None
        self.dataset_name = dataset_name
        print(dataset_name)
        self.attribute_dict = None

        # Every split is an independent instance of the same SCM class.
        scm_classes = {
            Cte.TRIANGLE: TriangleSCM,
            Cte.COLLIDER: ColliderSCM,
            Cte.MGRAPH: MGraphSCM,
            Cte.LOAN: LoanSCM,
            Cte.CHAIN: ChainSCM,
        }
        if dataset_name not in scm_classes:
            raise NotImplementedError(f"Unsupported dataset_name: {dataset_name!r}")
        scm_cls = scm_classes[dataset_name]

        self.train_dataset = scm_cls(equations_type=equations_type, transform=None)
        self.valid_dataset = scm_cls(equations_type=equations_type, transform=None)
        self.test_dataset = scm_cls(equations_type=equations_type, transform=None)

        self.topological_nodes, self.topological_parents = self.train_dataset.get_topological_nodes_pa()

    @property
    def num_features(self):
        """Per-node feature dimension (``self.dims``)."""
        return self.dims

    @property
    def num_nodes(self):
        """Number of nodes in the SCM graph, read from the training split."""
        return self.train_dataset.num_nodes

    @property
    def edge_dimension(self):
        """Returns the training split's ``num_edges`` attribute."""
        return self.train_dataset.num_edges

    def get_random_train_sampler(self):
        """Return a function that draws a random batch from the training split.

        The training split's transform is set to :meth:`_default_transforms`
        once, here. The returned ``tmp_fn(num_samples)`` builds a fresh
        shuffled ``DataLoader`` on every call and returns its first batch of
        ``num_samples`` graphs.
        """
        self.train_dataset.set_transform(self._default_transforms())

        def tmp_fn(num_samples):
            dataloader = DataLoader(self.train_dataset, batch_size=num_samples, shuffle=True)
            return next(iter(dataloader))

        return tmp_fn

    def get_deg(self, indegree=True, bincount=False):
        """Collect node degrees over every graph in the training split.

        Args:
            indegree: Count occurrences in ``edge_index[1]`` if true, otherwise
                in ``edge_index[0]``.
            bincount: If true, return a histogram of the degrees (length at
                least the total number of nodes seen) instead of the raw
                per-node degrees.

        Returns:
            A float tensor: the concatenated per-node degrees, or their histogram.

        NOTE(review): iteration uses whatever transform is currently set on
        the training split.
        """
        idx = 1 if indegree else 0
        d = torch.cat([
            degree(data.edge_index[idx], num_nodes=data.num_nodes, dtype=torch.long)
            for data in self.train_dataset
        ])
        deg = torch.bincount(d, minlength=d.numel()) if bincount else d
        return deg.float()

    def prepare_data(self):
        """Prepare all three splits and choose ``self.scaler`` from ``self.normalize``.

        The training split gets ``num_samples_tr`` samples; validation and
        test each get half of that. The scaler is fitted on the training
        split's ``X`` only:

        * ``'std'``   -> ``StandardScaler`` wrapped in :class:`TensorScaler`
        * ``'power'`` -> Yeo-Johnson ``PowerTransformer`` (standardized)
        * ``'norm'``  -> ``MinMaxScaler`` to ``[0, 1]``
        * ``'lip'``   -> ``NotImplementedError``
        * otherwise   -> identity ``FunctionTransformer``

        NOTE(review): only ``'std'`` is wrapped in ``TensorScaler``, so
        ``self.scaler.transform`` presumably returns NumPy arrays for the
        other modes; confirm downstream callers handle both.

        NOTE(review): per the Lightning docs, ``prepare_data`` runs on a
        single process in distributed training, so state assigned here may
        be missing on other ranks. Consider ``setup()`` if DDP is needed.

        Raises:
            NotImplementedError: If ``self.normalize == 'lip'``.
        """
        num_eval_samples = int(self.num_samples_tr * 0.5)
        self.train_dataset.prepare_data(self.num_samples_tr, normalize_A=self.normalize_A, mode='train')
        self.valid_dataset.prepare_data(num_eval_samples, normalize_A=self.normalize_A, mode='valid')
        self.test_dataset.prepare_data(num_eval_samples, normalize_A=self.normalize_A, mode='test')

        if self.normalize == 'std':
            self.scaler = TensorScaler(preprocessing.StandardScaler().fit(self.train_dataset.X))
        elif self.normalize == 'power':
            self.scaler = preprocessing.PowerTransformer(method='yeo-johnson', standardize=True).fit(
                self.train_dataset.X)
        elif self.normalize == 'lip':
            raise NotImplementedError()
        elif self.normalize == 'norm':
            self.scaler = preprocessing.MinMaxScaler(feature_range=(0, 1)).fit(self.train_dataset.X)
        else:
            # func/inverse_func default to the identity. Unlike lambdas, this
            # keeps the scaler picklable for multi-process DataLoaders.
            self.scaler = preprocessing.FunctionTransformer()

    def train_dataloader(self):
        """Shuffled training loader; incomplete last batches are dropped."""
        self.train_dataset.set_transform(self._default_transforms())
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )
        return loader

    def val_dataloader(self):
        """Unshuffled validation loader.

        NOTE(review): ``drop_last=True`` skips the trailing partial batch, so
        validation does not see every sample. It is kept as-is in case model
        code assumes a fixed batch size; confirm.
        """
        self.valid_dataset.set_transform(self._default_transforms())

        loader = DataLoader(
            self.valid_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )
        return loader

    def test_dataloader(self):
        """Unshuffled test loader.

        NOTE(review): ``drop_last=True`` skips the trailing partial batch, so
        testing does not see every sample. It is kept as-is in case model
        code assumes a fixed batch size; confirm.
        """
        self.test_dataset.set_transform(self._default_transforms())

        loader = DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True,
            pin_memory=True
        )
        return loader

    @staticmethod
    def _scale_sample(scaler, num_nodes, x):
        """Reshape one sample to ``(1, num_nodes)`` and apply ``scaler.transform``."""
        return scaler.transform(x.reshape(1, num_nodes))

    def _default_transforms(self):
        """Build the per-sample transform installed on the datasets.

        With a scaler: reshape the sample to ``(1, num_nodes)``, scale it,
        then apply ``ToTensor``. Without one: apply ``ToTensor`` only.

        The scaling step is a ``functools.partial`` over a static method that
        binds just the scaler and node count, not a lambda closing over
        ``self``. This lets the transform be pickled into DataLoader worker
        processes (needed when ``num_workers > 0`` under ``spawn``).
        """
        if self.scaler is None:
            return ToTensor()

        from functools import partial

        scale = partial(self._scale_sample, self.scaler, self.train_dataset.num_nodes)
        return transform_lib.Compose([scale, ToTensor()])
