# load trade
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

dataset_names = ['trade']

# Load the pickled p-values of every dataset and, for each entry whose key
# contains 'fusion', print the share of p-values that are >= 0.05.
p_values_regression = {}
p_values_regression_key_names = None
for dataset_name in dataset_names:
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open('./expdata/' + dataset_name + '/p_values.pkl', 'rb') as f:
        p_values_regression[dataset_name] = pickle.load(f)
    # The file is no longer needed once unpickled, so the reporting below
    # runs after the handle has been closed.
    p_values_regression_key_names = p_values_regression[dataset_name].keys()
    for key in p_values_regression_key_names:
        # only keep keys with 'fusion' in their name
        if 'fusion' not in key:
            continue
        list_p = p_values_regression[dataset_name][key]
        # len() is used instead of truthiness so that array-like containers
        # are handled as well as plain lists.
        if len(list_p) == 0:
            # No p-values recorded: the ratio is undefined, report NaN
            # instead of raising ZeroDivisionError.
            ratio = float('nan')
        else:
            # share of p-values at or above the 0.05 threshold
            ratio = sum(1 for x in list_p if x >= 0.05) / len(list_p)
        print(dataset_name, key, ratio)

print("p_values_regression:", p_values_regression)
# Pause so the output above can be read; the prompt makes it visible that
# the script is waiting for input rather than hanging.
input("Press Enter to continue...")

# Unpickle './expdata/<dataset>/errors.pkl' for every dataset and remember
# the keys of the most recently loaded object.
errors_regression = {}
errors_regression_key_names = None
for dataset_name in dataset_names:
    errors_path = './expdata/' + dataset_name + '/errors.pkl'
    with open(errors_path, 'rb') as f:
        loaded_errors = pickle.load(f)
    errors_regression[dataset_name] = loaded_errors
    errors_regression_key_names = loaded_errors.keys()

# Q-Q plots of the 'trade' errors against a normal distribution: one figure
# per experiment for each of the two models listed in `key_names`.
import scipy.stats as stats  # imported once here instead of on every loop iteration

errors = errors_regression['trade']
key_names = ['roi_last10days_fusion', 'catboost_fusion']
max_qq_plots = 5  # plot at most this many experiments per model
for key_name in key_names:
    # Bound the loop by the number of experiments actually stored, so a model
    # with fewer than `max_qq_plots` runs does not raise IndexError.
    n_plots = min(max_qq_plots, len(errors[key_name]))
    for i in range(n_plots):
        # Optional (disabled): winsorize the errors before plotting.
        # def winsorize(data, lower=0.001, upper=0.999):
        #     lower_bound = np.quantile(data, lower)
        #     upper_bound = np.quantile(data, upper)
        #     return np.clip(data, lower_bound, upper_bound)
        # plot_error = winsorize(errors[key_name][i])
        plot_error = errors[key_name][i]
        # mean and standard deviation of this experiment's errors
        print(key_name, i, np.mean(plot_error), np.std(plot_error))
        # Q-Q plot; the title numbers experiments from 1 while `i` starts at 0
        stats.probplot(plot_error, dist="norm", plot=plt)
        plt.title('Q-Q plot of ' + key_name + ' ' + str(i+1))
        plt.show()
        # Optional (disabled): histogram of the error distribution.
        # df = pd.DataFrame(plot_error)
        # plt.figure(figsize=(10, 8))
        # sns.histplot(data=df, kde=True)
        # plt.ylabel('Frequency')
        # plt.xlabel('Error')
        # plt.title('Error Distribution of ' + key_name + ' ' + str(i))
        # plt.show()

def _load_error_pickles(file_name):
    """Unpickle ``./expdata/<dataset>/<file_name>`` for every dataset.

    Iterates over the module-level ``dataset_names`` list.

    Args:
        file_name: name of the pickle file inside each dataset directory,
            e.g. ``'errors_s.pkl'``.

    Returns:
        A ``(loaded, key_names)`` tuple. ``loaded`` maps each dataset name to
        its unpickled object; ``key_names`` is ``.keys()`` of the object
        loaded last, or ``None`` when ``dataset_names`` is empty.

    Raises:
        OSError: if a pickle file cannot be opened.
    """
    loaded = {}
    key_names = None
    for name in dataset_names:
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open('./expdata/' + name + '/' + file_name, 'rb') as fh:
            loaded[name] = pickle.load(fh)
        key_names = loaded[name].keys()
    return loaded, key_names


# Error pickles for the individual input subsets. The suffixes s / p / c
# presumably stand for subject / proxy / condition (the labels used by the
# pie chart later in this script) -- TODO confirm.
errors_s_regression, errors_s_regression_key_names = _load_error_pickles('errors_s.pkl')

errors_p_regression, errors_p_regression_key_names = _load_error_pickles('errors_p.pkl')

errors_c_regression, errors_c_regression_key_names = _load_error_pickles('errors_c.pkl')

errors_cp_regression, errors_cp_regression_key_names = _load_error_pickles('errors_cp.pkl')

errors_sc_regression, errors_sc_regression_key_names = _load_error_pickles('errors_sc.pkl')

errors_sp_regression, errors_sp_regression_key_names = _load_error_pickles('errors_sp.pkl')

# print(errors_regression)

# RMSE of every experiment for each model in `show_names` on the 'trade'
# dataset, plus the mean of those RMSEs per model.
trade_errors = errors_regression['trade']
print(trade_errors.keys())
show_names = [
    'roi_last10days_fusion', 'roi_last10days_lr', 'roi_last10days_nn',
    'linear_lr', 'linear_nn', 'linear_fusion',
    'catboost_lr', 'catboost_nn', 'catboost_fusion',
]

errors_rmse_regression_list = {}  # model name -> list of per-experiment RMSEs
errors_rmse_regression = {}       # model name -> mean of that list
for key_name in show_names:
    model_runs = trade_errors[key_name]
    # root of the mean squared error, computed separately for each experiment
    per_run_rmse = [
        np.sqrt(np.mean(np.square(model_runs[run_idx])))
        for run_idx in range(len(model_runs))
    ]
    errors_rmse_regression_list[key_name] = per_run_rmse
    errors_rmse_regression[key_name] = np.mean(per_run_rmse)

print(errors_rmse_regression)
input("Press Enter to continue...")

# print(errors_rmse_regression)
# plot, 
# Bar chart with one bar per model, built from the per-experiment RMSE lists
# (error bars: 95% confidence interval, per the original note).
df = pd.DataFrame(errors_rmse_regression_list)
plt.figure(figsize=(10, 8))
sns.barplot(data=df)
plt.title('Evaluate Regression Model by RMSE')
plt.xlabel('Evaluation Model')
plt.ylabel('RMSE')
plt.xticks(rotation=90)
plt.show()

# Mean RMSE of the CatBoost fusion model, followed by the relative RMSE
# reduction of the CatBoost and linear fusion models with respect to
# 'roi_last10days_fusion' (positive means a lower RMSE than that reference).
catboost_rmse = errors_rmse_regression['catboost_fusion']
print("catboost:", catboost_rmse)
reference_rmse = errors_rmse_regression['roi_last10days_fusion']
print("catboost:", -(catboost_rmse - reference_rmse) / reference_rmse)
linear_rmse = errors_rmse_regression['linear_fusion']
print("linear:", -(linear_rmse - reference_rmse) / reference_rmse)
# ----------------------------------------------------------------------------------------
# Shapley values of the three input groups s, p and c (labelled subject,
# proxy and condition in the pie chart further down) for the
# 'catboost_fusion' model on the 'trade' dataset.
# The score of a coalition of inputs is the negative RMSE (so higher is
# better), averaged over all experiments. The 'zero' coalition is scored with
# the 'roi_last10days_fusion' errors of the full-input pickle.
def _negative_rmse(run_errors):
    """Return minus the root-mean-square of one experiment's errors."""
    return -np.sqrt(np.mean(np.square(run_errors)))


def _shapley_value(value, solo, pair_a, single_a, pair_b, single_b, complement):
    """Weighted sum of the four marginal gains of one input group.

    Weights are 1/3 for joining the empty and the two-member coalition and
    1/6 for joining each single-member coalition.
    """
    return (1/3 * (value[solo] - value['zero'])
            + 1/6 * (value[pair_a] - value[single_a])
            + 1/6 * (value[pair_b] - value[single_b])
            + 1/3 * (value['scp'] - value[complement]))


shapleys_rmse = {'s': 0, 'p': 0, 'c': 0}
outcomes_rmse = {'s': 0, 'p': 0, 'c': 0, 'sc': 0, 'sp': 0, 'cp': 0, 'zero': 0}

print("s:", errors_s_regression['trade'].keys(),
      'c:', errors_c_regression['trade'].keys(),
      'p:', errors_p_regression['trade'].keys(),
      'sc:', errors_sc_regression['trade'].keys(),
      'sp:', errors_sp_regression['trade'].keys(),
      'cp:', errors_cp_regression['trade'].keys(),
      'scp:', errors_regression['trade'].keys()
      )

# coalition name -> (error dictionary, model key) its score is computed from
_coalition_sources = (
    ('s', errors_s_regression, 'catboost_fusion'),
    ('p', errors_p_regression, 'catboost_fusion'),
    ('c', errors_c_regression, 'catboost_fusion'),
    ('cp', errors_cp_regression, 'catboost_fusion'),
    ('sc', errors_sc_regression, 'catboost_fusion'),
    ('sp', errors_sp_regression, 'catboost_fusion'),
    ('scp', errors_regression, 'catboost_fusion'),
    ('zero', errors_regression, 'roi_last10days_fusion'),
)

# The number of experiments is taken from the 's' pickle (30 according to the
# original note); every other source is indexed over that same range.
_n_runs = len(errors_s_regression['trade']['catboost_fusion'])
_scores = {
    name: [_negative_rmse(source['trade'][model][i]) for i in range(_n_runs)]
    for name, source, model in _coalition_sources
}
rmse_s, rmse_p, rmse_c = _scores['s'], _scores['p'], _scores['c']
rmse_cp, rmse_sc, rmse_sp = _scores['cp'], _scores['sc'], _scores['sp']
rmse_zero, rmse_scp = _scores['zero'], _scores['scp']

# mean score of each coalition over the experiments
for _coalition in ('s', 'p', 'c', 'cp', 'sc', 'sp', 'zero', 'scp'):
    outcomes_rmse[_coalition] = np.mean(_scores[_coalition])

shapleys_rmse['s'] = _shapley_value(outcomes_rmse, 's', 'sc', 'c', 'sp', 'p', 'cp')

shapleys_rmse['p'] = _shapley_value(outcomes_rmse, 'p', 'sp', 's', 'cp', 'c', 'sc')

shapleys_rmse['c'] = _shapley_value(outcomes_rmse, 'c', 'cp', 'p', 'sc', 's', 'sp')

# the score of the empty coalition is stored alongside the Shapley values
shapleys_rmse['zero'] = outcomes_rmse['zero']

# plot, 
# Print the averaged coalition scores, then tabulate and plot the entries of
# `shapleys_rmse` ('s', 'p', 'c' and the 'zero' reference score).
print(outcomes_rmse)
df = pd.DataFrame(shapleys_rmse, index=[0])  # single row, one column per entry
print(df)

plt.figure(figsize=(10, 8))
sns.barplot(data=df)  # one bar per column; with a single row the bar is the value itself
plt.title('Evaluate Regression Model by Shapley Value')
plt.xlabel('Evaluation Model')
plt.ylabel('Shapley Value')
plt.xticks(rotation=90)
plt.show()

# Pie chart of each input group's share of the summed Shapley values.
# Reference values from an earlier run:
# base is -5.235153519028588, total is -0.5534320712034532,
# contribution of s is 1.676271, contribution of p is 1.675149, contribution of c is 1.330301,
labels = ['subject', 'proxy', 'condition']
sizes = [shapleys_rmse['s'], shapleys_rmse['p'], shapleys_rmse['c']]
# Uplift shown in the title: computed from the values actually plotted rather
# than hard-coded, so it stays correct when the experiment data change (the
# earlier-run contributions above sum to 4.6817).
total_uplift = sum(sizes)
colors = ['gold', 'yellowgreen', 'lightcoral']
explode = (0, 0, 0)  # no wedge is pulled out of the pie
plt.figure(figsize=(10, 8))
# NOTE: plt.pie raises ValueError if any wedge size is negative, i.e. if an
# input group has a negative Shapley value.
plt.pie(sizes, explode=explode, labels=labels, colors=colors, autopct='%1.1f%%',  startangle=140)
plt.axis('equal')  # equal aspect ratio so the pie is drawn as a circle
plt.title('Shapley Value of Inputs (Uplift ROI is {:.4f})'.format(total_uplift))
plt.show()
# ----------------------------------------------------------------------------------------

# 0.905002  0.905709  0.729403
# the three values are 0.905002, 0.905709 and 0.729403; their sum is 2.540114
# share of the first:  0.905002 / 2.540114 = 0.356
# share of the second: 0.905709 / 2.540114 = 0.357
# share of the third:  0.729403 / 2.540114 = 0.287