#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import sys
from pathlib import PurePath
import numpy as np
import math

# Path setup so that modules located in other directories can be imported.
# 'lcode' is the path of this source file.
lcode = PurePath(__file__)
# 'datalocation' is the POSIX path (with a trailing '/') of the directory two
# levels above the directory containing this file.
# NOTE(review): presumably the root under which data files are stored; it is
# not used in this part of the file -- confirm against callers.
datalocation = lcode.parents[2].as_posix() + '/'
# Prepend the parent of this file's directory to sys.path, unless it is
# already listed there.
if  (lcode.parents[1].as_posix() in sys.path)  is False:
    sys.path.insert(0,  lcode.parents[1].as_posix())


def sort_and_shift_thetas(a1,b1,a2,b2):
    """
    Order the two arms by posterior mean and return, for each of them,
    the posterior mean and the shifted number of draws.

    Input : a1,b1,a2,b2
    a1 is the number of successes of the first arm (rewards equal to 1)
    b1 is the number of failures of the first arm (rewards equal to 0)
    a2 is the number of successes of the second arm (rewards equal to 1)
    b2 is the number of failures of the second arm (rewards equal to 0)

    Output : thetamax, Nmax, thetamin, Nmin

    Details :
    Rewards are assumed to be Bernoulli distributed. For an arm with counts
    (a, b) the posterior mean is theta = (a + 1) / (a + b + 2) and the shifted
    number of draws is N = a + b + 3, chosen so that the posterior variance
    equals theta * (1 - theta) / N.

    Note :
    - An arm that has never been drawn (a + b == 0) is reported with the
      sentinel values theta = -1 and N = 0, and is always placed in the
      'min' slot (both slots hold the sentinel when no arm has been drawn).
    - When both posterior means are equal, the arm with the most draws is
      returned as the max; if the draws are equal too, the second arm is
      returned as the max.
    - The formulas rely on true division ('/' on integer counts must not
      truncate).
    """
    first_undrawn = (a1 + b1 == 0)
    second_undrawn = (a2 + b2 == 0)

    # Guard clauses: arms without any draw get the sentinel pair (-1, 0).
    if first_undrawn and second_undrawn:
        return -1, 0, -1, 0
    if first_undrawn:
        return (a2 + 1)/(a2 + b2 + 2), a2 + b2 + 3, -1, 0
    if second_undrawn:
        return (a1 + 1)/(a1 + b1 + 2), a1 + b1 + 3, -1, 0

    # Both arms have been drawn: compare their posterior means.
    denom_first = a1 + b1 + 2
    denom_second = a2 + b2 + 2
    theta_first = (a1 + 1)/denom_first
    theta_second = (a2 + 1)/denom_second
    shifted_first = a1 + b1 + 3
    shifted_second = a2 + b2 + 3

    # The first arm is the max when its mean is strictly larger, or when the
    # means tie and it has strictly more draws; otherwise the second arm is.
    first_is_max = (theta_first > theta_second) or (
        theta_first == theta_second and denom_first > denom_second
    )

    if first_is_max:
        return theta_first, shifted_first, theta_second, shifted_second
    return theta_second, shifted_second, theta_first, shifted_first


def analytic_entropy(a1,b1,a2,b2):
    """
    Calculate the analytical approximation of the entropy of the posterior distribution of the maximum of two arms.

    Input : a1,b1,a2,b2
    a1 is the number of success for the first arm (1 rewards)
    b1 is the number of failure for the first arm (0 rewards)
    a2 is the number of success for the second arm (1 rewards)
    b2 is the number of failure for the second arm (0 rewards)

    Output : entropy approximation value

    Details :
    The arms are assumed to follow a Bernoulli distribution and are sorted with
    'sort_and_shift_thetas' to derive the entropy components.
    The parameter 'thetaeq' is computed, and the entropy is the sum of two
    components: the body (named 'corpse' in the code), which depends on the
    better empirical arm, and the tail, which involves the worse empirical arm.

    Note :
    The tail is dropped (only the body component is returned) when 'thetaeq'
    is unusable, i.e. when:
    - at least one arm has never been drawn, or both arms have the same number of draws;
    - the better empirical arm has a posterior mean equal to 1 or fewer draws than the worse one;
    - the computed value is NaN (square root of a negative number), or is
      greater than or equal to 1 (at exactly 1 the tail formula divides by
      1 - thetaeq = 0 and would evaluate to NaN).

    Raises :
    ZeroDivisionError if neither arm has been drawn: 'sort_and_shift_thetas'
    then returns the sentinel Nmax = 0 and the body variance is undefined.

    """

    # sort arms and get the mean posterior and the associated shifted number of draws
    thetamax, Nmax, thetamin, Nmin = sort_and_shift_thetas(a1,b1,a2,b2)

    # compute thetaeq; -1 is the sentinel meaning "no usable value, drop the tail"
    thetaeq = -1
    tail_is_candidate = (
        (a1 + b1 != 0) and (a2 + b2 != 0)
        and (a1 + b1 != a2 + b2)
        and (thetamax != 1)
        and (thetamax >= thetamin) and (Nmax >= Nmin)
    )
    if tail_is_candidate:
        try:
            thetaeq_value = thetamax + np.sqrt( (2*thetamax*(1-thetamax)/Nmax)*(1/2.0*np.log(Nmax/Nmin)+ Nmin*((1-thetamin)*np.log((1-thetamin)/(1-thetamax)) + thetamin*np.log(thetamin/thetamax))) )
        except (ArithmeticError, ValueError):
            # only numerical failures are treated as "unusable thetaeq";
            # anything else (e.g. KeyboardInterrupt) must propagate
            thetaeq_value = -1

        # reject NaN and values outside [0, 1)
        if (not math.isnan(thetaeq_value)) and thetaeq_value < 1.0:
            thetaeq = thetaeq_value

    # get the variance of the posterior distribution of the better empirical arm
    Vmax = thetamax*(1-thetamax)/(Nmax)

    # case the tail is taken to 0: only the body component remains
    if not thetaeq > 0:
        return -1/2.0*np.log(2*math.pi*Vmax)

    # case thetaeq is well defined
    # The variance of the worse empirical arm is only computed here: when that
    # arm has never been drawn Nmin is 0 and the division would fail.
    Vmin = thetamin*(1-thetamin)/(Nmin)

    #compute the tail probability
    #compute the Kullback divergence
    kullbackdivergence =  thetamin*np.log(thetamin/thetaeq) + (1-thetamin)*np.log((1-thetamin)/(1-thetaeq))
    # get the prefactor due to the Laplace approximation of the tail
    tau = (-thetamin/thetaeq + (1-thetamin)/(1-thetaeq))*Nmin
    probatail = 1.0/np.sqrt(2*math.pi*Vmin)/tau*np.exp(-Nmin*kullbackdivergence)

    # compute the entropy components
    entropy_tail_value = probatail*(Nmin*kullbackdivergence)
    entropy_corpse_value = 1/2.0*np.log(2*math.pi*Vmax)*(1-probatail)

    return -entropy_tail_value -entropy_corpse_value


def analytic_entropy_armax(amax,bmax):
    """
    Return the body component of the entropy approximation, which depends
    only on the better empirical arm.
    This function is used when computing the entropy increment along the better empirical arm.

    Input : amax,bmax
    amax is the number of successes of the better empirical arm (rewards equal to 1)
    bmax is the number of failures of the better empirical arm (rewards equal to 0)

    Output : entropy body component value, -1/2 * log(2 * pi * V), where
    V = theta * (1 - theta) / N, theta = (amax + 1) / (amax + bmax + 2)
    and N = amax + bmax + 3.

    """
    total_draws = amax + bmax

    # posterior mean and variance of the arm (shifted number of draws: total + 3)
    posterior_mean = (amax + 1)/(total_draws + 2)
    posterior_variance = posterior_mean*(1 - posterior_mean)/(total_draws + 3)

    # body component of the entropy
    log_term = np.log(2*math.pi*posterior_variance)
    return -0.5*log_term



def analytic_entropy_ptail(a1,b1,a2,b2):
    """
    The 'analytic_entropy_ptail' function returns the probability of the tail, which is primarily influenced by the worse empirical arm.
    This function is used when calculating the entropy increment along the better empirical arm.

    Input : a1,b1,a2,b2
    a1 is the number of success for the first arm (1 rewards)
    b1 is the number of failure for the first arm (0 rewards)
    a2 is the number of success for the second arm (1 rewards)
    b2 is the number of failure for the second arm (0 rewards)

    Output : tail probability value (0 when the tail is dropped)

    Details :
    The arms are assumed to follow a Bernoulli distribution and are sorted with
    'sort_and_shift_thetas' to derive the entropy components.
    'thetaeq' is computed, and the probability of the tail is derived from it.

    Note :
    'probatail' is set to 0 when 'thetaeq' is unusable, i.e. when:
    - at least one arm has never been drawn, or both arms have the same number of draws;
    - the better empirical arm has a posterior mean equal to 1 or fewer draws than the worse one;
    - the computed value is NaN (square root of a negative number), or is
      greater than or equal to 1 (at exactly 1 the tail formula divides by
      1 - thetaeq = 0).

    """

    # sort arms and get the mean posterior and the associated shifted number of draws
    thetamax, Nmax, thetamin, Nmin = sort_and_shift_thetas(a1,b1,a2,b2)

    # compute thetaeq; -1 is the sentinel meaning "no usable value, drop the tail"
    thetaeq = -1
    tail_is_candidate = (
        (a1 + b1 != 0) and (a2 + b2 != 0)
        and (a1 + b1 != a2 + b2)
        and (thetamax != 1)
        and (thetamax >= thetamin) and (Nmax >= Nmin)
    )
    if tail_is_candidate:
        try:
            thetaeq_value = thetamax + np.sqrt( (2*thetamax*(1-thetamax)/Nmax)*(1/2.0*np.log(Nmax/Nmin)+ Nmin*((1-thetamin)*np.log((1-thetamin)/(1-thetamax)) + thetamin*np.log(thetamin/thetamax))) )
        except (ArithmeticError, ValueError):
            # only numerical failures are treated as "unusable thetaeq";
            # anything else (e.g. KeyboardInterrupt) must propagate
            thetaeq_value = -1

        # reject NaN and values outside [0, 1)
        if (not math.isnan(thetaeq_value)) and thetaeq_value < 1.0:
            thetaeq = thetaeq_value

    # case the tail is taken to 0
    if not thetaeq > 0:
        return 0

    # case thetaeq is well defined
    # get the variance of the posterior distribution of the worse empirical arm
    Vmin = thetamin*(1-thetamin)/(Nmin)
    #compute the Kullback divergence
    kullbackdivergence =  thetamin*np.log(thetamin/thetaeq) + (1-thetamin)*np.log((1-thetamin)/(1-thetaeq))
    # get the prefactor due to the Laplace approximation of the tail
    tau = (-thetamin/thetaeq + (1-thetamin)/(1-thetaeq))*Nmin
    # compute the probability of the tail
    return 1.0/np.sqrt(2*math.pi*Vmin)/tau*np.exp(-Nmin*kullbackdivergence)

