from abc import ABC, abstractmethod
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
import torch

from utils import DATA_DIR


class IndexTensorDataset(TensorDataset):
    """A ``TensorDataset`` whose samples additionally carry their own index.

    Each item is the tuple produced by ``TensorDataset.__getitem__`` with the
    requested index appended as the last element.
    """

    @classmethod
    def from_numpy(cls, feat, sens, label):
        """Build a dataset from three NumPy arrays.

        Implemented as a classmethod so that subclasses get an instance of
        their own type back rather than a plain ``IndexTensorDataset``.

        Args:
            feat: Array of features.
            sens: Array of sensitive attributes.
            label: Array of labels; it is flattened to one dimension.

        Returns:
            An instance of ``cls`` holding the three arrays as float32 tensors.
        """
        return cls(torch.from_numpy(feat).float(),
                   torch.from_numpy(sens).float(),
                   torch.from_numpy(label).flatten().float())

    def __getitem__(self, index):
        """Return the tensors stored at ``index`` followed by ``index`` itself."""
        return *super().__getitem__(index), index


class Data(pl.LightningDataModule, ABC):
    """Abstract Lightning data module that serves train/val/test tensor datasets.

    Subclasses must define ``name`` and implement ``setup``. The helper
    ``_preprocess_df`` converts a DataFrame into the three
    ``IndexTensorDataset`` splits that the ``*_dataloader`` methods serve.
    """

    # Identifier of the dataset; also the sub-directory name used by ``data_dir``.
    name = None
    # Feature dimension; ``info`` subtracts ``sens_dim`` from it when the
    # sensitive features are dropped from the feature matrix.
    feat_dim = None
    # Number of sensitive columns.
    sens_dim = None
    # Positional selector into the (expanded) sensitive columns; the selected
    # columns are checked to sum to exactly 1 in every row.
    simple_sens_cols = None

    def __init_subclass__(cls, **kwargs):
        """Reject subclasses for which ``name`` is still ``None``."""
        super().__init_subclass__(**kwargs)
        if cls.name is None:
            raise ValueError(f"Data class {cls} must set a 'name'.")

    def __init__(self,
                 batch_size=32,
                 drop_sens_feat=False,
                 **kwargs):
        """Store the loader configuration and read the global seed.

        Args:
            batch_size: Batch size used by every dataloader.
            drop_sens_feat: If true, the sensitive columns are removed from the
                feature matrix (they are still returned as a separate tensor).
            **kwargs: Must be empty; anything left over is reported as an error.

        Raises:
            ValueError: If unexpected keyword arguments are passed.
            KeyError: If the ``PL_GLOBAL_SEED`` environment variable is not set.
        """
        super().__init__()

        if kwargs:
            raise ValueError(f"The following kwargs were not used: {kwargs}")
        try:
            self.seed = int(os.environ["PL_GLOBAL_SEED"])
        except KeyError as err:
            # Same exception type as before, but with an actionable message.
            raise KeyError("PL_GLOBAL_SEED is not set; seed the run (e.g. with "
                           "pytorch_lightning.seed_everything) before creating a Data module.") from err
        self.batch_size = batch_size
        self.drop_sens_feat = drop_sens_feat
        self.train_data = None
        self.val_data = None
        self.test_data = None

    @abstractmethod
    def setup(self, stage: str) -> None:
        """Prepare the datasets for ``stage``; must be implemented by subclasses."""
        pass

    def info(self) -> dict:
        """Return the dimensions a model needs to be built for this dataset.

        Returns:
            A dict with ``feat_dim`` (reduced by ``sens_dim`` when sensitive
            features are dropped), ``sens_dim`` and ``simple_sens_cols``.
        """
        return {'feat_dim': self.feat_dim - (self.sens_dim if self.drop_sens_feat else 0),
                'sens_dim': self.sens_dim,
                'simple_sens_cols': self.simple_sens_cols}

    def dataloader(self, stage: str) -> DataLoader:
        """Create a ``DataLoader`` over ``self.<stage>_data``.

        Only the ``'train'`` stage is shuffled.
        """
        dataset = getattr(self, f"{stage}_data")
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=stage == 'train')

    def train_dataloader(self):
        """Return the (shuffled) training dataloader."""
        return self.dataloader("train")

    def val_dataloader(self):
        """Return the validation dataloader."""
        return self.dataloader("val")

    def test_dataloader(self):
        """Return the test dataloader."""
        return self.dataloader("test")

    @property
    def data_dir(self):
        """Directory of this dataset: ``DATA_DIR`` joined with the class ``name``."""
        return os.path.join(DATA_DIR, self.__class__.name)

    def _preprocess_df(self, df, sens_columns=None, label_column=None, drop_columns=None, categorical_values=None,
                       mapping=None, mapping_label=None, normalise_columns=None, drop_rows=None):
        """Clean ``df`` and store train/val/test datasets on the instance.

        The steps run in this order: drop rows, drop columns, replace values,
        drop rows with missing values, one-hot encode, map labels, normalise,
        sanity-check the simple sensitive columns, split, and convert to
        tensors. The split is 80/20 into train+val/test, then 80/20 into
        train/val, both seeded with ``self.seed``.

        Args:
            df: The raw DataFrame.
            sens_columns: Names of the sensitive columns. ``None`` is treated
                as "no sensitive columns".
            label_column: Name(s) of the label column(s); a single string is
                accepted and wrapped in a list.
            drop_columns: Columns to remove.
            categorical_values: Columns to one-hot encode with
                ``pd.get_dummies``. Sensitive columns listed here are replaced
                by their dummy columns.
            mapping: ``{column: replacements}``; each entry is passed to
                ``Series.replace`` for that column.
            mapping_label: Callable applied element-wise to the label column(s).
            normalise_columns: Columns to normalise; sensitive ones are min-max
                scaled, all others are standardised (mean / std).
            drop_rows: ``{column: values}``; rows whose ``column`` equals any of
                the values are removed.

        Raises:
            ValueError: If ``label_column`` is missing, or if the simple
                sensitive columns do not sum to 1 in every row.
        """
        if label_column is None:
            raise ValueError("'label_column' must be provided.")
        if isinstance(label_column, str):
            # A list is needed for DataFrame selection and list concatenation below.
            label_column = [label_column]
        if sens_columns is None:
            sens_columns = []

        if drop_rows:
            for column, values in drop_rows.items():
                for value in values:
                    df = df.drop(df[df[column] == value].index)

        if drop_columns:
            df = df.drop(columns=drop_columns)

        if mapping:
            for column, replacements in mapping.items():
                df[column] = df[column].squeeze().replace(replacements)

        df = df.dropna()

        if categorical_values:
            df = pd.get_dummies(df, columns=categorical_values)
            expanded_sens_columns = []
            for column in sens_columns:
                if column in categorical_values:
                    # Plain prefix match: a regex built from the column name
                    # would misbehave on names containing metacharacters.
                    prefix = f"{column}_"
                    expanded_sens_columns += [c for c in df.columns if str(c).startswith(prefix)]
                else:
                    expanded_sens_columns.append(column)
            sens_columns = expanded_sens_columns

        if mapping_label:
            labels = df[label_column]
            # DataFrame.applymap is deprecated in favour of DataFrame.map in
            # newer pandas; look on the class so a column called "map" cannot
            # be mistaken for the method.
            elementwise = labels.map if hasattr(type(labels), "map") else labels.applymap
            df[label_column] = elementwise(mapping_label)

        if normalise_columns:
            for column in normalise_columns:
                values = df[column]
                if column in sens_columns:
                    df[column] = (values - values.min()) / (values.max() - values.min())
                else:
                    df[column] = (values - values.mean()) / values.std()

        print(f"Sensitive features: {sens_columns}")
        if self.simple_sens_cols:
            sens_df = df[sens_columns]
            simple_sens_df = sens_df[sens_df.columns[self.simple_sens_cols].tolist()]
            print(f"Simple sensitive features set: {simple_sens_df.columns}")

            # Mostly a sanity check.
            # It is expected that simple sens columns are one-hot encodings.
            if (simple_sens_df.sum(axis=1) != 1).any():
                raise ValueError("Simple sensitive features are not mutually exclusive.")

        # Seed the splits explicitly so they do not depend on how much of the
        # global random state was consumed before this call.
        train_val, test = train_test_split(df, test_size=0.2, random_state=self.seed)
        train, val = train_test_split(train_val, test_size=0.2, random_state=self.seed)

        if self.drop_sens_feat:
            not_feat_cols = label_column + sens_columns
        else:
            not_feat_cols = label_column

        self.train_data, self.val_data, self.test_data = (IndexTensorDataset.from_numpy(
            sub_df.drop(columns=not_feat_cols).to_numpy(dtype="float64"),
            sub_df[sens_columns].to_numpy(dtype="float64"),
            sub_df[label_column].to_numpy(dtype="float64"))
            for sub_df in [train, val, test])


class CSVData(Data):
    """Data module for datasets stored as semicolon-delimited CSV files."""

    name = "csv"

    def setup(self, stage: str) -> None:
        """Not implemented at this level; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def read(self, data_set, **preprocess_kwargs):
        """Load ``data_set`` from ``DATA_DIR`` and preprocess it.

        Args:
            data_set: Path of the CSV file relative to ``DATA_DIR``.
            **preprocess_kwargs: Forwarded unchanged to ``_preprocess_df``.

        Returns:
            Whatever ``_preprocess_df`` returns.
        """
        csv_path = os.path.join(DATA_DIR, data_set)
        frame = pd.read_csv(csv_path, delimiter=";")
        return self._preprocess_df(frame, **preprocess_kwargs)


class XLSData(Data):
    """Data module for datasets stored as Excel workbooks."""

    name = "xls"

    def setup(self, stage: str) -> None:
        """Not implemented at this level; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def read(self, data_set, **preprocess_kwargs):
        """Load ``data_set`` from ``DATA_DIR`` and preprocess it.

        Args:
            data_set: Path of the Excel file relative to ``DATA_DIR``.
            **preprocess_kwargs: Forwarded unchanged to ``_preprocess_df``.

        Returns:
            Whatever ``_preprocess_df`` returns.
        """
        workbook_path = os.path.join(DATA_DIR, data_set)
        frame = pd.read_excel(workbook_path)
        return self._preprocess_df(frame, **preprocess_kwargs)
