#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
from pathlib import PurePath
# Make project code living in other directories importable: the directory one
# level above this file's folder is put at the front of sys.path, unless it is
# already listed there. `datalocation` is the directory two levels above this
# file's folder, as a POSIX string with a trailing slash.
lcode = PurePath(__file__)
datalocation = f"{lcode.parents[2].as_posix()}/"
if lcode.parents[1].as_posix() not in sys.path:
    sys.path.insert(0, lcode.parents[1].as_posix())

import numpy as np
from scipy.special import betainc, erfc, erf
import math
#import tools.tools as tls

from bandit_main_classes.algo_bandit_general import Algo_Bandit



class Algo_Entropy_Increment(Algo_Bandit):

    """
    Bandit algorithm that uses an entropy increment as the criterion to select the next arm.
    Here for Gaussian rewards.

    The entropy increment is approximated with closed-form (analytic) expressions,
    which keeps the selection step cheap. The increment obtained by pulling the
    best empirical arm is compared with the increment obtained by pulling each of
    the worse empirical arms.

    Child classes can tune the algorithm by overriding ``compute_teq`` and/or
    ``choose_index``.
    NOTE(review): earlier documentation mentioned ``self.function``,
    ``self.ptail_fct`` and ``self.max_increment_fct`` as override points; none of
    them is defined in this class.

    Objects:

    self.mem: dictionary of per-arm arrays with keys
        'count'  (number of pulls, int),
        'cumsum' (sum of the collected rewards),
        'theta'  (read by ``choose_index`` as the empirical mean of the arm; it is
                  never written in this class -- presumably the base class keeps it
                  up to date, TODO confirm).

    Methods:

    self.initiate: Selects the arm during the initialization phase.
    self.draw_index: Selects the arm to be pulled at the next step.
    self.get_max_theta_index: Returns the best empirical arm (mean, count, index).
    self.choose_index: Selects the arm to be pulled once the initialization phase is over.
    self.compute_teq: Computes the crossing point used by the tail approximations.
    self.reset: Resets the algorithm for another round.

    """

    def __init__(self, method, K):
        """
        Initiate the algorithm and allocate the arms memory.

        ``method`` and ``K`` are forwarded unchanged to the base class, which is
        expected to define ``self.Nbarm`` (used to size the memory).
        """
        super().__init__(method, K)

        self.mem = self._empty_memory()

    def _empty_memory(self):
        """Return a fresh, zero-filled memory dictionary sized for ``self.Nbarm`` arms."""
        return {
            'count': np.zeros(self.Nbarm, dtype=int),
            'cumsum': np.zeros(self.Nbarm),
            'theta': np.zeros(self.Nbarm),
        }

    def initiate(self, index):
        """
        Arm selection rule of the initialization phase: the arm given in input is returned as is.
        """
        return index

    def draw_index(self):
        """
        Return the index of the arm to be pulled at the next step.

        While ``self.time < self.Nbarm`` (initialization phase) the arm whose index
        is ``self.time`` is returned, so every arm is visited once; afterwards the
        arm is selected by ``self.choose_index``.
        """
        if self.time < self.Nbarm:
            return self.initiate(self.time % self.Nbarm)
        return self.choose_index()

    def get_max_theta_index(self):
        """
        Return ``(thetamax, Nmax, indexmax)`` for the best empirical arm.

        The empirical mean of an arm is ``cumsum / count``; arms that were never
        pulled (``count == 0``) are ignored.

        Tie-breaking when two arms have exactly the same mean:
          * the arm with the smaller number of pulls wins, so that the entropy
            approximations are anchored on the less reliable arm;
          * if the numbers of pulls are equal too, the newcomer replaces the
            current one with probability 1/2 (drawn from ``self.seeded_generator``).

        If no arm has been pulled yet, ``(-1, -1, 0)`` is returned.
        """
        counts = self.mem['count']
        sums = self.mem['cumsum']

        # Gaussian rewards are unbounded: start from -inf (not -1) so that arms
        # whose empirical mean is <= -1 can still be recognised as the best one.
        thetamax = -math.inf
        Nmax = -1
        indexmax = 0

        for index, Nc in enumerate(counts):
            if Nc <= 0:
                continue
            theta = sums[index] / Nc
            if theta > thetamax:
                thetamax, Nmax, indexmax = theta, Nc, index
            elif theta == thetamax:
                if Nc < Nmax:
                    indexmax, Nmax = index, Nc
                elif Nc == Nmax:
                    # fair coin flip between the two fully tied arms
                    if self.seeded_generator.integers(low=0, high=2, size=1)[0] > 0.5:
                        indexmax, Nmax = index, Nc

        if Nmax == -1:
            # no arm qualified: keep the historical sentinel value
            thetamax = -1
        return thetamax, Nmax, indexmax

    def choose_index(self):
        """
        Return the index of the arm to be pulled once the initialization phase is over.

        Steps:
          1. Find the best empirical arm. If any arm has been pulled more often
             than it, the best empirical arm is returned right away.
          2. Otherwise compute, for every worse empirical arm, the difference
             between the entropy increment along that arm and the one along the
             best empirical arm. The part of this difference that does not depend
             on the candidate arm is accumulated once (``constant_delta``).
          3. Return the worse arm with the largest strictly positive difference;
             if there is none, return the best empirical arm.
        """
        # select the best empirical arm
        thetamax, Nmax, indexmax = self.get_max_theta_index()

        counts = self.mem['count']
        thetas = self.mem['theta']

        # first check that the best empirical arm has been drawn at least as often as the others
        if any(counts[index] > Nmax for index in range(self.Nbarm)):
            return indexmax

        other_arms = [index for index in range(self.Nbarm) if index != indexmax]

        # first compute the part that is the same for all the worse empirical arms
        constant_delta = 1 / 2.0 * np.log(Nmax / (Nmax + 1))
        log_term = 1 / 4.0 * math.log(Nmax / (2 * math.pi * math.e))
        tail_mass = 0.0
        # compute_teq results are kept so that each arm is solved only once
        teq_by_arm = {}
        for index in other_arms:
            Nmin = counts[index]
            thetamin = thetas[index]
            result, teqmin = self.compute_teq(Nmax, Nmin, thetamax, thetamin)
            teq_by_arm[index] = (result, teqmin)
            if not result:
                # the tail of this arm is assumed to be equal to zero
                continue
            gap = teqmin - thetamin
            tail_mass += 1 / 2.0 * erfc(np.sqrt(Nmin) * gap / np.sqrt(2))
            constant_delta += (
                Nmin ** (3 / 2) * gap / (Nmax ** 2 * math.sqrt(2 * math.pi))
                * np.exp(-Nmin * gap ** 2 / 2)
                * (log_term - 3.0 / 4.0 + Nmin * gap ** 2 / 4)
            )

        # The summed tails approximate a probability; the decomposition can
        # overshoot, so it is capped at 1 - 1/Nbarm.
        tail_cap = 1 - 1.0 / self.Nbarm
        if tail_mass > tail_cap:
            tail_mass = tail_cap

        constant_delta += 1 / (2.0 * Nmax) * tail_mass

        # compute the arm-specific part of the increment difference and keep the best arm
        next_future_index = -1
        current_best_increment = 0.0  # only strictly positive differences qualify
        for index in other_arms:
            result, teqk = teq_by_arm[index]
            deltak = constant_delta
            if result:
                Nk = counts[index]
                gap = teqk - thetas[index]
                # The quadratic term uses the gap of the arm under evaluation
                # (it previously reused stale values left over from the loop above).
                deltak = constant_delta + (
                    np.exp(-Nk * gap ** 2 / 2)
                    * math.sqrt(1.0 / (2 * math.pi * Nk))
                    * gap
                    * (log_term / Nk + 1 / 2.0 + gap ** 2 / 4)
                )

            # The first positive candidate is recorded together with its value,
            # so later candidates are compared against it instead of replacing
            # it unconditionally.
            if deltak > current_best_increment:
                next_future_index = index
                current_best_increment = deltak

        # if next_future_index == -1 it means that the best empirical arm is the one to pull
        if next_future_index == -1:
            return indexmax
        return next_future_index

    def compute_teq(self, Nmax, Nmin, thetamax, thetamin):
        """
        Compute the crossing point ``thetaeq`` between the best arm and another arm.

        ``thetaeq`` is the "+" root of
            Nmax*(t - thetamax)**2 - log(Nmax) = Nmin*(t - thetamin)**2 - log(Nmin),
        i.e. a point where two Gaussian densities with means ``thetamax`` /
        ``thetamin`` and variances ``1/Nmax`` / ``1/Nmin`` take the same value.

        Parameters:
            Nmax, thetamax: number of pulls and empirical mean of the best arm.
            Nmin, thetamin: number of pulls and empirical mean of the other arm.

        Returns:
            ``(True, thetaeq)`` on success, ``(False, 0)`` when the crossing point
            is not defined (equal or non-positive counts) or cannot be evaluated
            numerically.
        """
        if Nmax == Nmin:
            return False, 0
        if Nmax <= 0 or Nmin <= 0:
            # log(Nmax/Nmin) is undefined; with numpy integers the division would
            # silently produce inf/nan instead of raising.
            return False, 0

        diff = Nmax - Nmin
        try:
            discriminant = Nmax * Nmin * (thetamax - thetamin) ** 2 + diff * math.log(Nmax / Nmin)
            thetaeq = (Nmax * thetamax - thetamin * Nmin) / diff + 1.0 / abs(diff) * math.sqrt(discriminant)
        except (ArithmeticError, ValueError):
            # only numerical failures are treated as "no crossing point"
            return False, 0
        return True, thetaeq

    def reset(self):
        """
        Reset the algorithm for another simulation round by resetting its memory.
        """
        super().reset()
        self.mem = self._empty_memory()

class Algo_AIM(Algo_Entropy_Increment):
    """
    Child class of Algo_Entropy_Increment relying on the analytic approximation
    of the entropy increment in order to get a faster algorithm.

    The entropy increment is evaluated along the best empirical arm and along
    every worse empirical arm. See Analytic_Entropy.py for more details about
    the analytic expressions used to compute the entropy increment.
    """

    def __init__(self, method, K):
        """Build the algorithm; both arguments are handed over to the parent class unchanged."""
        super(Algo_AIM, self).__init__(method, K)
        



