#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import numpy as np

import matplotlib.pyplot as plt
import matplotlib.cm as cm

import sys
from pathlib import PurePath
# Make the packages that live next to this script's folder (bandit_main_classes,
# tools) importable: they are imported just below.
lcode = PurePath(__file__)
# Folder two levels above the one containing this file, with a trailing '/'.
# NOTE(review): not referenced in the code shown here; presumably a data root.
datalocation = lcode.parents[2].as_posix() + '/'
# parents[1] is the parent of this file's folder; prepend it only once.
if lcode.parents[1].as_posix() not in sys.path:
    sys.path.insert(0, lcode.parents[1].as_posix())

import bandit_main_classes.bandit_game as Bandit
from bandit_main_classes.bandit_game import init_thetas_and_seeds_for_bandit_objects
from tools.initiate_algo import init_bandit_algo_argpasre
from tools.tools import get_label


def _run_simulation(bandit_simu, nbsteps, nb_arms):
    """
    Play the bandit game step by step until the algo time reaches ``nbsteps``.

    After every step, the empirical mean and std of the pulled arm and the
    current regret are recorded, together with the algo time. This storage is
    only meant for demonstration: for a longer simulation it is wiser to store
    the regret at some specific times only.

    :param bandit_simu: the full bandit game object (arms, algo and regret).
    :param nbsteps: the loop stops as soon as ``bandit_simu.algo.time`` is no
        longer smaller than this value.
    :param nb_arms: number of arms of the game (size of ``arms_curves``).
    :return: ``(arms_curves, regret_curve)``. ``arms_curves[i]`` is a dict
        holding the lists 'time', 'average' and 'std' of arm ``i``.
        ``regret_curve`` is a dict holding the lists 'time' and 'regret'.
    """
    arms_curves = [{'time': [], 'average': [], 'std': []} for _ in range(nb_arms)]
    regret_curve = {'time': [], 'regret': []}

    while bandit_simu.algo.time < nbsteps:
        # The algo chooses the next arm, then the game updates the arm, the
        # algo and the regret. `index` is used to index Python lists below,
        # so it is expected to be 0-based.
        index, _reward = bandit_simu.make_one_step()

        now = bandit_simu.algo.time
        arm = bandit_simu.arms[index]
        arms_curves[index]['time'].append(now)
        arms_curves[index]['average'].append(arm.get_average())
        arms_curves[index]['std'].append(arm.get_std())

        # The regret is recorded at every step, whichever arm was pulled.
        regret_curve['time'].append(now)
        regret_curve['regret'].append(bandit_simu.regret.regret)

    return arms_curves, regret_curve


def _plot_arms_empirical_means(arms_curves, arms, startstd=10):
    """
    Plot, for each arm, the evolution of its empirical mean along the time.

    A +/- one std band is shaded around each curve, skipping the first
    ``startstd`` recorded pulls of the arm. Presumably this keeps the noisy
    early estimates out of the figure.

    :param arms_curves: list of per-arm dicts with the lists 'time',
        'average' and 'std', as returned by ``_run_simulation``.
    :param arms: the arms of the game; ``arms[i].param`` is shown in the
        legend label of arm ``i``.
    :param startstd: number of initial pulls without a std band.
    """
    color_step = 1.0 / (len(arms_curves) + 1)

    plt.figure()
    for index, curve in enumerate(arms_curves):
        times = np.array(curve['time'])
        averages = np.array(curve['average'])
        stds = np.array(curve['std'])
        color = cm.rainbow((index + 1) * color_step)
        label = r'$ \mu_{' + str(index + 1) + '}=' + str(round(arms[index].param, 2)) + '$'
        plt.plot(times, averages, linestyle='-', label=label, color=color)
        # An arm pulled at most `startstd` times has nothing left to shade.
        if len(times) > startstd:
            plt.fill_between(times[startstd:],
                             averages[startstd:] - stds[startstd:],
                             averages[startstd:] + stds[startstd:],
                             alpha=0.2, label='_nolabel', edgecolor=color,
                             facecolor=color, linewidth=2, linestyle='dashdot',
                             antialiased=True)

    plt.ylabel(r'$\theta$')
    plt.xlabel(r'$t$')
    plt.legend(loc='best', fontsize=12)
    plt.show()


def _plot_regret(regret_curve, label):
    """
    Plot the regret recorded at every step against the algo time.

    :param regret_curve: dict with the lists 'time' and 'regret', as returned
        by ``_run_simulation``.
    :param label: legend label of the curve.
    """
    plt.figure()
    times = np.array(regret_curve['time'])
    regrets = np.array(regret_curve['regret'])
    plt.plot(times, regrets, linestyle='-', label=label, color='blue')
    plt.ylabel(r'$Regret$')
    plt.xlabel(r'$t$')
    plt.legend(loc='best', fontsize=12)
    plt.show()


def main(argv=None, nbsteps=400):
    """
    Launch this function to test the bandit game and observe the result of one simulation.

    It plots the evolution of the arms' empirical means, then the evolution
    of the regret along the time. It also prints the seeds used, so the run
    can be reproduced.

    :param argv: list of command-line style arguments given to the bandit
        algo parser. Defaults to
        ``['--thetas=-1', '--K=8', '--methodname=klucb']``. ``--thetas=-1``
        means random thetas, generated by
        ``init_thetas_and_seeds_for_bandit_objects``. See
        ``init_bandit_algo_argpasre`` for the accepted method names.
    :param nbsteps: number of steps of the simulation; also set as the
        horizon of the algo.
    """
    if argv is None:
        argv = ['--thetas=-1', '--K=8', '--methodname=klucb']
    parser = init_bandit_algo_argpasre()
    args = parser.parse_args(argv)

    # Initiate the arms mean rewards and the seeds for the simulation.
    K, thetas, seeds = init_thetas_and_seeds_for_bandit_objects(args.K, args.thetas, nrun=1)
    args.K = K
    args.thetas = thetas
    args.horizon = nbsteps

    # Initiate the full bandit game object: it encompasses the arms, the algo and the regret.
    bandit_simu = Bandit.init_bandit_from_argparse(args, seeds)

    arms_curves, regret_curve = _run_simulation(bandit_simu, nbsteps, K)

    # Print the seeds used for the simulation, so the results can be reproduced.
    print('seeds used for the simulation: ', seeds)

    _plot_arms_empirical_means(arms_curves, bandit_simu.arms)
    _plot_regret(regret_curve, get_label(args.methodname))


if __name__ == '__main__':
    # Run the demonstration only when this file is executed as a script.
    main()
    