import numpy as np

if __name__ == '__main__':
    # This module only defines library classes; refuse to run it as a script.
    # SystemExit is used instead of the site-provided ``exit()`` helper, which
    # is meant for the interactive shell and is absent under ``python -S``.
    print("Warning: this script 'MABAlgorithms.py' is NOT executable..")  # DEBUG
    raise SystemExit(0)


class MABAlgorithm(object):
    """Base class for multi-player multi-armed bandit algorithms.

    Subclasses must implement ``learn_policy`` and ``reset``. The base class
    stores the environment dimensions and provides helpers to count per-arm
    pulls (collisions), convert a pull matrix into per-player arm choices, and
    compute per-player rewards where a player is only paid when alone on its arm.
    """

    def __init__(self, param):
        """Store the environment parameters.

        Args:
            param: mapping with at least the keys ``"nbPlayer"``, ``"nbArm"``
                and ``"horizon"``.
        """
        self.nbPlayer = param["nbPlayer"]
        self.nbArm = param["nbArm"]
        self.horizon = param["horizon"]

        # Number of agents in the algorithm; an algorithm can be centralized,
        # decentralized or partially decentralized. Left at 0 here.
        self.nbAgent = 0

        # An agent usually corresponds to a player. The list starts empty;
        # observe_distributed_payoff reads ``self.agents[i].selected_arm``.
        self.agents = []

    # --- Printing
    def __str__(self):
        return self.__class__.__name__

    def __repr__(self):
        # Show the instance attributes so the repr reflects the object's state.
        return "{}({})".format(self.__class__.__name__, self.__dict__)

    # --- functionalities
    def resolve_collision(self, pulls):
        """Count how many players pulled each arm.

        Args:
            pulls: (nbPlayer, nbArm) array-like; ``pulls[i][k]`` is non-zero
                when player i pulled arm k (presumably 0/1 indicators -- confirm
                against callers).

        Returns:
            1-D numpy array of length nbArm holding the column sums of
            ``pulls``. With 0/1 indicators, 1 means a single player on that arm.

        Raises:
            AssertionError: if the shape of ``pulls`` does not match the stored
                parameters, or if there are more players than arms.
        """
        pulls = np.asarray(pulls)  # also accept nested lists
        (nbPlayer, nbArm) = pulls.shape
        assert nbPlayer == self.nbPlayer and nbArm == self.nbArm, "input does not match the stored environment parameters."
        assert nbPlayer <= nbArm, "arm number should be larger than or equal to player number."

        collisions = pulls.sum(axis=0)

        assert len(collisions) == nbArm, "dimension of collisions is incorrect"
        return collisions

    def learn_policy(self, game_env, context=None, time=None):
        """
        Learn policies based on the given game environment.

        A game environment can be in the form of (context, sample_reward_matrix).
        Must be implemented by subclasses.
        """
        raise NotImplementedError(
            "This method learn_policy(t) has to be implemented in the class inheriting from MABAlgorithm.")

    def reset(self, horizon=None):
        """
        Reset the algorithm state; must be implemented by subclasses.

        Only ``self.horizon`` may be changed (via ``horizon``); the other
        parameters cannot be reset.
        """
        raise NotImplementedError(
            "This method reset() has to be implemented in the class inheriting from MABAlgorithm.")

    def pulls2choices(self, pulls):
        """Convert a pull matrix into one arm index per player.

        Args:
            pulls: (nbPlayer, nbArm) array-like of pull indicators.

        Returns:
            1-D int numpy array of length nbPlayer whose i-th entry is the arm
            pulled by player i. A player with no non-zero entry keeps the
            default 0; if a player has several non-zero entries, the one with
            the largest arm index is kept.
        """
        (nbPlayer, nbArm) = np.shape(pulls)
        assert nbPlayer == self.nbPlayer and nbArm == self.nbArm, "input does not match the stored environment parameters."

        arm_choices = np.zeros(nbPlayer, dtype=int)

        # Some algorithms let a player refuse to choose any arm; such a player
        # has no non-zero entry and keeps the default value.
        # NOTE(review): that default (0) is indistinguishable from pulling arm 0,
        # while observe_distributed_payoff treats a negative arm as "no pull".
        # Confirm whether -1 was intended before relying on this.
        rows, cols = np.nonzero(pulls)  # indices come back in row-major order
        for player_id, arm in zip(rows, cols):
            arm_choices[player_id] = arm  # player ids are 0 .. nbPlayer-1

        return arm_choices

    def observe_distributed_payoff(self, game_env, collisions):
        """Compute each player's reward for the current round.

        Args:
            game_env: (nbPlayer, nbArm) array-like; ``game_env[i][k]`` is the
                reward player i receives from arm k when not colliding.
            collisions: length-nbArm sequence of per-arm pull counts, e.g. the
                output of ``resolve_collision``.

        Returns:
            1-D float numpy array of length nbPlayer. A player earns
            ``game_env[i][arm]`` only if its selected arm is non-negative and
            ``collisions[arm] == 1``; otherwise it earns 0.
        """
        (nbPlayer, nbArm) = np.shape(game_env)
        assert nbPlayer == self.nbPlayer and nbArm == self.nbArm, "input does not match the stored environment parameters."

        current_reward = np.zeros(self.nbPlayer)

        for playerID in range(self.nbPlayer):
            selected_arm = self.agents[playerID].selected_arm

            # A negative arm means the player refused to choose any arm; an
            # arm pulled by several players (collision) pays nothing. In both
            # cases the reward stays at its initial 0.
            if selected_arm >= 0 and collisions[selected_arm] == 1:
                current_reward[playerID] = game_env[playerID][selected_arm]

        # returns an array of dimension nbPlayer
        return current_reward
