#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import argparse

from multiprocessing import Pool
import numpy.random as rnd

import sys
from pathlib import PurePath
# Make the project packages (tools, bandit_main_classes) importable by adding the parent directory of this script's folder to sys.path
lcode = PurePath(__file__)
datalocation = lcode.parents[2].as_posix() + '/'
if  (lcode.parents[1].as_posix() in sys.path)  is False:
    sys.path.insert(0,  lcode.parents[1].as_posix())

import tools.tools_to_save as tool_save
import tools.tools as tool
import bandit_main_classes.bandit_game as Bandit
from bandit_main_classes.bandit_game import init_thetas_and_seeds_for_bandit_objects
from tools.initiate_algo import init_bandit_algo_argpasre



def play_one_bandit(numsimu, seed, thetas, ditsrbandit, Nbsteps,method, K):
    """
    Run one complete bandit game and record the regret after every step.

    Parameters:
    - 'numsimu': Identifier of the run supplied by the caller. It is not used inside this function.
    - 'seed': A sequence of seeds. It is forwarded unchanged to the bandit object, and 'seed[-1] + 1' seeds the generator used to draw random arm means.
    - 'thetas': A sequence of arm means. When its first entry equals -1, K means are drawn with the generator's 'random()' method instead, redrawing any value that is exactly 0 or 1.
    - 'ditsrbandit': Forwarded unchanged to 'Bandit.Bandit.fromrawinfo' (described by the original author as an integer selecting the reward distribution of the arms).
    - 'Nbsteps': Number of steps to simulate; also the length of the returned regret array.
    - 'method': Object describing the bandit algorithm, forwarded unchanged to 'Bandit.Bandit.fromrawinfo'.
    - 'K': Number of arms.

    Output:
    A two-element list '[regret_curves, high_regret]' where
    - 'regret_curves': A numpy array of length Nbsteps holding the value of 'regret.regret' read after each step.
    - 'high_regret': 1 if 'regret.Nwrong' is strictly greater than int(Nbsteps / 3), 0 otherwise.

    """
    # The generator is always created, even when the supplied thetas are used as-is.
    theta_rng = rnd.default_rng(seed=seed[-1] + 1)

    def draw_interior_mean():
        # Redraw until the value is neither exactly 0 nor exactly 1.
        candidate = theta_rng.random()
        while candidate == 0 or candidate == 1:
            candidate = theta_rng.random()
        return candidate

    if thetas[0] == -1:
        # The first entry being -1 requests randomly generated arm means.
        arm_means = [draw_interior_mean() for _ in range(K)]
    else:
        arm_means = thetas

    game = Bandit.Bandit.fromrawinfo(method, K, arm_means, seed, ditsrbandit)

    # One regret value is stored per executed step.
    regret_trace = np.zeros(Nbsteps)
    cursor = 0
    while game.algo.time < Nbsteps:
        game.make_one_step()
        regret_trace[cursor] = game.regret.regret
        cursor += 1

    # Flag the run when the wrong-choice counter exceeds a third of the horizon.
    high_regret = 1 if game.regret.Nwrong > int(Nbsteps / 3.0) else 0

    return [regret_trace, high_regret]



def log_result(result):
    """
    Store the outcome of one simulation in the module-level accumulators.

    'result' is indexed as '[regret_curve, high_regret]': element 0 is appended
    to the global list 'result_list' and element 1 to the global list
    'Nhigh_regret_tot'.
    """
    # Passed as the `callback` of pool.apply_async in the main script.
    regret_curve = result[0]
    result_list.append(regret_curve)

    high_regret_flag = result[1]
    Nhigh_regret_tot.append(high_regret_flag)


def init_simu_argparse():
    """
    Build the argparse parser that holds the simulation settings.

    Command-line options:
    - '--Nbsteps' (int, required): Number of steps made by each bandit.
    - '--nprocess' (int, default 20): Number of worker processes used when the parallel option is set.
    - '--nrun' (int, required): Total number of bandit simulations.
    - '--parralel_option' / '--parallel_option' (flag, default False): When present, the simulation is parallelized. Both spellings store into the attribute 'parralel_option'.
    - '--timescale' (one of 'all', 'linspace', 'logspace'; default 'all'): Selects the time grid on which the regret is saved.
    - '--Ntime' (int, default 200): Number of points of the time grid when 'timescale' is 'linspace' or 'logspace'.

    Output:
    - The configured 'argparse.ArgumentParser' (arguments are not parsed here).

    """

    parser = argparse.ArgumentParser(description='main in order to generate simulation parameters ')
    parser.add_argument('--Nbsteps', type=int, help='Number of steps made by each bandit', required=True)
    # Not marked as required: the option has a default and only matters for
    # parallel runs, so serial runs should not be forced to provide it.
    parser.add_argument('--nprocess', type=int, default=20, help='Number of processes used if parrallel option is True')
    parser.add_argument('--nrun', type=int, help='Number of bandit simulation in total', required=True)
    # The historical (misspelled) flag is kept and remains the attribute name;
    # the correctly spelled alias is accepted as well.
    parser.add_argument('--parralel_option', '--parallel_option', dest='parralel_option', default=False, help='if true the simulation is parallelized (simple parallelization)', action='store_true')
    parser.add_argument('--timescale', default='all', const='all', nargs='?', choices=['all', 'linspace', 'logspace'], help = 'handle the can be logspace linspace or all it will generate a time scale for the which the regret will be saved')
    parser.add_argument('--Ntime', type=int, default=200, help = 'used if timescale options is linspace or logspace, it will control the size of the time array used for save the regret')
    return parser

if __name__ == '__main__' :
    # Script entry point: configure the bandit algorithm and the simulation,
    # run `nrun` independent games (optionally in a process pool), then save
    # the mean and standard deviation of the regret curves to a JSON file.

    # Bandit-algorithm parameters. The argument list is hard-coded here, so
    # the real command line is not read.
    # Method names listed by the original author: aim, thompson
    parser = init_bandit_algo_argpasre()
    args = parser.parse_args(['--thetas=-1', '--K=8','--methodname=aim'])

    # Simulation parameters (also hard-coded).
    parser_simu = init_simu_argparse()
    args_simu = parser_simu.parse_args(['--Nbsteps=400', '--nrun=200','--parralel_option', '--nprocess=25', '--timescale=all', '--Ntime=20'])

    # Number of steps of each bandit game
    Nbsteps = args_simu.Nbsteps
    # Number of worker processes
    nprocess = args_simu.nprocess
    # Number of independent bandit runs
    nrun = args_simu.nrun
    # Whether the runs are distributed over a process pool
    parralel_option = args_simu.parralel_option

    K, thetas, seeds = init_thetas_and_seeds_for_bandit_objects(args.K, args.thetas, nrun)
    args.K = K
    args.thetas = thetas
    method = Bandit.init_method_from_argparse(args)

    # Module-level accumulators filled by log_result.
    result_list = []
    Nhigh_regret_tot = []

    # Each run receives its own slice of K + 1 consecutive seeds so that the
    # seeds differ from one run to the next.
    seeds_per_run = args.K + 1
    run_arguments = [
        (j, seeds[seeds_per_run * j:seeds_per_run * (j + 1)], args.thetas,
         args.ditsrbandit, Nbsteps, method, args.K)
        for j in range(nrun)
    ]

    if parralel_option:
        with Pool(processes=nprocess) as pool:
            pending = [
                pool.apply_async(play_one_bandit, args=run_args, callback=log_result)
                for run_args in run_arguments
            ]
            pool.close()
            pool.join()
            # The callback only fires for successful runs. Calling get()
            # re-raises any exception raised in a worker, so a failed run is
            # reported instead of being silently left out of the statistics.
            for handle in pending:
                handle.get()
    else:
        for run_args in run_arguments:
            log_result(play_one_bandit(*run_args))

    # Convert the collected results into numpy arrays
    result_list = np.array(result_list)
    Nhigh_regret_tot = np.array(Nhigh_regret_tot)

    # Mean and standard deviation of the regret over the runs
    mean_result = np.mean(result_list, axis=0)
    std_result = tool.get_std(result_list)

    # Generate the filename for the simulation
    filename = tool_save.generate_filename('', args.methodname,Nbsteps, args.thetas, args.K,suffix='.json',nrun=nrun)

    # Choose the indices of the time steps that are saved
    # (None means every time step is kept).
    indexscale = None
    extra_entries = []
    if args_simu.timescale == 'linspace':
        # Ntime linearly spaced indices between the first and the last step
        indexscale = np.int32(np.linspace(0,len(mean_result)-1, num=args_simu.Ntime))
        # Only this branch historically saved 'theta2'. It is written only
        # when the algorithm parser actually defines it, so a missing
        # attribute no longer aborts the save after the simulation has run.
        # NOTE(review): confirm whether 'theta2' is still a valid option.
        if hasattr(args, 'theta2'):
            extra_entries.append(['theta2', args.theta2])
    elif args_simu.timescale == 'logspace':
        # Ntime logarithmically spaced indices; the grid starts at index 1
        # (10**log10(1)) and truncation to integers may repeat small indices.
        indexscale = np.int32(np.logspace(np.log10(1), np.log10(len(mean_result)-1), num=args_simu.Ntime))
    elif args_simu.timescale != 'all':
        # Fallback for a value outside the parser's declared choices:
        # report it and save every time step.
        print('weird case timscale option hasnt been recognized')

    if indexscale is None:
        timescale = np.arange(0, len(mean_result))+1
    else:
        mean_result = mean_result[indexscale]
        std_result = std_result[indexscale]
        timescale = indexscale + 1

    # Save the result of the simulation
    entries = [['mean_result', mean_result], ['std_result', std_result], ['timescale', timescale],
               ['methodname',args.methodname], ['thetas', args.thetas]]
    entries += extra_entries
    entries += [['Nbsteps', Nbsteps], ['seeds', seeds], ['K', args.K]]
    tool_save.save_json_file(filename, *entries)

  
   