from scipy.stats import truncnorm, bernoulli, beta, norm, uniform, norm
from torch.utils.data import Dataset, random_split, ConcatDataset
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
import numpy as np
import pandas as pd
import torch


def _scale_rows_in_place(values, scaler, fit_indices, index_groups):
    """
    Fit ``scaler`` on ``values[fit_indices]`` and overwrite ``values[indices]`` with the scaled
    rows for every ``indices`` in ``index_groups`` (processed in order).

    ``values`` may be a NumPy array or a torch tensor; it is modified in place.
    Empty index groups are skipped.
    """
    def as_numpy(rows):
        # the scaler is fed host arrays, so torch tensors are detached and copied to the CPU first
        if torch.is_tensor(rows):
            return rows.detach().cpu().numpy()
        return np.asarray(rows)

    scaler.fit(as_numpy(values[fit_indices]))
    for indices in index_groups:
        if len(indices) == 0:
            continue
        scaled = np.array(scaler.transform(as_numpy(values[indices])))
        if torch.is_tensor(values):
            # write back a tensor with the destination's dtype and device
            values[indices] = torch.as_tensor(scaled, dtype=values.dtype, device=values.device)
        else:
            values[indices] = scaled


def data_preprocess(data, partition_seed, x_scale=True, y_scale=True, cross_val=None, cross_fit_no = None,
                    train_size=None, val_size=None, test_size=None, cat_var=False):
    """
    Split a dataset into train/validation/test subsets and standardize it in place.

    The scalers are fit on the training rows only and then applied to the training, validation
    and test rows of the underlying arrays of ``data``; the returned subsets index into ``data``
    and therefore see the scaled values.

    data: Dataset object
        map-style dataset (only map-style dataset has __len__() property). When ``x_scale`` is
        set it must expose ``x`` (or ``num_var`` if ``cat_var`` is True); when ``y_scale`` is
        set it must expose ``y``.
    partition_seed: int
        seed used to randomly partition the dataset
    x_scale: bool
        whether the input variables will be scaled
    y_scale: bool
        whether the output variable will be scaled
    cross_val: int
        the number of cross-fitting folds (at least 3). Cross-fitting is used only when both
        ``cross_val`` and ``cross_fit_no`` are given; otherwise the sizes below are used.
    cross_fit_no: int
        1-based index of the fold used as the validation set; the following fold (wrapping
        around to the first one) is the test set and all remaining folds form the train set
    train_size, val_size, test_size: int
        the number of data points for each set (random split only)
    cat_var: bool
        if there are categorical variables in the features; only ``data.num_var`` is then
        scaled, so categorical variables are not normalized

    Returns
    -------
    train_set, val_set, test_set, x_scalar, y_scalar, train_size, val_size, test_size

    Raises
    ------
    ValueError
        if cross-fitting is requested with ``cross_val < 3`` or ``cross_fit_no`` outside
        ``[1, cross_val]``
    """
    if cross_val is not None and cross_fit_no is not None:  # split the dataset for cross-fitting
        if cross_val < 3:
            raise ValueError("cross_val must be at least 3 (validation, test and at least one train fold)")
        if not 1 <= cross_fit_no <= cross_val:
            raise ValueError("cross_fit_no must be between 1 and cross_val")

        data_size = len(data)
        fold_size = data_size // cross_val
        # the last fold absorbs the remainder so that the fold sizes add up to data_size
        size_list = [fold_size] * (cross_val - 1) + [data_size - (cross_val - 1) * fold_size]
        folds = random_split(data, size_list, generator=torch.Generator().manual_seed(partition_seed))

        val_set = folds.pop(cross_fit_no - 1)
        # after the pop, position cross_fit_no - 1 holds the fold that followed the validation
        # fold; wrap around to the first fold when the validation fold was the last one
        test_set = folds.pop(cross_fit_no - 1 if cross_fit_no < cross_val else 0)
        train_set = ConcatDataset(folds)

        val_indices = val_set.indices
        test_indices = test_set.indices
        train_indices = [idx for fold in folds for idx in fold.indices]
    else:  # randomly split the dataset
        train_set, val_set, test_set = random_split(data, [train_size, val_size, test_size],
                                                    generator=torch.Generator().manual_seed(partition_seed))

        val_indices = val_set.indices
        test_indices = test_set.indices
        train_indices = train_set.indices

    train_size = len(train_indices)
    val_size = len(val_indices)
    test_size = len(test_indices)

    x_scalar = StandardScaler()
    y_scalar = StandardScaler()
    split_indices = (train_indices, val_indices, test_indices)

    if x_scale:
        # with categorical variables present, only the numeric block is standardized
        x = data.num_var if cat_var is True else data.x
        _scale_rows_in_place(x, x_scalar, train_indices, split_indices)

    if y_scale:
        _scale_rows_in_place(data.y, y_scalar, train_indices, split_indices)

    return train_set, val_set, test_set, x_scalar, y_scalar, train_size, val_size, test_size


class ProxySim_Normal(Dataset):
    """
    Simulation dataset for proxy variable, with normal confounder

    Every sample has a 5-dimensional standard-normal confounder ``z`` (not stored), a proxy
    ``x`` built from truncated normals centred at ``z.mean()``, a binary treatment whose
    propensity depends on ``z[0]``, ``z[2]`` and ``z[4]``, and an outcome
    ``y = c(z) + tau(z) * treat + 0.25 * N(0, 1)``. Samples are drawn one at a time and kept
    until each treatment arm has reached its quota, so the two arms are balanced.

    input_size: the dimension of the input variable
    sample_size: the sample size for the generated dataset (when it is odd, the treated arm
        receives the extra sample)
    seed: random seed to replicate the dataset; seeds NumPy's global RNG (used by the scipy
        samplers) and torch's global RNG (used for the final shuffle)

    Items are ``(y, treat, x, tau)`` where ``tau`` is the coefficient of ``treat`` in ``y``,
    i.e. the true individual treatment effect.
    """
    def __init__(self, input_size, sample_size, seed):
        self.sample_size = sample_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # per-arm quotas that add up to exactly sample_size
        zero_quota = sample_size // 2
        one_quota = sample_size - zero_quota

        one_count, zero_count = 0, 0  # count of the samples in treatment group and control group, respectively
        one_treat, one_x, one_z = [], [], []
        zero_treat, zero_x, zero_z = [], [], []

        np.random.seed(seed)
        torch.manual_seed(seed)
        while one_count < one_quota or zero_count < zero_quota:
            # generate confounder
            z_temp = norm.rvs(size=5)

            # generate proxy: a shared error term ee plus per-coordinate noise, rescaled by
            # sqrt(2) so that each coordinate has roughly unit variance
            ee = truncnorm.rvs(a=-10, b=10, loc=z_temp.mean())
            x_temp = truncnorm.rvs(a=-10, b=10, loc=z_temp.mean(), size=input_size) + ee
            x_temp /= np.sqrt(2)

            # generate treatment; beta.cdf lies in [0, 1], so the propensity lies in [0.25, 0.5]
            temp = (norm.cdf(z_temp[0]) + norm.cdf(z_temp[2]) + norm.cdf(z_temp[4]))/3
            temp = beta.cdf(temp, 2, 4)
            prop = 0.25 * (1+temp)
            treat_temp = bernoulli.rvs(p=prop)

            if treat_temp == 1:
                one_count += 1
                one_z.append(z_temp)
                one_x.append(x_temp)
                one_treat.append(treat_temp)
            else:
                zero_count += 1
                zero_z.append(z_temp)
                zero_x.append(x_temp)
                zero_treat.append(treat_temp)

        # keep the first quota samples of each arm (treated first, then control)
        z = np.array(one_z[:one_quota] + zero_z[:zero_quota])
        x = np.array(one_x[:one_quota] + zero_x[:zero_quota])
        treat = np.array(one_treat[:one_quota] + zero_treat[:zero_quota])

        # generate outcome
        c = 5*z[..., 2]/(1+z[..., 3]**2) + 2*z[..., 4]
        f1 = 2/(1+np.exp(-z[..., 0]+0.5))
        f2 = 2/(1+np.exp(-z[..., 1]+0.5))
        # NOTE(review): `f1*f2.mean()` parses as f1 * mean(f2), so ita = f1 * (f2 - mean(f2));
        # confirm this (rather than centring the product f1*f2) is the intended effect modifier.
        ita = f1*f2 - f1*f2.mean()
        tau = 3+ita
        y = c + tau*treat + 0.25*norm.rvs(size=self.sample_size)

        self.x = torch.FloatTensor(x).to(self.device)
        self.y = torch.FloatTensor(y.reshape(self.sample_size, 1)).to(self.device)
        self.treat = torch.FloatTensor(treat).to(self.device)
        self.tau = torch.FloatTensor(tau.reshape(self.sample_size, 1)).to(self.device)

        # shuffle the samples so that the treated/control blocks are interleaved
        permute_idx = torch.randperm(sample_size)
        self.x = self.x[permute_idx]
        self.y = self.y[permute_idx]
        self.treat = self.treat[permute_idx]
        self.tau = self.tau[permute_idx]

    def __len__(self):
        return int(self.sample_size)

    def __getitem__(self, idx):
        y = self.y[idx]
        treat = self.treat[idx]
        x = self.x[idx]
        tau = self.tau[idx]
        return y, treat, x, tau


# Twins Dataset
class TwinsData(Dataset):
    """
    load twins dataset
    all the columns are categorical (or binary) variables, no need to scale the data

    path: location of the twins CSV (defaults to the original location, resolved against the
        current working directory). It must contain the columns ``y``, ``treat`` and
        ``counter``; every column other than ``y``, ``treat``, ``counter``, ``mort_0`` and
        ``mort_1`` becomes a feature.

    Items are ``(y, treat, x)``: ``y`` as a long tensor, ``treat`` and ``x`` as float tensors.
    ``counter`` is loaded into ``self.counter`` but is not returned by ``__getitem__``.
    """
    def __init__(self, path="./experiment/data_preprocess/twins/twins_data.csv"):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        data = pd.read_csv(path)
        self.data_size = len(data.index)

        # columns that are never used as features
        non_feature_cols = ['y', 'treat', 'counter', 'mort_0', 'mort_1']
        features = data.loc[:, ~data.columns.isin(non_feature_cols)]

        self.y = torch.FloatTensor(np.array(data['y'])).long().to(device)
        self.treat = torch.FloatTensor(np.array(data['treat'], dtype=np.float32)).to(device)
        self.counter = torch.FloatTensor(np.array(data['counter'], dtype=np.float32)).to(device)
        self.x = torch.FloatTensor(np.array(features, dtype=np.float32)).to(device)

    def __len__(self):
        return int(self.data_size)

    def __getitem__(self, idx):
        y = self.y[idx]
        treat = self.treat[idx]
        x = self.x[idx]
        return y, treat, x


# brca_data dataset
class BRCA(Dataset):
    """
    Load the BRCA dataset.

    path: location of the BRCA CSV (defaults to the original location, resolved against the
        current working directory). It must contain ``vital_status`` (the label) and
        ``radiation_therapy``. Every column from the third one onwards is used as a numeric
        feature (``self.num_var``).
        NOTE(review): this assumes the first two columns are the non-gene columns; confirm
        against the CSV layout.

    Items are ``(y, a)`` where ``a`` is the numeric features followed by ``radiation_therapy``
    (``self.cat_var``). Integer and list/array indices are both supported.
    """
    def __init__(self, path="./data_preprocess/brca/brca_data.csv"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        data = pd.read_csv(path)

        gene_col = data.columns[2:].to_list()
        self.data_size = len(data.index)

        self.y = torch.FloatTensor(np.array(data['vital_status'])).long().to(self.device)
        self.a = np.array(data.loc[:, ~data.columns.isin(['vital_status'])])
        self.num_var = np.array(data[gene_col], dtype=np.float32)
        self.cat_var = np.array(data[["radiation_therapy"]], dtype=np.float32)

    def __len__(self):
        return int(self.data_size)

    def __getitem__(self, idx):
        y = self.y[idx]
        # join along the feature axis so that batched (list/array) indices work as well
        a = torch.FloatTensor(np.concatenate((self.num_var[idx], self.cat_var[idx]), axis=-1)).to(self.device)
        return y, a


# class MultipleSim_Separable(Dataset):
#     def __init__(self, input_size, sample_size, confounder_size, seed, epsilon_idx=None, epsilon=0.1):
#         """
#         generate the simulation data for multiple cause scenario with separable confounding.
#         """
#         self.sample_size = sample_size
#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#         za_coef = uniform.rvs(loc=-2, scale=4, size=confounder_size, random_state=seed)
#         ay_coef = uniform.rvs(loc=-2, scale=4, size=input_size, random_state=seed)
#         zy_coef = uniform.rvs(loc=-2, scale=4, size=confounder_size, random_state=seed)

#         np.random.seed(seed)
#         a, z = np.zeros((sample_size, input_size)), np.zeros((sample_size, confounder_size))
#         k = confounder_size // 3
#         for i in range(sample_size):
#             z[i] = norm.rvs(size=confounder_size)
#             c1 = np.matmul(np.sin(z[i][0:k]), za_coef[0:k])
#             c2 = np.matmul(np.cos(z[i][k:(k*2)]), za_coef[k:(k*2)])
#             c3 = np.matmul(1/(1+np.exp(-z[i][(k*2):(k*3)]+0.5)), za_coef[(k*2):(k*3)])
#             a_mean = c1 +c2 +c3
#             a[i] = a_mean + norm.rvs(size=input_size)

#         if epsilon_idx is not None:
#             a[:, epsilon_idx] += epsilon

#         means_yz = np.matmul(z**2, zy_coef[:, None])
#         means_ya = np.matmul(a**2, ay_coef[:, None])
#         y = means_ya + means_yz + norm.rvs()

#         # self.z = torch.FloatTensor(z).to(self.device)
#         # self.a = torch.FloatTensor(a).to(self.device)
#         # self.y = torch.FloatTensor(y).to(self.device)
#         self.z = z
#         self.a = a
#         self.y = y

#     def __len__(self):
#         return int(self.sample_size)

#     def __getitem__(self, idx):
#         y = torch.FloatTensor(self.y[idx]).to(self.device)
#         a = torch.FloatTensor(self.a[idx]).to(self.device)
#         return y, a


class MultipleSim_Separable(Dataset):
    def __init__(self, input_size, sample_size, confounder_size, seed, epsilon_idx=None, epsilon=0.1):
        """
        generate the simulation data for multiple cause scenario with separable confounding.

        input_size: number of causes (columns of ``self.a``)
        sample_size: number of samples
        confounder_size: number of confounders (columns of ``self.z``); only the first
            ``3 * (confounder_size // 3)`` of them affect the data, so at least 3 are required
        seed: random seed for the coefficients and for NumPy's global RNG
        epsilon_idx: optional column index (or indices) of ``a`` shifted by ``epsilon`` before
            the outcome is generated
        epsilon: size of that shift

        For every sample, the confounders give three scores c1 (sines), c2 (cosines) and c3
        (sigmoids). Each cause is drawn by inverse-CDF sampling on [-1, 1] with parameter
        c = c1 + c2 + c3. The outcome is additive in causes and confounders:
        y = -[a**2, s(a)] @ ay_coef + [c1, c2, c3] @ zy_coef + N(0, 1) noise per sample,
        where s(a) is the sum of the degree-2 interaction-only polynomial features of a
        (a_i and a_i * a_j for i < j).

        Raises ValueError if confounder_size < 3.
        """
        if confounder_size < 3:
            # with k = 0 every score is 0 and the inverse-CDF step divides by zero
            raise ValueError("confounder_size must be at least 3")

        self.sample_size = sample_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(review): all three vectors use the same random_state, so they presumably share
        # their leading values (e.g. zy_coef == za_coef[:3]); confirm this is intended.
        za_coef = uniform.rvs(loc=-1, scale=2, size=confounder_size, random_state=seed)
        ay_coef = uniform.rvs(loc=-1, scale=2, size=input_size+1, random_state=seed)
        zy_coef = uniform.rvs(loc=-1, scale=2, size=3, random_state=seed)

        np.random.seed(seed)
        a = np.zeros((sample_size, input_size))
        z = np.zeros((sample_size, confounder_size))
        z_transform = np.zeros((sample_size, 3))
        k = confounder_size // 3
        for i in range(sample_size):
            z[i] = norm.rvs(size=confounder_size)
            c1 = np.matmul(np.sin(z[i][0:k]), za_coef[0:k])
            c2 = np.matmul(np.cos(z[i][k:(k*2)]), za_coef[k:(k*2)])
            c3 = np.matmul(1/(1+np.exp(-z[i][(k*2):(k*3)]+0.5)), za_coef[(k*2):(k*3)])
            c = c1 + c2 + c3
            u = uniform.rvs(loc=0, scale=1, size=input_size)  # inverse CDF sampling
            # u = 0 maps to -1 and u = 1 maps to 1
            a[i] = np.log(np.exp(c*u)*(1+np.exp(-c))-1)/c
            z_transform[i] = [c1, c2, c3]

        if epsilon_idx is not None:
            a[:, epsilon_idx] += epsilon

        # Row sum of the interaction-only degree-2 features (a_i, and a_i * a_j for i < j) in
        # closed form: S + (S**2 - sum_i a_i**2) / 2 with S = sum_i a_i. This avoids building
        # the O(input_size**2) feature matrix.
        row_sum = a.sum(axis=1, keepdims=True)
        interaction_sum = row_sum + 0.5 * (row_sum**2 - (a**2).sum(axis=1, keepdims=True))
        temp = np.concatenate([a**2, interaction_sum], axis=1)

        means_yz = np.matmul(z_transform, zy_coef[:, None])
        means_ya = -np.matmul(temp, ay_coef[:, None])
        # independent N(0, 1) noise for every sample; a single draw would only shift all outcomes
        y = means_ya + means_yz + norm.rvs(size=(sample_size, 1))

        self.z = z
        self.a = a
        self.y = y

    def __len__(self):
        return int(self.sample_size)

    def __getitem__(self, idx):
        y = torch.FloatTensor(self.y[idx]).to(self.device)
        a = torch.FloatTensor(self.a[idx]).to(self.device)
        return y, a


class MultipleSim_NonSeparable(Dataset):
    def __init__(self, input_size, sample_size, confounder_size, seed, epsilon_idx=None, epsilon=0.1):
        """
        generate the simulation data for multiple cause scenario with non-separable confounding.

        input_size: number of causes (columns of ``self.a``)
        sample_size: number of samples
        confounder_size: number of confounders (columns of ``self.z``); only the first
            ``3 * (confounder_size // 3)`` of them affect the data, so at least 3 are required
        seed: random seed for the coefficients and for NumPy's global RNG
        epsilon_idx: optional column index (or indices) of ``a`` shifted by ``epsilon`` before
            the outcome is generated
        epsilon: size of that shift

        For every sample, the confounders give a score c = c1 + c2 + c3 (sines, cosines and
        sigmoids), and each cause is drawn by inverse-CDF sampling on [-1, 1] with parameter c.
        The outcome mixes the score with the causes (non-separable):
        y = a**2 @ ay_coef + c * (1 - s(a)) + N(0, 1) noise per sample,
        where s(a) is the sum of the degree-2 interaction-only polynomial features of a
        (a_i and a_i * a_j for i < j).

        Raises ValueError if confounder_size < 3.
        """
        if confounder_size < 3:
            # with k = 0 the score is 0 and the inverse-CDF step divides by zero
            raise ValueError("confounder_size must be at least 3")

        self.sample_size = sample_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(review): both vectors use the same random_state, so they presumably share their
        # leading values; confirm this is intended.
        za_coef = uniform.rvs(loc=-1, scale=2, size=confounder_size, random_state=seed)
        ay_coef = uniform.rvs(loc=-1, scale=2, size=input_size, random_state=seed)

        np.random.seed(seed)
        a = np.zeros((sample_size, input_size))
        z = np.zeros((sample_size, confounder_size))
        c = np.zeros((sample_size, 1))
        k = confounder_size // 3
        for i in range(sample_size):
            z[i] = norm.rvs(size=confounder_size)
            c1 = np.matmul(np.sin(z[i][0:k]), za_coef[0:k])
            c2 = np.matmul(np.cos(z[i][k:(k*2)]), za_coef[k:(k*2)])
            c3 = np.matmul(1/(1+np.exp(-z[i][(k*2):(k*3)]+0.5)), za_coef[(k*2):(k*3)])
            c[i] = c1 + c2 + c3
            u = uniform.rvs(loc=0, scale=1, size=input_size)  # inverse CDF sampling
            # u = 0 maps to -1 and u = 1 maps to 1
            a[i] = np.log(np.exp(c[i]*u)*(1+np.exp(-c[i]))-1)/c[i]

        if epsilon_idx is not None:
            a[:, epsilon_idx] += epsilon

        y_mean1 = np.matmul(a**2, ay_coef[:, None])
        # Row sum of the interaction-only degree-2 features (a_i, and a_i * a_j for i < j) in
        # closed form: S + (S**2 - sum_i a_i**2) / 2 with S = sum_i a_i. This avoids building
        # the O(input_size**2) feature matrix.
        row_sum = a.sum(axis=1, keepdims=True)
        interaction_sum = row_sum + 0.5 * (row_sum**2 - (a**2).sum(axis=1, keepdims=True))
        y_mean2 = c - np.multiply(c, interaction_sum)
        # independent N(0, 1) noise for every sample; a single draw would only shift all outcomes
        y = y_mean1 + y_mean2 + norm.rvs(size=(sample_size, 1))

        self.z = z
        self.a = a
        self.y = y

    def __len__(self):
        return int(self.sample_size)

    def __getitem__(self, idx):
        y = torch.FloatTensor(self.y[idx]).to(self.device)
        a = torch.FloatTensor(self.a[idx]).to(self.device)
        return y, a