import numpy as np
import pandas as pd
import os
from scipy.stats import bernoulli
import matplotlib.image as mpimg
from PIL import Image
import PIL
# --- Simulation size -------------------------------------------------------
# Units per replication, covariates per unit, and number of CSV replications.
n_units, n_covariate, n_replication = 1000, 10, 10

# --- Noise model -----------------------------------------------------------
# Covariates are drawn from N(loc, scale_x); outcome noise from N(0, scale_y).
loc = 0
scale_x = scale_y = 0.5

# --- Random coefficient vectors --------------------------------------------
# Drawn once at import time from the global NumPy RNG (unseeded here), in the
# order bete_1, bete_0, bete_t, bete_age, bete_sex.
# bete_0 / bete_t weight the covariates in the outcome and treatment models of
# data_generate; bete_1 is not referenced by the code visible in this file.
bete_1, bete_0, bete_t = (np.random.randn(n_covariate) for _ in range(3))
# Projections from covariates to 116 age logits and 2 sex logits.
bete_age = np.random.randn(n_covariate, 116)
bete_sex = np.random.randn(n_covariate, 2)

# NOTE(review): the three lines below look like size presets for other
# simulation variants -- meaning of the two numbers is not shown here; confirm.
# SIM-X 2000 10
# SIM-Y 2000 20
# SIM-C 5000 20

# Constant offset added to the treated potential outcome before the sigmoid.
theta = 0.1
# Only used by a commented-out perturbation in data_generate.
epsilon = 0.1
# Sub-directory name and CSV file-name prefix of the generated dataset.
data_name = 'Multimodal'
def data_generate(path):
    """Generate the synthetic multimodal dataset as CSV files under *path*.

    The directory *path* is emptied (or created, including missing parents)
    and ``n_replication`` files named ``<data_name><k>.csv`` are written.
    Each file holds ``n_units`` rows with the columns:

    * ``treatment`` -- Bernoulli draw with probability ``sigmoid(x . bete_t)``.
    * ``yf`` / ``ycf`` -- outcome under the assigned / the opposite treatment.
    * ``mu0`` / ``mu1`` -- control / treated outcome (same values as stored in
      ``yf``/``ycf``, i.e. sigmoid-squashed and including the noise term).
    * ``x1..x<n_covariate>`` -- the sampled covariates.
    * ``image1..image2700`` -- flattened image embedding returned by
      ``read_image`` for the age/sex derived from the covariates.

    Args:
        path: Output directory. Any existing content is deleted.
    """
    import shutil  # local import: only needed to reset the output directory

    # Start from an empty directory; rmtree also copes with nested folders,
    # and makedirs with a missing parent directory.
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)

    n_image_features = 900 * 3  # length of the vector returned by read_image
    x_columns = ['x{}'.format(col) for col in range(1, n_covariate + 1)]
    image_columns = ['image{}'.format(col) for col in range(1, n_image_features + 1)]

    # read_image scans a directory and decodes a file on every call; its
    # result depends only on (age, sex), so load each combination once.
    image_cache = {}

    for i in range(n_replication):
        if i % 10 == 0:
            print('Have generated {} replications for Simulation dataset!'.format(i))
        treatments = np.zeros(n_units, dtype=int)
        yf = np.zeros(n_units, dtype=float)
        ycf = np.zeros(n_units, dtype=float)
        mu0 = np.zeros(n_units, dtype=float)
        mu1 = np.zeros(n_units, dtype=float)
        x = np.zeros(shape=[n_units, n_covariate], dtype=float)
        images = np.zeros(shape=[n_units, n_image_features], dtype=float)

        # Units are sampled one at a time so the sequence of random draws
        # (covariates, noise, treatment) stays the same as before.
        for j in range(n_units):
            unit = np.random.normal(loc=loc, scale=scale_x, size=n_covariate)
            epsilon_i = np.random.normal(loc=0, scale=scale_y, size=1)[0]

            treat_score = np.dot(unit, bete_t)
            base_score = np.dot(unit, bete_0)

            treatments[j] = bernoulli.rvs(size=1, p=sigmoid(treat_score))[0]

            # Potential outcomes share the same noise draw; the treated one
            # adds the treatment score and the constant offset theta.
            y_1 = sigmoid(treat_score + base_score + epsilon_i + theta)
            y_0 = sigmoid(base_score + epsilon_i)

            if treatments[j]:
                yf[j], ycf[j] = y_1, y_0
            else:
                yf[j], ycf[j] = y_0, y_1
            mu1[j] = y_1
            mu0[j] = y_0
            x[j] = unit

            # Derive a discrete age (1..116) and sex (0/1) from the covariates
            # and use them to pick the image attached to this unit.
            cov = unit.reshape((1, -1))
            age_probs = softmax(np.matmul(cov, bete_age).squeeze())
            unit_age = int(np.argmax(age_probs)) + 1
            unit_sex = int(np.argmax(sigmoid(np.matmul(cov, bete_sex).squeeze())))

            cache_key = (unit_age, unit_sex)
            if cache_key not in image_cache:
                image_cache[cache_key] = read_image(unit_age, unit_sex)
            images[j] = image_cache[cache_key]

        data = {
            'treatment': treatments,
            'yf': yf,
            'ycf': ycf,
            'mu0': mu0,
            'mu1': mu1,
        }
        # One DataFrame column per covariate / image feature.
        data.update(zip(x_columns, x.T))
        data.update(zip(image_columns, images.T))
        df = pd.DataFrame(data)

        treated_pct = round(100 * treatments.mean())
        print('Treated ratios:{}, Control ratios:{}'.format(treated_pct, 100 - treated_pct))
        df.to_csv(os.path.join(path, '{}{}.csv'.format(data_name, i + 1)), index=False)

    print('Totally generated {} replications!'.format(n_replication))

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of *x*.

    Works element-wise on NumPy arrays as well as on plain scalars.
    """
    denominator = 1. + np.exp(-x)
    return 1. / denominator

def softmax(x, axis=None):
    """Return the softmax of *x* along *axis* (over all entries if None).

    The maximum is subtracted before exponentiating so large inputs do not
    overflow; this does not change the result.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    weights = np.exp(x - peak)
    total = np.sum(weights, axis=axis, keepdims=True)
    return weights / total

def _load_image_embedding(file_path):
    """Load one image and return its 30x30 RGB pixels, flattened and min-max scaled."""
    # Image.open is lazy and keeps the file open; the context manager makes
    # sure the handle is released once the pixels have been copied out.
    with Image.open(file_path) as image:
        # Force three channels so every image yields 30 * 30 * 3 = 2700 values
        # (grayscale/palette/RGBA files would otherwise give another length).
        pixels = np.asarray(image.convert('RGB').resize((30, 30)), dtype=float)
    data = pixels.ravel()[0:2700]

    low, high = data.min(), data.max()
    if high == low:
        # A single-colour image has no range to scale; avoid 0/0 -> NaN.
        return np.zeros_like(data)
    return (data - low) / (high - low)


def read_image(age, sex, path='/Users/wangzhenlei/Downloads/part1',
               fallback='24_0_0_20170117150006042.jpg'):
    """Return the image embedding for a given age and sex.

    Looks in *path* for the first file (in sorted name order, so the choice is
    reproducible) whose name starts with ``"<age>_<sex>"``; if none matches,
    the *fallback* file in the same directory is used instead.

    Args:
        age: Age used as the first component of the file-name prefix.
        sex: Sex code used as the second component of the file-name prefix.
        path: Directory containing the image files.
        fallback: File name used when no file matches the prefix.

    Returns:
        1-D float array of 2700 values scaled to [0, 1] (all zeros for a
        single-colour image).
    """
    prefix = '{}_{}'.format(age, sex)
    for file in sorted(os.listdir(path)):
        if file.startswith(prefix):
            return _load_image_embedding(os.path.join(path, file))
    return _load_image_embedding(os.path.join(path, fallback))



def read_ihdp(path):
    """Convert the IHDP ``.npz`` archives in *path* into one CSV per replication.

    Reads ``ihdp_npci_1-1000.train.npz`` and ``ihdp_npci_1-1000.test.npz``,
    concatenates train and test rows, and writes ``IHDP<k>.csv`` (k starting
    at 1) into ``<path>/IHDP``, one file per index of the last array axis.
    Each CSV has the columns ``treatment, yf, ycf, mu0, mu1, x1..x<n>``.

    Args:
        path: Directory holding the two ``.npz`` files. The ``IHDP``
            sub-directory inside it is deleted and recreated.
    """
    import shutil  # local import: only needed to reset the output directory

    train_path = os.path.join(path, 'ihdp_npci_1-1000.train.npz')
    test_path = os.path.join(path, 'ihdp_npci_1-1000.test.npz')

    keys = ('t', 'x', 'yf', 'ycf', 'mu0', 'mu1')
    # NOTE(review): allow_pickle=True lets a crafted archive run arbitrary
    # code on load -- only use with trusted files; drop it if the archives
    # contain plain numeric arrays.
    # The context managers close the underlying zip files after reading.
    with np.load(train_path, allow_pickle=True) as train_datas, \
            np.load(test_path, allow_pickle=True) as test_datas:
        merged = {
            key: np.concatenate([train_datas[key], test_datas[key]], axis=0)
            for key in keys
        }

    t = merged['t'].astype(int)
    x = merged['x']
    yf = merged['yf']
    ycf = merged['ycf']
    mu0 = merged['mu0']
    mu1 = merged['mu1']

    # Start from an empty output directory (rmtree also handles nested folders).
    ihdp_path = os.path.join(path, 'IHDP')
    if os.path.exists(ihdp_path):
        shutil.rmtree(ihdp_path)
    os.mkdir(ihdp_path)
    print(t.shape)
    print(x.shape)

    # The last axis of every array indexes the replication.
    for i in range(t.shape[1]):
        file_path = os.path.join(ihdp_path, 'IHDP{}.csv'.format(i + 1))
        replication_x = x[:, :, i]
        data = {
            "treatment": t[:, i],
            "yf": yf[:, i],
            "ycf": ycf[:, i],
            "mu0": mu0[:, i],
            "mu1": mu1[:, i]
        }
        # One column per covariate, named x1, x2, ...
        for index in range(replication_x.shape[1]):
            data['x{}'.format(index + 1)] = replication_x[:, index]

        pd.DataFrame(data).to_csv(file_path, index=False)

if __name__ == "__main__":
    # Write the dataset to <project root>/dataset/<data_name>, where the
    # project root is the part of this script's directory path up to and
    # including the 'Causally_Rebuttle' folder.
    curPath = os.path.abspath(os.path.dirname(__file__))
    marker = 'Causally_Rebuttle'
    marker_pos = curPath.find(marker)
    if marker_pos == -1:
        # str.find returns -1 when the marker is absent; slicing with it would
        # silently yield a truncated, meaningless path. Fall back to the
        # directory containing this script.
        rootPath = curPath
    else:
        rootPath = curPath[:marker_pos + len(marker)]
    data_path = os.path.join(rootPath, 'dataset')
    path = os.path.join(data_path, data_name)
    data_generate(path)
