import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import itertools
import torch

class CSVDataset(Dataset):
    """In-memory dataset loaded from a CSV file whose last column is the label.

    After optional column removal, every column except the last is stored as
    a ``float32`` feature matrix in ``self.X``; the last column is cast to
    ``int64`` and stored one-hot encoded (as a float tensor) in ``self.y``.

    Column specifications accepted by ``__init__``, ``rand`` and ``fill`` are
    either a single integer position or an iterable whose items are integers
    or iterables of integers (flattened one level).
    """

    def __init__(self, path, idx=-1):
        """Read ``path`` and build the feature matrix and one-hot labels.

        Args:
            path: CSV file to read; its first row is treated as the header.
            idx: Positional index (or column specification, see class
                docstring) of columns to drop before splitting features from
                the label. The default ``-1`` is a sentinel meaning "drop
                nothing".
        """
        frame = pd.read_csv(path)
        drop_nothing = isinstance(idx, (int, np.integer)) and idx == -1
        if not drop_nothing:
            positions = self._flatten_columns(idx)
            frame = frame.drop(columns=[frame.columns[i] for i in positions])

        values = frame.values
        self.X = values[:, :-1].astype('float32')
        labels = values[:, -1].astype('int64')

        # one_hot requires every label to be < num_classes. Counting distinct
        # labels alone is only enough when labels are exactly 0..k-1, so widen
        # the encoding to max label + 1 when they are not (e.g. labels 1..k).
        distinct = len(np.unique(labels))
        highest = int(labels.max()) + 1 if labels.size else 0
        self.output = max(distinct, highest)
        self.y = torch.nn.functional.one_hot(
            torch.tensor(labels), num_classes=self.output
        ).float()
        self.input = self.X.shape[1]

    @staticmethod
    def _flatten_columns(spec):
        """Return ``spec`` as a flat list of column positions.

        A scalar integer (Python or NumPy) becomes a one-element list; an
        iterable is flattened by one level, so ``[0, [2, 3]]`` gives
        ``[0, 2, 3]``.
        """
        if isinstance(spec, (int, np.integer)):
            return [int(spec)]
        flat = []
        for item in spec:
            if isinstance(item, (int, np.integer)):
                flat.append(int(item))
            else:
                flat.extend(item)
        return flat

    def __len__(self):
        """Return the number of rows."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``[features, one_hot_label]`` for row ``idx``."""
        return [self.X[idx], self.y[idx]]

    def get_input(self):
        """Return the number of feature columns."""
        return self.input

    def get_output(self):
        """Return the width of the one-hot label encoding."""
        return self.output

    def rand(self, fair, show=False):
        """Overwrite feature columns with values resampled from themselves.

        For each selected column, every row receives a value drawn uniformly
        (via ``np.random.choice``) from that column's distinct values.
        ``self.X`` is modified in place.

        Args:
            fair: Column specification (see class docstring).
            show: Unused; kept so existing callers keep working.
        """
        for col in self._flatten_columns(fair):
            candidates = np.unique(self.X[:, col])
            self.X[:, col] = np.random.choice(candidates, size=len(self.X))

    def fill(self, fair, v):
        """Set every row of the selected feature columns to ``v`` in place.

        Args:
            fair: Column specification (see class docstring).
            v: Value written into each selected column.
        """
        # One vectorised assignment replaces the per-element Python loops.
        self.X[:, self._flatten_columns(fair)] = v

def prepare_data(path, idx=-1, fair=-1, move=None):
    """Split a header-less CSV into train/test sets and wrap them in DataLoaders.

    The file at ``path`` is read without a header row and split 70/30 with a
    fixed ``random_state`` of 42. The two parts are written next to the input
    as ``path + '1'`` (train) and ``path + '2'`` (test), then reloaded as one
    training ``CSVDataset`` and two separate copies of the test set.

    Args:
        path: Location of the source CSV; also the prefix of the two files
            written by this function.
        idx: Column specification forwarded to ``CSVDataset`` for dropping
            columns; ``-1`` drops nothing.
        fair: Column specification of the attribute(s) to intervene on. When
            not ``-1``, the first test copy has them set to 0, the second to
            1, and the training set has them randomised.
        move: Optional column specification randomised in all three datasets.

    Returns:
        A 7-tuple ``(train_dl, test_dl_0, test_dl_1, n_inputs, n_outputs,
        n_train_rows, n_test_rows)``.
    """
    raw = pd.read_csv(path, header=None)
    from sklearn.model_selection import train_test_split

    train_part, test_part = train_test_split(raw, test_size=0.3, random_state=42)
    train_path = path + '1'
    test_path = path + '2'
    train_part.to_csv(train_path, index=False)
    test_part.to_csv(test_path, index=False)

    train = CSVDataset(train_path, idx)
    test0 = CSVDataset(test_path, idx)
    test1 = CSVDataset(test_path, idx)

    if move is not None:
        # NOTE(review): each dataset is randomised by its own call, so the two
        # test copies receive different random draws -- confirm this is intended.
        for dataset in (train, test0, test1):
            dataset.rand(move)

    if fair != -1:
        # Pin the selected attribute(s) to 0 / 1 in the two test copies and
        # randomise them in the training data.
        test0.fill(fair, 0)
        test1.fill(fair, 1)
        train.rand(fair)

    train_dl = DataLoader(train, batch_size=256, shuffle=True)
    test_dl_0 = DataLoader(test0, batch_size=99999, shuffle=False)
    test_dl_1 = DataLoader(test1, batch_size=99999, shuffle=False)

    return (
        train_dl,
        test_dl_0,
        test_dl_1,
        train.get_input(),
        train.get_output(),
        len(train),
        len(test0),
    )
