#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from math import isnan
import sys
from pathlib import PurePath
# Locate this file so that .py modules living in other directories can be imported.
lcode = PurePath(__file__)
# Path two levels above the directory containing this file, as a POSIX string ending with '/'.
datalocation = lcode.parents[2].as_posix() + '/'
# Put the directory one level above this file's directory first on sys.path,
# unless it is already listed there.
if lcode.parents[1].as_posix() not in sys.path:
    sys.path.insert(0, lcode.parents[1].as_posix())
    


def compute_increment_along_second_index(func, a1,b1,a2,b2):
    """
    Compute the increment of the function given as input along the second arm (arm 2).

    Input:
    - 'func': A callable invoked as func(a1, b1, a2, b2); its return value is tested with math.isnan.
    - 'a1': The number of successes for the first arm (1 represents rewards).
    - 'b1': The number of failures for the first arm (0 represents no rewards).
    - 'a2': The number of successes for the second arm (1 represents rewards).
    - 'b2': The number of failures for the second arm (0 represents no rewards).

    Output:
    - 'success': A value of 1 indicates that the computation is successful; otherwise, it's set to -1.
    - 'increment': The value of the increment along the second arm (0 whenever 'success' is -1).

    Details:
    - The increment is abs(a2/(a2+b2) * func(a1, b1, a2+1, b2) + b2/(a2+b2) * func(a1, b1, a2, b2+1) - func(a1, b1, a2, b2)),
      i.e. the next-step values of 'func' weighted by the empirical mean of arm 2, compared with its current value.
    - The increment is calculated only if the second arm is not better than the first arm. Therefore, the arms should
      have been sorted beforehand so that (a1+1)/(a1+b1+2) >= (a2+1)/(a2+b2+2).
    - The computation fails (returns (-1, 0)) when the arms are not sorted as required, when arm 2 has no
      observation (a2 + b2 == 0, the empirical mean is undefined), or when any of the three values of 'func' is NaN.

    """

    # The arms must be sorted: refuse to work if the second arm is better than the first one.
    if (a1+1)/(a1+b1+2) < (a2+1)/(a2+b2+2):
        return -1, 0

    # Without any pull of arm 2 the empirical-mean weights would divide by zero.
    pulls_2 = a2 + b2
    if pulls_2 == 0:
        return -1, 0

    # Value of func if arm 2 is pulled and wins.
    f1 = func(a1, b1, a2+1, b2)
    # Value of func if arm 2 is pulled and loses.
    f2 = func(a1, b1, a2, b2+1)
    # Value of func in the current state.
    current = func(a1, b1, a2, b2)

    # A NaN anywhere would make the increment meaningless.
    if isnan(f1) or isnan(f2) or isnan(current):
        return -1, 0

    return 1, abs(a2/pulls_2*f1 + b2/pulls_2*f2 - current)

# Compare two arms using a metric function given as input; no accuracy requirement is imposed.
def compute_and_display_increment_of_one_index(func, a1,b1,a2,b2):
    """
    Compute the increment of the function, given as input, along the second arm (arm 2) and display detailed information. This function is intended for debugging or pedagogical purposes.

    Input:
    - 'func': A callable invoked as func(a1, b1, a2, b2); its return value is tested with math.isnan.
    - 'a1': The number of successes for the first arm (1 represents rewards).
    - 'b1': The number of failures for the first arm (0 represents no rewards).
    - 'a2': The number of successes for the second arm (1 represents rewards).
    - 'b2': The number of failures for the second arm (0 represents no rewards).

    Output:
    - 'success': A value of 1 indicates that the computation is successful; otherwise, it's set to -1.
    - 'increment': The value of the increment along the second index (0 whenever 'success' is -1).

    Details:
    - The increment is abs(a2/(a2+b2) * func(a1, b1, a2+1, b2) + b2/(a2+b2) * func(a1, b1, a2, b2+1) - func(a1, b1, a2, b2)),
      i.e. the next-step values of 'func' weighted by the empirical mean of arm 2, compared with its current value.
    - The increment is calculated only if the second arm is not better than the first one. Therefore, the arms should
      have been sorted beforehand so that (a1+1)/(a1+b1+2) >= (a2+1)/(a2+b2+2).
    - The computation fails (returns (-1, 0)) when the arms are not sorted as required, when arm 2 has no
      observation (a2 + b2 == 0, the empirical mean is undefined), or when any of the three values of 'func' is NaN.
    - The three values of 'func', the increment and every failure reason are printed to standard output.
      The printed increment is exactly the value that is returned.

    """

    # The arms must be sorted: refuse to work if the second arm is better than the first one.
    if (a1+1)/(a1+b1+2) < (a2+1)/(a2+b2+2):
        print("major issue the second arm is better than the first one")
        return -1, 0

    # Without any pull of arm 2 the empirical-mean weights would divide by zero.
    pulls_2 = a2 + b2
    if pulls_2 == 0:
        print('major issue the increment cant be computed because arm 2 has no observation for (a1,b1,a2,b2)' + str((a1,b1,a2,b2)))
        return -1, 0

    # Value of func if arm 2 is pulled and wins.
    f1 = func(a1, b1, a2+1, b2)
    # Value of func if arm 2 is pulled and loses.
    f2 = func(a1, b1, a2, b2+1)
    # Value of func in the current state.
    current = func(a1, b1, a2, b2)
    print('f1 ,f2, initial:' + str((f1 , f2, current)))

    # A NaN anywhere would make the increment meaningless.
    if isnan(f1) or isnan(f2) or isnan(current):
        print('major issue the real indexes cant be computed because a nan value is observed for (a1,b1,a2,b2)' + str((a1,b1,a2,b2)))
        return -1, 0

    # Compute the increment once, with arm 2's weights, so that the value shown
    # is the value returned.
    increment = abs(a2/pulls_2*f1 + b2/pulls_2*f2 - current)
    print('increment: ' + str(increment))
    return 1, increment
 
  
def compute_ptail(func, a1,b1,a2,b2):
    """
    Return the tail probability used when computing the entropy increment along the better empirical arm.

    Input:
    - 'func': A callable invoked as func(a1, b1, a2, b2) that yields the tail probability.
    - 'a1': The number of successes for the first arm (1 represents rewards).
    - 'b1': The number of failures for the first arm (0 represents no rewards).
    - 'a2': The number of successes for the second arm (1 represents rewards).
    - 'b2': The number of failures for the second arm (0 represents no rewards).

    Output:
    - 0 when the first arm has been drawn strictly fewer times than the second arm (a1+b1 < a2+b2);
      otherwise the value of func(a1, b1, a2, b2).

    Details:
    - The second arm is assumed to be the worse empirical one. When it has more draws than the
      first (better) arm there is no tail, so 'func' is not called at all and 0 is returned;
      per the original author's note, evaluating it in that case could involve the square root
      of a negative value.
    """

    draws_first = a1 + b1
    draws_second = a2 + b2
    # No tail when arm 2 has been drawn more often than arm 1: skip the evaluation of func.
    return 0 if draws_first < draws_second else func(a1, b1, a2, b2)

def compute_armax_increment(fct, a1,b1, ptail_res):
    """
    Compute the increment of 'fct' along the first arm (arm 1), which is assumed to be the better empirical arm.

    Input:
    - 'fct': A callable invoked as fct(a1, b1) with the success/failure counts of arm 1.
    - 'a1': The number of successes for the first arm (1 represents rewards).
    - 'b1': The number of failures for the first arm (0 represents no rewards).
    - 'ptail_res': The sum of the tail probabilities of all the worse empirical arms.

    Output:
    - The value of the increment along the first arm.

    Details:
    - The increment is abs(a1/(a1+b1) * fct(a1+1, b1) + b1/(a1+b1) * fct(a1, b1+1) - fct(a1, b1)) * abs(1 - ptail_res),
      i.e. the next-step values of 'fct' weighted by the empirical mean of arm 1, compared with its
      current value, then scaled by abs(1 - ptail_res).
    - The weights divide by a1 + b1, so arm 1 must have been drawn at least once
      (with plain Python numbers, a1 + b1 == 0 raises ZeroDivisionError).

    """

    # Value of fct after one more pull of arm 1 that wins / that loses, and in the current state.
    after_win = fct(a1 + 1, b1)
    after_loss = fct(a1, b1 + 1)
    now = fct(a1, b1)

    # Next-step value of fct, weighted by the empirical success/failure frequencies of arm 1.
    pulls_1 = a1 + b1
    expected_next = a1 / pulls_1 * after_win + b1 / pulls_1 * after_loss

    drift = abs(expected_next - now)
    return drift * abs(1 - ptail_res)