import numpy as np
from Players import Player

if __name__ == '__main__':
    # This module only defines a player class; running it directly just warns
    # and exits. Use sys.exit: the bare `exit` builtin is injected by the
    # `site` module for interactive use and is absent under `python -S`.
    import sys

    print("Warning: this script 'PlayerResult2.py' is NOT executable..")  # DEBUG
    sys.exit(0)

class ESEPlayer(Player):
    """
    Player for the "Explore-Signal-Exploit (ESE)" multi-player bandit algorithm.

    ESE is proposed in "Multiplayer multiarmed bandits for optimal assignment in
    heterogeneous networks," arXiv preprint arXiv:1901.03868, by Tibrewal et al.
    (co-authored by Sumit J. Darak and Manjesh K. Hanawal) [Tibrewal2019].

    The algorithm relies on a signaling protocol in which the players' states
    carry the data being signaled.
    """

    def __init__(self, param):
        """
        Build a player from a configuration dict.

        Args:
            param: dict providing "nbArm" (number of arms), "nbPlayer"
                (number of players) and "playerID" (index of this player).
        """
        # NOTE(review): Player.__init__ is not invoked here (the call was
        # commented out in the original); confirm the base class needs no setup.
        self.nbArm = param["nbArm"]
        self.nbPlayer = param["nbPlayer"]
        self.playerID = param["playerID"]

        self._reset_state()

    def _reset_state(self):
        """Set every learning/decision attribute to its initial value."""
        self.flag_lock = False

        self.selected_arm = 0  # index of the locked arm
        self.policy = -1  # set to an invalid value until a policy is chosen
        # Start on arm `playerID`: sequential_hop shifts every player by the
        # same +1 (mod nbArm), so distinct playerIDs keep players on distinct
        # arms while hopping. A common start arm would make them collide on
        # every step.
        self.current_arm = self.playerID

        self.arm_score = np.zeros(self.nbArm)
        self.estimated_arm_matrix = np.zeros((self.nbPlayer, self.nbArm))

        self.accumulated_value = np.zeros(self.nbArm)  # sum of recorded payoffs
        self.nb_observation = np.zeros(self.nbArm)  # number of observed non-zero payoffs

    def reset(self):
        """
        Restore the player to exactly the state it had after construction.

        Restarts on arm ``playerID`` (as ``__init__`` does), so that players
        keep staggered positions during sequential hopping.
        """
        self._reset_state()

    def sequential_hop(self, context=None, time=None):
        """
        Advance to the next arm in round-robin order and return its index.

        ``context`` and ``time`` are accepted but not used.
        """
        self.current_arm = (self.current_arm + 1) % self.nbArm

        return self.current_arm

    def learn_arm_value(self, context=None, arm_values=None, collisions=None):
        """
        Update the reward estimate of the arm currently being played.

        Args:
            context: accepted but not used.
            arm_values: per-arm payoffs for this round, indexed by arm.
            collisions: per-arm collision information, indexed by arm;
                presumably the number of players on each arm (confirm with
                the caller).

        A sample is recorded only when ``collisions[arm] == 1`` and the payoff
        is non-zero. ``arm_score`` is then recomputed for all arms as the
        empirical mean of the recorded samples.
        """
        arm = self.current_arm

        if collisions[arm] == 1 and arm_values[arm] != 0:
            self.nb_observation[arm] += 1
            self.accumulated_value[arm] += arm_values[arm]

        # Empirical mean payoff per arm (not a UCB index). The epsilon avoids
        # 0/0 for arms with no recorded sample, which therefore score 0.
        self.arm_score = self.accumulated_value / (self.nb_observation + 1e-9)

    def exploit(self, context=None, time=None):
        """
        Play the arm stored in ``self.policy`` and return it.

        NOTE(review): ``policy`` is -1 until something assigns it. Calling this
        earlier returns -1, which would silently select the last element if
        used as a NumPy index. Confirm that callers set the policy first.
        """
        self.current_arm = self.policy
        return self.current_arm
