import numpy as np
from numfi import numfi
from scipy.optimize import linear_sum_assignment

from MABAlgorithms import MABAlgorithm
from Players2a import ESEPlayer

# This module is import-only: when launched directly it prints a notice and
# stops before any class below is defined.
if __name__ == '__main__':
    print("Warning: this script 'MABAlgorithms2a.py' is NOT executable...")  # DEBUG
    # Raise SystemExit directly instead of calling the bare ``exit`` helper:
    # that helper is injected by the ``site`` module for interactive sessions
    # and is not available when the interpreter runs without ``site``.
    raise SystemExit(0)

"""
 Algorithm: Explore-Signal-Exploit (ESE)
"""    
class ESE(MABAlgorithm):
    """Driver for the Explore-Signal-Exploit (ESE) multi-player algorithm.

    One ``ESEPlayer`` is created per player. Every call to ``learn_policy``
    advances the internal round counter by one and routes all players
    through the phase selected by that counter.
    """

    def __init__(self, param):
        """Build the algorithm and its players from a parameter dictionary.

        Required keys: ``"nbPlayer"``, ``"nbArm"`` and ``"context_set"``.
        Optional keys: ``"delta_R"`` and ``"epsilon"`` (each defaults to 0.1).
        """
        super().__init__(param)
        self.nbPlayer = param["nbPlayer"]
        self.nbArm = param["nbArm"]
        # not used by this class itself; only forwarded to the players
        self.context_set = param["context_set"]

        # allowable prob. of non-orthogonal allocation (stored, not read in this class)
        self.delta_R = param.get('delta_R', 0.1)
        self.epsilon = param.get('epsilon', 0.1)
        self.time = 0
        self.last_t = 0
        self.new = 0

        # Exploration lasts nbArm * Ts rounds and the signalling window
        # nbPlayer * nbArm * Tb rounds (see learn_policy).
        self.Ts = np.ceil(8. * self.nbPlayer ** 2 / (self.epsilon ** 2))
        self.Tb = np.ceil(np.log2(4. * self.nbPlayer / self.epsilon))

        self.current_epoch = 1

        self.agents = []
        for playerID in range(self.nbPlayer):
            player_param = {
                "context_set": self.context_set,
                "nbArm": self.nbArm,
                "nbPlayer": self.nbPlayer,
                "playerID": playerID,
            }
            self.agents.append(ESEPlayer(player_param))

    # --- Printing
    def __str__(self):
        # NOTE(review): the trailing "n" looks like a typo; the string is left
        # unchanged because it is runtime output that callers may rely on.
        return "Explore Signal Exploitn"

    # --- functionalities
    def reset(self, horizon=None):
        """Return the algorithm and every player to the initial state.

        ``horizon`` is accepted for interface compatibility and is not used.
        """
        self.time = 0
        self.current_epoch = 1
        # Also clear the phase bookkeeping so that no value from a previous
        # run survives a reset.
        self.last_t = 0
        self.new = 0

        for agent in self.agents:
            agent.reset()

    def _assign_policies(self):
        """Derive every player's arm from the collected arm scores."""
        arm_matrix = np.zeros((self.nbPlayer, self.nbArm))
        for agentID, agent in enumerate(self.agents):
            arm_matrix[agentID, :] = agent.arm_score

        # Fixed-point version of the score matrix.
        truncated_arm_matrix = numfi(arm_matrix, bits_frac=int(np.log2(4 * self.nbPlayer / self.epsilon)))

        for agentID, agent in enumerate(self.agents):
            # NOTE(review): every player receives the SAME matrix object, so
            # the row written below stays visible to the players handled
            # afterwards (unless ESEPlayer copies on assignment) - confirm
            # whether a per-player copy was intended.
            agent.estimated_arm_matrix = truncated_arm_matrix
            agent.estimated_arm_matrix[agentID, :] = agent.arm_score

            # Each player runs the assignment solver locally. The matrix is
            # negated to maximise the score and passed transposed
            # (arms x players), so the first returned index array holds arms
            # and the second holds players.
            cost_matrix = np.negative(agent.estimated_arm_matrix.transpose())
            arm_indices, player_indices = linear_sum_assignment(cost_matrix)

            # Keep only the arm assigned to this player.
            for armID, playerID in zip(arm_indices, player_indices):
                if playerID == agentID:
                    agent.policy = armID
                    print(f'arm policy = {armID}')

    def learn_policy(self, game_env, context=None, time=None):
        """Play one round and return its outcome.

        Parameters
        ----------
        game_env : array of shape (nbPlayer, nbArm)
            Environment values of the current round, one row per player.
        context : optional
            Not used by ESE except that it is forwarded to
            ``ESEPlayer.exploit``.
        time : optional
            Forwarded unchanged to ``ESEPlayer.sequential_hop``; the phase
            schedule is driven by the internal counter ``self.time``.

        Returns
        -------
        tuple
            ``(pulls, total_rewards, current_rewards)`` where ``pulls`` is the
            0/1 player-by-arm matrix of the arms played in this round.

        Raises
        ------
        AssertionError
            If the shape of ``game_env`` does not match the stored sizes or
            there are more players than arms.
        """
        (nbPlayer, nbArm) = np.shape(game_env)
        assert nbPlayer == self.nbPlayer and nbArm == self.nbArm, "input does not match the stored environment parameters."
        assert nbPlayer <= nbArm, "player number should be smaller than or equal to arm number."

        self.time = self.time + 1

        pulls = np.zeros((nbPlayer, nbArm))

        if self.time <= self.nbArm * self.Ts:
            # exploration with sequential hopping
            for agentID, agent in enumerate(self.agents):
                armID = agent.sequential_hop(None, time)
                pulls[agentID][armID] = 1

            collisions = self.resolve_collision(pulls)
            for agentID, agent in enumerate(self.agents):
                agent.learn_arm_value(None, game_env[agentID, :], collisions)

            if self.time == self.nbArm * self.Ts:
                # Last exploration round: fix the policies used afterwards.
                # `pulls` is deliberately left untouched here; it must keep
                # describing the arms that were actually played and for which
                # `collisions` was computed.
                print('i')
                self.last_t = self.time
                self._assign_policies()

        elif self.time - self.last_t <= self.nbPlayer * self.nbArm * self.Tb:
            # Signalling window: every player plays its assigned arm.
            # NOTE(review): `last_t` is re-armed when this window closes, so
            # on the next round the difference is 1 again and this condition
            # holds once more whenever nbPlayer * nbArm * Tb >= 1. As written,
            # the exploitation branch below is then never reached - confirm
            # the intended phase schedule.
            for agentID in range(self.nbPlayer):
                print('?')
                armID = self.agents[agentID].policy
                pulls[agentID][armID] = 1

            collisions = self.resolve_collision(pulls)
            if self.time - self.last_t == self.nbPlayer * self.nbArm * self.Tb:
                self.last_t = self.time

        elif self.time - self.last_t <= int(np.exp(self.current_epoch)):
            # exploitation phase
            for agentID, agent in enumerate(self.agents):
                armID = agent.exploit(context, self.time)
                pulls[agentID][armID] = 1

            collisions = self.resolve_collision(pulls)

            if self.time - self.last_t == int(np.exp(self.current_epoch)):
                # update round number
                self.last_t = self.time
                self.current_epoch = self.current_epoch + 1

        current_rewards = self.observe_distributed_payoff(game_env, collisions)
        total_rewards = np.sum(current_rewards)
        return pulls, total_rewards, current_rewards
    
# Names exported by `from <module> import *`; add other algorithms here
__all__ = ["ESE"]        