#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import math 
import types

# Arm class: stores the parameters of a bandit arm along with its number of rounds (pulls) and its rewards.
class Arm:

    """
    Arm of a bandit game, holding its configuration and a summary of its pulls.

    Parameters:
    -----------
    distr: str, distribution of the arm: 'gaussian' or 'exact'.
    param: float, mean of the arm. A 'gaussian' arm draws rewards from a normal
        distribution with this mean and std 1; an 'exact' arm always returns it.
    generator_seeded: seeded random generator used to draw rewards (a numpy
        Generator is assumed; only its `normal(loc, scale)` method is used).

    Attributes:
    -----------
    distr, param: as passed to the constructor.
    mean: the mean used when sampling (equal to param).
    generator: the generator passed as generator_seeded.
    rewards: running total of rewards. It is updated by enforce_reward and
        enforce_state; sample() does NOT modify it.
    rounds: int, number of times the arm has been pulled.

    Methods:
    --------
    fromarm: Alternate constructor copying distr, param and generator from another arm.
    sample: Draws a reward and increments the round counter.
    enforce_state: Sets rewards and rounds directly (for testing or debugging).
    enforce_reward: Records one round with the given reward (for testing or debugging).
    get_average: Returns the empirical mean rewards / rounds (not the posterior mean).
    get_std: Returns 1 / sqrt(rounds - 1), the spread of the empirical mean under a
        unit reward std (not the posterior one).
    reset: Clears rewards and rounds, keeping the same distribution and mean.
    none_null_theta: Returns True if the arm has been pulled at least once.
    show: Prints the current attributes (and optionally methods) of the arm.

    Notes:
    - 'gaussian' is mostly used for simulation; 'exact' is used for debugging.
    """

    def __init__(self, distr, param, generator_seeded):
        """
        Initializes the arm with the given distribution, mean value and generator.

        Parameters:
        -----------
        distr: str, distribution of the arm ('gaussian' or 'exact'). An unknown
            value is accepted here; sample() raises ValueError for it.
        param: float, mean of the arm.
        generator_seeded: seeded random generator used by sample() for 'gaussian'
            arms. It is stored as `self.generator`.
        """
        self.distr = distr
        self.param = param
        self.generator = generator_seeded

        # Both supported distributions use param directly as the mean. An 'exact'
        # arm is deterministic: sample() always returns this value.
        self.mean = param

        # Summary variables of the arm history.
        self.rewards = 0
        self.rounds = 0

    @classmethod
    def fromarm(cls, arm):
        """
        Builds a new arm with the same distribution, mean value and generator as `arm`.

        Only the configuration is copied: the new arm starts with rewards = rounds = 0.
        The generator object is shared, not copied, so both arms draw from the same
        random stream.

        Parameters:
        -----------
        arm: Arm object providing distr, param and generator.

        Returns:
        --------
        A new instance of `cls`.
        """
        # The constructor stores the generator as `generator`; there is no
        # `generator_seeded` attribute on instances.
        return cls(arm.distr, arm.param, arm.generator)

    def sample(self):
        """
        Pulls the arm: increments `rounds` and returns a reward.

        `rewards` is NOT updated here. Accumulating the returned value is left to the
        caller (presumably the bandit object — verify against callers).

        Returns:
        --------
        float: a draw from N(mean, 1) for 'gaussian', or exactly `mean` for 'exact'.

        Raises:
        -------
        ValueError: if `distr` is neither 'gaussian' nor 'exact'. The round counter
            has already been incremented at that point.
        """
        self.rounds += 1
        if self.distr == 'gaussian':
            return self.generator.normal(self.mean, 1.0)
        if self.distr == 'exact':
            return self.mean
        raise ValueError('Unknown distr specified')

    def enforce_state(self, *args):
        """
        Sets the arm state directly (for testing or debugging).

        Parameters:
        -----------
        args: either one sequence `(rewards, rounds)` or two separate values
            `rewards, rounds`.
        """
        if len(args) == 1:
            self.rewards = args[0][0]
            self.rounds = args[0][1]
        else:
            self.rewards = args[0]
            self.rounds = args[1]

    def enforce_reward(self, reward):
        """
        Records one round with the given reward (for testing or debugging).

        Parameters:
        -----------
        reward: float, the reward to add to the running total.

        Returns:
        --------
        None
        """
        self.rounds += 1
        # Accumulate so that get_average() == rewards / rounds stays the mean of all
        # recorded rewards; overwriting would discard earlier rounds.
        self.rewards += reward

    def get_average(self):
        """
        Returns the empirical mean rewards / rounds (not the mean of the posterior).

        Returns:
        --------
        float, or None if the arm has never been pulled.
        """
        if self.rounds > 0:
            return self.rewards / self.rounds
        return None

    def get_std(self):
        """
        Returns 1 / sqrt(rounds - 1), the spread of the empirical mean assuming a unit
        reward std, as used by 'gaussian' arms (not the posterior one).

        Returns:
        --------
        float: 0.0 when rounds <= 1, where the estimator is undefined. Previously,
            rounds == 0 crashed with a math domain error.
        """
        if self.rounds <= 1:
            return 0.0
        return 1.0 / math.sqrt(self.rounds - 1)

    def reset(self):
        """
        Resets rewards and rounds while keeping the same distribution and mean value.
        """
        self.rewards = 0
        self.rounds = 0

    def none_null_theta(self):
        """
        Tests whether the arm has been pulled at least once.

        Returns:
        --------
        bool: True if rounds != 0, False otherwise.
        """
        return self.rounds != 0

    def show(self, methods=False):
        """
        Prints the current attributes of the arm (for testing or debugging).

        Each entry is printed as 'self.<name>: <value>'. Dunder names are skipped.

        Parameters:
        -----------
        methods: bool, if True, the bound methods are printed as well, after the
            attributes.
        """
        print('arm object with current arguments:')
        list_arguments = []
        list_methods = []

        for name in dir(self):
            if name.startswith('__'):
                continue
            fullname = 'self.' + name
            # getattr is equivalent to eval('self.<name>') but avoids eval and
            # looks the attribute up only once.
            current_value = getattr(self, name)
            line = fullname + ': ' + str(current_value)
            if isinstance(current_value, types.MethodType):
                list_methods.append(line)
            else:
                list_arguments.append(line)

        for cstring in list_arguments:
            print(cstring)
        if methods:
            print('arm object with current methods:')
            for cstring in list_methods:
                print(cstring)
                
