#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import numpy as np
import types


# Regret class: stores the arms' mean rewards and tracks the cumulative regret.
class Regret:
    """
    Track the cumulative regret of a multi-armed bandit game.

    The object stores the mean reward of every arm and accumulates the
    regret each time a suboptimal arm is pulled. It also counts how many
    times a suboptimal arm was selected, mainly for debugging and
    pedagogical purposes.

    Attributes:
    -----------
    bestreward: float, mean value of the best arm (``max(thetas)``).
    bestindex: int, index of the best arm (``np.argmax(thetas)``).
    thetas: sequence of float, mean reward of each arm.
    regret: float, current cumulative regret value.
    time: int, total number of steps drawn for the bandit.
    Nwrong: int, number of times a suboptimal arm has been selected.

    Methods:
    --------
    fromregret: Build a fresh Regret object from the arms of another one.
    update: Update the regret value when an arm is pulled.
    get_Delta: Return the gap between the best arm's mean and an arm's mean.
    reset: Reset the Regret object for a new simulation.
    enforce_state: Set the state of the Regret object from input values.
    show: Display the current attributes of the Regret object.

    """

    def __init__(self, thetas):
        """
        Initialize the regret object with the mean values of the arms.

        Parameters:
        -----------
        thetas: sequence of float, mean reward of each arm.
        """
        self.bestreward = max(thetas)  # mean value of the best arm
        self.bestindex = np.argmax(thetas)  # index of the best arm
        self.thetas = thetas  # mean values of the different arms
        self.regret = 0  # current regret value
        self.time = 0  # total number of steps drawn for the bandit
        self.Nwrong = 0  # number of times a suboptimal arm was selected

    @classmethod
    def fromregret(cls, regret):
        """
        Build a new regret object using the arm means of another one.

        Only ``regret.thetas`` is reused; the counters of the new object
        start from zero.
        """
        return cls(regret.thetas)

    def update(self, index):
        """
        Update the regret value after pulling an arm.

        Parameters:
        -----------
        index: int, index of the arm that was pulled.

        Returns:
        --------
        1 if the pulled arm was not the best arm, 0 otherwise
        (debugging aid).
        """
        self.time += 1
        if self.bestindex != index:
            # A weaker arm was chosen: pay the gap to the best arm.
            self.regret += self.bestreward - self.thetas[index]
            self.Nwrong += 1
            return 1
        return 0

    def get_Delta(self, index):
        """
        Return the difference between the mean value of the best arm and
        the mean value of arm ``index``.
        """
        return self.bestreward - self.thetas[index]

    def reset(self):
        """
        Reset the regret object for a new simulation.

        The arm means are kept; regret, time and Nwrong go back to zero.
        """
        self.regret = 0
        self.time = 0
        self.Nwrong = 0

    def enforce_state(self, *args):
        """
        Set the state of the regret object from input values, primarily
        for testing or debugging purposes.

        Two arguments ``(regret, time)``: the regret value and the total
        number of steps are set directly and ``Nwrong`` is set to 0.

        One argument ``counts``: a sequence with one entry per arm, each
        entry being indexable as a pair whose two elements are summed to
        give the number of pulls of that arm (presumably success/failure
        counts -- confirm against the caller). ``time``, ``regret`` and
        ``Nwrong`` are recomputed from these pull counts.

        Any other number of arguments leaves the object unchanged.
        """
        if len(args) == 2:
            self.regret, self.time = args
            self.Nwrong = 0
        elif len(args) == 1:
            self.time = 0
            self.regret = 0
            # Start from zero so no stale count survives, and so the pulls
            # of every suboptimal arm can be accumulated below.
            self.Nwrong = 0
            for index, counts in enumerate(args[0]):
                pulls = counts[0] + counts[1]
                self.time += pulls
                if index != self.bestindex:
                    self.regret += pulls * (self.bestreward - self.thetas[index])
                    # Accumulate over all suboptimal arms rather than
                    # keeping only the last one visited.
                    self.Nwrong += pulls

    def show(self, methods=False):
        """
        Print the current attributes of the regret object for testing or
        debugging purposes.

        Parameters:
        -----------
        methods: bool, when true the bound methods are printed as well.
        """
        print('regret object with current arguments:')
        attribute_lines = []
        method_lines = []

        for name in dir(self):
            if name.startswith('__'):
                continue  # skip dunder attributes
            # getattr gives the same value as evaluating 'self.<name>'
            # without resorting to eval.
            value = getattr(self, name)
            line = f'self.{name}: {value!s}'
            if isinstance(value, types.MethodType):
                method_lines.append(line)
            else:
                attribute_lines.append(line)

        for line in attribute_lines:
            print(line)
        if methods:
            print('regret object with current methods:')
            for line in method_lines:
                print(line)
 
