#%%
# Define a function to plot the data.

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import numpy as np

# define drawing function
def plot_obs(x, x_real=None):
    """Draw a seaborn pair plot of ``x``, optionally overlaid with ``x_real``.

    Args:
        x: Samples to plot, one row per observation. If it exposes a
            ``columns`` attribute (e.g. a pandas DataFrame) those names label
            the axes; otherwise (torch.Tensor, ndarray, nested list) the
            columns are named ``dim_1``, ``dim_2``, ...
        x_real: Optional second sample set that is concatenated below ``x``
            row-wise. When given, the rows coming from ``x`` are tagged
            ``'Gener'`` and the rows from ``x_real`` ``'Real'`` in an extra
            ``'Distribution'`` column that is used as the plot hue.

    Returns:
        The object returned by ``sns.pairplot``.
    """
    # Tensors have no ``columns`` attribute, so convert them before looking
    # for column names. ``detach`` is required for tensors that track grads.
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    if hasattr(x, "columns"):
        columns = list(x.columns)
    else:
        x = np.asarray(x)
        columns = [f"dim_{i + 1}" for i in range(x.shape[1])]

    if x_real is None:
        df = pd.DataFrame(data=x, columns=columns)
        return sns.pairplot(df)

    if isinstance(x_real, torch.Tensor):
        x_real = x_real.detach().cpu().numpy()

    num_generated = x.shape[0]
    x_total = np.concatenate([x, x_real], 0)
    df = pd.DataFrame(data=x_total, columns=columns)
    # Build the label column in one assignment; the chained form
    # ``df[col].iloc[:n] = ...`` does not write through under pandas
    # copy-on-write and warns on older versions.
    df['Distribution'] = (['Gener'] * num_generated
                          + ['Real'] * (len(df) - num_generated))
    return sns.pairplot(df, hue='Distribution',
                        plot_kws={'alpha': 0.3}, hue_order=['Real', 'Gener'],
                        diag_kind="hist")


#%%
import pandas as pd
# Locations of the raw observational data and of the generated interventional
# samples. NOTE: ``dir`` shadows the builtin of the same name; the name is
# kept because the cells further down build their output paths from it.
path = "PATH_TO_DATA"
dir = path + "/data"
data_path = "{}/student-mat.csv".format(dir)
interventional0_data_gene_path = "{}/interventional0_data_gened.csv".format(dir)
interventional1_data_gene_path = "{}/interventional1_data_gened.csv".format(dir)

# The observational file is ';'-separated, the generated files ','-separated.
data = pd.read_csv(data_path, delimiter=';')
# Replace the three grade columns G1..G3 by their mean.
data['Grade'] = (data['G3'] + data['G2'] + data['G1']) / 3
data = data.drop(columns=['G1', 'G2', 'G3'])

# Split the observations by sex. Explicit copies: the preprocessing cell
# below assigns into these frames in place, which would otherwise trigger
# pandas' SettingWithCopyWarning on a filtered selection.
real0 = data[data['sex'] == 'F'].copy()
real1 = data[data['sex'] == 'M'].copy()

gene0 = pd.read_csv(interventional0_data_gene_path, delimiter=',')
gene1 = pd.read_csv(interventional1_data_gene_path, delimiter=',')


#%%
# convert into numerical vars
from sklearn.preprocessing import StandardScaler

def get_dummy(df, features):
    """Return ``df`` with the ``features`` columns one-hot encoded.

    Thin wrapper around ``pd.get_dummies`` with ``drop_first=True``, i.e. the
    first level of every encoded column is omitted. ``df`` itself is not
    modified; a new DataFrame is returned.
    """
    return pd.get_dummies(df, columns=features, drop_first=True)

# Column groups of the student dataset, split by the encoding applied below:
# standardised numeric columns, multi-level categoricals and binary columns.
num_vars = ['age', 'Medu', 'Fedu', 'traveltime', 'studytime', 'famrel', 'freetime',
            'goout', 'Dalc', 'Walc', 'health', 'absences']
cat_vars = ['Mjob', 'Fjob', 'reason', 'guardian']
bi_vars = ['higher', 'internet', 'romantic', 'nursery', 'famsup', 'activities', 'schoolsup',
           'paid', 'famsize', 'school', 'address', 'Pstatus']

# Every frame that receives the in-place recoding and scaling.
frames = [data, real0, real1, gene0, gene1]

# Recode sex in place: 'M' -> 1, then 'F' -> 0.
for frame in frames:
    for label, code in (('M', 1), ('F', 0)):
        frame.loc[frame['sex'] == label, 'sex'] = code

# Standardise each numeric column with statistics fitted on the full
# observational data only, then apply that same scaler to every frame.
for col in num_vars:
    scaler = StandardScaler().fit(data[col].values.reshape(-1, 1))
    for frame in frames:
        frame[col] = scaler.transform(frame[col].values.reshape(-1, 1))


def _one_hot(frame, binary_cols, nominal_cols):
    """One-hot encode binary columns (first level dropped), then nominal ones."""
    binary_encoded = pd.get_dummies(frame, columns=binary_cols,
                                    drop_first=True, prefix=binary_cols)
    return pd.get_dummies(binary_encoded, columns=nominal_cols,
                          drop_first=False, prefix=nominal_cols)


data = _one_hot(data, bi_vars, cat_vars)
gene0, real0, gene1, real1 = (
    _one_hot(frame, bi_vars, cat_vars) for frame in (gene0, real0, gene1, real1)
)

# From here on the two lists hold the names of the encoded dummy columns of
# `data` (every column whose name starts with one of the original names).
bi_prefixes = tuple(bi_vars)
cat_prefixes = tuple(cat_vars)
bi_vars = [name for name in data.columns if name.startswith(bi_prefixes)]
cat_vars = [name for name in data.columns if name.startswith(cat_prefixes)]

#%%
# save the data!
# Write every processed frame next to the raw data as a comma-separated file
# with a header row and without the index column.
csv_outputs = (
    ("{}/interventional0_data_gene.csv", gene0),
    ("{}/interventional1_data_gene.csv", gene1),
    ("{}/interventional0_data_real.csv", real0),
    ("{}/interventional1_data_real.csv", real1),
    ("{}/observation_data.csv", data),
)
for path_template, frame in csv_outputs:
    frame.to_csv(path_template.format(dir), sep=',', index=False, header=True)


#%%
# Settings for the jittered pair plots below.
# Standard deviation of the Gaussian jitter added to the plotted columns.
# Original note (presumably values tried for other column groups — confirm):
std = 0.1   # 0.3  # cat-vars: 0.1, num-vars: 0.05
include = 12          # exclusive end of the slice taken from the chosen group
cur_set = 'bi_vars'   # name of the column group; also used in the file names

# Resolve the group by name through an explicit mapping rather than eval(),
# so an unexpected value fails with a KeyError instead of being executed.
_var_sets = {'num_vars': num_vars, 'cat_vars': cat_vars, 'bi_vars': bi_vars}
cur_idx = _var_sets[cur_set][6:include]

def dataRandomize(data, cur_idx, std):
    """Return a copy of ``data`` with Gaussian noise added to selected columns.

    Args:
        data: Frame to jitter; it is copied, not modified.
        cur_idx: Names of the columns that receive noise.
        std: Standard deviation of the zero-mean normal noise.

    Returns:
        The copy, with every column in ``cur_idx`` replaced by its noisy
        version. Noise is drawn from NumPy's global random state, one draw
        per column in the order given.
    """
    jittered = data.copy()
    for column in cur_idx:
        original = data[column]
        noise = np.random.normal(0, std, original.shape)
        jittered[column] = original + noise
    return jittered

# Jitter the selected columns of every frame (generated first, then real, so
# the random draws happen in a fixed order).
gened_data0, gened_data1, real_data0, real_data1 = (
    dataRandomize(frame, cur_idx, std) for frame in (gene0, gene1, real0, real1)
)

# One pair plot per intervention: generated samples against real samples.
fig0 = plot_obs(gened_data0[cur_idx], real_data0[cur_idx])
fig1 = plot_obs(gened_data1[cur_idx], real_data1[cur_idx])

# Save both figures; the file name records the column group and slice end.
for name_template, figure in (("/interventional0_{}_{}.png", fig0),
                              ("/interventional1_{}_{}.png", fig1)):
    figure.savefig(dir + name_template.format(cur_set, include), dpi=400)

