# coding: utf-8
import os

import matplotlib.pyplot as plt
import numpy as np
from joblib import load

from src.utils import plot_performance_same_strategy, plot_performance_benchmark

####################################################################################
####################################################################################
############################# AREA OF INPUT PARAMETERS #############################
####################################################################################
####################################################################################
##### Environment Parameters
PATH = "."  # Root directory; should be the folder that contains the project's "README.md"
PATH_DATA = f"{PATH}/data"  # Path for data
PATH_MODELS = f"{PATH}/models"  # Path for models
PATH_PICS = f"{PATH}/pics"  # Path for graphs

##### Parameters for Simulated Bandits
# Seeds of the simulation runs to aggregate: 1986 through 1995 (the range end is exclusive).
# Use these seeds to reproduce the results.
list_seed_simulation = list(range(1986, 1996))
T = 100000  # T, use 100000 to reproduce the results
with_union_bound = False
# Tag inserted into the file names of the saved belief models that are loaded below
prefix_union_bound = 'ub' if with_union_bound else 'no_ub'

# Make sure the output folder for the figures exists. `exist_ok=True` replaces the
# check-then-create sequence, so an already existing directory is simply accepted.
os.makedirs(PATH_PICS, exist_ok=True)

####################################################################################
####################################################################################
####################### Collect the Simulated Results ##############################
####################################################################################
####################################################################################
list_batch_belief = [320, 2200, 5600]
list_ucb_multipler = np.round(np.geomspace(0.0005, 0.25, 10), 4)
# File-name tag of every multiplier: "0" followed by the text after the leading "0."
# (e.g. 0.0005 -> "00005").
list_ucb_multipler_str = ["0" + str(x)[2:] for x in list_ucb_multipler]

T_range = np.arange(1, T + 1)


def _mean_std_cumulative_regret(optimal_rewards_by_seed, model_filename):
    """Return the mean and std, across seeds, of the cumulative regret of one saved model.

    For every seed in ``list_seed_simulation`` the pickle ``model_filename`` is loaded
    from that seed's folder under ``PATH_MODELS``. The regret of a seed is the cumulative
    sum of ``optimal_rewards_by_seed[seed]`` minus the model's ``believed_rewards``.

    Parameters
    ----------
    optimal_rewards_by_seed : dict
        Maps each seed to its 'optim_believed_reward_belief' reward series.
    model_filename : str
        File name (without folder) of the saved model to evaluate.

    Returns
    -------
    tuple
        ``(mean, std)`` of the cumulative regret, both reduced over the seed axis.
    """
    regret_columns = []
    for seed_ in list_seed_simulation:
        print(seed_)
        # NOTE: joblib.load unpickles its input - only point it at trusted files.
        model_ = load(f"{PATH_MODELS}/random_seed_{seed_}/{model_filename}")
        regret_columns.append(
            np.cumsum(optimal_rewards_by_seed[seed_] - np.array(model_.believed_rewards)))
    # Stack once, one column per seed. Unlike growing the matrix inside the loop this
    # copies the data a single time and stays 2-D even when only one seed is used.
    regret_matrix = np.column_stack(regret_columns)
    return regret_matrix.mean(axis=1), regret_matrix.std(axis=1)


### Collect the results
# The benchmark rewards depend only on the seed, so each file is read once up front
# instead of once per (batch, multiplier, seed) combination.
optimal_rewards_by_seed = {
    seed_: load(f"{PATH_MODELS}/random_seed_{seed_}/dict_kpi_rewards.pkl")['optim_believed_reward_belief']
    for seed_ in list_seed_simulation
}

dict_results_benchmark = {}
dict_results_regrets_belief = {}

# Belief variant: one result dictionary per batch value, keyed by the multiplier C.
for batch_ in list_batch_belief:
    print(batch_)
    dict_results_regrets_ = {}
    for C_, C_str_ in zip(list_ucb_multipler, list_ucb_multipler_str):
        print(C_)
        mean_, std_ = _mean_std_cumulative_regret(
            optimal_rewards_by_seed,
            f"linear_bandits_Belief_B{batch_}_C{C_str_}_{prefix_union_bound}.pkl")
        dict_results_regrets_[f"Linear Bandits with Belief --- C={C_}"] = mean_
        dict_results_regrets_[f"Linear Bandits with Belief --- C={C_} - std"] = std_
    dict_results_regrets_belief[batch_] = dict_results_regrets_

# Plain linear bandits: a single result dictionary keyed by the multiplier C.
for C_, C_str_ in zip(list_ucb_multipler, list_ucb_multipler_str):
    print(C_)
    mean_, std_ = _mean_std_cumulative_regret(
        optimal_rewards_by_seed, f"linear_bandits_C{C_str_}.pkl")
    dict_results_benchmark[f"Linear Bandits --- C={C_}"] = mean_
    dict_results_benchmark[f"Linear Bandits --- C={C_} - std"] = std_
    print("----------------------------")

# Belief results in the same order as `list_batch_belief`
list_dict_belief_results = [dict_results_regrets_belief[batch_] for batch_ in list_batch_belief]

####################################################################################
####################################################################################
####################### Plot and export the results ################################
####################################################################################
####################################################################################

### Export UCB Belief + Linear Bandits with Hyperparameters
# 2x2 grid: the first panel shows the plain linear bandits, the remaining panels are
# paired with the belief results (one entry of `list_dict_belief_results` each).
fig, axs = plt.subplots(2, 2, figsize=(26, 24))
axs = axs.flatten()
fig.subplots_adjust(
    left=0.06,
    right=0.98,
    bottom=0.06,
    top=0.88,
    wspace=0.06,
    hspace=0.05
)

# Keep only the entries that do not belong to the belief variant and plot exactly those.
dict_results_linear = {
    key_: value_ for key_, value_ in dict_results_benchmark.items() if 'with Belief' not in key_
}

plot_performance_same_strategy(dict_results_linear, T_range=T_range, sample_size=len(list_seed_simulation),
                               prefix_name='Linear Bandits',
                               keep_legend=False, ax=axs[0], sampling=0.1)

for dict_results_, ax_ in zip(list_dict_belief_results, axs[1:]):
    plot_performance_same_strategy(dict_results_, T_range=T_range, sample_size=len(list_seed_simulation),
                                   prefix_name='with Belief',
                                   keep_legend=False, ax=ax_, sampling=0.1)

# Shared figure legend, built from the last belief panel drawn above (`ax_`).
# Only the text after "---" of every label is kept.
handles, labels = ax_.get_legend_handles_labels()
labels = [x.split("---")[1].strip() for x in labels]
# The entries are spread over five small legends placed side by side above the panels;
# the last legend takes every remaining entry.
legend_columns = [
    (slice(0, 2), 0.2),
    (slice(2, 4), 0.3),
    (slice(4, 6), 0.4),
    (slice(6, 8), 0.5),
    (slice(8, None), 0.6),
]
for entries_, x_anchor_ in legend_columns:
    fig.legend(handles[entries_], labels[entries_], bbox_to_anchor=(x_anchor_, 0.94),
               prop={'size': 20}, title_fontsize=20)

# Caption in the top-left corner of every panel. The raw f-string keeps the backslash of
# "\ell" without relying on an invalid escape sequence.
panel_titles = ['Classical LinUCB'] + [rf'S-UCB-Belief $\ell$ = {batch_}' for batch_ in list_batch_belief]
for ax, title in zip(axs, panel_titles):
    ax.text(
        0.01, 0.99,
        s=title,
        transform=ax.transAxes,
        fontsize=30,
        fontweight='semibold',
        va='top', ha='left',
        color='dimgray',
        bbox=dict(facecolor='white', alpha=0.3, edgecolor='none', boxstyle='round,pad=0.3')
    )

fig.savefig(f"{PATH_PICS}/plot_performance_compare_with_legend.pdf", dpi=300, bbox_inches='tight')

# Closing is sufficient to release the figure. Calling plt.clf()/plt.cla() afterwards
# would implicitly create a new, empty figure.
plt.close("all")
del fig, ax, axs

### Export With best performance
fig, ax = plt.subplots(1, 1, figsize=(13, 12))

fig.subplots_adjust(
    left=0.06,
    right=0.98,
    bottom=0.06,
    top=0.88,
    wspace=0.06,
    hspace=0.05
)

# Configurations compared head-to-head in this figure. They are defined once so that the
# selected curves and the legend texts cannot drift apart.
best_c_linear = "0.0628"   # multiplier C of the plain linear bandits
best_batch_belief = 320    # batch value of the belief variant
best_c_belief = "0.0005"   # multiplier C of the belief variant

# Mean and std entries of the selected linear-bandit run first, then those of the selected
# belief run (the keys contain "C=<value>").
dict_results_benchmark_compare = {
    key_: value_
    for key_, value_ in dict_results_benchmark.items()
    if f"C={best_c_linear}" in key_
}
dict_results_benchmark_compare.update(
    (key_, value_)
    for key_, value_ in dict_results_regrets_belief[best_batch_belief].items()
    if f"C={best_c_belief}" in key_
)

plot_performance_benchmark(dict_results_benchmark_compare, T_range=T_range, sample_size=len(list_seed_simulation),
                           keep_legend=False, ax=ax, sampling=0.1)

# Two separate figure legends: the first handle is labelled as the linear bandits, all
# remaining handles share the belief label.
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles[:1], [f'Linear Bandits --- C={best_c_linear}'], bbox_to_anchor=(0.45, 0.94),
           prop={'size': 20}, title_fontsize=20)
fig.legend(handles[1:], [rf'S-UCB-Belief --- $\ell$ = {best_batch_belief} C={best_c_belief}'],
           bbox_to_anchor=(0.95, 0.94),
           prop={'size': 20}, title_fontsize=20)
ax.text(
    0.01, 0.99,
    s='Classical LinUCB vs S-UCB-Belief',
    transform=ax.transAxes,
    fontsize=30,
    fontweight='semibold',
    va='top', ha='left',
    color='dimgray',
    bbox=dict(facecolor='white', alpha=0.3, edgecolor='none', boxstyle='round,pad=0.3'))

fig.savefig(f"{PATH_PICS}/plot_performance_with_legend.pdf", dpi=300, bbox_inches='tight')

# Closing is sufficient to release the figure. Calling plt.clf()/plt.cla() afterwards
# would implicitly create a new, empty figure.
plt.close("all")
del fig, ax
