import numpy as np
import pandas as pd
import os
from scipy.stats import bernoulli
# Seed NumPy's global RNG at import time so repeated runs draw the same numbers.
np.random.seed(47)

# Size of the simulation as consumed by data_generate():
n_units = 10000       # rows (units) generated per replication
n_covariate = 50      # number of covariate columns x1..x{n_covariate}
n_replication = 100   # number of CSV files written

# NOTE(review): these four values are not referenced by the code shown in this
# file. data_generate() hard-codes 0.7 and 0.3 when building the covariate
# scale -- presumably these constants were meant to be used there; confirm.
sigmoid_x_1 = 0.7
sigmoid_x_0 = 0.3
sigmoid_z_1 = 0.7
sigmoid_z_0 = 0.9


# Coefficient vectors of length n_covariate drawn from a standard normal.
# Only bete_x (treatment assignment) and bete_b (outcome) are read by
# data_generate(); the others are still drawn, and because every draw advances
# the global RNG, removing or reordering any of these lines changes all the
# values generated after it.
# NOTE(review): "bete" presumably means "beta" -- the names are kept as-is
# because other modules may import them.
bete_a = np.random.randn(n_covariate)
bete_b = np.random.randn(n_covariate)
bete_c = np.random.randn(n_covariate)
bete_d = np.random.randn(n_covariate)
bete_x = np.random.randn(n_covariate)
bete_z = np.random.randn(n_covariate)
# NOTE(review): the scalars below are not referenced by the code shown in this
# file; their intended role cannot be determined from here.
mu_x = 0
mu_z = 0.5
sigmoid_x = 0.5
sigmoid_z = 0.5
sigmoid_y = 0.5
r = 2


# NOTE(review): `splits` is not referenced by the code shown in this file.
splits = 0.5
# Used as the output sub-directory name and as the CSV file-name prefix.
data_name = 'Toy'

def data_generate(path):
    """Generate the simulated replications and write one CSV per replication.

    For each of ``n_replication`` replications, ``n_units`` units are drawn
    one at a time. For every unit:

    * a latent vector ``eta`` is sampled from N(0.5, 1.5) per covariate;
    * the stored covariates ``x`` are a noisy draw centred on ``eta``;
    * the treatment is a Bernoulli draw with probability
      ``sigmoid(sum(0.5 * bete_x * eta))``;
    * the two potential outcomes are ``sum(eta * bete_b) + epsilon +/- 3``
      and share the same noise term ``epsilon``.

    Each replication is saved as ``<data_name><k>.csv`` (k starting at 1)
    with the columns ``treatment, yf, ycf, mu0, mu1, x1..x<n_covariate>``.

    Args:
        path: Output directory. Files already inside it are deleted first;
            the directory and any missing parents are created.

    Raises:
        OSError: If ``path`` holds an entry that ``os.remove`` cannot delete
            (for example a sub-directory).
    """
    # Start from a clean directory: drop the files left by an earlier run.
    if os.path.exists(path):
        for file_name in os.listdir(path):
            os.remove(os.path.join(path, file_name))
    # makedirs also creates missing parent directories; a plain os.mkdir
    # raised FileNotFoundError when the parent folder did not exist yet.
    os.makedirs(path, exist_ok=True)

    x_columns = ['x{}'.format(col) for col in range(1, n_covariate + 1)]
    treatment_column = 'treatment'
    yf_column = 'yf'
    ycf_column = 'ycf'
    mu0_column = 'mu0'
    mu1_column = 'mu1'

    for i in range(n_replication):
        if i % 10 == 0:
            print('Have generated {} replications for Simulation dataset!'.format(i))
        treatments = np.zeros(n_units, dtype=int)
        yf = np.zeros(n_units, dtype=float)
        ycf = np.zeros(n_units, dtype=float)
        mu0 = np.zeros(n_units, dtype=float)
        mu1 = np.zeros(n_units, dtype=float)
        units = np.zeros(shape=[n_units, n_covariate], dtype=float)

        # The order of the random draws inside this loop (eta, x, treatment,
        # epsilon) determines the generated values for a given seed, so it
        # must not be rearranged.
        for j in range(n_units):
            eta = np.random.normal(loc=0.5, scale=1.5, size=n_covariate)

            # One scalar standard deviation shared by all covariates of this
            # unit. np.random.normal rejects a negative scale, hence the
            # fallback to 1 when the sum is negative.
            scale_x = 0.7**2*eta + 0.3**2 * (1-eta)
            total_scale = sum(scale_x)
            x_scale = 1 if total_scale < 0 else total_scale
            x = np.random.normal(loc=eta, scale=x_scale, size=n_covariate)

            # Treatment assignment depends on the latent eta, not on x.
            treat_pre = sum(0.5 * bete_x * eta)
            p = sigmoid(treat_pre)
            treatments[j] = bernoulli.rvs(size=1, p=p)[0]

            epsilon = np.random.normal(loc=0, scale=0.5, size=1)[0]

            # Both potential outcomes share the same linear term and noise;
            # they differ only by the constant offset (+3 treated, -3 control).
            outcome_base = sum(eta * bete_b)
            y_1 = outcome_base + epsilon + 3
            y_0 = outcome_base + epsilon - 3

            yf[j] = y_1 if treatments[j] else y_0
            ycf[j] = y_0 if treatments[j] else y_1

            # NOTE(review): mu1/mu0 store the same noisy values as the
            # potential outcomes (epsilon included). If they are meant to be
            # noiseless expected outcomes, epsilon should be left out here --
            # kept unchanged so existing datasets stay reproducible.
            mu1[j] = y_1
            mu0[j] = y_0
            units[j] = x

        data = {
            treatment_column: treatments,
            yf_column: yf,
            ycf_column: ycf,
            mu0_column: mu0,
            mu1_column: mu1,
        }
        for index, key in enumerate(x_columns):
            data[key] = units[:, index]
        df = pd.DataFrame(data)

        # Percentage of treated units, computed once and reused for both
        # numbers in the summary line.
        treated_pct = round(100 * (treatments.sum() / len(treatments)))
        ret = 'Treated ratios:{}, Control ratios:{}'.format(
            treated_pct, 100 - treated_pct
        )
        print(ret)
        df.to_csv(os.path.join(path, '{}{}.csv'.format(data_name, i + 1)), index=False)

    print('Totally generated {} replications!'.format(n_replication))

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    Works element-wise on NumPy arrays as well as on plain scalars.
    """
    denominator = 1. + np.exp(-x)
    return 1. / denominator

if __name__ == "__main__":
    # Resolve the project root as the part of this script's directory path up
    # to and including the 'DFHTE' folder name.
    curPath = os.path.abspath(os.path.dirname(__file__))
    project_marker = 'DFHTE'
    marker_pos = curPath.find(project_marker)
    if marker_pos == -1:
        # str.find returns -1 when the marker is absent; slicing with that
        # value would silently keep only the first few characters of the
        # path. Fall back to the script's own directory instead.
        rootPath = curPath
    else:
        rootPath = curPath[:marker_pos + len(project_marker)]
    data_path = os.path.join(rootPath, 'dataset')
    # Make sure the parent folder exists before the dataset directory is
    # created inside it.
    os.makedirs(data_path, exist_ok=True)
    path = os.path.join(data_path, data_name)
    data_generate(path)
