#%%
# Define a helper function to plot the data.

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import numpy as np

# Plotting helpers
def _as_plot_frame(values, columns=None):
    """Convert ``values`` into a fresh DataFrame suitable for ``sns.pairplot``.

    Args:
        values: A pandas DataFrame, a torch Tensor or a 2-D array-like.
        columns: Optional column names to impose. When omitted, a DataFrame
            keeps its own names and any other input gets ``dim_1 .. dim_k``.

    Returns:
        A new DataFrame with a 0..n-1 index; the input is not modified.
    """
    if isinstance(values, pd.DataFrame):
        frame = values.reset_index(drop=True)
        if columns is not None:
            frame = frame.set_axis(list(columns), axis=1)
        return frame
    if isinstance(values, torch.Tensor):
        # detach()/cpu() so tensors that require grad or live on a GPU convert too.
        values = values.detach().cpu().numpy()
    values = np.asarray(values)
    if columns is None:
        columns = [f"dim_{i + 1}" for i in range(values.shape[1])]
    return pd.DataFrame(data=values, columns=list(columns))


def plot_obs(x, x_real=None):
    """Draw a seaborn pair plot of ``x``, optionally overlaid with ``x_real``.

    Args:
        x: Samples to plot (the generated data when ``x_real`` is given).
            A pandas DataFrame, a torch Tensor or a 2-D array-like.
        x_real: Optional reference samples with the same columns, in the same
            order, as ``x``. When given, both sets are stacked and coloured by
            a ``'Distribution'`` column: rows of ``x`` are labelled ``'Gener'``
            and rows of ``x_real`` ``'Real'``.

    Returns:
        The grid object returned by ``sns.pairplot``.
    """
    gener = _as_plot_frame(x)
    if x_real is None:
        return sns.pairplot(gener)

    # Real samples are matched to the generated frame positionally and take
    # its column names.
    real = _as_plot_frame(x_real, columns=gener.columns)
    # Label each part before stacking rather than writing through
    # ``df['Distribution'].iloc[...]``: that is chained assignment, which
    # warns today and silently does nothing under pandas Copy-on-Write.
    df = pd.concat(
        [gener.assign(Distribution='Gener'), real.assign(Distribution='Real')],
        ignore_index=True,
    )
    return sns.pairplot(df, hue='Distribution',
                        plot_kws={'alpha': 0.3}, hue_order=['Real', 'Gener'],
                        diag_kind="hist")


#%%
import pandas as pd
path = "D:\\universityWorks\\thirdYear\\Spring\\Aoqi Zuo\\0406\\0501code"
# NOTE: ``dir`` shadows the builtin; the name is kept because later cells use it.
dir = path+"/data" 
data_path = "{}/credit_preprocessed.csv".format(dir)
interventional0_data_gene_path = "{}/interventional0_data_gened.csv".format(dir)
interventional1_data_gene_path = "{}/interventional1_data_gened.csv".format(dir)

data = pd.read_csv(data_path, delimiter=',', header=0)

# Real interventional samples: the rows of ``data`` with Age 23 and Age 30.
# ``.copy()`` makes them independent frames. Later cells modify their columns
# in place, which would otherwise raise SettingWithCopyWarning and depend on
# pandas view/copy semantics.
real0 = data[data.Age==23].copy()
real1 = data[data.Age==30].copy()

gene0 = pd.read_csv(interventional0_data_gene_path, delimiter=',')
gene1 = pd.read_csv(interventional1_data_gene_path, delimiter=',')


#%%
# convert into numerical vars
from sklearn.preprocessing import StandardScaler

def get_dummy(df, features):
    """One-hot encode the ``features`` columns of ``df``.

    Thin wrapper over ``pd.get_dummies`` with ``drop_first=True``, so the first
    level of every encoded column is dropped.

    Args:
        df: Input DataFrame.
        features: Names of the columns to encode.

    Returns:
        The DataFrame produced by ``pd.get_dummies``.
    """
    encoded = pd.get_dummies(df, columns=features, drop_first=True)
    return encoded


# Variable groups. The original note read "removed Age, Loan_status,
# 'Loan_Percent_Income'" with positions 3, 6, 7 (presumably column indexes —
# TODO confirm). NOTE(review): 'Loan_Percent_Income' still appears in
# num_vars below, which contradicts that note.
cat_vars = [
    'Home_Status',
    'Loan_Intent',
    'Loan_Grade',
]
bi_vars = ['Historical_Default']
num_vars = [
    'Income',
    'Employment_Length',
    'Loan_Amount',
    'Interest_Rate',
    'Historical_Length',
    'Loan_Percent_Income',
]

# Candidate variable sets for the pair plots; one is chosen by name via
# ``cur_set`` further down.
loan_percent_income = [
    'Loan_Percent_Income',
    "Income",
    "Loan_Amount",
    "Home_Status",
    "Historical_Default",
]
Age = [
    "Historical_Length",
    "Employment_Length",
    "Income",
    "Loan_Intent",
    "Home_Status",
]

# Recode the intervention ages in every frame: Age 23 -> 0, Age 30 -> 1.
# ``.loc`` writes into each frame directly. The previous
# ``temp = cur_data['Age']; temp[...] = ...`` mutated an intermediate Series
# (chained assignment). That warns on sliced frames and is a silent no-op
# under pandas Copy-on-Write.
for cur_data in [data, real0, real1, gene0, gene1]:
    cur_data.loc[cur_data['Age'] == 23, 'Age'] = 0
    cur_data.loc[cur_data['Age'] == 30, 'Age'] = 1

# for feature in num_vars:
#     scaler = StandardScaler()
#     scaler.fit(data[feature].values.reshape(-1,1))
#     for cur_data in [data, real0, real1, gene0, gene1]:
#         cur_data[feature] = scaler.transform(cur_data[feature].values.reshape(-1,1))


# Integer-encode the categorical columns of the observational data. The
# returned uniques (``*_idx2label``) map each code back to its label.
data['Home_Status'], Home_Status_idx2label = pd.factorize(data['Home_Status'])
data['Historical_Default'], Historical_Default_idx2label = pd.factorize(data['Historical_Default'])
data['Loan_Intent'], Loan_Intent_idx2label = pd.factorize(data['Loan_Intent'])

#%%
# Invert the factorize() uniques into label -> code lookups, so the generated
# and real interventional frames get the same integer codes as ``data``.
Home_Status_label2idx = {label: i for i, label in enumerate(Home_Status_idx2label)}
Historical_Default_label2idx = {label: i for i, label in enumerate(Historical_Default_idx2label)}
Loan_Intent_label2idx = {label: i for i, label in enumerate(Loan_Intent_idx2label)}

# Encode every frame/column pair in one loop instead of twelve copy-pasted
# lines. Frames are processed in the same order as before. Indexing the dict
# (rather than using ``Series.map``) keeps the KeyError on any label that never
# appeared in ``data``.
_label2idx_by_col = {
    'Home_Status': Home_Status_label2idx,
    'Historical_Default': Historical_Default_label2idx,
    'Loan_Intent': Loan_Intent_label2idx,
}
for _frame in (gene0, gene1, real0, real1):
    for _col, _label2idx in _label2idx_by_col.items():
        _frame[_col] = _frame[_col].apply(_label2idx.__getitem__)

# data = pd.get_dummies(data, columns=bi_vars, drop_first=True, prefix=bi_vars)
# data = pd.get_dummies(data, columns=cat_vars, drop_first=False, prefix=cat_vars)

# gene0 = pd.get_dummies(gene0, columns=bi_vars, drop_first=True, prefix=bi_vars)
# real0 = pd.get_dummies(real0, columns=bi_vars, drop_first=True, prefix=bi_vars)
# gene1 = pd.get_dummies(gene1, columns=bi_vars, drop_first=True, prefix=bi_vars)
# real1 = pd.get_dummies(real1, columns=bi_vars, drop_first=True, prefix=bi_vars)

# gene0 = pd.get_dummies(gene0, columns=cat_vars, drop_first=False, prefix=cat_vars)
# real0 = pd.get_dummies(real0, columns=cat_vars, drop_first=False, prefix=cat_vars)
# gene1 = pd.get_dummies(gene1, columns=cat_vars, drop_first=False, prefix=cat_vars)
# real1 = pd.get_dummies(real1, columns=cat_vars, drop_first=False, prefix=cat_vars)


# bi_vars = [name for name in data.columns if any(map(lambda x: name.startswith(x), bi_vars))]
# cat_vars = [name for name in data.columns if any(map(lambda x: name.startswith(x), cat_vars))]

#%%
# Write each encoded frame to its CSV under ``dir`` (header row, no index).
_csv_outputs = (
    (gene0, "{}/interventional0_data_gene.csv"),
    (gene1, "{}/interventional1_data_gene.csv"),
    (real0, "{}/interventional0_data_real.csv"),
    (real1, "{}/interventional1_data_real.csv"),
    (data, "{}/observation_data.csv"),
)
for _frame, _template in _csv_outputs:
    _frame.to_csv(_template.format(dir), sep=',', index=False, header=True)


#%%

# num_vars = ['Income', 'Loan_Amount', 'Loan_Percent_Income']
# num_vars = ['Employment_Length',  'Interest_Rate', 'Historical_Length']
# loan_percent_income = ['Loan_Percent_Income', "Historical_Default", "Loan_Amount", "Home_Status", "Income"]
# loan_percent_income = [name for name in data.columns if any(map(lambda x: name.startswith(x), loan_percent_income))]

# Per-variable std of the Gaussian jitter applied by ``dataRandomize``; the
# entries pair positionally with the chosen variable set, and 0 means no noise.
# std = 0.01   # 0.3  # cat-vars: 0.1, num-vars: 0.05
# cur_set =  'loan_percent_income' # 'Age' #'loan_percent_income'
# stds = [0, 0, 0, 0.1, 0.1] # loan_percent_income
cur_set = "Age"
stds = [0.1, 0.05, 0, 0.1, 0.1] # age
pre = 0
# Look the variable set up by name rather than with ``eval``. A typo now
# raises a clear KeyError instead of evaluating arbitrary code.
variable_sets = {'Age': Age, 'loan_percent_income': loan_percent_income}
include = len(variable_sets[cur_set])
# include = 12
cur_idx = variable_sets[cur_set][pre:include]
# Apply the same [pre:include] window to the stds so each std stays paired
# with its variable. Without this, a non-zero ``pre`` would misalign them.
stds = stds[pre:include]

def dataRandomize(data, cur_idx, stds, rng=None):
    """Return a copy of ``data`` with Gaussian jitter added to selected columns.

    Args:
        data: DataFrame to perturb; it is not modified.
        cur_idx: Names of the columns to jitter.
        stds: Noise standard deviation for each name in ``cur_idx``, paired
            positionally (``zip`` stops at the shorter of the two). A std of 0
            leaves that column untouched.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) to draw
            the noise from, for reproducible output. Defaults to NumPy's
            global random state, as before.

    Returns:
        A new DataFrame in which each jittered column is
        ``data[var] + N(0, std)`` noise of the same shape.
    """
    sampler = np.random if rng is None else rng
    cur_data = data.copy()
    for var, cur_std in zip(cur_idx, stds):
        if cur_std != 0:
            cur_data[var] = data[var] + sampler.normal(0, cur_std, data[var].shape)
    return cur_data

# Jitter each frame with the same variable/std pairing. The noise is drawn in
# the order gene0, gene1, real0, real1.
gened_data0, gened_data1, real_data0, real_data1 = [
    dataRandomize(frame, cur_idx, stds)
    for frame in (gene0, gene1, real0, real1)
]


# Subsample each generated set to the row count of its real counterpart, so
# both distributions contribute the same number of points to the plot.
# fig0 = plot_obs(gened_data0[cur_idx], real_data0[cur_idx])
# fig1 = plot_obs(gened_data1[cur_idx], real_data1[cur_idx])
fig0 = plot_obs(gened_data0[cur_idx].sample(n=len(real_data0)), real_data0[cur_idx])
fig1 = plot_obs(gened_data1[cur_idx].sample(n=len(real_data1)), real_data1[cur_idx])
# fig0.savefig(dir+"/interventional0_{}_{}.pdf".format(cur_set,include), format='pdf', dpi = 400, bbox_inches='tight')
# fig1.savefig(dir+"/interventional1_{}_{}.pdf".format(cur_set,include), format='pdf', dpi = 400, bbox_inches='tight')

# %%

# %%
# Save both pair plots as 400-dpi PNGs named after the variable set and its size.
for _fig, _name in ((fig0, "/interventional0_{}_{}.png"), (fig1, "/interventional1_{}_{}.png")):
    _fig.savefig(dir + _name.format(cur_set, include), dpi=400)

#%%
# import patchworklib as pw
# pw.overwrite_axisgrid()

# Seed NumPy's global RNG: ``DataFrame.sample`` without ``random_state`` draws
# from it, so the subsample below is reproducible.
np.random.seed(0)
x_real = real_data0[cur_idx]
x = gened_data0[cur_idx].sample(n=real_data0.shape[0])

columns = list(x.columns)
num1 = x.shape[0]
x_total = np.concatenate([x, x_real], 0)
df = pd.DataFrame(data=x_total, columns=columns)
# The first ``num1`` rows are the generated sample, the rest are real. Build
# the label column in a single assignment. Writing through
# ``df['Distribution'].iloc[...]`` is chained assignment: it warns today and
# silently does nothing under pandas Copy-on-Write.
df['Distribution'] = ['Gener'] * num1 + ['Real'] * (len(df) - num1)

# sns.axes_style("white")
# sns.set(font_scale=2)
# fig = sns.pairplot(df, hue='Distribution',
#                     plot_kws={'alpha': 0.3}, 
#                     hue_order=['Real', 'Gener'],
#                     diag_kind="hist")


# Hand-built PairGrid (histograms on the diagonal, scatter elsewhere) so that
# ticks and the title can be styled individually.
g = sns.PairGrid(df, hue='Distribution', hue_order=['Real', 'Gener'])
g.map_diag(sns.histplot)
g.map_offdiag(sns.scatterplot)
g.add_legend()

g.tick_params(labelsize=10)
# A matplotlib Figure has no ``set_title`` (that call raised AttributeError);
# ``suptitle`` sets the figure-level title.
g.fig.suptitle("Plot", fontsize = 20)

# legend = g._legend

# g = sns.PairGrid(df, x_vars=cur_idx, hue="Distribution", height=4)
# g.map(sns.regplot, color=".3")
# g.set(ylim=(-1, 11), yticks=[0, 5, 10])

# fig  = pw.load_seaborngrid(fig)
# fig.move_legend("upper left", bbox_to_anchor=(0.08,1.01))
# plt.show()

# fig.map_diag(sns.histplot)
# fig.set_xlabel("X-Axis", fontsize = 20)
# fig.set_ylabel("Y-Axis", fontsize = 20)
# fig.set_title("Plot", fontsize = 20)

# for ax in fig.axes_dict.values():
    # ax.legend(loc='upper left', bbox_to_anchor=(0, 1.3), fontsize = 20)

# plt.legend(labels=['Real', 'Gener'], fontsize = 20)
# plt.tick_params(axis='both', which='minor', labelsize=12)

