import numpy as np
import pandas as pd
import os
from scipy.stats import bernoulli


# ---- Simulation size -------------------------------------------------------
# Rows (units) generated per replication.
n_units = 10000
# Length of each of the two covariate vectors (x and z) drawn per unit.
# Only x is written to the output files.
n_covariate = 50
# Number of independent replications; one CSV file is written per replication.
n_replication = 100


# ---- Coefficient vectors ---------------------------------------------------
# Drawn once at import time from a standard normal and shared by every
# replication. No seed is set in this file, so they differ between runs.
# NOTE(review): "bete" is presumably a misspelling of "beta"; the names are
# kept because data_generate refers to them.
# bete_a / bete_b: weights on x / z that enter both potential outcomes.
bete_a = np.random.randn(n_covariate)
bete_b = np.random.randn(n_covariate)
# bete_c / bete_d: extra weights on x / z that enter only the control outcome.
bete_c = np.random.randn(n_covariate)
bete_d = np.random.randn(n_covariate)
# bete_x / bete_z: weights on x / z inside the treatment-probability model.
bete_x = np.random.randn(n_covariate)
bete_z = np.random.randn(n_covariate)
# Means of the normal distributions that x and z are sampled from.
mu_x = 0
mu_z = 0.5
# Despite the "sigmoid" prefix these are passed as `scale` (standard
# deviation) to np.random.normal for x, z and the outcome noise respectively.
sigmoid_x = 0.5
sigmoid_z = 0.5
sigmoid_y = 0.5
# Constant offset added to the control outcome y_0.
r = 2


# Mixing weight of the treatment-probability model: the x term is weighted by
# (1 - splits) and the z term by splits.
splits = 0.5
# Used both as the output sub-directory name and as the CSV file-name prefix.
data_name = 'SIM-Z'

def data_generate(path):
    """Generate the simulated replications and write them to *path* as CSV.

    For each of the ``n_replication`` replications, ``n_units`` units are
    simulated and saved to ``<path>/<data_name><k>.csv`` (``k`` starts at 1).
    Each file has the columns ``treatment``, ``yf``, ``ycf``, ``mu0``,
    ``mu1`` and ``x1`` .. ``x<n_covariate>``.

    Per unit:

    * ``x`` and ``z`` are normal covariate vectors; only ``x`` is saved.
    * the treatment is a Bernoulli draw with probability
      ``sigmoid((1 - splits) * x.bete_x + splits * z.bete_z)``.
    * ``y_1 = x.bete_a + z.bete_b + epsilon``.
    * ``y_0 = x.bete_c + z.bete_d + epsilon + r + x.bete_a + z.bete_b``.
    * ``yf`` / ``ycf`` are the outcomes under the drawn / opposite treatment.
    * ``mu1`` / ``mu0`` store ``y_1`` / ``y_0`` as-is, i.e. including the
      noise term ``epsilon``.

    All units of a replication are sampled in one vectorised step. The
    sampling distribution is the same as drawing unit by unit, but the order
    in which random numbers are consumed differs, so results under a fixed
    seed will not match a per-unit loop.

    Args:
        path: Output directory. Existing files directly inside it are
            deleted first; sub-directories are left untouched. Missing
            directories (including parents) are created.
    """
    # Clear stale files from a previous run. Sub-directories are skipped so
    # that os.remove is never called on a directory.
    if os.path.exists(path):
        for file_name in os.listdir(path):
            file_path = os.path.join(path, file_name)
            if os.path.isdir(file_path) and not os.path.islink(file_path):
                continue
            os.remove(file_path)
    # makedirs also creates missing parents and tolerates an existing dir.
    os.makedirs(path, exist_ok=True)

    x_columns = ['x{}'.format(col) for col in range(1, n_covariate + 1)]
    treatment_column = 'treatment'
    yf_column = 'yf'
    ycf_column = 'ycf'
    mu0_column = 'mu0'
    mu1_column = 'mu1'

    for i in range(n_replication):
        if i % 10 == 0:
            print('Have generated {} replications for Simulation dataset!'.format(i))

        # One row per unit, one column per covariate.
        x = np.random.normal(loc=mu_x, scale=sigmoid_x, size=(n_units, n_covariate))
        z = np.random.normal(loc=mu_z, scale=sigmoid_z, size=(n_units, n_covariate))
        epsilon = np.random.normal(loc=0, scale=sigmoid_y, size=n_units)

        # Treatment assignment depends on both x and z.
        p = sigmoid((1 - splits) * x.dot(bete_x) + splits * z.dot(bete_z))
        treatments = np.asarray(bernoulli.rvs(p=p, size=n_units), dtype=int)

        # Potential outcomes; both share the same noise draw per unit.
        shared = x.dot(bete_a) + z.dot(bete_b)
        y_1 = shared + epsilon
        y_0 = x.dot(bete_c) + z.dot(bete_d) + epsilon + r + shared

        treated = treatments.astype(bool)
        yf = np.where(treated, y_1, y_0)
        ycf = np.where(treated, y_0, y_1)

        data = {
            treatment_column: treatments,
            yf_column: yf,
            ycf_column: ycf,
            mu0_column: y_0,
            mu1_column: y_1,
        }
        for index, key in enumerate(x_columns):
            data[key] = x[:, index]
        df = pd.DataFrame(data)

        # Percentage of treated units, rounded to a whole number.
        treated_pct = round(100 * (treatments.sum() / len(treatments)))
        ret = 'Treated ratios:{}, Control ratios:{}'.format(
            treated_pct, 100 - treated_pct
        )
        print(ret)
        df.to_csv(os.path.join(path, '{}{}.csv'.format(data_name, i + 1)), index=False)

    print('Totally generated {} replications!'.format(n_replication))

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of *x*.

    Computed in a numerically stable way: the exponential is only ever taken
    of a non-positive number, so large-magnitude inputs do not overflow.

    Args:
        x: A scalar or array-like of real numbers.

    Returns:
        A NumPy float scalar for scalar input, otherwise an array with the
        same shape as *x*.
    """
    x = np.asarray(x, dtype=float)
    # exp(-|x|) lies in (0, 1], so it cannot overflow.
    e = np.exp(-np.abs(x))
    result = np.where(x >= 0, 1. / (1. + e), e / (1. + e))
    # Indexing with () turns a 0-d array into a scalar and leaves n-d arrays
    # unchanged.
    return result[()]

if __name__ == "__main__":
    # Write the dataset to <root>/dataset/<data_name>, where <root> is the
    # path of this script's directory truncated after the first occurrence
    # of 'Causally'.
    curPath = os.path.abspath(os.path.dirname(__file__))
    marker = 'Causally'
    marker_pos = curPath.find(marker)
    if marker_pos == -1:
        # str.find returns -1 when the marker is absent; slicing with that
        # value would yield a meaningless 7-character prefix, so fall back
        # to the script's own directory.
        rootPath = curPath
    else:
        rootPath = curPath[:marker_pos + len(marker)]
    data_path = os.path.join(rootPath, 'dataset')
    # Make sure the parent directory exists before generating into it.
    os.makedirs(data_path, exist_ok=True)
    path = os.path.join(data_path, data_name)
    data_generate(path)
