import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt


class Agent():
    """Base class for a learning agent playing a repeated normal-form game.

    Each agent keeps its own deep copy of the game object and records the
    outcome of every round it is told about (its own action/payoff and the
    full action/payoff profiles) as reported by that private copy. The base
    class does not choose actions itself; ``run_T_rounds`` calls
    ``draw_action_from_action_dist`` on each agent, which subclasses provide.

    Args:
        agent_idx: index of this agent in an action profile. It is also the
            axis of ``game_obj.game_matrix_as_array`` whose length gives this
            agent's number of actions.
        game_obj: game object. This class uses its ``game_matrix_as_array``
            (only ``.shape``), ``play_game_round`` and ``simulate_game_round``.
        restrict_strictly_dominated_strategies: flag stored on the instance;
            no method of this base class reads it.
    """

    def __init__(self, agent_idx, game_obj, restrict_strictly_dominated_strategies=False):
        self.idx = agent_idx
        # Private copy: playing rounds on it never mutates the caller's game.
        self.intrinsic_game_obj = deepcopy(game_obj)
        # Number of actions = length of this agent's axis of the game matrix.
        self.n_actions = self.intrinsic_game_obj.game_matrix_as_array.shape[self.idx]
        self.restrict_strictly_dominated_strategies = restrict_strictly_dominated_strategies
        self.agent_action_history = []  # this agent's action, one entry per round
        self.all_players_action_history = []  # full action profile, one per round
        self.agent_perceived_total_payoff = 0  # running sum of this agent's payoffs
        self.agent_perceived_payoff_history = []  # this agent's payoff per round
        self.all_players_perceived_payoff_history = []  # all players' payoffs per round

    def __str__(self):
        return 'Agent (base class)'

    def get_agent_index(self):
        """Return this agent's index within an action profile."""
        return self.idx

    def set_agent_game_obj(self, new_game_obj):
        """Replace the agent's private game copy and recompute ``n_actions``.

        Recorded histories are left untouched.
        """
        self.intrinsic_game_obj = deepcopy(new_game_obj)
        self.n_actions = self.intrinsic_game_obj.game_matrix_as_array.shape[self.idx]

    def get_agent_game_obj(self):
        """Return the agent's private game copy (not a copy of it)."""
        return self.intrinsic_game_obj

    def update_agent_action_history(self, new_action):
        """Append ``new_action`` to this agent's own action history."""
        self.agent_action_history.append(new_action)

    def get_agent_action_history(self):
        """Return the list of this agent's recorded actions."""
        return self.agent_action_history

    def update_all_players_action_history(self, all_players_new_action):
        """Append a full action profile to the all-players action history."""
        self.all_players_action_history.append(all_players_new_action)

    def get_all_players_action_history(self):
        """Return the list of recorded full action profiles."""
        return self.all_players_action_history

    def update_agent_perceived_total_payoff(self, new_payoff):
        """Add ``new_payoff`` to this agent's running total payoff."""
        self.agent_perceived_total_payoff += new_payoff

    def get_agent_perceived_total_payoff(self):
        """Return this agent's running total payoff."""
        return self.agent_perceived_total_payoff

    def update_agent_perceived_payoff_history(self, all_players_new_payoff):
        """Append a value to this agent's per-round payoff history.

        Despite the parameter name, ``update_action`` stores only this
        agent's own payoff in this history.
        """
        self.agent_perceived_payoff_history.append(all_players_new_payoff)

    def get_agent_perceived_payoff_history(self):
        """Return this agent's per-round payoff history."""
        return self.agent_perceived_payoff_history

    def get_all_players_perceived_payoff_history(self):
        """Return the per-round payoff profiles of all players."""
        return self.all_players_perceived_payoff_history

    def update_action(self, action_profile):
        """Play one round on the private game copy and record its outcome.

        Args:
            action_profile: one action index per player; this agent's entry
                is ``action_profile[self.idx]``.

        Raises:
            ValueError: if this agent's action is outside ``[0, n_actions)``.
        """
        my_action = action_profile[self.idx]
        # Explicit check rather than ``assert`` so it is not stripped by -O.
        if not 0 <= my_action < self.n_actions:
            raise ValueError(
                'agent {}: action {} out of range [0, {})'.format(
                    self.idx, my_action, self.n_actions))
        actions, payoffs = self.intrinsic_game_obj.play_game_round(action_profile)
        self.agent_action_history.append(actions[self.idx])
        self.all_players_action_history.append(actions)
        self.agent_perceived_total_payoff += payoffs[self.idx]
        self.agent_perceived_payoff_history.append(payoffs[self.idx])
        self.all_players_perceived_payoff_history.append(payoffs)

    def get_forgone_payoff_vector_from_game(self, steps_back=1):
        """Return this agent's realized and foregone payoffs for a past round.

        Takes the action profile recorded ``steps_back`` rounds ago
        (1 = most recent). It substitutes each of this agent's actions in turn,
        keeps the other players' actions fixed, and evaluates each profile with
        ``simulate_game_round`` on the private game copy.

        Args:
            steps_back: how many rounds back to look; must be >= 1.

        Returns:
            np.ndarray of length ``n_actions``. Entry ``a`` is the payoff this
            agent would have received by playing ``a`` in that round.

        Raises:
            ValueError: if ``steps_back`` < 1. Zero or negative values would
                otherwise silently index from the *start* of the history.
            IndexError: if fewer than ``steps_back`` rounds are recorded.
        """
        if steps_back < 1:
            raise ValueError('steps_back must be >= 1, got {}'.format(steps_back))
        n_recorded = len(self.all_players_action_history)
        if steps_back > n_recorded:
            raise IndexError(
                'steps_back={} but only {} round(s) recorded'.format(steps_back, n_recorded))

        recent_actions = self.all_players_action_history[-steps_back]
        last_payoff_vector = np.zeros(self.n_actions)
        for a in range(self.n_actions):
            # Copy so the recorded profile in the history is not mutated.
            actions_a = deepcopy(recent_actions)
            actions_a[self.idx] = a
            _, all_players_payoffs_a = self.intrinsic_game_obj.simulate_game_round(actions_a)
            last_payoff_vector[a] = all_players_payoffs_a[self.idx]

        return last_payoff_vector


class MultiplicativeWeightsAgent(Agent):
    """Agent that learns a mixed strategy with multiplicative-weights updates.

    After every round the agent computes the payoff each of its actions would
    have earned against the other players' recorded actions. It then scales
    every weight by ``1 + eta * payoff`` and renormalizes.

    Args:
        agent_idx, game_obj, restrict_strictly_dominated_strategies: see ``Agent``.
        eta: learning rate of the multiplicative update.
        initial_weights: optional initial weight per action (length
            ``n_actions``); defaults to all ones (uniform strategy).
    """

    def __init__(self, agent_idx, game_obj, restrict_strictly_dominated_strategies=False, eta=0.01, initial_weights=None):
        # Forward the caller's flag; it used to be replaced by a hard-coded False.
        super().__init__(agent_idx, game_obj,
                         restrict_strictly_dominated_strategies=restrict_strictly_dominated_strategies)
        self.eta = eta
        if initial_weights is None:
            self.weights = self.reset_weights()
        else:
            self.weights = self.set_weights_by_input(initial_weights)
        # Probability vector before any update, then one entry per update_action.
        self.strategy_dist_history = [self.get_action_dist()]

    def __str__(self):
        return 'MultiplicativeWeightsAgent'

    def reset_weights(self):
        """Set all weights to 1 (uniform strategy) and return them."""
        self.weights = [1] * self.n_actions
        return self.weights

    def set_weights_by_input(self, w):
        """Set the weights from ``w`` (one entry per action) and return them.

        The values are copied into a new list, so later mutation of ``w`` by
        the caller does not affect the agent.

        Raises:
            ValueError: if ``len(w) != n_actions``.
        """
        if len(w) != self.n_actions:
            raise ValueError(
                'expected {} weights, got {}'.format(self.n_actions, len(w)))
        self.weights = list(w)
        return self.weights

    def get_weights(self):
        """Return the current (unnormalized or normalized) weight list."""
        return self.weights

    def set_agent_game_obj(self, new_game_obj):
        """Swap in a new game and reset the weights to uniform.

        ``strategy_dist_history`` is kept.
        """
        super(MultiplicativeWeightsAgent, self).set_agent_game_obj(new_game_obj)
        self.reset_weights()

    def get_action_dist(self):
        """Return the weights normalized to a probability list."""
        z = np.sum(self.weights)
        return [w / z for w in self.weights]

    def update_weights(self, payoff_vector):
        """Apply one multiplicative-weights step [Arora et al. 2012].

        Each weight becomes ``w * (1 + eta * payoff)``, and the vector is then
        normalized to sum to 1. The weights stay positive only if
        ``1 + eta * payoff > 0`` for every entry (for example, payoffs in
        [-1, 1] with eta <= 1/2). This is not checked here.

        Raises:
            ValueError: if ``len(payoff_vector) != n_actions``.
        """
        if len(payoff_vector) != self.n_actions:
            raise ValueError(
                'expected {} payoffs, got {}'.format(self.n_actions, len(payoff_vector)))

        w = np.array(self.get_weights())
        payoff_vector = np.array(payoff_vector)
        w = w * (1 + self.eta * payoff_vector)
        # Normalize so the weights stay on the probability simplex.
        w = w / np.sum(w)
        self.weights = w.tolist()

    def draw_action_from_action_dist(self):
        """Sample one action index from the current mixed strategy."""
        pdf = self.get_action_dist()
        return np.random.choice(list(range(self.n_actions)), size=1, p=pdf)[0]

    def update_action(self, action_profile):
        """Record the round, update the weights from foregone payoffs, and log the new strategy."""
        super(MultiplicativeWeightsAgent, self).update_action(action_profile)
        payoff_vect = self.get_forgone_payoff_vector_from_game()
        self.update_weights(payoff_vect)
        self.strategy_dist_history.append(self.get_action_dist())

    def visualize_action_dist(self):
        """Plot the probability of each action over time and show the figure."""
        # Keep the original Hawk/Dove labels for 2-action games; otherwise
        # label actions by index so the legend matches the number of lines.
        if self.n_actions == 2:
            labels = ['a = H', 'a = D']
        else:
            labels = ['a = ' + str(i) for i in range(self.n_actions)]
        plt.figure()
        plt.plot(self.strategy_dist_history)
        plt.title('agent ' + str(self.idx) + ' strategy (action probabilities)')
        plt.xlabel('t')
        plt.ylabel('Pr(a)')
        plt.legend(labels)
        plt.show()


class FTPLAgent(Agent):
    """Follow-the-perturbed-leader agent.

    The agent keeps a running sum of its foregone-payoff vectors and adds
    fresh Gumbel(0, eta) noise to that sum every round. It plays the action
    with the largest perturbed total.

    Args:
        agent_idx, game_obj: see ``Agent``.
        history_weight: stored on the instance; no method of this class reads it.
        eta: scale of the Gumbel perturbation; ``eta <= 0`` disables noise.
    """

    def __init__(self, agent_idx, game_obj, history_weight=1, eta=0.25):
        super().__init__(agent_idx, game_obj, restrict_strictly_dominated_strategies=False)
        self.history_weight = history_weight  # NOTE(review): currently unused
        self.t = 0  # number of update_action calls since the last game reset
        self.eta = eta
        # One-hot vector of the action actually played, one per round.
        self.strategy_dist_history = []
        # Scalar 0 until the first update, then an array of length n_actions.
        self.total_historical_payoffs = 0

    def __str__(self):
        return 'FTPLAgent'

    def set_agent_game_obj(self, new_game_obj):
        """Swap in a new game and clear the accumulated payoffs and round counter."""
        super(FTPLAgent, self).set_agent_game_obj(new_game_obj)
        self.total_historical_payoffs = 0
        self.t = 0

    def update_avg_forgone_payoff(self, payoff_vect):
        """Add ``payoff_vect`` plus fresh Gumbel noise to the running total.

        Despite the name, this keeps a cumulative sum, not an average. The
        noise drawn each round is also accumulated, so the total perturbation
        grows with the number of rounds.
        """
        history_term = self.total_historical_payoffs
        new_payoff_term = payoff_vect
        if self.eta <= 0:
            perturbation_term = 0
        else:
            perturbation_term = np.random.gumbel(0, self.eta, size=payoff_vect.shape)

        self.total_historical_payoffs = history_term + new_payoff_term + perturbation_term

    def draw_action_from_action_dist(self):
        """Return the index of the largest perturbed cumulative payoff.

        Before any update the total is the scalar 0, so this returns 0.
        """
        return np.argmax(self.total_historical_payoffs)

    def update_action(self, action_profile):
        """Record the round, update the perturbed payoff totals, and log the played action."""
        super(FTPLAgent, self).update_action(action_profile)
        payoff_vect = self.get_forgone_payoff_vector_from_game()
        self.update_avg_forgone_payoff(payoff_vect)
        self.t += 1
        # One-hot of this agent's own action. For 2 actions this equals the
        # former [1 - a, a], but it works for any player index and action count
        # (the old code only handled players 0/1 with 2 actions).
        my_action = action_profile[self.idx]
        action_dist = [1 if a == my_action else 0 for a in range(self.n_actions)]
        self.strategy_dist_history.append(action_dist)

    def visualize_action_dist(self):
        """Plot the per-round played-action indicators and show the figure."""
        # Keep the original Hawk/Dove labels for 2-action games; otherwise
        # label actions by index so the legend matches the number of lines.
        if self.n_actions == 2:
            labels = ['a = H', 'a = D']
        else:
            labels = ['a = ' + str(i) for i in range(self.n_actions)]
        plt.figure()
        plt.plot(self.strategy_dist_history)
        plt.title('agent ' + str(self.idx) + ' strategy (action probabilities)')
        plt.xlabel('t')
        plt.ylabel('Pr(a)')
        plt.legend(labels)
        plt.show()


######################################################################
def run_T_rounds(game_input, agents_input, T_rounds):
    """Simulate ``T_rounds`` rounds of repeated play.

    In every round each agent draws an action (in list order). The resulting
    profile is played once on ``game_input`` and then passed to every agent's
    ``update_action``.

    Returns:
        The same ``(game_input, agents_input)`` objects, after being updated
        in place.
    """
    for _ in range(T_rounds):
        profile = []
        for agent in agents_input:
            profile.append(agent.draw_action_from_action_dist())
        game_input.play_game_round(profile)
        for agent in agents_input:
            agent.update_action(profile)
    return game_input, agents_input
