#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import numpy as np
import types

import sys
from pathlib import PurePath
# Make sibling packages (e.g. tools/) importable by adding the parent directory of this file's folder to sys.path
lcode = PurePath(__file__)
datalocation = lcode.parents[2].as_posix() + '/'
if  (lcode.parents[1].as_posix() in sys.path)  is False:
    sys.path.insert(0,  lcode.parents[1].as_posix())

#import tools.tools as tls
import tools.tools as tls
import math

# General base class of the bandit algorithms; each concrete algorithm should subclass it
class Algo_Bandit:
    """
    General class of the bandit algorithm.

    Description:
    This class serves as a general framework for a bandit algorithm, providing the minimal
    requirements for the algorithm object and its interactions with the arms of the bandit game.
    Concrete algorithms are expected to subclass it and override `choose_index`
    (and `initiate` / `draw_index` when they need an initialization phase).

    Parameters:
    -----------
    - method: object exposing a `name` attribute and a `generator` attribute
      (the latter is used as a NumPy random generator).
    - K: an integer, the number of arms.

    Attributes:
    -----------
    - name: name copied from `method.name`.
    - Nbarm: an integer, the number of arms.
    - time: an integer, the number of pulls recorded so far.
    - seeded_generator: random generator copied from `method.generator`.
    - mem: dict with per-arm arrays 'count' (number of pulls), 'cumsum' (sum of rewards)
      and 'theta' (empirical mean reward).

    Methods:
    --------
    - update_mem: Updates the memory of the bandit algorithm object when an arm has been pulled.
    - draw_index: Returns the index of the arm to be pulled at the next step.
    - choose_index: Returns the index of the arm to be pulled at the next step after the initialization phase.
    - initiate: Returns the index of the arm to be pulled at the next step during the initialization phase.
    - reset: Resets the bandit algorithm object for a new simulation.
    - show: Prints the current parameters of the bandit algorithm object.
    - get_max_time: Returns the maximum number of draws of an arm and its index.
    - get_theta: Returns the empirical reward of an arm.
    - get_max_theta_index: Returns the empirical reward of the best arm, its number of draws and its index.

    """
    def __init__(self, method, K):
        """
        Initializes the bandit object with the method object and the number of arms.

        Parameters:
        -----------
        method: object providing `name` and `generator` attributes.
        K: int, the number of arms.
        """
        self.name = method.name
        self.Nbarm = K  # number of arms
        self.time = 0
        self.seeded_generator = method.generator
        # Allocate the statistics here so that update_mem / get_* work on a fresh
        # instance even when a subclass does not create the memory itself.
        self.mem = {
            'count': np.zeros(self.Nbarm, dtype=int),
            'cumsum': np.zeros(self.Nbarm),
            'theta': np.zeros(self.Nbarm),
        }

    # memory shared by the bandit algorithm
    def update_mem(self, arm, reward):
        """
        Updates the memory of the bandit algorithm object when an arm has been pulled.

        Parameters:
        -----------
        arm: int, index of the arm pulled.
        reward: int, reward of the arm pulled (can be a float for a Gaussian distribution).
        """
        self.mem['count'][arm] += 1  # number of selections of this arm
        self.mem['cumsum'][arm] += reward  # sum of all the rewards of this arm
        self.mem['theta'][arm] = self.mem['cumsum'][arm] / self.mem['count'][arm]
        # update the algo time
        self.time += 1  # total number of pulls

    def draw_index(self):
        """
        Returns the index of the arm to be pulled at the next step.

        Notes:
        - The base implementation has no initialization phase and simply delegates
          to `choose_index`; subclasses may override it to pull each arm once first.

        Returns:
        Whatever `choose_index` returns (an arm index in concrete algorithms).
        """
        return self.choose_index()

    def choose_index(self):
        """
        Returns the index of the arm to be pulled at the next step.

        Notes:
        - Placeholder meant to be overridden: the base implementation only prints a
          warning and returns None.
        """
        print('issue: should not be used each algo should have a new chosse index function')
        return None

    def initiate(self):
        """
        Returns the arm to pull during the initialization phase.

        Notes:
        - Placeholder meant to be overridden: the base implementation returns None.
        """
        return None

    # Backward-compatible alias: the hook was historically misspelled `intiate`.
    intiate = initiate

    def reset(self):
        """
        Resets the algorithm for another round (time set to 0, memory zeroed).
        """
        self.time = 0
        self.mem = {
            'count': np.zeros(self.Nbarm, dtype=int),
            'cumsum': np.zeros(self.Nbarm),
            'theta': np.zeros(self.Nbarm),
        }

    def show(self, methods=False):
        """
        Prints the current parameters of the bandit algorithm object.

        Parameters:
        -----------
        methods: bool, when true the bound methods of the object are printed as well.

        Notes:
        - This method is used for debugging or pedagogical purposes.
        - Names starting with a double underscore are skipped.
        - For the memory, the empirical means are also printed (-1 for arms never pulled).
        """
        print('bandit object with current arguments:')

        list_arguments = []
        list_methods = []

        for name in dir(self):
            if name[0:2] == '__':
                continue
            fullname = 'self.' + name
            # getattr replaces the former eval('self.<name>') lookup
            current_value = getattr(self, name)
            if isinstance(current_value, types.MethodType):
                list_methods.append(fullname + ': ' + str(current_value))
                continue

            list_arguments.append(fullname + ': ' + str(current_value))
            if name == 'mem':
                current_thetas = [
                    total / pulls if pulls != 0 else -1
                    for pulls, total in zip(current_value['count'], current_value['cumsum'])
                ]
                list_arguments.append('currentthetas:' + str(current_thetas))

        for cstring in list_arguments:
            print(cstring)
        if methods:
            print('bandit object with current methods:')
            for cstring in list_methods:
                print(cstring)

    def get_max_time(self):
        """
        Returns the maximum number of draws of an arm and its index.

        Returns:
        (int, int): A tuple containing the maximum number of draws and the index of the corresponding arm.
        """
        return (np.max(self.mem['count']), np.argmax(self.mem['count']))

    def get_theta(self, index):
        """
        Gets the empirical reward of an arm given its index.

        Parameters:
        -----------
        index: int, the index of the arm for which the empirical reward is to be retrieved.

        Returns:
        float: The empirical reward stored for the specified arm.
        """
        return self.mem['theta'][index]

    def get_max_theta_index(self):
        """
        Returns the best empirical arm (according to the empirical mean cumsum / count).

        Notes:
        - Only arms that have been pulled at least once are considered.
        - In case two arms have the same empirical mean, the one with the fewer draws is kept.
        - If they also have the same number of draws, a fair coin drawn from the seeded
          generator decides whether the later arm replaces the current one.
        - Empirical means of any sign are handled (rewards may be negative floats).

        Returns:
        (float, int, int): the best empirical mean, the number of draws of that arm and its
        index. When no arm has been pulled yet, (-1, -1, 0) is returned.
        """
        # Start from -inf (not -1) so that arms with a mean <= -1 can still win.
        thetamax = -np.inf
        Nmax = -1  # stays at -1 as long as no pulled arm has been selected
        indexmax = 0
        # iterate over all the arms
        for index, Nc in enumerate(self.mem['count']):
            if Nc <= 0:
                continue
            theta = self.mem['cumsum'][index] / Nc
            if thetamax < theta:
                thetamax = theta
                indexmax = index
                Nmax = Nc
            # case where thetas are equal
            elif thetamax == theta:
                if Nc < Nmax:
                    indexmax = index
                    Nmax = Nc
                elif Nc == Nmax:
                    if self.seeded_generator.integers(low=0, high=2, size=1)[0] > 0.5:
                        indexmax = index
                        Nmax = Nc

        if Nmax < 0:
            # keep the historical sentinel when nothing has been pulled yet
            return -1, -1, 0
        return thetamax, Nmax, indexmax
        
         
# Thompson Sampling algorithm, used as a baseline to compare against AIM performance
class Algo_Thompson(Algo_Bandit):
    """
    Thompson Sampling algorithm, used as a baseline to compare performance with AIM.

    Description:
    Particular case of the Algo_Bandit class where, after pulling each arm once, the arm is
    selected by drawing one normal sample per arm (mean cumsum/count, standard deviation
    1/sqrt(count)) and taking the arm with the largest sample.
    """

    def __init__(self, method, K):
        """
        Initializes the bandit object with the method object and the number of arms.

        Parameters:
        -----------
        method: object forwarded to Algo_Bandit.__init__.
        K: int, the number of arms.

        Notes:
        - This method also initiates the memory of the algorithm.
        """
        super().__init__(method, K)
        self.mem = {
            'count': np.zeros(K, dtype=int),
            'cumsum': np.zeros(K),
            'theta': np.zeros(K),
        }

    def draw_index(self):
        """
        Returns the index of the arm to be pulled at the next step.

        Notes:
        - While the time is smaller than the number of arms, `initiate` is used.
        - After that, the index is computed by `choose_index`.

        Returns:
        int: Index of the arm to be pulled at the next step.
        """
        still_initializing = self.time < self.Nbarm
        return self.initiate() if still_initializing else self.choose_index()

    def initiate(self):
        """
        Initialization phase: the arm pulled is the one whose index equals the current time.
        """
        return self.time

    def choose_index(self):
        """
        Returns the index of the arm to be pulled at the next step using Thompson Sampling.

        Description:
        One normal sample is drawn per arm, centered on the arm's empirical mean with
        standard deviation 1/sqrt(count). The arm with the highest sample is returned;
        exact ties are broken uniformly at random with the seeded generator.

        Returns:
        int: Index of the arm to be pulled at the next step.
        """
        rng = self.seeded_generator
        samples = []
        for pulls, total in zip(self.mem['count'], self.mem['cumsum']):
            samples.append(rng.normal(loc=total / pulls, scale=1.0 / math.sqrt(pulls)))
        draws = np.array(samples)

        # indices of every arm reaching the maximum sample
        top = np.where(draws - np.max(draws) == 0.0)[0]
        if len(top) < 2:
            # single maximum: return the index of the arm with the highest sample
            return np.argmax(draws)
        # more than one arm shares the maximum value: pick one of them at random
        return top[rng.integers(len(top))]

    def reset(self):
        """
        Resets the algorithm for another round (time and memory).
        """
        super().reset()
        self.mem = {
            'count': np.zeros(self.Nbarm, dtype=int),
            'cumsum': np.zeros(self.Nbarm),
            'theta': np.zeros(self.Nbarm),
        }
            


# Thompson Sampling variant: each arm's score is randomly either a posterior sample or its empirical mean
class Algo_Thompson_Plus(Algo_Bandit):
    """
    Thompson Sampling variant.

    After pulling each arm once, every arm gets a score that is, with a coin flip drawn from
    the seeded generator, either a normal sample (mean cumsum/count, standard deviation
    1/sqrt(count)) or the plain empirical mean cumsum/count. The arm with the highest score
    is pulled.
    """

    def __init__(self, method, K):
        """
        Initializes the algorithm and its per-arm memory.

        Parameters:
        -----------
        method: object forwarded to Algo_Bandit.__init__.
        K: int, the number of arms.
        """
        super().__init__(method, K)
        self.mem = {
            'count': np.zeros(K, dtype=int),
            'cumsum': np.zeros(K),
            'theta': np.zeros(self.Nbarm),
        }

    def draw_index(self):
        """
        Returns the arm to pull: `initiate` while time < number of arms, `choose_index` after.
        """
        still_initializing = self.time < self.Nbarm
        return self.initiate() if still_initializing else self.choose_index()

    def initiate(self):
        """
        Initialization phase: draw each arm once, in index order.
        """
        return self.time

    def choose_index(self):
        """
        Returns the index of the arm with the highest score.

        Notes:
        - For each arm, a draw in {0, 1} decides whether the score is a normal sample (draw 0)
          or the empirical mean (draw 1).
        - Exact ties on the maximum are broken uniformly at random.
        """
        rng = self.seeded_generator
        scores = []
        for pulls, total in zip(self.mem['count'], self.mem['cumsum']):
            # the coin is drawn first, the normal sample only when needed
            if rng.integers(low=0, high=2) < 0.5:
                score = rng.normal(loc=total / pulls, scale=1.0 / math.sqrt(pulls))
            else:
                score = total / pulls
            scores.append(score)
        draws = np.array(scores)

        top = np.where(draws - np.max(draws) == 0.0)[0]
        if len(top) < 2:
            return np.argmax(draws)
        return top[rng.integers(len(top))]

    def reset(self):
        """
        Resets the algorithm for another round (time and memory).
        """
        super().reset()
        self.mem = {
            'count': np.zeros(self.Nbarm, dtype=int),
            'cumsum': np.zeros(self.Nbarm),
            'theta': np.zeros(self.Nbarm),
        }


class Algo_MED(Algo_Bandit):
    """
    Randomized algorithm: after pulling each arm once, an arm is sampled with probability
    proportional to exp(-count * (mean - best_mean)**2 / 2), where mean = cumsum / count.
    """

    def __init__(self, method, K):
        """
        Initializes the algorithm and its per-arm memory.

        Parameters:
        -----------
        method: object forwarded to Algo_Bandit.__init__.
        K: int, the number of arms.
        """
        super().__init__(method, K)
        self.mem = {
            'count': np.zeros(K, dtype=int),
            'cumsum': np.zeros(K),
            'theta': np.zeros(self.Nbarm),
        }

    def draw_index(self):
        """
        Returns the arm to pull: `initiate` while time < number of arms, `choose_index` after.
        """
        still_initializing = self.time < self.Nbarm
        return self.initiate() if still_initializing else self.choose_index()

    def initiate(self):
        """
        Initialization phase: the arm pulled is the one whose index equals the current time.
        """
        return self.time

    def choose_index(self):
        """
        Samples the next arm from the weights described in the class docstring.

        Returns:
        Index drawn by the seeded generator among the Nbarm arms.
        """
        pulls = self.mem['count']
        # empirical mean of every arm
        thetas = np.array([total / n for n, total in zip(pulls, self.mem['cumsum'])])
        gaps = thetas - np.max(thetas)
        # unnormalized weights: 1 for the best empirical arm, smaller for the others
        energy = np.exp(-pulls * gaps ** 2 / (2))
        probas = energy / np.sum(energy)
        return self.seeded_generator.choice(self.Nbarm, p=probas)


class Algo_KLUCB_Plus_Plus(Algo_Bandit):
    """
    Upper-confidence-bound algorithm whose exploration term depends on the horizon.

    With r = horizon / (Nbarm * count), the index of an arm is
    mean + sqrt(2 * max(log(r * (max(log(r), 0)**2 + 1)), 0) / count).
    """

    def __init__(self, method, K):
        """
        Initializes the algorithm, its horizon (read from `method.horizon`) and its memory.

        Parameters:
        -----------
        method: object forwarded to Algo_Bandit.__init__; must also expose `horizon`.
        K: int, the number of arms.
        """
        super().__init__(method, K)
        self.horizon = method.horizon
        self.mem = {
            'count': np.zeros(self.Nbarm, dtype=int),
            'cumsum': np.zeros(self.Nbarm),
            'theta': np.zeros(self.Nbarm),
        }

    def draw_index(self):
        """
        Returns the arm to pull: `initiate` while time < number of arms, `choose_index` after.
        """
        still_initializing = self.time < self.Nbarm
        return self.initiate() if still_initializing else self.choose_index()

    def initiate(self):
        """
        Initialization phase: draw each arm once, in index order.
        """
        return self.time

    def choose_index(self):
        """
        Returns the arm with the largest upper confidence bound.

        Notes:
        - Arms never pulled get an empirical mean of 0.
        - Exact ties on the maximum are broken uniformly at random with the seeded generator.
        """
        pulls = self.mem['count']
        empirical_means = np.array(
            [total / n if n > 0 else 0. for n, total in zip(pulls, self.mem['cumsum'])]
        )

        # r = horizon / (K * N_a), shared by both logarithms below
        ratio = self.horizon / (self.Nbarm * pulls)
        ginside = np.maximum(np.log(ratio), np.zeros(self.Nbarm)) ** 2 + 1
        budget = np.maximum(np.log(ratio * ginside), np.zeros(self.Nbarm))
        radii = np.sqrt(2 * budget / pulls)
        ucb = empirical_means + radii

        top = np.where(ucb - np.max(ucb) == 0.0)[0]
        if len(top) < 2:
            return np.argmax(ucb)
        return top[self.seeded_generator.integers(len(top))]

    def reset(self):
        """
        Resets the algorithm for another round (time and memory).
        """
        super().reset()
        self.mem = {
            'count': np.zeros(self.Nbarm, dtype=int),
            'cumsum': np.zeros(self.Nbarm),
            'theta': np.zeros(self.Nbarm),
        }



class Algo_KLUCB(Algo_Bandit):
    """
    KL-UCB style index algorithm with a Gaussian (unit variance) confidence radius.

    After pulling each arm once, the arm maximizing
    mean + sqrt(2 * (log(t) + c * log(log(t))) / count)
    is pulled, where t is the total number of pulls so far and c = self.c (0 by default).
    """

    def __init__(self, method, K):
        """
        Initializes the algorithm and its per-arm memory.

        Parameters:
        -----------
        method: object forwarded to Algo_Bandit.__init__.
        K: int, the number of arms.
        """
        super().__init__(method, K)
        # weight of the optional log(log(t)) term of the exploration budget
        self.c = 0.0
        self.mem = {
            'count': np.zeros(self.Nbarm, dtype=int),
            'cumsum': np.zeros(self.Nbarm),
            'theta': np.zeros(self.Nbarm),
        }

    def draw_index(self):
        """
        Returns the arm to pull: `initiate` while time < number of arms, `choose_index` after.
        """
        if self.time < self.Nbarm:
            return self.initiate()
        return self.choose_index()

    def initiate(self):
        """
        Initialization phase: draw each arm once, in index order.
        """
        return self.time

    def choose_index(self):
        """
        Returns the arm with the largest upper confidence bound.

        Notes:
        - The exploration budget is computed from the total time (number of pulls so far),
          not from the per-arm counts: with log(count) the bonus of an arm pulled once would
          be exactly zero and would never grow while the arm is ignored.
        - The c * log(log(t)) term is only added when c is non-zero and log(t) > 1, so that
          the double logarithm is positive.
        - Arms never pulled get an empirical mean of 0.
        - Exact ties on the maximum are broken uniformly at random with the seeded generator.
        """
        pulls = self.mem['count']
        empirical_means = np.array(
            [total / n if n > 0 else 0. for n, total in zip(pulls, self.mem['cumsum'])]
        )

        # max(..., 1) keeps the logarithm defined if called before any pull
        log_t = math.log(max(self.time, 1))
        exploration = log_t
        if self.c and log_t > 1.0:
            exploration += self.c * math.log(log_t)

        radii = np.sqrt(2 * exploration / pulls)
        ucb = empirical_means + radii

        top = np.where(ucb - np.max(ucb) == 0.0)[0]
        if len(top) >= 2:
            return top[self.seeded_generator.integers(len(top))]
        return np.argmax(ucb)

    def reset(self):
        """
        Resets the algorithm for another round (time and memory).
        """
        super().reset()
        self.mem = {
            'count': np.zeros(self.Nbarm, dtype=int),
            'cumsum': np.zeros(self.Nbarm),
            'theta': np.zeros(self.Nbarm),
        }