# coding: utf-8
import os

import numpy as np
import pandas as pd
from joblib import load, dump

from src.Linear_Bandits_Belief import LinearBanditsBelief
from src.utils import INIT_DISTRIBUTION

####################################################################################
####################################################################################
############################# AREA OF INPUT PARAMETERS #############################
####################################################################################
####################################################################################
##### Environment Parameters
PATH = "." # Root directory, should be the same path this "README.md" file locates
PATH_DATA = f"{PATH}/data" # Path for data
PATH_MODELS = f"{PATH}/models"  # Path for models

##### Parameters for Bandits
number_rounds = 100000 # T, use 100000 to reproduce the results
list_ucb_multipler = np.round(np.geomspace(0.0005, 0.25, 10), 4)
random_seed = 1987 #  To reproduce the results, run 10 times with random_seed from 1986 to 1995
list_batch = [320, 2200, 5600]
with_union_bound = False
prefix_union_bound = 'ub' if with_union_bound else 'no_ub'

####################################################################################
####################################################################################
############# Create Output Path, Load the data and Model###############
####################################################################################
####################################################################################
##### Load Data and Model
dt_reward_belief = pd.read_parquet(f"{PATH_MODELS}/random_seed_{random_seed}/dt_reward_belief.pq")
dt_reward_dummy = pd.read_parquet(f"{PATH_MODELS}/random_seed_{random_seed}/dt_reward_dummy.pq")
dict_kpi_rewards = load(f"{PATH_MODELS}/random_seed_{random_seed}/dict_kpi_rewards.pkl")
Belief_simulator = load(f"{PATH_MODELS}/random_seed_{random_seed}/belief_simulator.pkl")
Spectral_direct_estimator = load(f"{PATH_MODELS}/random_seed_{random_seed}/spectral_direct_estimator.pkl")
dict_hyper = load(f"{PATH_MODELS}/dict_hyper.pkl")
print(dict_hyper)

lambda_no_belief = dict_hyper["lambda_no_belief"]
lambda_belief = dict_hyper["lambda_belief"]

##### Create the folder for the output model
if os.path.isdir(f"{PATH_MODELS}/random_seed_{random_seed}"):
    pass
else:
    os.makedirs(f"{PATH_MODELS}/random_seed_{random_seed}")

####################################################################################
####################################################################################
######### Simulate The results for the linear bandits with belief ##############
####################################################################################
####################################################################################
for batch_len_ in list_batch:
    print(batch_len_)
    for UCB_multiply_ in list_ucb_multipler:
        UCB_multiply_str_ = "0" + str(UCB_multiply_)[2:]
        print(UCB_multiply_str_)
        if os.path.isfile(
                f"{PATH_MODELS}/random_seed_{random_seed}/linear_bandits_Belief_B{batch_len_}_C{UCB_multiply_str_}_{prefix_union_bound}.pkl"):
            print('skip')
            continue
        dict_linear_bandits_ = \
            {"seed": random_seed, "number_rounds": number_rounds, "lmd": lambda_belief,
             "vec_init": INIT_DISTRIBUTION, "batch_lengh": batch_len_, "verbose": False,
             "UCB_multiply": UCB_multiply_, "union_bound": with_union_bound, "hot_start": 100}

        obj_linear_bandits_ = LinearBanditsBelief(dict_linear_bandits_, dt_reward_belief, dt_reward_dummy,
                                                  Spectral_direct_estimator)
        obj_linear_bandits_.run_simulation()
        dump(obj_linear_bandits_,
             f"{PATH_MODELS}/random_seed_{random_seed}/linear_bandits_Belief_B{batch_len_}_C{UCB_multiply_str_}_{prefix_union_bound}.pkl")
        print("------------------------------")
    print('---------------------------------------------')
    print('---------------------------------------------')