#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy.random as rnd
import numpy as np
from numpy.random import SeedSequence
import sys
from pathlib import PurePath
# Make code located in other directories importable.
# `lcode` is the path of this file; parents[1] is the parent of the directory
# containing this file. It is prepended to sys.path (once) before the
# `bandit_main_classes` imports below are executed.
lcode = PurePath(__file__)
# POSIX-style path, with a trailing '/', of the directory two levels above the
# directory containing this file (not used in the visible part of this file).
datalocation = lcode.parents[2].as_posix() + '/'
if  (lcode.parents[1].as_posix() in sys.path)  is False:
    sys.path.insert(0,  lcode.parents[1].as_posix())

#import bandit_main_classes.bandit_generator as bandit
import bandit_main_classes.initiate_bandit  as initiate_bandit
import bandit_main_classes.regret as Regret
import bandit_main_classes.arms as Arms


def init_thetas_and_seeds_for_bandit_objects(old_K, thetas, nrun):
    """
    Normalise the user-supplied thetas and draw the pool of seeds for a simulation.

    Parameters
    ----------
    old_K: int, number of arms to use when thetas[0] is the sentinel -1 or -2
    thetas: either a list of values, or a single value convertible to float
            (it is then wrapped in a one-element list). If the first element
            is -1 or -2 the number of arms is 'old_K', otherwise it is the
            length of the list.
    nrun: int, number of runs requested by the user

    Returns
    -------
    (K, thetas, seeds) where K is the number of arms, thetas is a list and
    seeds is the array of (K+1)*(nrun+1) integers produced by
    SeedSequence.generate_state. The SeedSequence is created without explicit
    entropy, so the seeds differ from one call to the next.
    -1 is returned (after printing a message) when thetas is an empty list or
    cannot be converted to a float.
    """
    # first check the type and length of the thetas asked by the user
    if not isinstance(thetas, list):
        try:
            thetas = [float(thetas)]
        except (TypeError, ValueError, OverflowError):
            # only the conversion failures of float() are handled here, so that
            # KeyboardInterrupt / SystemExit are no longer swallowed
            print('major issue on thetas init')
            return -1
    elif not thetas:
        # an empty list gives neither a sentinel nor a number of arms
        print('major issue on thetas init')
        return -1

    # -1 and -2 are sentinels: the number of arms then comes from old_K
    K = old_K if thetas[0] in (-1, -2) else len(thetas)

    # generate a global list of seeds to share between all arms and the bandit algo
    # (more seeds than strictly needed are drawn so that the algorithm never runs short)
    global_seed_sequence = SeedSequence()
    seeds = global_seed_sequence.generate_state(int((K + 1) * (nrun + 1)))

    return K, thetas, seeds


def init_bandit_from_argparse(args, seeds):
    """
    Initiate the bandit object from the argparse object and the seeds.

    Parameters
    ----------
    args: argparse namespace; the attributes read are methodname, K, thetas,
          ditsrbandit and, for the 'klucb_plus_plus' method only, horizon
    seeds: sequence of seeds forwarded to Bandit.fromrawinfo; at least three
           are required

    Returns
    -------
    The Bandit object, or -1 (after printing a message) when fewer than three
    seeds are given.
    """
    if len(seeds) < 3:
        print('major issue the Bandit cant be initialized since it required at least three seed')
        return -1

    # the method object is built by the shared helper so that the
    # method-specific hyperparameter handling lives in a single place
    method = init_method_from_argparse(args)
    # initiate the bandit object
    return Bandit.fromrawinfo(method, args.K, args.thetas, seeds, args.ditsrbandit)

def init_method_from_argparse(args):
    """
    Build the method object described by the argparse namespace.

    The method object carries the hyperparameters of the bandit algorithm
    (not used yet for AIM or Thompson sampling). Only 'klucb_plus_plus'
    receives an extra argument: a one-element list holding args.horizon.
    """
    name = args.methodname
    if name != 'klucb_plus_plus':
        return initiate_bandit.initiate_method(name, args.K)
    # klucb_plus_plus is the only method that is given the horizon
    return initiate_bandit.initiate_method(name, args.K, [args.horizon])

class Bandit:
    """
    Bandit object is the main object that will handle the bandit game
    It is constituted of the algorithm, the arms and the regret
    It will be used for the simulation and the plotting

    Parameters
    ----------
    algo: An object of the class Algo_Bandit, responsible for selecting the next arms to play and maintaining its own memory of previous rewards.
    arms: A list of objects of the class Arms, where each arm has its own distribution. When an arm is pulled (sample), it returns a reward.
    regret: An object of the class Regret, used to store and compute the regret of the bandit game.

    Methods
    -------
    fromrawinfo: Alternate constructor building the algorithm, the regret and the arms from raw user input.
    make_one_step: Executes one step of the bandit game, involving querying the algorithm to determine the next arm to play. Afterward, the Bandit object updates the algorithm and the regret.
    show: Prints the current parameters and status of the Bandit object by invoking the show method for the algorithm, the regret and every arm.
    """
    def __init__(self, algo, arms, regret):
        """
        Initialize the Bandit object.

        Parameters:
        - algo: An object of the class Algo_Bandit, responsible for arm selection and maintaining memory of rewards.
        - arms: A list of objects of the class Arms, each representing an arm with its own distribution.
        - regret: An object of the class Regret, used for storing the regret of the bandit game.

        The number of arms (self.Nbarm) is not an argument: it is read from
        algo.Nbarm. A message is printed, without raising, when it differs
        from len(arms).
        """
        self.algo = algo
        self.arms = arms
        self.Nbarm = algo.Nbarm
        self.regret = regret
        if self.Nbarm != len(self.arms):
            print("major issue bad number of arm given in input")

    @classmethod
    def fromrawinfo(cls, method, K, thetas, seeds, ditsrbandit):
        """
        Initiate the bandit object from the raw information given by the user

        Parameters
        ----------
        method: object of the class Method (used to store more specific parameters; in practice it is not used for AIM or Thompson sampling). Its 'generator' attribute is overwritten here.
        K: int number of arms
        thetas: list of float, one value per arm. If thetas[0] is -1, K values are drawn at random instead, with a generator seeded by the last seed + 1.
        seeds: sequence of int. seeds[0] ... seeds[K-1] seed the arms and the last seed (seeds[-1]) seeds the algorithm, so more than K seeds are needed for the algorithm to get a seed of its own.
        ditsrbandit: forwarded unchanged as first argument of Arms.Arm (presumably the reward distribution of the arms; confirm in Arms).

        Returns
        -------
        A new Bandit object.
        """
        # check if the thetas are given by the user or must be randomly generated
        # NOTE(review): init_thetas_and_seeds_for_bandit_objects also treats -2 as
        # a sentinel for the number of arms, but only -1 is handled here.
        if thetas[0] == -1:
            simu_generator_thetas = rnd.default_rng(seed=seeds[-1]+1)
            thetas = []
            for _ in range(K):
                # redraw until the value is strictly between 0 and 1
                thetav = 0
                while thetav == 0 or thetav == 1:
                    thetav = simu_generator_thetas.random()
                thetas.append(thetav)

        # at this stage thetas is already a list and K and seeds should have the expected size

        # the last seed is used for the algo
        method.generator = rnd.default_rng(seed=seeds[-1])
        # initiate the algo object
        algotest = initiate_bandit.initiate_algo(method, K)
        # initiate the regret object giving the thetas in input
        regret = Regret.Regret(np.array(thetas))
        # initiate the arms, each one with its own theta and its own generator
        arms = [
            Arms.Arm(ditsrbandit, thetas[index], rnd.default_rng(seed=seeds[index]))
            for index in range(K)
        ]
        # initiate the bandit object
        return cls(algotest, arms, regret)

    def make_one_step(self):
        """
        Make one step of the bandit game by asking the algorithm which arm to play next. Then, the bandit samples that arm and updates the algorithm and the regret.

        Returns
        -------
        (index, reward): the index of the selected arm and the reward it returned.
        """
        index = self.algo.draw_index()  # choose an index (the algo can still be in the initialization phase)
        reward = self.arms[index].sample()  # draw the chosen arm
        self.algo.update_mem(index, reward)  # update the algorithm with the new result
        self.regret.update(index)  # update regret

        # returned so that the user can have access to the selected arm and its output
        return index, reward

    def show(self):
        """
        Display the current parameters of the bandit object by printing what the 'show' method of the algorithm, of the regret and of every arm returns
        """
        print('')
        print('Bandit object at time: ' + str(self.algo.time))
        print('')
        print(self.algo.show())
        print('')
        print(self.regret.show())
        print('')
        # every arm is shown, whatever the number of arms
        for arm in self.arms:
            print(arm.show())
            print('')
