#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import math 
import types

# Arm class: stores the parameters of an arm together with its number of rounds, wins and losses.
class Arm:
    """
    One arm of a bandit game, together with a summary of its pull history.

    Attributes:
    -----------
    distr: str, the distribution of the arm ('bernoulli' or 'exact').
    param: float, the success probability if 'bernoulli', or the target win ratio if 'exact'.
    mean: float, copy of ``param``; only set when ``distr`` is 'bernoulli' or 'exact'.
    generator: the generator object given at construction; only its ``random()`` method
        is used, to draw the rewards of a 'bernoulli' arm.
    wins: int, current number of wins of the arm.
    losses: int, current number of losses of the arm.
    rounds: int, current number of rounds of the arm (the number of times it has been pulled).

    Methods:
    --------
    fromarm: Alternate constructor copying the configuration of another arm.
    sample: Samples from the distribution and updates the history.
    enforce_state: Enforces the state of the arm with the input values (for testing or debugging).
    enforce_reward: Records one extra round with the given reward (for testing or debugging).
    get_average: Returns the current empirical mean (not the mean of the posterior), or -1 if never pulled.
    get_std: Returns the estimated standard deviation of the empirical mean (not the posterior one).
    get_order3: Returns (1/N)**(1/3), N being the number of recorded rewards (for testing or debugging).
    reset: Resets the history, keeping the same distribution and parameter.
    none_null_theta: Returns True if at least one reward has been recorded.
    none_zero_var: Returns True if at least one win and one loss have been recorded.
    show: Prints the current attributes of the arm object (for testing or debugging).

    Notes:
    - 'bernoulli' is mostly used for the simulation; 'exact' is deterministic and used for debugging purposes.
    """

    def __init__(self, distr: str, param: float, generator_seeded):
        """
        Initializes the arm with the specified distribution, parameter, and seeded generator.

        Parameters:
        -----------
        distr: str, specifying the distribution of the arm ('bernoulli' or 'exact').
        param: float, the success probability if 'bernoulli' or the target win ratio if 'exact'.
        generator_seeded: seeded random generator exposing a ``random()`` method
            (e.g. a numpy ``Generator``); only used by 'bernoulli' arms.

        Notes:
        - An unknown ``distr`` is accepted here (``mean`` is then left unset);
          it is reported by ``sample`` with a ValueError.
        """
        self.distr = distr
        self.param = param
        self.generator = generator_seeded

        # For an 'exact' arm there is no randomness: the arm adapts its wins and
        # losses to keep its empirical ratio as close as possible to this mean.
        if distr in ('bernoulli', 'exact'):
            self.mean = param

        # summary variables of the arm history
        self.wins = 0
        self.losses = 0
        self.rounds = 0

    @classmethod
    def fromarm(cls, arm: "Arm") -> "Arm":
        """
        Builds a fresh arm (empty history) with the distribution, parameter, and generator of another arm.

        Parameters:
        -----------
        arm: Arm object, providing the distribution, parameter, and generator.

        Returns:
        Arm: a new arm; note that it shares the *same* generator object as ``arm``.
        """
        # The generator is stored under ``generator`` (``generator_seeded`` is
        # only the name of the constructor argument).
        return cls(arm.distr, arm.param, arm.generator)

    def sample(self) -> float:
        """
        Pulls the arm once, updates its history, and returns the reward.
        The history is not used by the arm itself but can be used by the bandit
        object or for debugging, plotting, etc.

        Returns:
        float: 1.0 if the arm wins, 0.0 if it loses.

        Raises:
        ValueError: if ``distr`` is neither 'bernoulli' nor 'exact'
            (the history is left untouched in that case).
        """
        if self.distr == 'bernoulli':
            # win with probability param
            won = self.generator.random() < self.param
        elif self.distr == 'exact':
            pulls = self.wins + self.losses
            if pulls == 0:
                # First pull: there is no ratio yet, so round the target to the
                # nearest outcome (a win as soon as param >= 0.5).
                won = self.param >= 0.5
            else:
                # Win only while the empirical ratio is below the target.
                won = self.wins / pulls < self.param
        else:
            # Validate before touching the counters so a failed call has no side effect.
            raise ValueError('Unknown distr specified')

        self.rounds += 1
        if won:
            self.wins += 1
            return 1.0
        self.losses += 1
        return 0.0

    def enforce_state(self, *args):
        """
        Enforces the state of the arm with the input values for testing or debugging purposes.

        Parameters:
        args: either a single sequence ``[wins, losses]`` or two separate values ``wins, losses``.

        Usage:
        - ``arm.enforce_state([15, 12])`` or ``arm.enforce_state(15, 12)``.
        - ``rounds`` is set to ``wins + losses``.
        """
        if len(args) == 1:
            # single sequence holding (wins, losses)
            wins, losses = args[0][0], args[0][1]
        else:
            wins, losses = args[0], args[1]

        self.wins = wins
        self.losses = losses
        self.rounds = wins + losses

    def enforce_reward(self, win):
        """
        Records one extra round with an imposed reward (for testing or debugging purposes).

        Parameters:
        win: int, should be 0 or 1; ``win`` is added to the wins and ``1 - win`` to the losses.

        Returns:
        None
        """
        self.rounds += 1
        self.wins += win
        self.losses += 1 - win

    def get_average(self):
        """
        Returns the current empirical mean (not the mean of the posterior).

        Returns:
        float: wins / (wins + losses), or -1 if no reward has been recorded yet.
        """
        pulls = self.losses + self.wins
        if pulls > 0:
            return self.wins / pulls
        # sentinel value used when the arm has never been pulled
        return -1

    def get_std(self) -> float:
        """
        Returns the estimated standard deviation of the empirical mean (not the posterior one),
        i.e. sqrt(p * (1 - p) / (N - 1)) with p the empirical mean and N the number of recorded rewards.

        Returns:
        float: The estimated standard deviation, or 0.0 when fewer than two rewards are recorded.
        """
        pulls = self.losses + self.wins
        # With 0 or 1 reward the estimator is undefined (division by zero): return 0.0.
        if pulls <= 1:
            return 0.0
        ratio = self.wins / pulls
        return math.sqrt(ratio * (1 - ratio) / (pulls - 1))

    def get_order3(self) -> float:
        """
        Returns (1/N)**(1/3), N being the number of recorded rewards: the order around which
        the Gaussian approximation is no longer valid (for testing or debugging purposes).

        Returns:
        float: The order 3 value.

        Raises:
        ZeroDivisionError: if no reward has been recorded yet.
        """
        return (1.0 / (self.losses + self.wins)) ** (1.0 / 3.0)

    def reset(self):
        """
        Resets the history of the arm while keeping the same distribution, parameter and generator.
        """
        self.wins = 0
        self.losses = 0
        self.rounds = 0

    def none_null_theta(self) -> bool:
        """
        Tests if at least one reward has been recorded.

        Returns:
        bool: True if wins + losses is not zero, False otherwise.
        """
        return self.wins + self.losses != 0

    def none_zero_var(self) -> bool:
        """
        Tests if the arm has been pulled sufficiently to have a nonzero empirical variance.

        Returns:
        bool: True if both the number of wins and the number of losses are nonzero, False otherwise.
        """
        return self.wins != 0 and self.losses != 0

    def show(self, methods=False):
        """
        Prints the current attributes of the arm object for testing or debugging purposes.

        Parameters:
        methods: bool, if True, the methods of the arm object are also printed.

        Usage:
        - Call this method to display the current attributes of the arm object.
        - Set 'methods' to True if you want to print the methods as well.
        """
        print('arm object with current arguments:')
        list_arguments = []
        list_methods = []

        for name in dir(self):
            # skip dunder names
            if name.startswith('__'):
                continue
            # getattr gives the same value as evaluating 'self.<name>', without eval
            current_value = getattr(self, name)
            line = 'self.' + name + ': ' + str(current_value)
            if isinstance(current_value, types.MethodType):
                list_methods.append(line)
            else:
                list_arguments.append(line)

        for cstring in list_arguments:
            print(cstring)
        if methods:
            print('arm object with current methods:')
            for cstring in list_methods:
                print(cstring)
                
if __name__ == '__main__':

    # Small manual smoke test of the Arm class: every step prints the arm
    # state so the output can be inspected by eye.
    from numpy.random import SeedSequence
    import numpy.random as rnd

    distr, param, Nsteps = 'bernoulli', 0.6, 20

    # Draw one fresh seed from an (unseeded) SeedSequence and build the
    # generator handed to the arm.
    seeds = SeedSequence().generate_state(1)
    simu_generator_arm1 = rnd.default_rng(seed=seeds[0])

    arm1 = Arm(distr, param, simu_generator_arm1)

    # Pull the arm Nsteps times, dumping its state after each pull.
    print('Is theta non null at initialization:' + str(arm1.none_null_theta()))
    for _ in range(Nsteps):
        arm1.sample()
        arm1.show()
        print('')
    print('')
    print('Is theta non null after nsteps :' + str(arm1.none_null_theta()))

    # Summary statistics of the recorded history.
    print('test average std etc..')
    average = arm1.get_average()
    print('average : ' + str(average))
    std = arm1.get_std()
    print('std: ' + str(std))
    order3 = arm1.get_order3()
    print('1/N^3: ' + str(order3))

    # Impose one loss, then one win.
    print('')
    for reward in (0, 1):
        arm1.enforce_reward(reward)
        arm1.show()
        print('')

    print('test reset')
    arm1.reset()
    arm1.show()
    print('')

    # Both calling conventions of enforce_state: a single list, then two values.
    arm1.enforce_state([15, 12])
    arm1.show()
    print('')
    arm1.enforce_state(18, 22)
    arm1.show()
    print('')