from scipy.stats import truncnorm, bernoulli
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import RobustScaler, StandardScaler
import os


def SimData_bench(input_size, seed, data_size, data_out=False, out_dir='./raw_data/sim'):
    """Simulate a balanced binary-treatment dataset with known counterfactuals.

    Each sample's covariates are standard-normal draws truncated to [-10, 10]
    plus one shared truncated-normal offset, scaled by 1/sqrt(2). A fixed tanh
    network maps them to a treatment propensity (treatment is
    Bernoulli(sigmoid(h22))) and to the outcome under the assigned treatment
    (``y``) and under the opposite treatment (``y_count``), each with its own
    N(0, 1) noise term. Sampling continues until at least ``data_size // 2``
    treated and ``data_size // 2`` control samples exist; exactly that many of
    each are kept, treated samples first.

    NumPy's global RNG is re-seeded with ``seed``; the scipy.stats samplers below
    draw from it because no ``random_state`` is passed to them.

    Args:
        input_size: number of covariates. Must be >= 5 because the network reads
            x[0]..x[4]; any further covariates do not enter the network.
        seed: RNG seed; also used in the CSV file name.
        data_size: requested sample count. An odd value yields ``data_size - 1``
            samples so that both arms stay the same size.
        data_out: if truthy, also write the data (rows shuffled after re-seeding
            with ``seed``) to ``<out_dir>/sim<seed>.csv`` with columns
            y, y_count, treat, x0, x1, ...
        out_dir: directory for the CSV; created if it does not exist.

    Returns:
        (y, treat, x, y_count): factual outcomes, 0/1 treatment indicators,
        covariate matrix of shape (n, input_size), and counterfactual outcomes.
    """
    half = data_size // 2
    # accepted samples per arm, stored as (x, y, treat, y_count) tuples
    arms = {1: [], 0: []}

    np.random.seed(seed)
    while min(len(arms[1]), len(arms[0])) < half:
        # generate x
        ee = truncnorm.rvs(-10, 10)
        x_temp = truncnorm.rvs(-10, 10, size=input_size) + ee
        x_temp /= np.sqrt(2)

        # nodes in the first hidden layer
        h11 = np.tanh(2*x_temp[0]+1*x_temp[3])
        h12 = np.tanh(-x_temp[0]-2*x_temp[4])
        h13 = np.tanh(2*x_temp[1]-2*x_temp[2])
        h14 = np.tanh(-2*x_temp[3]+1*x_temp[4])

        # nodes in the second hidden layer
        h21 = np.tanh(-2*h11+h13)
        h22 = h12-h13
        h23 = np.tanh(h13-2*h14)

        # generate treatment from a logistic propensity of h22
        prob = np.exp(h22)/(1 + np.exp(h22))
        treat_temp = bernoulli.rvs(p=prob)

        # nodes in the third hidden layer (assigned treatment)
        h31 = np.tanh(1*h21-2*treat_temp)
        h32 = np.tanh(-1*treat_temp+2*h23)

        # counterfactual nodes in the third hidden layer (opposite treatment)
        h31_count = np.tanh(1*h21-2*(1-treat_temp))
        h32_count = np.tanh(-1*(1-treat_temp)+2*h23)

        # outcomes; draw order (factual noise, then counterfactual noise) is kept
        # so seeded datasets are reproducible
        y_temp = -4*h31+2*h32 + np.random.normal(0, 1)
        y_count_temp = -4*h31_count+2*h32_count + np.random.normal(0, 1)

        arms[1 if treat_temp == 1 else 0].append((x_temp, y_temp, treat_temp, y_count_temp))

    kept = arms[1][:half] + arms[0][:half]
    # reshape keeps x 2-D even when no samples are kept (data_size < 2)
    x = np.array([s[0] for s in kept]).reshape(-1, input_size)
    y = np.array([s[1] for s in kept])
    treat = np.array([s[2] for s in kept])
    y_count = np.array([s[3] for s in kept])

    if data_out:
        # column_stack sizes the table from the data itself, so an odd
        # data_size (which keeps data_size - 1 rows) no longer breaks the reshape
        table = np.column_stack((y, y_count, treat, x))
        np.random.seed(seed)
        np.random.shuffle(table)
        frame = pd.DataFrame(table, columns=['y', 'y_count', 'treat'] + ['x' + str(k) for k in range(input_size)])
        os.makedirs(out_dir, exist_ok=True)
        frame.to_csv(os.path.join(out_dir, 'sim' + str(seed) + '.csv'), index=False)
    return y, treat, x, y_count


def true_cate(y, treat, y_count):
    """Return the true individual treatment effect Y(1) - Y(0).

    ``y`` is the factual outcome and ``y_count`` the counterfactual one. For
    treat == 1 this is y - y_count; for treat == 0 it is y_count - y. Both cases
    are folded into one signed product. Works element-wise on arrays.
    """
    difference = y - y_count
    direction = 2*treat - 1  # +1 for treated samples, -1 for controls
    return difference * direction


def ACIC_bench(dgp, data_dir='./raw_data/acic'):
    """Load and preprocess one ACIC benchmark file.

    Reads ``acic<dgp>.csv`` from ``data_dir``. Column ``Y`` is the outcome and
    ``A`` the treatment; every other column is a covariate. A covariate counts as
    categorical when its absolute values are at most 10 and it has no more
    distinct values than ``max + 1`` (i.e. it looks like small integer codes).
    The remaining (numerical) covariates are standardized with StandardScaler.

    Args:
        dgp: identifier of the data-generating process, used in the file name.
        data_dir: directory containing the ACIC CSV files.

    Returns:
        (y, treat, x): float32 outcome and treatment vectors, and the covariate
        matrix laid out as [standardized numerical columns | categorical columns].
    """
    csv_name = 'acic' + str(dgp) + '.csv'
    csv_dir = os.path.join(data_dir, csv_name)
    data = pd.read_csv(csv_dir)

    # Exclude outcome and treatment by name rather than by position: a binary Y
    # also passes the categorical test, and a positional drop would then put
    # Y or A into the covariates.
    cat_col = [
        col for col in data.columns
        if col not in ('Y', 'A')
        and data[col].abs().max() <= 10
        and len(data[col].unique()) <= data[col].max() + 1
    ]

    cat_var = np.array(data[cat_col], dtype=np.float32)
    num_var = np.array(data.loc[:, ~data.columns.isin(['Y', 'A', *cat_col])], dtype=np.float32)
    y = np.array(data['Y'], dtype=np.float32)
    treat = np.array(data['A'], dtype=np.float32)

    # data preprocess: standardize numerical covariates only
    x_scalar = StandardScaler()
    x_scalar.fit(num_var)
    num_var = np.array(x_scalar.transform(num_var))

    # concatenate preprocessed numerical variable and categorical variable
    x = np.concatenate((num_var, cat_var), axis=1)

    return y, treat, x


def Twins_bench(csv_path="./raw_data/twins/twins_data.csv"):
    """Load the Twins benchmark.

    Args:
        csv_path: location of the Twins CSV. It must contain columns ``y``
            (outcome) and ``treat`` (treatment). A ``counter`` column, if present,
            is excluded from the covariates (presumably the counterfactual
            outcome — confirm against the data preparation script). All other
            columns are covariates.

    Returns:
        (y, treat, x): ``y`` keeps the dtype read from the CSV; ``treat`` and
        ``x`` are float32.
    """
    data = pd.read_csv(csv_path)

    y = np.array(data['y'])
    treat = np.array(data['treat'], dtype=np.float32)
    x = np.array(data.loc[:, ~data.columns.isin(['y', 'treat', 'counter'])], dtype=np.float32)

    return y, treat, x


def BRCA_bench():
    """Load and preprocess the TCGA BRCA benchmark.

    The outcome is ``vital_status`` and the treatment is ``radiation_therapy``.
    Three named clinical fields plus every column from position 23 onward are
    numerical and get standardized. All other columns are passed through unchanged
    as categorical covariates, except the outcome, the treatment,
    ``days_to_death`` and ``days_to_last_followup``.

    Returns:
        (y, treat, x) with x laid out as
        [categorical columns | standardized numerical columns].
    """
    frame = pd.read_csv("./raw_data/tcga/brca_data.csv")

    numeric_cols = [
        'years_to_birth',
        'date_of_initial_pathologic_diagnosis',
        'number_of_lymph_nodes',
        *frame.columns[23:].to_list(),
    ]
    # survival-time fields are left out of x as well (presumably outcome-derived)
    excluded = ['vital_status', 'radiation_therapy', 'days_to_death',
                'days_to_last_followup', *numeric_cols]

    outcome = np.array(frame['vital_status'])
    treatment = np.array(frame['radiation_therapy'])

    raw_numeric = np.array(frame[numeric_cols], dtype=np.float32)
    scaled_numeric = np.array(StandardScaler().fit_transform(raw_numeric))
    categorical = np.array(frame.loc[:, ~frame.columns.isin(excluded)])

    covariates = np.concatenate((categorical, scaled_numeric), axis=1)
    return outcome, treatment, covariates
