import pandas as pd
import torch
from scipy.sparse import coo_matrix
import numpy as np
from torch.utils.data import Dataset

def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), elementwise for arrays.

    Evaluated directly via ``np.exp``: for very negative inputs ``exp(-x)``
    overflows to inf (NumPy emits a RuntimeWarning) and the result is 0.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator
        

def generate_simple(beta, n_samples=1500, observational=True, device='cpu', plotting=False, log_gamma=10.0):
    """Simulate a 1-D dataset confounded by a binary hidden variable.

    X ~ Unif[-2, 2] (or an evenly spaced grid on [-2, 2] when ``plotting``),
    U ~ Bern(beta) drawn independently of X. The treatment propensity is the
    nominal propensity sigmoid(0.75 X + 0.5) perturbed by U through the
    sensitivity parameter gamma = exp(log_gamma):
        e(x, u) = u / alpha_t(x) + (1 - u) / beta_t(x).

    Outcome, with s = 2T - 1:
        Y = s X + 2 s - 2 sin(2 s X) - 2 (2U - 1)(1 + X/2) + eps,  eps ~ N(0, 1).

    Args:
        beta: success probability of the hidden confounder U.
        n_samples: number of rows to generate.
        observational: if truthy, T ~ Bern(e(X, U)); otherwise T ~ Bern(1/2).
        device: torch device for the returned tensors.
        plotting: if True, X is ``np.linspace(-2, 2, n_samples)`` instead of
            random draws.
        log_gamma: log of the sensitivity parameter gamma.

    Returns:
        Tuple (X, T, Y, mu0, mu1, U) of float32 tensors, each of shape
        (n_samples, 1). mu0/mu1 are the noise-free potential-outcome means with
        U replaced by its expectation beta.
    """
    # Observed covariate X ~ Unif[-2, 2]; a deterministic grid when plotting.
    if plotting:
        X = np.linspace(-2, 2, n_samples).reshape(-1, 1)
    else:
        X = np.random.uniform(-2, 2, (n_samples, 1))

    # Binary unobserved confounder U ~ Bern(beta).
    U = np.random.binomial(1, beta, (n_samples, 1))

    # Complete propensity e(x, u). For gamma > 0 and a nominal propensity in
    # (0, 1), both alpha_t and beta_t are > 1, so e(x, u) is a valid probability.
    nominal_propensity = sigmoid(0.75 * X + 0.5)
    gamma = np.exp(log_gamma)
    alpha_t = (1 / (gamma * nominal_propensity)) + 1 - (1 / gamma)
    beta_t = (gamma / nominal_propensity) + 1 - gamma
    complete_propensity = U / alpha_t + (1 - U) / beta_t

    if observational:
        T = (np.random.rand(n_samples, 1) < complete_propensity).astype(float)
    else:
        # Randomized assignment, independent of X and U.
        T = np.random.binomial(1, 0.5, (n_samples, 1)).astype(float)

    epsilon = np.random.normal(0, 1, (n_samples, 1))  # outcome noise

    # Potential-outcome means averaged over U: the outcome is linear in U, so
    # this substitutes E[2U - 1] = 2 * beta - 1.
    mu1 = X + 2 - 2 * np.sin(2 * X) - 2 * (2 * beta - 1) * (1 + 0.5 * X)
    mu0 = - X - 2 + 2 * np.sin(2 * X) - 2 * (2 * beta - 1) * (1 + 0.5 * X)

    # Observed outcome; sign is +1 for treated rows and -1 for controls.
    sign = 2 * T - 1
    Y = sign * X + 2 * sign - 2 * np.sin(2 * sign * X) - 2 * (2 * U - 1) * (1 + 0.5 * X) + epsilon

    def _to_tensor(arr):
        return torch.tensor(arr, dtype=torch.float32, device=device)

    return tuple(_to_tensor(arr) for arr in (X, T, Y, mu0, mu1, U))


def generate_high_dimensional(beta=0.3, n_samples=1500, dim_X=10, dim_U=5, observational=True, device='cpu', plotting=False, log_gamma=10.0):
    """Simulate a confounded dataset with several covariates and hidden confounders.

    X has i.i.d. Unif[-2, 2] entries (n_samples x dim_X) and U has i.i.d.
    Bern(beta) entries (n_samples x dim_U). Every random coefficient vector is
    drawn from N(0, 0.1^2), and most have their first entry pinned to 1.

    Treatment: nominal propensity sigmoid(0.75 * X @ w + 0.5), perturbed with
    sensitivity parameter gamma = exp(log_gamma) as
        e = m / alpha_t + (1 - m) / beta_t,  m = row mean of U.

    Outcome, with s = 2T - 1, a = X @ b and c = U @ v:
        Y = s a + s - 2 sin(2 s a) - 2 (2c - 1)(1 + a/2) + eps,  eps ~ N(0, 1).
    mu0 / mu1 are the same expressions at s = -1 / +1, without noise and with c
    replaced by its expectation beta * sum(v).

    Notes:
        * ``plotting`` is accepted for signature parity with ``generate_simple``
          but is unused here; X is always random.
        * All randomness comes from the global NumPy RNG. The order of the draws
          below fixes the output for a given seed and must not be changed.
        * NOTE(review): the propensity uses the unweighted mean of U while the
          outcome uses U @ v -- confirm this asymmetry is intended.

    Returns:
        Tuple (X, T, Y, mu0, mu1, U) of float32 tensors on ``device`` with
        shapes (n, dim_X), (n, 1), (n, 1), (n, 1), (n, 1), (n, dim_U).
    """
    coef_std = 0.1

    # Draws 1-2: observed covariates and hidden confounders.
    X = np.random.uniform(-2, 2, (n_samples, dim_X))
    U = np.random.binomial(1, beta, (n_samples, dim_U))

    # Draw 3: propensity coefficients for X.
    prop_coef = np.random.normal(0, coef_std, (dim_X, 1))
    prop_coef[0] = 1
    nominal_propensity = sigmoid(0.75 * np.dot(X, prop_coef) + 0.5)

    gamma = np.exp(log_gamma)
    alpha_t = (1 / (gamma * nominal_propensity)) + 1 - (1 / gamma)
    beta_t = (gamma / nominal_propensity) + 1 - gamma

    # Draw 4: a coefficient vector for U whose value is never used. The call is
    # kept only because it advances the RNG; dropping it would shift all later draws.
    np.random.normal(0, coef_std, (dim_U, 1))

    u_share = np.mean(U, axis=1, keepdims=True)
    complete_propensity = u_share / alpha_t + (1 - u_share) / beta_t

    # Draw 5: treatment assignment.
    if not observational:
        T = np.random.binomial(1, 0.5, (n_samples, 1)).astype(float)
    else:
        T = (np.random.rand(n_samples, 1) < complete_propensity).astype(float)

    # Draws 6-8: outcome noise, then outcome coefficients for X and for U.
    noise = np.random.normal(0, 1.0, (n_samples, 1))
    out_coef_x = np.random.normal(0, coef_std, (dim_X, 1))
    out_coef_u = np.random.normal(0, coef_std, (dim_U, 1))
    out_coef_u[0] = 1

    x_index = np.dot(X, out_coef_x)
    u_index = np.dot(U, out_coef_u)
    u_index_mean = np.sum(out_coef_u) * beta  # E[U @ v] for i.i.d. Bern(beta) entries

    shift = 2 * (2 * u_index_mean - 1) * (1 + 0.5 * x_index)
    mu1 = x_index + 1 - 2 * np.sin(2 * x_index) - shift
    mu0 = -x_index - 1 + 2 * np.sin(2 * x_index) - shift

    sign = 2 * T - 1
    Y = sign * x_index + sign - 2 * np.sin(2 * sign * x_index) - 2 * (2 * u_index - 1) * (1 + 0.5 * x_index) + noise

    return tuple(
        torch.tensor(arr, dtype=torch.float32, device=device)
        for arr in (X, T, Y, mu0, mu1, U)
    )



# Code based on the CORNet repository
# TODO: add reference to the CORNet repository

import csv
import pandas as pd
from sklearn.model_selection import train_test_split
import pyreadr
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import StandardScaler
import numpy as np
import os

def read_START_data(q_prime=0.1):
    """Build unconfounded / confounded / test splits from the STAR students data.

    Reads ``src/data/STAR_Students.RData`` relative to the current working
    directory and prints that directory. Keeps rows whose ``g1classtype`` is
    ``'REGULAR CLASS'`` or ``'SMALL CLASS'`` and whose three scores
    (``g1tlistss``, ``g1treadss``, ``g1tmathss``) and covariates are present.

    * treatment: 1 for ``'SMALL CLASS'``, 0 for ``'REGULAR CLASS'``.
    * outcome: sum of the three scores, then standardized.
    * covariates, in this column order: gender, race, birthmonth, birthday,
      birthyear, g1freelunch, g1tchid, rural. gender, race, birthmonth,
      g1freelunch and rural are ordinal-encoded. birthday, birthyear and
      g1tchid are standardized.

    Rows are then split at random (global NumPy RNG) into an unconfounded
    "RCT" part and an observational remainder. Each row enters the RCT part
    with probability 0.5 if its standardized birthday is negative, and 0.1
    otherwise. Confounding is injected into the remainder by keeping only
    control rows with outcome < mean - std and treated rows with outcome
    > mean + std.

    Args:
        q_prime: Unused; kept for backward compatibility.

    Returns:
        dict of float64 arrays. ``x_<split>`` has shape (n, 8) and
        ``t_<split>`` / ``y_<split>`` have shape (n, 1). The splits are ``unc``
        (RCT part), ``conf`` (confounded observational part), ``test`` (every
        row not in ``unc``) and ``all``.
    """
    print("Current working directory:", os.getcwd())
    file_path = os.path.join(os.getcwd(), 'src/data/STAR_Students.RData')
    df = pyreadr.read_r(file_path)['x']

    # Only regular and small classes; rows with a missing class type go too.
    class_type = df['g1classtype']
    df = df[class_type.notna() & ((class_type == 'REGULAR CLASS') | (class_type == 'SMALL CLASS'))]

    # Students with all three outcome scores observed. The .copy() makes the
    # column assignments below act on an independent frame rather than a
    # slice (avoids SettingWithCopyWarning).
    score_cols = ['g1tlistss', 'g1treadss', 'g1tmathss']
    df = df[df[score_cols].notna().all(axis=1)].copy()

    df['treatment'] = (df['g1classtype'] == 'SMALL CLASS').astype(int)
    df['outcome'] = df['g1tlistss'] + df['g1treadss'] + df['g1tmathss']
    # NOTE(review): 'INNER CITY' schools are also flagged as rural -- confirm intended.
    df['rural'] = (df['g1surban'] == 'RURAL') | (df['g1surban'] == 'INNER CITY')

    covariates = ['gender', 'race', 'birthmonth', 'birthday', 'birthyear', 'g1freelunch', 'g1tchid', 'rural']
    df_all = df[['outcome', 'treatment'] + covariates]
    # Drop students with any missing covariate.
    df_all = df_all[df_all[covariates].notna().all(axis=1)].copy()

    # Ordinal-encode the categorical covariates.
    for col in ['gender', 'race', 'birthmonth', 'g1freelunch', 'rural']:
        enc = OrdinalEncoder(dtype=int)
        values = np.array(df_all[col].values).reshape(-1, 1)
        df_all[col] = enc.fit_transform(values).ravel()

    # Center/scale. Fit and transform both use plain arrays.
    scaled_cols = ['birthday', 'birthyear', 'g1tchid', 'outcome']
    df_all[scaled_cols] = StandardScaler().fit_transform(df_all[scaled_cols].values)

    # Random RCT membership, biased towards negative (standardized) birthdays.
    # The 0.5-draw happens before the 0.1-draw, as in the original ordering.
    n_rows = df_all.shape[0]
    birthday = df_all['birthday']
    rct_draws = ((birthday < 0) * np.random.binomial(1, 0.5, n_rows)
                 + (birthday >= 0) * np.random.binomial(1, 0.1, n_rows))
    rct_indicator = rct_draws > 0
    df_unc = df_all[rct_indicator]
    df_os = df_all[~rct_indicator]

    # Introduce confounding in the observational part.
    # NOTE(review): both thresholds use the mean/std of the *treated*
    # observational rows -- confirm that is intended.
    os_treated = df_os[df_os['treatment'] == 1]
    os_control = df_os[df_os['treatment'] == 0]
    mean = os_treated['outcome'].mean()
    std = os_treated['outcome'].std()
    df_conf = pd.concat((os_control[os_control['outcome'] < mean - std],
                         os_treated[os_treated['outcome'] > mean + std]))

    def _xty(frame):
        """Return float64 (x, t, y); x is every column after outcome/treatment."""
        x = np.array(frame.values[:, 2:], dtype='float64')
        t = np.array(frame['treatment'].values.reshape(-1, 1), dtype='float64')
        y = np.array(frame['outcome'].values.reshape(-1, 1), dtype='float64')
        return x, t, y

    # The test split is ALL \ Unc, which is exactly the observational rows.
    splits = {'unc': df_unc, 'conf': df_conf, 'test': df_os, 'all': df_all}
    result = {}
    for name, frame in splits.items():
        x, t, y = _xty(frame)
        result['x_' + name] = x
        result['t_' + name] = t
        result['y_' + name] = y
    return result

import random
from sklearn import preprocessing

def upload_actg(n_unc=500):
    """Build unconfounded / confounded / test splits from ``src/data/actg175.csv``.

    The CSV is read relative to the current working directory. The outcome is
    the standardized change ``cd420 - cd40``. Covariates, in this column order,
    are the standardized continuous columns age, wtkg, cd40, karnof, cd80,
    followed by the binary columns gender, homo, race, drugs, symptom, str2,
    hemo. Only arm 0 (control, t=0) and arm 2 (treated, t=1) are kept.

    Splits:
      * ``unc``: ``n_unc`` rows drawn without replacement with Python's
        ``random`` module. Seed ``random`` (not only NumPy) for reproducibility.
      * ``conf``: from the remaining rows, all females (gender == 0), plus males
        (gender == 1) selected on outcome to induce confounding. Treated males
        are kept if above the treated-male mean, control males if below the
        control-male mean.
      * ``test``: every row not in ``unc``.
      * ``all``: every row of the two arms.

    Args:
        n_unc: size of the unconfounded split. ``random.sample`` raises
            ValueError if it exceeds the number of rows.

    Returns:
        dict of float64 arrays. ``x_<split>`` has shape (n, 12) and
        ``t_<split>`` / ``y_<split>`` have shape (n, 1).
    """
    file_path = os.path.join(os.getcwd(), 'src/data/actg175.csv')
    df = pd.read_csv(file_path, index_col=0, header=0)

    outcome_norm = preprocessing.scale((df['cd420'] - df['cd40']).values)
    cov_cont_norm = preprocessing.scale(df[['age', 'wtkg', 'cd40', 'karnof', 'cd80']])
    cov_bin_val = df[['gender', 'homo', 'race', 'drugs', 'symptom', 'str2', 'hemo']].values
    t = df[['arms']].values

    # Column layout: 12 covariates | treatment arm | outcome.
    data = np.concatenate((cov_cont_norm, cov_bin_val, t.reshape(-1, 1), outcome_norm.reshape(-1, 1)), axis=1)

    # Arms: 0=zidovudine, 1=zidovudine and didanosine, 2=zidovudine and zalcitabine, 3=didanosine.
    # Compare arm t_1 (recoded as treated, 1) against arm t_0 (recoded as control, 0).
    t_0, t_1 = 0, 2
    data_rct = data[((t == t_0) | (t == t_1)).ravel()]
    data_rct[:, -2] = np.where(data_rct[:, -2] == t_1, 1, 0)

    x_all = data_rct[:, :-2]
    t_all = data_rct[:, -2].reshape(-1, 1)
    y_all = data_rct[:, -1].reshape(-1, 1)

    # Unconfounded selection.
    ind_unc = random.sample(range(x_all.shape[0]), n_unc)
    x_unc = x_all[ind_unc]
    t_unc = t_all[ind_unc].reshape(-1, 1)
    y_unc = y_all[ind_unc].reshape(-1, 1)

    # Remaining rows. t/y are kept 1-D here so they work as boolean masks.
    x_rest = np.delete(x_all, ind_unc, axis=0)
    t_rest = np.delete(t_all.ravel(), ind_unc)
    y_rest = np.delete(y_all.ravel(), ind_unc)

    gender_col = 5  # 'gender' is the first binary covariate, after the 5 continuous ones
    is_female = x_rest[:, gender_col] == 0
    is_male = x_rest[:, gender_col] == 1

    # Among males, keep treated above and controls below their own arm's mean.
    treated_m = (t_rest == 1) & is_male
    y_tm = y_rest[treated_m]
    keep_tm = y_tm > y_tm.mean()

    control_m = (t_rest == 0) & is_male
    y_cm = y_rest[control_m]
    keep_cm = y_cm < y_cm.mean()

    x_conf = np.concatenate((x_rest[treated_m][keep_tm], x_rest[control_m][keep_cm], x_rest[is_female]))
    t_conf = np.concatenate((t_rest[treated_m][keep_tm], t_rest[control_m][keep_cm], t_rest[is_female])).reshape(-1, 1)
    y_conf = np.concatenate((y_tm[keep_tm], y_cm[keep_cm], y_rest[is_female])).reshape(-1, 1)

    x_test = x_rest
    t_test = t_rest.reshape(-1, 1)
    y_test = y_rest.reshape(-1, 1)

    return {'x_unc': x_unc, 't_unc': t_unc, 'y_unc': y_unc, 'x_conf': x_conf, 't_conf': t_conf, 'y_conf': y_conf, 'x_test': x_test, 't_test': t_test, 'y_test': y_test, 'x_all': x_all, 't_all': t_all, 'y_all': y_all}


import pandas as pd
from sklearn import preprocessing
import numpy as np
import random
from sklearn.preprocessing import StandardScaler

def sample_jobs(n_unc):
    """Build unconfounded / confounded / test splits from the Jobs (LaLonde) data.

    Reads ``src/data/Jobs_Lalonde_Data.csv.gz`` relative to the current working
    directory. The outcome is ``RE78``. ``Age``, ``Education`` and the outcome
    are standardized over all rows. The first ``297 + 425`` rows are treated as
    the randomized experiment.

    Splits (random draws use the global NumPy RNG through ``DataFrame.sample``):
      * ``unc``: exactly ``n_unc`` experimental rows. ``int(0.95 * n_unc)`` come
        from rows with age below the median and the remainder from rows at or
        above it.
      * ``conf``: rows not in ``unc``. Treated rows are kept if their outcome is
        above mean + 0.25 std, control rows if below mean - 1.2 std, using
        per-arm statistics.
      * ``test``: every row not in ``unc``. Rows with any missing value are
        dropped.
      * ``all``: the experimental rows only.

    NOTE(review): covariates are taken as the first six CSV columns -- confirm
    against the file's column layout.

    Args:
        n_unc: number of rows in the unconfounded split.

    Returns:
        dict of float64 arrays. ``x_<split>`` has shape (n, 6) and
        ``t_<split>`` / ``y_<split>`` have shape (n, 1).
    """
    file_path = os.path.join(os.getcwd(), 'src/data/Jobs_Lalonde_Data.csv.gz')
    df = pd.read_csv(file_path)
    df['outcome'] = df['RE78']

    # Center/scale. Fit and transform both use plain arrays.
    scaled_cols = ['Age', 'Education', 'outcome']
    df[scaled_cols] = StandardScaler().fit_transform(df[scaled_cols].values)

    n_rct = 297 + 425
    df_rct = df[:n_rct]

    # Covariate shift in age among the experimental rows. NOTE(review): this
    # oversamples rows *below* the median age (younger) -- confirm the intended
    # direction. The median is taken over all rows, not only the experimental ones.
    median_age = df['Age'].median()
    n_below = int(95 / 100 * n_unc)
    n_above = n_unc - n_below  # remainder, so exactly n_unc rows are drawn
    df_unc = pd.concat((df_rct[df_rct['Age'] < median_age].sample(n_below),
                        df_rct[df_rct['Age'] >= median_age].sample(n_above)))

    # Every row not sampled into unc (rows with any NaN are dropped as well).
    df_not_unc = df[~df.isin(df_unc)].dropna()

    # Mild confounding for the treated, since there are few treated samples.
    treated = df_not_unc[df_not_unc['Treatment'] == 1]
    treated_out = treated['outcome']
    df_conf_t = treated[treated_out > treated_out.mean() + 0.25 * treated_out.std()]

    # Stronger confounding for the controls, since there are many control samples.
    control = df_not_unc[df_not_unc['Treatment'] == 0]
    control_out = control['outcome']
    df_conf_c = control[control_out < control_out.mean() - 1.2 * control_out.std()]

    df_conf = pd.concat((df_conf_t, df_conf_c))

    def _xty(frame):
        """Return float64 (x, t, y); x is the first six columns."""
        x = np.array(frame.values[:, :6], dtype='float64')
        t = np.array(frame['Treatment'].values.reshape(-1, 1), dtype='float64')
        y = np.array(frame['outcome'].values.reshape(-1, 1), dtype='float64')
        return x, t, y

    splits = {'unc': df_unc, 'conf': df_conf, 'test': df_not_unc, 'all': df_rct}
    result = {}
    for name, frame in splits.items():
        x, t, y = _xty(frame)
        result['x_' + name] = x
        result['t_' + name] = t
        result['y_' + name] = y
    return result



class causal_inference_dataset(Dataset):
    """Map-style torch ``Dataset`` over aligned per-sample sequences.

    Each item is a dict with keys ``'X'``, ``'T'``, ``'Y'``, ``'U'``, ``'mu0'``
    and ``'mu1'``, in that order. Each value is the ``idx``-th element of the
    corresponding constructor argument. The length is taken from ``X``; the
    other sequences are assumed (not checked) to be at least as long.
    """

    def __init__(self, X, T, Y, mu0, mu1, U):
        self.X, self.T, self.Y = X, T, Y
        self.mu0, self.mu1 = mu0, mu1
        self.U = U

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        field_names = ('X', 'T', 'Y', 'U', 'mu0', 'mu1')
        return {name: getattr(self, name)[idx] for name in field_names}

    


