#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
from pathlib import PurePath
# Make the project's sibling packages importable: the parent of this file's
# directory is put at the front of sys.path, unless it is already listed.
lcode = PurePath(__file__)
# Directory two levels above this file's directory, with a trailing slash.
datalocation = lcode.parents[2].as_posix() + '/'
_package_root = lcode.parents[1].as_posix()
if _package_root not in sys.path:
    sys.path.insert(0, _package_root)

import numpy as np

#import tools.tools as tls
import tools.tools_entropy_object as tls_entropy
from bandit_main_classes.algo_bandit_general import Algo_Bandit
import entropy.analytic_entropy as Analytic_Entropy


class Algo_Entropy_Increment(Algo_Bandit):

    """
    Bandit algorithm that uses the entropy increment as the criterion to select the next arm.

    The entropy increment is evaluated once along the better empirical arm and once along
    each worse empirical arm; the arm with the largest increment is pulled.
    Different expressions are used depending on which arm the increment is evaluated along.

    This class leaves the three expression hooks (self.function, self.ptail_fct,
    self.max_increment_fct) set to None. A child class is expected to assign them
    to obtain a usable algorithm.

    Objects:

    self.mem: Stores the arm count ('count') and the sum of the rewards ('cumsum') for each arm.
    self.increment_function: Computes the entropy increment along the selected worse empirical arm.
    self.compute_max_increment: Computes the entropy increment along the better empirical arm.
    self.get_ptail_function: Computes the probability tail associated with the selected worse empirical arm.
    self.function: Expression passed to self.increment_function (worse empirical arm).
    self.ptail_fct: Expression passed to self.get_ptail_function (used for the better empirical arm increment).
    self.max_increment_fct: Expression passed to self.compute_max_increment (better empirical arm).

    Methods:

    self.initiate: Selects the arm during the initialization phase.
    self.draw_index: Selects the arm to be pulled at the next step.
    self.get_max_theta_index: Returns the posterior mean and the index of the better empirical arm.
    self.choose_index: Selects the arm to be pulled at the next step after the initialization phase.
    self.reset: Resets the algorithm for another round.
    self.derive_increment_of_max: Computes the entropy increment along the better empirical arm.

    NOTE(review): self.Nbarm, self.time and self.seeded_generator are read here but never
    assigned in this class; they are presumably provided by Algo_Bandit - confirm.
    """

    def __init__(self, method, K):
        """
        Initiate the algorithm.

        Forwards `method` and `K` to the parent constructor, then creates the
        arm memory and the entropy increment helpers.
        """
        super().__init__(method, K)

        self.mem = self._new_memory()

        # Thin wrappers around the tools_entropy_object helpers.
        self.increment_function = lambda fct1, j1, j2, j3, j4: tls_entropy.compute_increment_along_second_index(fct1, j1, j2, j3, j4)
        #self.display_increment_function= lambda fct1, j1, j2, j3, j4, id  : tls_entropy.compute_and_display_increment_of_one_index(fct1, j1, j2, j3, j4, increment_index=id)
        self.get_ptail_function = lambda fct1, j1, j2, j3, j4: tls_entropy.compute_ptail(fct1, j1, j2, j3, j4)
        self.compute_max_increment = lambda fct1, j1, j2, pmaxv: tls_entropy.compute_armax_increment(fct1, j1, j2, pmaxv)

        # Expression hooks; left unset here and assigned by child classes.
        self.function = None
        self.ptail_fct = None
        self.max_increment_fct = None

    def _new_memory(self):
        """
        Return a fresh, zeroed memory: integer pull counts and cumulative rewards per arm.
        """
        return {'count': np.zeros(self.Nbarm, dtype=int), 'cumsum': np.zeros(self.Nbarm)}

    def _arm_wins_losses(self, index):
        """
        Return (wins, losses) for arm `index`.

        Wins are the cumulative reward of the arm; losses are the pull count minus the wins.
        """
        wins = self.mem['cumsum'][index]
        return wins, self.mem['count'][index] - wins

    def initiate(self, index):
        """
        Return the arm to pull during the initialization phase, i.e. `index` unchanged.
        """
        return index

    def draw_index(self):
        """
        Return the index of the arm to be pulled at the next step.

        While self.time is smaller than the number of arms the algorithm is in its
        initialization phase and the arm is derived from self.time;
        afterwards the arm is the one selected by self.choose_index.
        """
        if self.time < self.Nbarm:
            return self.initiate(self.time % self.Nbarm)
        return self.choose_index()

    def get_max_theta_index(self):
        """
        Return (thetamax, indexmax): the largest posterior mean and the index of the arm holding it.

        The posterior mean of an arm is theta = (1 + wins) / (wins + losses + 2).
        Arms that have never been pulled are ignored; if no arm has been pulled
        the result is (-1, 0).

        When two arms have the same theta, the one with fewer draws is kept, so that
        the less reliable arm is the one treated as the better empirical arm.
        When the number of draws is also equal, the later arm replaces the
        current one with probability 1/2 (drawn from self.seeded_generator).
        """
        thetamax = -1
        count_of_max = -1
        indexmax = 0
        # iterate over all the arms
        for index, count in enumerate(self.mem['count']):
            if count <= 0:
                continue
            theta = (self.mem['cumsum'][index] + 1) / (count + 2)
            if thetamax < theta:
                thetamax = theta
                indexmax = index
                count_of_max = count
            elif thetamax == theta:
                # Tie on theta: prefer fewer draws, then break the tie at random.
                # The random draw only happens when the counts are equal too.
                if count < count_of_max or (
                    count == count_of_max
                    and self.seeded_generator.integers(low=0, high=2, size=1)[0] > 0.5
                ):
                    indexmax = index
                    count_of_max = count

        return thetamax, indexmax

    def choose_index(self):
        """
        Return the index of the arm to be pulled at the next step after the initialization phase.

        The increment along the better empirical arm is computed first and used as the
        value to beat. The increment along every other arm is then computed, and the arm
        with the strictly largest increment is returned. If no other arm beats it,
        the better empirical arm itself is returned.
        """
        # select the better empirical arm
        _, indexmax = self.get_max_theta_index()

        # increment along the better empirical arm: the value to beat
        best_increment = self.derive_increment_of_max(indexmax)
        best_index = indexmax

        # The first two arguments of the increment function must describe the better
        # empirical arm (wins, losses); the last two describe the candidate worse arm.
        wins_max, losses_max = self._arm_wins_losses(indexmax)
        for index in range(self.Nbarm):
            if index == indexmax:
                continue
            wins, losses = self._arm_wins_losses(index)
            status, increment = self.increment_function(self.function, wins_max, losses_max, wins, losses)
            # A status different from 1 means the computation of the increment went wrong;
            # such an arm is not considered (this is not expected to happen).
            if status == 1 and best_increment < increment:
                best_increment = increment
                best_index = index

        return best_index

    def reset(self):
        """
        Reset the algorithm for another simulation round by clearing its memory.
        """
        super().reset()
        self.mem = self._new_memory()

    def derive_increment_of_max(self, indexmax):
        """
        Return the entropy increment along the better empirical arm `indexmax`.

        The probability tails of all the other arms are summed first; the sum is then
        passed, together with the wins and losses of the better empirical arm,
        to self.compute_max_increment.
        """
        wins_max, losses_max = self._arm_wins_losses(indexmax)

        # Sum of the probability tails over every worse empirical arm. The first two
        # arguments describe the better empirical arm, the last two the worse arm.
        ptail_sum = sum(
            self.get_ptail_function(self.ptail_fct, wins_max, losses_max, *self._arm_wins_losses(index))
            for index in range(self.Nbarm)
            if index != indexmax
        )

        # increment along the better empirical arm, using the summed probability tails
        return self.compute_max_increment(self.max_increment_fct, wins_max, losses_max, ptail_sum)

class Algo_AIM(Algo_Entropy_Increment):

    """
    Child class of Algo_Entropy_Increment that plugs in the analytic entropy expressions.

    The three expression hooks of the parent class (self.function, self.ptail_fct and
    self.max_increment_fct) are bound to the functions of the entropy.analytic_entropy
    module, imported in this file as Analytic_Entropy.

    Please refer to that module for more details about the analytic expressions
    used to compute the entropy increment.
    """

    def __init__(self, method, K):
        """
        Initiate the parent algorithm, then assign the analytic entropy expressions.
        """
        super().__init__(method, K)

        def entropy(x, y, v, w):
            # expression used for the increment along a worse empirical arm
            return Analytic_Entropy.analytic_entropy(x, y, v, w)

        def entropy_ptail(x, y, v, w):
            # expression used for the probability tail
            return Analytic_Entropy.analytic_entropy_ptail(x, y, v, w)

        def entropy_armax(x, y):
            # expression used for the increment along the better empirical arm
            return Analytic_Entropy.analytic_entropy_armax(x, y)

        self.function = entropy
        self.ptail_fct = entropy_ptail
        self.max_increment_fct = entropy_armax


