from collections import deque
import numpy as np
from Players import GoTPlayer

class MABAlgorithm(object):
    """Base class for multi-player multi-armed bandit (MAB) algorithms.

    Stores the environment dimensions and offers helpers for counting
    collisions, converting a pull matrix into per-player arm choices and
    computing the payoff each player observes.  Subclasses have to implement
    learn_policy() and reset().
    """

    def __init__(self, param):
        """Store the environment parameters and pre-sample the feedback delays.

        Args:
            param: dict providing the keys "nbPlayer", "nbArm" and "horizon".

        Raises:
            KeyError: if one of the required keys is missing.
        """
        self.nbPlayer = param["nbPlayer"]
        self.nbArm = param["nbArm"]
        self.horizon = param["horizon"]

        # number of agents in the algorithm; it can be centralized,
        # decentralized or partially decentralized
        self.nbAgent = 0

        # an agent usually corresponds to a player
        self.agents = []

        delay_mean = 50
        delay_std = 30
        # Gaussian-distributed delays: the absolute value keeps them
        # non-negative and astype(int) truncates them to whole rounds;
        # the array has one entry per (player, round)
        self.delays = np.abs(
            np.random.normal(delay_mean, delay_std, size=(self.nbPlayer, self.horizon))
        ).astype(int)

        # deques storing the delayed reward and collision information,
        # one deque per (player, round)
        self.queues = [[deque() for _ in range(self.horizon)] for _ in range(self.nbPlayer)]

    # --- Printing
    def __str__(self):
        """Return the class name."""
        return self.__class__.__name__

    def __repr__(self):
        """Return the class name followed by the scalar attributes of the instance.

        Only int/float/str attributes are listed so that large members such
        as the delay matrix and the queues do not flood the output.
        """
        scalars = {key: value for key, value in self.__dict__.items()
                   if isinstance(value, (int, float, str))}
        return "{}({})".format(self.__class__.__name__, scalars)

    # --- functionalities
    def resolve_collision(self, pulls):
        """Count how many players pulled each arm.

        Args:
            pulls: array-like of shape (nbPlayer, nbArm); entry [i][j] is
                non-zero when player i pulls arm j.

        Returns:
            Array of length nbArm holding the column sums of pulls.

        Raises:
            AssertionError: if the shape does not match the stored parameters
                or there are more players than arms.
        """
        (nbPlayer, nbArm) = np.shape(pulls)
        assert nbPlayer == self.nbPlayer and nbArm == self.nbArm, "input does not match the stored environment parameters."
        assert nbPlayer <= nbArm, "player number should be smaller than or equal to arm number."

        # np.asarray lets nested lists through, matching what np.shape accepts
        collisions = np.asarray(pulls).sum(axis=0)

        assert len(collisions) == nbArm, "dimension of collisions is incorrect"
        return collisions

    def learn_policy(self, game_env, context=None, time=None):
        """
        Learn policies based on the given game environments.
        A game environment can be in the form of (context, sample_reward_matrix)

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError(
            "This method learn_policy(t) has to be implemented in the class inheriting from MABAlgorithm.")

    def reset(self, horizon=None):
        """
        Reset the algorithm. Apart from self.horizon, the stored parameters
        cannot be changed by a reset.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError(
            "This method reset() has to be implemented in the class inheriting from MABAlgorithm.")

    def pulls2choices(self, pulls):
        """Convert a pull matrix into one arm index per player.

        Args:
            pulls: array-like of shape (nbPlayer, nbArm); a non-zero entry
                [i][j] means player i pulls arm j.

        Returns:
            Integer array of length nbPlayer.  A player without any non-zero
            entry gets -1, the same "no arm" convention that
            observe_distributed_payoff() checks via ``selected_arm < 0``.
            If a row has several non-zero entries the highest arm index wins.

        Raises:
            AssertionError: if the shape does not match the stored parameters.
        """
        (nbPlayer, nbArm) = np.shape(pulls)
        assert nbPlayer == self.nbPlayer and nbArm == self.nbArm, "input does not match the stored environment parameters."

        # for some algorithms a player may refuse to choose any arm, so start
        # from the "no arm" marker instead of 0, which is a valid arm index
        arm_choices = np.full(nbPlayer, -1, dtype=int)

        # np.nonzero yields the (row, column) indices in row-major order
        player_ids, arm_ids = np.nonzero(pulls)
        for playerID, armID in zip(player_ids, arm_ids):
            arm_choices[playerID] = armID  # playerID is one of 0, 1, ..., nbPlayer-1

        return arm_choices

    def observe_distributed_payoff(self, game_env, collisions):
        """Compute the reward each player observes for its selected arm.

        Args:
            game_env: array of shape (nbPlayer, nbArm) with the reward of
                every (player, arm) pair for the current round.
            collisions: per-arm number of players pulling that arm.

        Returns:
            Array of length nbPlayer.  A player receives game_env[player][arm]
            only when it is the sole player on its arm; it receives 0 when it
            collides or when its selected_arm is negative (no arm chosen).

        Raises:
            AssertionError: if the shape does not match the stored parameters.
        """
        (nbPlayer, nbArm) = np.shape(game_env)
        assert nbPlayer == self.nbPlayer and nbArm == self.nbArm, "input does not match the stored environment parameters."

        current_reward = np.zeros(self.nbPlayer)

        for playerID in range(self.nbPlayer):
            selected_arm = self.agents[playerID].selected_arm

            # a negative arm index means the player refused to choose an arm;
            # the reward then stays 0, as it does for colliding players
            if selected_arm >= 0 and collisions[selected_arm] == 1:
                current_reward[playerID] = game_env[playerID][selected_arm]  # not colliding

        # return an array of dimension nbPlayer
        return current_reward


class GameofThrone(MABAlgorithm):
    """
    Decentralized assignment algorithm in the form of game-of-throne learning algorithm.
    Implemented for the paper "Distributed Multi-Player Bandits - a Game of Thrones Approach", by Ilai Bistritz et al.
    Note that this algorithm is designed for multi-player without considering contextual information.

    Every epoch consists of three consecutive phases: exploration, the
    game-of-thrones (GoT) phase and exploitation.  The phase lengths grow
    with the epoch index and are controlled by c1, c2 and c3.
    """

    def __init__(self, param):
        """Create one GoTPlayer agent per player and set up the first epoch.

        Args:
            param: dict with the required keys "nbPlayer", "nbArm", "horizon",
                "c1", "c2" and "c3", and the optional keys "epsilon"
                (default 0.1), "delta" (default 2) and "rho" (default 0.5).

        Raises:
            KeyError: if a required key is missing.
        """
        # the parent constructor stores nbPlayer, nbArm and horizon;
        # agents don't know the fixed horizon when running the algorithm
        super().__init__(param)

        # each player will be attached a single agent
        self.agents = []

        # a large epsilon leads to more frequent transitions (explorations) in the intermediate game
        self.epsilon = param.get("epsilon", 0.1)
        # see Theorem 1 in [Wang2019], not kept by the agents, determining trial-and-error rounds
        self.delta = param.get("delta", 2)
        # fraction of the GoT phase (c2 * epoch**delta rounds) that has to
        # pass before state visits are counted, see learn_policy()
        self.rho = param.get("rho", 0.5)

        # base lengths of the exploration, GoT and exploitation phases
        self.c1 = param["c1"]
        self.c2 = param["c2"]
        self.c3 = param["c3"]

        for playerID in range(self.nbPlayer):
            player_param = {"nbArm": self.nbArm,
                            "nbPlayer": self.nbPlayer,
                            "playerID": playerID,
                            "epsilon": self.epsilon,
                            "delta": self.delta
                            }

            self.agents.append(GoTPlayer(player_param))

        self._start_first_epoch()

    def _start_first_epoch(self):
        """Set the time, epoch and round counters to their first-epoch values."""
        self.time = 0
        # used for determining the epoch
        self.epoch = 1

        # for simplicity, the parameter names are kept the same as the TnE algorithm.
        self.exploration_round = self.c1  # last round of the exploration phase
        self.got_round = self.exploration_round + self.c2  # last round of the GoT phase
        self.rounds_in_epoch = self.got_round + self.c3 * 2  # c3 * (2 ** 1) exploitation rounds
        self.current_round = 1

    # --- Printing

    def __str__(self):
        """Return the display name of the algorithm."""
        return "Game of Throne"

    # --- functionalities

    def reset(self, horizon=None):
        """Reset all agents and restart from the first epoch.

        Args:
            horizon: optional new horizon; the stored horizon is kept when None.
        """
        for agent in self.agents:
            agent.reset()

        # exploration_round is rescaled on every epoch change, so it has to be
        # restored to c1 before the other first-epoch lengths are derived from it
        self._start_first_epoch()

        # NOTE(review): the delays and queues created by the base class are
        # sized with the original horizon and are not rebuilt here.
        if horizon is not None:
            self.horizon = horizon

    def learn_policy(self, game_env, context=None, time=None):
        """
        learn_policy() implements the 3 phases in Alg. 1 of [Leshem2018].
        Implemented in the same structure for trial-and-error

        Args:
            game_env: array of shape (nbPlayer, nbArm) with the rewards of the
                current round.
            context: unused; the agents are always called without a context.
            time: time index handed to the agents while exploring; must not
                be None.

        Returns:
            Tuple (pulls, total_rewards, current_rewards): the 0/1 pull matrix
            of shape (nbPlayer, nbArm), the sum of all player rewards and the
            per-player reward array.

        Raises:
            AssertionError: if the shape of game_env does not match the stored
                parameters, there are more players than arms, or time is None.
        """
        (nbPlayer, nbArm) = np.shape(game_env)
        assert nbPlayer == self.nbPlayer and nbArm == self.nbArm, "input does not match the stored environment parameters."
        assert nbPlayer <= nbArm, "player number should be smaller than or equal to arm number."
        assert time is not None, "time is not given."

        self.time = self.time + 1

        if self.current_round > self.rounds_in_epoch:
            # update epoch
            self.epoch = self.epoch + 1
            # rounds in the k-th epoch
            self.exploration_round = int(self.c1 * (self.epoch ** self.delta))
            self.got_round = int(self.exploration_round + self.c2 * (self.epoch ** self.delta))
            self.rounds_in_epoch = int(self.got_round + self.c3 * (2 ** self.epoch))
            # reset
            self.current_round = 1

        pulls = np.zeros((nbPlayer, nbArm))

        if self.current_round <= self.exploration_round:  # exploration rounds
            # reset the phase to exploration in an epoch
            if self.current_round == 1:
                for agentID in range(nbPlayer):
                    self.agents[agentID].set_internal_state(context=None, input_state=0)

            # exploration by randomly choosing actions
            for agentID in range(nbPlayer):
                armID = self.agents[agentID].explore(None, time)
                pulls[agentID][armID] = 1

            collisions = self.resolve_collision(pulls)
            for agentID in range(nbPlayer):
                self.agents[agentID].learn_arm_value(None, game_env[agentID, :], collisions)

            # learn the real payoff
            current_rewards = self.observe_distributed_payoff(game_env, collisions)

        elif self.current_round <= self.got_round:  # game-of-thrones phase
            if self.current_round == self.exploration_round + 1:
                # reset the phase to learning in an epoch
                for agentID in range(nbPlayer):
                    agent = self.agents[agentID]
                    agent.set_internal_state(context=None, input_state=1)

                    # as per Alg.1 in [Leshem2018], initialize the mood to be content
                    if self.epoch != 1:
                        init_state = [0, agent.best_policy]  # (STATE_CONTENT, BEST ACTION)
                    else:
                        # no best policy in the first epoch yet: randomize
                        action = np.random.randint(self.nbArm)
                        init_state = [0, action]

                    # initialize the intermediate game
                    agent.initalize_static_game(init_state, self.epoch)

                    agent.init_got_states(context=None, starting_state=init_state)

            # game of throne phase, taking actions randomly according to the intermediate state
            for agentID in range(nbPlayer):
                armID = self.agents[agentID].learn_policy(context=None)
                pulls[agentID][armID] = 1

            collisions = self.resolve_collision(pulls)

            # count state visits only once rho * c2 * epoch**delta rounds of
            # the GoT phase have passed
            counting_start = self.exploration_round + self.rho * self.c2 * (self.epoch ** self.delta)
            flag_count_frequency = self.current_round >= counting_start

            for agentID in range(nbPlayer):
                self.agents[agentID].update_game_state(context=None, collisions=collisions,
                                                       flag_record_frequency=flag_count_frequency)

            # update reward according to actions taken
            current_rewards = self.observe_distributed_payoff(game_env, collisions)

        else:
            if self.current_round == self.got_round + 1:
                # reset the phase to exploitation in an epoch
                for agentID in range(nbPlayer):
                    # the best policy is computed in set_internal_state()
                    self.agents[agentID].set_internal_state(context=None, input_state=2)

            # exploitation
            for agentID in range(nbPlayer):
                armID = self.agents[agentID].exploit(None, self.current_round)
                pulls[agentID][armID] = 1

            collisions = self.resolve_collision(pulls)
            current_rewards = self.observe_distributed_payoff(game_env, collisions)

        # update round number
        self.current_round = self.current_round + 1

        total_rewards = np.sum(current_rewards)
        return pulls, total_rewards, current_rewards
