import numpy as np

if __name__ == '__main__':
    # This module only provides player classes; running it directly just
    # prints a warning and stops before any class is defined.
    print("Warning: this script 'Player.py' is NOT executable..")  # DEBUG
    # Raise SystemExit directly: the exit() helper is injected by the 'site'
    # module for interactive use and is not guaranteed to exist in programs.
    raise SystemExit(0)

class Player(object):
    """Base class for a player class.

    The functional methods (explore, learn_arm_value, exploit, reset) are
    placeholders: each one only prints a reminder that it has to be
    implemented by the agent adopting a particular algorithm, and returns None.
    """

    def __init__(self, param):
        """
        Store the configuration shared by all players.

        For clarity, we require each child class to re-implement completely the __init__() method.

        Args:
            param: mapping providing the keys "horizon", "nbArm", "context"
                and "playerID" (a missing key raises KeyError).
        """
        self.horizon = param["horizon"]  #: if the horizon is not known in advance, set it to None.
        self.nbArm = param["nbArm"]

        # for arm of a specific context-player
        self.context = param["context"]
        self.playerID = param["playerID"]

        # one estimated value per arm, all starting at zero
        self.arm_estimate = np.zeros(self.nbArm)

    # --- Printing
    def __str__(self):
        """Return the name of the concrete class."""
        return self.__class__.__name__

    # --- functionalities
    def explore(self, context=None, time=None):
        """Placeholder for the exploration step; prints a reminder only."""
        # The message names this method (it used to mention a non-existent decision()).
        print("explore() should be implemented for agent adopting a particular algorithm.")

    def learn_arm_value(self, context=None, arm_values=None, collisions=None):
        """Placeholder for the arm-value update; prints a reminder only."""
        print("learn_arm_value() should be implemented for agent adopting a particular algorithm.")

    def exploit(self, context=None, time=None):
        """Placeholder for the exploitation step; prints a reminder only."""
        print("exploit() should be implemented for agent adopting a particular algorithm.")

    def reset(self):
        """Placeholder for resetting the player; prints a reminder only."""
        print("reset() should be implemented for agent adopting a particular algorithm.")
   
# Learning phases of a player, numbered 0..2 in this order:
# exploration, learning (trial-and-error) and exploitation.
STATE_EXPLORE, STATE_LEARN, STATE_EXPLOIT = range(3)

# Moods of a player in the static game, numbered 0..3 in this order.
STATE_CONTENT, STATE_HOPEFUL, STATE_WATCHFUL, STATE_DISCONTENT = range(4)

class GoTPlayer(Player): # with almost the same structure of TnE
    """Player following the GoT scheme (see [Bistritz2019]); context is ignored.

    The player moves through three learning phases, selected with
    set_internal_state():

    * STATE_EXPLORE: arms are drawn uniformly at random (explore()) and the
      mean value of every arm is estimated (learn_arm_value()).
    * STATE_LEARN: a static game built on the arm estimates is played
      (learn_policy() followed by update_game_state()). The player keeps a
      state ``[mood, reference arm]`` and counts how often states are visited.
    * STATE_EXPLOIT: the arm visited most often while content is played
      (exploit()).

    ``collisions`` arguments are indexed by arm. On the selected arm, an entry
    of 1 is treated as "no collision", 0 is rejected as inconsistent and any
    other value is treated as a collision (presumably the entry is the number
    of players on that arm -- confirm against the caller).
    """

    def __init__(self, param):
        """Read the configuration and initialise the learning state.

        Player.__init__() is not called (it would require a "context" entry).

        Args:
            param: mapping with the keys "playerID", "nbArm", "nbPlayer" and
                "epsilon"; "horizon" is optional and defaults to 0.
        """
        self.horizon = param.get("horizon", 0)

        # for arm of a specific context-player
        self.playerID = param["playerID"]
        self.nbArm = param["nbArm"]
        self.nbPlayer = param["nbPlayer"]  # used for determining the probability of intermediate state switching

        # used in Eq. (10) and Eq. (11) in [Wang2019]
        self.epsilon = param["epsilon"]

        # requirement from [Bistritz2019], the discrepancy of sum of maximum value and the social-optimal value
        self.c = 1.2  # this is an estimation
        self.pert_factor = self.c * self.nbPlayer

        self._init_learning_state()

    def _init_learning_state(self):
        """Set every learned quantity back to its initial value."""
        self.nb_observation = np.zeros(self.nbArm)
        self.accumulated_value = np.zeros(self.nbArm)

        # the static game is formulated on arm_estimate
        self.arm_estimate = np.zeros(self.nbArm)

        self.learning_state = STATE_EXPLORE

        self.selected_arm = 0
        # row 0: visits while content, row 1: visits in any other mood
        self.nb_state_visit = np.zeros((2, self.nbArm))

        # default state: [mood = discontent, reference action (arm) = 0]
        self.current_state = [STATE_DISCONTENT, 0]

        self.max_u = 1
        self.best_policy = 0

    def reset(self):
        """Forget everything learned so far; the configuration is kept."""
        self._init_learning_state()

    # --- functionalities
    def explore(self, context=None, time=None):
        """
        Select an arm uniformly at random and remember it in selected_arm.

        The estimated arm values are updated in learn_arm_value();
        context and time are not used for this version.

        Returns:
            The index of the selected arm.
        """
        assert self.learning_state == STATE_EXPLORE, "learning state does not match"#debug

        self.selected_arm = np.random.randint(self.nbArm)

        return self.selected_arm

    def learn_arm_value(self, context=None, arm_values=None, collisions=None):
        """
        Update the estimate of the arm chosen by the last explore() call.

        The observation is used only when the selected arm is collision-free
        (collisions entry equal to 1) and its observed value is non-zero; the
        estimate is then the running mean of the values used so far.

        Args:
            context: not used.
            arm_values: per-arm observed values, of length nbArm.
            collisions: per-arm collision indicators, of length nbArm; the
                entry of the selected arm must not be 0.

        Returns:
            The array of arm estimates (the internal array, not a copy).
        """
        # must be called after explore
        assert self.learning_state == STATE_EXPLORE, "learning state does not match"  # debug
        assert len(arm_values) == self.nbArm and len(collisions) == self.nbArm, "inputs are invalid"
        assert collisions[self.selected_arm] != 0, "arm selection error"

        armID = self.selected_arm

        if collisions[armID] == 1 and arm_values[armID] != 0:
            self.nb_observation[armID] += 1
            self.accumulated_value[armID] += arm_values[armID]

            self.arm_estimate[armID] = self.accumulated_value[armID] / self.nb_observation[armID]

        return self.arm_estimate

    def set_internal_state(self, context=None, input_state=STATE_EXPLORE):
        """
        Switch the learning phase of the player.

        Entering STATE_EXPLOIT also computes the best policy from the visit
        counts (get_best_policy()).

        Args:
            context: not used, GoT does not use context information.
            input_state: 0 -- explore, 1 -- trial-and-error, 2 -- exploitation.

        Raises:
            ValueError: if input_state is none of the three phases.
        """
        if input_state not in (STATE_EXPLORE, STATE_LEARN, STATE_EXPLOIT):
            raise ValueError("input state is invalid")

        if input_state == STATE_EXPLOIT:
            self.get_best_policy()  # calculate once for all

        self.learning_state = input_state

    def initalize_static_game(self, epoch=None, context=None):
        """
        Record the largest arm estimate as max_u, the normaliser of the static game.

        State initialization is done in init_got_states,
        this function is to be removed in the future.
        epoch and context are not used.
        """
        self.max_u = np.max(self.arm_estimate)

    def init_got_states(self, context=None, starting_state=None):
        """
        Clear the state-visit counters and set the state at the beginning of an epoch.

        The visit counters cover 2 moods, Content (C) and Discontent (D), so
        the total # of local intermediate states is 2 * nbArm.

        Args:
            context: not used.
            starting_state: sequence whose first two entries are the mood and
                the reference arm; None selects [STATE_DISCONTENT, 0]. The
                sequence is copied, so later state updates never modify the
                caller's object.
        """
        # because of (1), the learning algorithm only uses the outcomes of the game play in this epoch.
        self.nb_state_visit = np.zeros((2, self.nbArm))  # (1): tracks the frequency of state visits

        if starting_state is None:
            # default state: [mood = discontent, reference action (arm) = 0]
            self.current_state = [STATE_DISCONTENT, 0]
        else:
            # update_game_state() mutates current_state in place: keep a private, mutable copy
            self.current_state = list(starting_state)

    def learn_policy(self, context=None, time=None):
        """
        Choose the arm to play in the static game from the current state.

        context and time are not used.

        Returns:
            The index of the selected arm (also stored in selected_arm).
        """
        assert self.learning_state == STATE_LEARN, "learning state does not match" #debug

        self.selected_arm = self.update_static_game_action(None, self.current_state)

        return self.selected_arm

    def update_static_game_action(self, context=None, current_state=None):
        """
        Sample the action of the static game.

        A content player keeps its reference arm with probability
        1 - epsilon**pert_factor and otherwise experiments uniformly over the
        other arms (Eq. (8), Alg. 2 of [Bistritz2019]); with a single arm
        there is nothing to experiment with and the reference arm is kept.
        A discontent player picks an arm uniformly at random.

        Args:
            context: not used.
            current_state: sequence starting with (mood, reference arm);
                None selects the player's own current_state.

        Returns:
            The index of the sampled arm.

        Raises:
            ValueError: if the mood is neither content nor discontent.
        """
        if current_state is None:
            current_state = self.current_state

        mood = current_state[0]
        reference_arm = current_state[1]

        if mood == STATE_CONTENT:
            if self.nbArm == 1:
                # no alternative arm exists; this also avoids dividing by nbArm - 1 == 0
                return reference_arm

            prob_experiment = self.epsilon ** self.pert_factor
            prob_array = np.full(self.nbArm, prob_experiment / (self.nbArm - 1), dtype=float)
            prob_array[reference_arm] = 1 - prob_experiment

            return np.random.choice(np.arange(self.nbArm), size=None, p=prob_array)

        if mood == STATE_DISCONTENT:
            return np.random.randint(self.nbArm)

        raise ValueError("the mood of the current state is invalid")

    def _content_probability(self, reward):
        """Return the probability of turning content for a positive reward.

        The value is reward / max_u * epsilon**(max_u - reward), kept inside
        [0, 1]. A reward at or above max_u (possible when max_u is stale or
        still at its default) gives probability 1 and never divides by a
        non-positive max_u.
        """
        if reward >= self.max_u:
            return 1.0
        prob = reward / self.max_u * (self.epsilon ** (self.max_u - reward))
        return min(max(prob, 0.0), 1.0)

    def update_game_state(self, context, collisions, flag_record_frequency=False):
        """
        Update [mood, reference arm] from the outcome of the last static-game play.

        Ignore any context. The GoT algorithm is designed for the MP-MAB in stochastic environment w/o context.

        The reward of the static game is the estimate of the selected arm when
        it is collision-free and 0 otherwise. Transitions:

        * reward <= 0: the player becomes (or stays) discontent;
        * content player that played its reference arm with reward > 0: the
          state is left unchanged;
        * otherwise: the player becomes content with the probability given by
          _content_probability() and discontent with the remaining probability.

        Whenever the state changes, the reference arm becomes the selected arm.

        Args:
            context: not used.
            collisions: per-arm collision indicators (see the class docstring).
            flag_record_frequency: when true, the visit counter of the state
                reached after the update is incremented.

        Raises:
            ValueError: if collisions reports 0 for the selected arm.
            RuntimeError: if the current mood is neither content nor discontent.
        """
        arm = self.selected_arm

        if collisions[arm] == 0:
            raise ValueError("the collision is not correctly computed.")
        # reward of the static game; it is 0 if there is a collision
        current_reward = self.arm_estimate[arm] if collisions[arm] == 1 else 0

        mood = self.current_state[0]
        if mood not in (STATE_CONTENT, STATE_DISCONTENT):
            raise RuntimeError("unexpected state.")

        if current_reward <= 0:
            self.current_state[0] = STATE_DISCONTENT
            self.current_state[1] = arm
        elif mood == STATE_CONTENT and arm == self.current_state[1]:
            # same action as the reference action and utility > 0:
            # a content player remains content with probability 1
            pass
        else:
            threshold = self._content_probability(current_reward)
            sampled_result = np.random.choice([0, 1], size=None, p=[threshold, 1 - threshold])

            self.current_state[0] = STATE_CONTENT if sampled_result == 0 else STATE_DISCONTENT
            self.current_state[1] = arm

        # only the last few rounds are considered to count toward the optimal policy
        if flag_record_frequency:
            # update the number of visited states
            id_mood = 0 if self.current_state[0] == STATE_CONTENT else 1
            id_action = self.current_state[1]

            self.nb_state_visit[id_mood, id_action] += 1

    def exploit(self, context=None, time=None):
        """
        Play the best policy computed when the exploitation phase was entered.

        Args:
            context: not used.
            time: current iteration; must not be None (it is only used in the
                assertion message).

        Returns:
            The index of the selected arm.
        """
        assert time is not None, "time is None"
        assert self.learning_state == STATE_EXPLOIT, "learning state does not match at iteration {}".format(time)

        self.selected_arm = self.best_policy
        return self.selected_arm  # return the action

    def get_best_policy(self, context=None):
        """
        Store and return the arm visited most often while content.

        context is not used. Ties are resolved by np.argmax, i.e. in favour
        of the lowest arm index.
        """
        mat_frequency = self.nb_state_visit[0, :]  # over the mood axis, over CONTENT
        assert np.shape(mat_frequency) == (self.nbArm,), "shape of frequency is wrong."

        id_max = np.argmax(mat_frequency)  # over the remaining action/arm axis

        self.best_policy = id_max

        return id_max