import os
import shutil

import numpy as np
import pandas as pd
from scipy.stats import bernoulli


# ---------------------------------------------------------------------------
# Simulation configuration (read as module-level globals by data_generate).
# ---------------------------------------------------------------------------
# Number of simulated units (rows) per replication.
n_units = 10000
# Number of covariates (x1..xN columns) per unit.
n_covariate = 50
# Number of independent replications to simulate.
n_replication = 100

# Mean and standard deviation of the normal distribution the covariates are
# drawn from.
loc = 0
scale_x = 0.5
# Standard deviation of the zero-mean normal outcome noise.
scale_y = 0.5
# Coefficient vectors, drawn from a standard normal when the module is
# imported. No seed is set in the visible code, so they differ between runs.
# NOTE(review): bete_1 is not referenced by any code visible in this file, but
# its draw still advances the global NumPy random state before bete_0/bete_t.
bete_1 = np.random.randn(n_covariate)
# Coefficients of the linear term shared by both potential outcomes.
bete_0 = np.random.randn(n_covariate)
# Coefficients of the treatment-assignment score; the same linear term is also
# added to the treated potential outcome.
bete_t = np.random.randn(n_covariate)


# Constant added to the treated potential outcome.
theta = 0.1
# NOTE(review): not used by any active code visible in this file.
epsilon = 0.1
# Dataset name; used to name the output directory.
data_name = 'SIM'
def _simulate_replication():
    """Draw one simulated replication and return it as a DataFrame.

    Reads the module-level configuration (``n_units``, ``n_covariate``,
    ``loc``, ``scale_x``, ``scale_y``, ``bete_0``, ``bete_t``, ``theta``).

    The frame has the columns ``treatment``, ``yf``, ``ycf``, ``mu0``,
    ``mu1`` followed by ``x1`` .. ``x<n_covariate>``, one row per unit.
    """
    # All units are drawn at once instead of one at a time: the distributions
    # are unchanged, only the order in which the global NumPy random stream
    # is consumed differs from a per-unit loop.
    x = np.random.normal(loc=loc, scale=scale_x, size=(n_units, n_covariate))
    noise = np.random.normal(loc=0, scale=scale_y, size=n_units)

    # Each linear term is computed once and reused below.
    treatment_score = x.dot(bete_t)
    baseline = x.dot(bete_0)

    # Treatment is Bernoulli with probability sigmoid(x . bete_t).
    propensity = sigmoid(treatment_score)
    treatments = bernoulli.rvs(p=propensity, size=n_units).astype(int)

    # Both potential outcomes share the same noise draw for a given unit.
    y_0 = baseline + noise
    y_1 = treatment_score + baseline + noise + theta

    treated = treatments.astype(bool)
    data = {
        'treatment': treatments,
        # Factual outcome: the potential outcome of the assigned arm.
        'yf': np.where(treated, y_1, y_0),
        # Counterfactual outcome: the potential outcome of the other arm.
        'ycf': np.where(treated, y_0, y_1),
        # NOTE(review): mu0/mu1 are stored with the noise term included, i.e.
        # they equal the potential outcomes themselves -- confirm that
        # noiseless means are not expected downstream.
        'mu0': y_0,
        'mu1': y_1,
    }
    for index in range(n_covariate):
        data['x{}'.format(index + 1)] = x[:, index]
    return pd.DataFrame(data)


def data_generate(path, save=False):
    """Simulate ``n_replication`` replications of the synthetic dataset.

    The directory *path* is removed (recursively) if it exists and is then
    recreated empty, together with any missing parent directories. For every
    replication the treated/control split is printed as rounded percentages.

    Args:
        path: Output directory for the dataset. Its previous contents are
            deleted.
        save: When true, each replication is written to
            ``<path>/<data_name><k>.csv`` (k starting at 1) without the index
            column. The default ``False`` only prints the summaries and
            leaves *path* empty.

    Returns:
        None.
    """
    # shutil.rmtree also copes with nested directories, which a plain
    # os.remove/os.rmdir sweep cannot delete.
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)

    for i in range(n_replication):
        if i % 10 == 0:
            print('Have generated {} replications for Simulation dataset!'.format(i))

        df = _simulate_replication()

        treated_pct = round(100 * (df['treatment'].sum() / len(df)))
        ret = 'Treated ratios:{}, Control ratios:{}'.format(
            treated_pct, 100 - treated_pct
        )
        print(ret)

        if save:
            df.to_csv(
                os.path.join(path, '{}{}.csv'.format(data_name, i + 1)),
                index=False,
            )

    print('Totally generated {} replications!'.format(n_replication))

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of *x*.

    *x* is passed to ``np.exp``, so scalars and NumPy arrays are both
    accepted; arrays are transformed element-wise.
    """
    denominator = 1. + np.exp(-x)
    return 1. / denominator

def read_ihdp(path):
    """Convert the IHDP ``.npz`` archives under *path* into per-replication CSVs.

    Reads ``ihdp_npci_1-1000.train.npz`` and ``ihdp_npci_1-1000.test.npz``
    from *path*, concatenates train and test rows for the arrays ``t``,
    ``x``, ``yf``, ``ycf``, ``mu0`` and ``mu1``, and writes one file
    ``<path>/IHDP/IHDP<k>.csv`` (k starting at 1) per replication. The
    replication is the last axis of every array: axis 1 of ``t``/``yf``/
    ``ycf``/``mu0``/``mu1`` and axis 2 of ``x``.

    Each CSV has the columns ``treatment``, ``yf``, ``ycf``, ``mu0``, ``mu1``
    followed by ``x1`` .. ``x<d>``, and no index column.

    The ``IHDP`` output directory is removed (recursively) if it already
    exists and recreated before writing.

    Args:
        path: Directory containing the two ``.npz`` archives.

    Returns:
        None.
    """
    train_path = os.path.join(path, 'ihdp_npci_1-1000.train.npz')
    test_path = os.path.join(path, 'ihdp_npci_1-1000.test.npz')

    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only use this on trusted archives.
    # The context managers close the underlying archive files once the arrays
    # have been materialised by np.concatenate.
    with np.load(train_path, allow_pickle=True) as train_datas, \
            np.load(test_path, allow_pickle=True) as test_datas:

        def stacked(key):
            # Train rows first, then test rows.
            return np.concatenate([train_datas[key], test_datas[key]], axis=0)

        t = stacked('t').astype(int)
        x = stacked('x')
        yf = stacked('yf')
        ycf = stacked('ycf')
        mu0 = stacked('mu0')
        mu1 = stacked('mu1')

    # Reset the output directory only after the inputs loaded successfully.
    # shutil.rmtree also removes nested directories, which a plain
    # os.remove/os.rmdir sweep cannot delete.
    ihdp_path = os.path.join(path, 'IHDP')
    if os.path.exists(ihdp_path):
        shutil.rmtree(ihdp_path)
    os.makedirs(ihdp_path)

    print(t.shape)
    print(x.shape)
    for i in range(t.shape[1]):
        file_path = os.path.join(ihdp_path, 'IHDP{}.csv'.format(i + 1))
        replication_x = x[:, :, i]
        data = {
            "treatment": t[:, i],
            "yf": yf[:, i],
            "ycf": ycf[:, i],
            "mu0": mu0[:, i],
            "mu1": mu1[:, i],
        }
        # One column per covariate, named x1, x2, ...
        for index in range(replication_x.shape[1]):
            data['x{}'.format(index + 1)] = replication_x[:, index]

        pd.DataFrame(data).to_csv(file_path, index=False)

if __name__ == "__main__":
    # Locate the project root as the prefix of this script's directory that
    # ends with the first occurrence of 'Causally'.
    curPath = os.path.abspath(os.path.dirname(__file__))
    marker = 'Causally'
    marker_pos = curPath.find(marker)
    if marker_pos == -1:
        # str.find returns -1 when the marker is missing; slicing with that
        # value would silently yield a meaningless 7-character prefix.
        raise RuntimeError(
            'Cannot locate the project root: {!r} does not occur in {!r}'.format(
                marker, curPath
            )
        )
    rootPath = curPath[:marker_pos + len(marker)]

    # Datasets live under <root>/dataset/<data_name>; make sure the parent
    # directory exists before generating into it.
    data_path = os.path.join(rootPath, 'dataset')
    os.makedirs(data_path, exist_ok=True)
    path = os.path.join(data_path, data_name)
    data_generate(path)
    # To convert the IHDP .npz archives instead, run:
    # read_ihdp(data_path)