import numpy as np
import os
import torch
from torch.utils.data import Dataset
from module import check_exists, makedir_exist_ok, save, load
from sklearn.datasets import make_regression, make_classification


class SimulateR(Dataset):
    """Synthetic linear-regression dataset with a sensitive group label.

    One group is generated per entry of ``num_samples``. For group ``i`` the
    features are drawn from ``Normal(mu[i], x_sigma[i])`` and the target is
    ``X @ weights[:, i] + bias[i] + Normal(0, error_sigma[i])``.

    The generated arrays are cached under ``processed_folder`` through the
    project helpers ``save``/``load``; generation only runs when
    ``check_exists(processed_folder)`` is false.

    Args:
        root: Base directory of the dataset (``~`` is expanded).
        split: Name of the saved split to load; ``process`` writes ``'train'``
            and ``'test'``.
        seed: Seed of the train/test permutation. It is also part of the
            cache folder name. The data values themselves are generated from
            a fixed seed and do not depend on it.
    """
    data_name = 'SimulateR'

    def __init__(self, root, split, seed):
        self.root = os.path.expanduser(root)
        self.split = split
        self.seed = seed
        # Generation parameters; the per-group lists are indexed by group id.
        self.num_features = 1
        self.num_samples = [2500, 2500]
        self.mu = [30., 10.]
        self.bias = [0., 0.]
        self.x_sigma = [2., 2.]
        self.error_sigma = [10., 1.]
        if not check_exists(self.processed_folder):
            self.process()
        self.id, self.data, self.target, self.sensitive = load(os.path.join(self.processed_folder, self.split))
        self.other = {}
        self.metadata = load(os.path.join(self.processed_folder, 'meta'))

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors.

        The dict has the keys ``'id'``, ``'data'``, ``'target'`` and
        ``'sensitive'`` plus one entry per key of ``self.other`` (entries of
        ``self.other`` win on a key clash).
        """
        sample = {
            'id': torch.tensor(self.id[index]),
            'data': torch.tensor(self.data[index]),
            'target': torch.tensor(self.target[index]),
            'sensitive': torch.tensor(self.sensitive[index]),
        }
        sample.update({k: torch.tensor(v[index]) for k, v in self.other.items()})
        return sample

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)

    @property
    def processed_folder(self):
        """Seed-specific directory holding the cached splits and metadata."""
        return os.path.join(self.root, 'processed', f'seed_{self.seed}')

    @property
    def raw_folder(self):
        """Directory reserved for raw data (nothing is stored there by this class)."""
        return os.path.join(self.root, 'raw')

    def process(self):
        """Generate the data and save the train split, test split and metadata."""
        if not check_exists(self.raw_folder):
            self.download()
        train_set, test_set, meta = self.make_data()
        save(train_set, os.path.join(self.processed_folder, 'train'))
        save(test_set, os.path.join(self.processed_folder, 'test'))
        save(meta, os.path.join(self.processed_folder, 'meta'))
        return

    def download(self):
        """Create the raw folder; the data is simulated, so nothing is fetched."""
        makedir_exist_ok(self.raw_folder)
        return

    def __repr__(self):
        return (f'Dataset {self.__class__.__name__}\nSize: {len(self)}\nRoot: {self.root}\n'
                f'Seed: {self.seed}\nSplit: {self.split}\nNGroup: {self.metadata["n_groups"]}')

    def make_data(self):
        """Simulate every group and split each one 80/20 into train and test.

        Returns:
            A tuple ``(train_set, test_set, metadata)`` where each set is
            ``(id, data, target, sensitive)``: ``id`` is an int64 running
            index, ``data`` is float32 of shape ``(n, num_features)``,
            ``target`` is float32 of shape ``(n, 1)`` and ``sensitive`` is the
            int64 group index. ``metadata`` holds ``'n_groups'``,
            ``'n_classes'`` (always 1) and the true ``'weights'`` of shape
            ``(num_features, num_groups)``.
        """
        # A single local generator feeds every draw. Creating a fresh
        # RandomState with the same seed for each call would hand the features
        # and the noise (and every group) the very same standard-normal draws,
        # which makes the "noise" an exact linear function of X. Keeping the
        # generator local also leaves the caller's global NumPy RNG untouched.
        rng = np.random.RandomState(20241224)
        num_features = self.num_features
        num_groups = len(self.num_samples)
        # Per-group regression weights, uniform in [2, 3).
        weights = np.ones((num_features, num_groups)) + 1 + rng.rand(num_features, num_groups)
        train_data, test_data = [], []
        train_target, test_target = [], []
        train_sensitive, test_sensitive = [], []
        for i in range(num_groups):
            num_samples_i = self.num_samples[i]
            num_train_samples_i = int(num_samples_i * 0.8)
            mu = self.mu[i]
            error_sigma = self.error_sigma[i]
            x_sigma = self.x_sigma[i]
            bias = self.bias[i]
            X = rng.normal(loc=mu, scale=x_sigma, size=(num_samples_i, num_features))
            noise = rng.normal(loc=0, scale=error_sigma, size=num_samples_i)
            y = np.dot(X, weights[:, i]) + bias + noise
            s = np.full(X.shape[0], i).astype(np.int64)
            y = y.reshape(-1, 1)
            X = X.astype(np.float32)
            y = y.astype(np.float32)
            # Only the split depends on self.seed; the values above do not.
            perm = np.random.RandomState(seed=self.seed).permutation(len(X))
            train_idx, test_idx = perm[:num_train_samples_i], perm[num_train_samples_i:]
            train_data.append(X[train_idx])
            train_target.append(y[train_idx])
            train_sensitive.append(s[train_idx])
            test_data.append(X[test_idx])
            test_target.append(y[test_idx])
            test_sensitive.append(s[test_idx])

        train_data = np.concatenate(train_data)
        train_target = np.concatenate(train_target)
        train_sensitive = np.concatenate(train_sensitive)
        test_data = np.concatenate(test_data)
        test_target = np.concatenate(test_target)
        test_sensitive = np.concatenate(test_sensitive)
        # Count the groups that actually appear in the training split.
        num_groups = len(np.unique(train_sensitive))
        self.metadata = {'n_groups': num_groups, 'n_classes': 1, 'weights': weights}
        train_id, test_id = np.arange(len(train_data)).astype(np.int64), np.arange(len(test_data)).astype(np.int64)
        return ((train_id, train_data, train_target, train_sensitive),
                (test_id, test_data, test_target, test_sensitive),
                self.metadata)


class SimulateC(Dataset):
    """Synthetic classification dataset with a sensitive group label.

    One group is generated per entry of ``num_samples`` with scikit-learn's
    ``make_classification``; groups differ in their class weights
    (``class_weights[i]``). The sensitive attribute is appended to the
    features as an extra column, so ``data`` has ``num_features + 1`` columns.

    The generated arrays are cached under ``processed_folder`` through the
    project helpers ``save``/``load``; generation only runs when
    ``check_exists(processed_folder)`` is false.

    Args:
        root: Base directory of the dataset (``~`` is expanded).
        split: Name of the saved split to load; ``process`` writes ``'train'``
            and ``'test'``.
        seed: Seed of the train/test permutation. It is also part of the
            cache folder name. ``make_classification`` itself is called with
            a fixed ``random_state``.
    """
    data_name = 'SimulateC'

    def __init__(self, root, split, seed):
        self.root = os.path.expanduser(root)
        self.split = split
        self.seed = seed
        # Generation parameters; the per-group lists are indexed by group id.
        self.num_samples = [2500, 2500]
        self.num_features = 10
        self.target_size = 2
        self.class_weights = [(0.5, 0.5), (0.2, 0.8)]
        if not check_exists(self.processed_folder):
            self.process()
        self.id, self.data, self.target, self.sensitive = load(os.path.join(self.processed_folder, self.split))
        self.other = {}
        self.metadata = load(os.path.join(self.processed_folder, 'meta'))

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors.

        The dict has the keys ``'id'``, ``'data'``, ``'target'`` and
        ``'sensitive'`` plus one entry per key of ``self.other`` (entries of
        ``self.other`` win on a key clash).
        """
        sample = {
            'id': torch.tensor(self.id[index]),
            'data': torch.tensor(self.data[index]),
            'target': torch.tensor(self.target[index]),
            'sensitive': torch.tensor(self.sensitive[index]),
        }
        sample.update({k: torch.tensor(v[index]) for k, v in self.other.items()})
        return sample

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)

    # NOTE: the class used to define __repr__ twice; the later, shorter
    # definition silently replaced this one. Only the informative variant
    # (the same format SimulateR and SimulateCM use) is kept.
    def __repr__(self):
        return (f'Dataset {self.__class__.__name__}\nSize: {len(self)}\nRoot: {self.root}\n'
                f'Seed: {self.seed}\nSplit: {self.split}\n'
                f'NClass: {self.metadata["n_classes"]}\nNGroup: {self.metadata["n_groups"]}')

    @property
    def processed_folder(self):
        """Seed-specific directory holding the cached splits and metadata."""
        return os.path.join(self.root, 'processed', f'seed_{self.seed}')

    @property
    def raw_folder(self):
        """Directory reserved for raw data (nothing is stored there by this class)."""
        return os.path.join(self.root, 'raw')

    def process(self):
        """Generate the data and save the train split, test split and metadata."""
        if not check_exists(self.raw_folder):
            self.download()
        train_set, test_set, meta = self.make_data()
        save(train_set, os.path.join(self.processed_folder, 'train'))
        save(test_set, os.path.join(self.processed_folder, 'test'))
        save(meta, os.path.join(self.processed_folder, 'meta'))
        return

    def download(self):
        """Create the raw folder; the data is simulated, so nothing is fetched."""
        makedir_exist_ok(self.raw_folder)
        return

    def make_data(self):
        """Simulate every group and split each one 80/20 into train and test.

        Returns:
            A tuple ``(train_set, test_set, metadata)`` where each set is
            ``(id, data, target, sensitive)``: ``id`` is an int64 running
            index, ``data`` is float32 of shape ``(n, num_features + 1)`` with
            the group index in the last column, ``target`` is the int64 class
            label and ``sensitive`` is the int64 group index. ``metadata``
            holds ``'n_groups'`` and ``'n_classes'``.
        """
        num_classes = self.target_size
        class_weights = self.class_weights
        num_features = self.num_features
        num_groups = len(self.num_samples)
        train_data, test_data = [], []
        train_target, test_target = [], []
        train_sensitive, test_sensitive = [], []
        for i in range(num_groups):
            num_samples_i = self.num_samples[i]
            num_train_samples_i = int(num_samples_i * 0.8)
            # Same random_state for every group; only the class weights and
            # the number of samples vary between groups.
            X, y = make_classification(n_features=num_features, n_redundant=0, n_informative=num_features,
                                       n_clusters_per_class=1, n_classes=num_classes, flip_y=0.,
                                       weights=class_weights[i], n_samples=num_samples_i, random_state=20241224)
            s = np.full(X.shape[0], i).astype(np.int64)
            X = X.astype(np.float32)
            # Append the sensitive attribute as an extra feature column. Cast
            # it to X's dtype first: concatenating float32 with int64 would
            # promote the whole matrix to float64 and undo the cast above.
            X = np.concatenate([X, s.reshape(-1, 1).astype(X.dtype)], axis=1)

            y = y.astype(np.int64)
            # Only the split depends on self.seed; the values above do not.
            perm = np.random.RandomState(seed=self.seed).permutation(len(X))
            train_idx, test_idx = perm[:num_train_samples_i], perm[num_train_samples_i:]
            train_data.append(X[train_idx])
            train_target.append(y[train_idx])
            train_sensitive.append(s[train_idx])
            test_data.append(X[test_idx])
            test_target.append(y[test_idx])
            test_sensitive.append(s[test_idx])

        train_data = np.concatenate(train_data)
        train_target = np.concatenate(train_target)
        train_sensitive = np.concatenate(train_sensitive)
        test_data = np.concatenate(test_data)
        test_target = np.concatenate(test_target)
        test_sensitive = np.concatenate(test_sensitive)
        # Count the groups that actually appear in the training split.
        num_groups = len(np.unique(train_sensitive))
        self.metadata = {'n_groups': num_groups, 'n_classes': num_classes}
        train_id, test_id = np.arange(len(train_data)).astype(np.int64), np.arange(len(test_data)).astype(np.int64)
        return ((train_id, train_data, train_target, train_sensitive),
                (test_id, test_data, test_target, test_sensitive),
                self.metadata)


class SimulateCM(SimulateC):
    """Multi-group variant of ``SimulateC`` with a configurable group count.

    A total of 100000 samples is divided over ``num_groups`` groups. The class
    weights move linearly from ``(0.5, 0.5)`` for the first group to
    ``(0.1, 0.9)`` for the last one. Data generation, saving and item access
    are inherited from ``SimulateC``.

    ``SimulateC.__init__`` is not called: it would set its own two-group
    ``num_samples``/``class_weights`` and load the data right away, so the
    loading steps are repeated here after the group layout is in place.

    Args:
        root: Base directory of the dataset (``~`` is expanded).
        split: Name of the saved split to load (``'train'`` or ``'test'``).
        seed: Seed of the train/test permutation; part of the cache folder name.
        num_groups: Number of sensitive groups; must be at least 1.

    Raises:
        ValueError: If ``num_groups`` is smaller than 1.
    """
    data_name = 'SimulateCM'

    def __init__(self, root, split, seed, num_groups):
        if num_groups < 1:
            raise ValueError(f'num_groups must be at least 1, got {num_groups}')
        self.root = os.path.expanduser(root)
        self.split = split
        self.seed = seed
        self.num_features = 10
        self.target_size = 2
        self.num_groups = num_groups
        total_samples = 100000
        # Every group gets the same share; the last group additionally takes
        # whatever remains when total_samples is not divisible by num_groups.
        samples_per_group = total_samples // num_groups
        self.num_samples = [samples_per_group] * (num_groups - 1) + [total_samples - samples_per_group * (num_groups - 1)]
        # Group parameters: alpha runs from 0 (first group) to 1 (last group).
        self.class_weights = []
        for i in range(num_groups):
            # A single group has no range to interpolate over; use the balanced
            # end of the scale instead of dividing by zero.
            alpha = i / (num_groups - 1) if num_groups > 1 else 0.0
            cls1_weight = 0.5 - 0.4 * alpha
            cls2_weight = 0.5 + 0.4 * alpha
            self.class_weights.append((cls1_weight, cls2_weight))

        if not check_exists(self.processed_folder):
            self.process()
        self.id, self.data, self.target, self.sensitive = load(os.path.join(self.processed_folder, self.split))
        self.other = {}
        self.metadata = load(os.path.join(self.processed_folder, 'meta'))

    @property
    def processed_folder(self):
        """Cache directory, specific to both the group count and the seed."""
        return os.path.join(self.root, 'processed', f'groups_{self.num_groups}', f'seed_{self.seed}')

    def __repr__(self):
        return (f'Dataset {self.__class__.__name__}\nSize: {len(self)}\nRoot: {self.root}\n'
                f'Seed: {self.seed}\nSplit: {self.split}\n'
                f'NClass: {self.metadata["n_classes"]}\nNGroup: {self.metadata["n_groups"]}')