import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import os

from copy import deepcopy


def moving_average(interval, windowsize):
    """Smooth a 1-D sequence with a flat (box) window.

    Args:
        interval: 1-D array-like of values to smooth.
        windowsize: window length; truncated to an int for the kernel size,
            while each kernel weight is ``1 / float(windowsize)``.

    Returns:
        A NumPy array of the same length as ``interval`` (``'same'`` mode
        convolution, so the edges are averaged against implicit zeros).
    """
    width = int(windowsize)
    kernel = np.full(width, 1.0 / float(windowsize))
    return np.convolve(interval, kernel, 'same')


def onehot_from_logits(logits, eps=0.0):
    """Turn a batch of logits into one-hot actions, epsilon-greedily.

    Args:
        logits: 2-D tensor of shape ``(batch, n_actions)``.
        eps: probability of replacing a row's greedy action with a uniformly
            random one. ``0.0`` returns the greedy one-hot rows directly.

    Returns:
        A float tensor of shape ``(batch, n_actions)`` on ``logits.device``.
        Greedy rows mark every position equal to the row maximum, so ties
        produce more than one ``1`` in that row.
    """
    # Best actions (according to the current logits) in one-hot form.
    argmax_acs = (logits == logits.max(1, keepdim=True)[0]).float()
    if eps == 0.0:
        return argmax_acs

    batch_size, n_actions = logits.shape[0], logits.shape[1]
    # Uniformly random actions in one-hot form, built on the logits' device
    # so the final selection does not mix CPU and accelerator tensors.
    rand_idx = torch.as_tensor(
        np.random.choice(n_actions, size=batch_size), device=logits.device
    )
    rand_acs = torch.eye(n_actions, dtype=argmax_acs.dtype, device=logits.device)[
        rand_idx
    ]
    # Per-row epsilon-greedy choice: keep the greedy row when r > eps.
    keep_greedy = (torch.rand(batch_size) > eps).to(logits.device).unsqueeze(1)
    return torch.where(keep_greedy, argmax_acs, rand_acs)


class Actor(nn.Module):
    """Categorical policy parameterised directly by a logits vector.

    There is no network: ``self.logits`` is a plain float32 leaf tensor (not
    an ``nn.Parameter``). Callers hand it to optimizers themselves and toggle
    its gradient with ``turn_on_grad`` / ``turn_off_grad``. ``self.actor`` is a
    ``Categorical`` distribution built from the logits. It must be rebuilt
    with ``update()`` after the logits change.
    """

    def __init__(self, logits=(5.0, 4.0, 4.0), epsilon=0.2):
        """Create the policy.

        Args:
            logits: initial unnormalised log-probabilities, one per action.
                The default reproduces the original ``[5, 4, 4]``.
            epsilon: exploration rate used by ``epsilon_greedy``.
        """
        super(Actor, self).__init__()
        # Detached copy so the actor never aliases (or backprops into) the
        # caller's tensor.
        self.logits = torch.as_tensor(logits, dtype=torch.float32).detach().clone()
        self.turn_off_grad()
        self.epsilon = epsilon
        self.actor = torch.distributions.Categorical(logits=self.logits)
        # Debug output kept from the original: initial action probabilities.
        print(self.actor.probs)

    def sample(self, bs):
        """Draw actions with sample shape ``bs`` (e.g. ``[batch]``)."""
        return self.actor.sample(bs)

    def entropy(self):
        """Return the entropy of the current distribution."""
        return self.actor.entropy()

    def probs(self):
        """Return the action probabilities of the current distribution."""
        return self.actor.probs

    def log_prob(self, action):
        """Return the log-probability of ``action`` under the current distribution."""
        return self.actor.log_prob(action)

    def update(self):
        """Rebuild the distribution from the current logits.

        The rebuilt distribution records the graph only if the logits
        currently require grad.
        """
        self.actor = torch.distributions.Categorical(logits=self.logits)

    def epsilon_greedy(self, bs):
        """Pick ``bs[0]`` actions epsilon-greedily.

        Args:
            bs: a sequence whose first entry is the number of actions to draw.
                A plain int is also accepted.

        Returns:
            An int64 tensor of shape ``(n,)``. Each entry is a uniformly random
            action with probability ``self.epsilon``, otherwise the argmax of
            the current probabilities.
        """
        count = bs if isinstance(bs, int) else bs[0]
        n_actions = self.logits.shape[-1]
        # Plain Python ints keep every element the same shape, unlike mixing
        # 0-d and shape-(1,) tensors before calling torch.tensor.
        greedy = int(self.actor.probs.argmax(dim=-1))
        actions = []
        for _ in range(count):
            if np.random.random() < self.epsilon:
                action = int(torch.randint(low=0, high=n_actions, size=(1,))[0])
            else:
                action = greedy
            actions.append(action)
        return torch.tensor(actions, dtype=torch.int64)

    def turn_on_grad(self):
        """Turn on grad for actor parameters."""
        self.logits.requires_grad_(True)

    def turn_off_grad(self):
        """Turn off grad for actor parameters."""
        self.logits.requires_grad_(False)

class ReplayBuffer:
    """Fixed-capacity ring buffer of (joint action, reward) rows.

    Holds at most ``maxsize`` rows. Once full, new rows overwrite the oldest
    ones in FIFO order.
    """

    def __init__(self, maxsize, n_agents):
        """Allocate storage for ``maxsize`` rows of ``n_agents`` actions plus one reward."""
        self.__maxsize = maxsize
        self.__pointer = 0  # next row to write
        self.__actions = torch.zeros((maxsize, n_agents), dtype=torch.float32)
        self.__rewards = torch.zeros((maxsize, 1), dtype=torch.float32)

        self.size = 0  # number of valid rows, never more than maxsize

    def put(self, action, reward):
        """Append a batch of rows, wrapping around the end of the buffer.

        Args:
            action: tensor of shape ``(bs, n_agents)``; stored as float32.
            reward: tensor of shape ``(bs, 1)`` (or broadcastable to it).
        """
        action = torch.as_tensor(action, dtype=self.__actions.dtype)
        reward = torch.as_tensor(reward, dtype=self.__rewards.dtype)
        bs = action.shape[0]
        start = self.__pointer
        if bs > self.__maxsize:
            # A sequential write of bs rows leaves only the last maxsize rows,
            # starting where the (bs - maxsize)-th row would have landed.
            action = action[-self.__maxsize:]
            if reward.dim() > 0:
                reward = reward[-self.__maxsize:]
            start = (start + bs - self.__maxsize) % self.__maxsize
        # Modular indices let one batch straddle the end of the storage,
        # which a plain [p:p + bs] slice cannot do.
        idx = (start + torch.arange(action.shape[0])) % self.__maxsize
        self.__actions[idx] = action
        self.__rewards[idx] = reward
        self.__pointer = (self.__pointer + bs) % self.__maxsize
        self.size = min(self.size + bs, self.__maxsize)

    def get(self, bs):
        """Sample ``bs`` distinct stored rows uniformly (NumPy global RNG).

        Raises:
            ValueError: from ``np.random.choice`` when ``bs > self.size``.
        """
        idx = np.random.choice(self.size, bs, replace=False)
        return self.__actions[idx], self.__rewards[idx]


def _make_payoff():
    """Return the flattened 3x3 payoff matrix, indexed by ``a0 * 3 + a1``.

    Matching actions pay 5, 10 or 20 (for actions 0, 1, 2); every
    mis-coordinated pair pays -20.
    """
    payoff = torch.full((9,), -20.0, dtype=torch.float32)
    payoff[0] = 5
    payoff[4] = 10
    payoff[8] = 20
    return payoff


def _joint_rewards(payoff, actions):
    """Look up the payoff of every joint-action row in ``actions``.

    ``actions`` is an int64 tensor whose columns 0 and 1 hold the actions of
    agent 0 and agent 1. The result has shape ``(len(actions), 1)``.
    """
    return payoff[actions[:, 0] * 3 + actions[:, 1]].unsqueeze(-1)


def _joint_policy(local_actors):
    """Return the joint distribution of the first two (independent) actors.

    Entry ``[i, j]`` equals ``P(agent 0 = i) * P(agent 1 = j)``.
    """
    return torch.outer(local_actors[0].probs(), local_actors[1].probs())


def _log_progress(name, epoch, payoff, joint_probs, log_file):
    """Print the joint policy and append ``epoch,expected_reward`` to ``log_file``."""
    print("Algorithm: {}, Epoch: {}, Joint policy: {}".format(
        name, epoch, joint_probs.detach().flatten()))
    aver_episode_rewards = torch.sum(payoff * joint_probs.flatten())
    log_file.write(",".join(map(str, [epoch, aver_episode_rewards.item()])) + "\n")
    log_file.flush()


def _run_hasac(seed, n_agents, payoff, n_epoches, bs, lr, log_file):
    """Train one HASAC-style run: a tabular centralized Q plus sequential actor updates."""
    buffer = ReplayBuffer(int(1e6), n_agents)
    centralized_q = torch.randn((3, 3), requires_grad=True)
    local_actors = [Actor() for _ in range(n_agents)]

    critic_optimizer = torch.optim.SGD([centralized_q], lr=lr)
    actor_optimizer = torch.optim.SGD(
        [local_actors[0].logits, local_actors[1].logits], lr=lr)

    for epoch in range(1, n_epoches + 1):
        # Collect a fresh batch of joint actions from the current policies.
        with torch.no_grad():
            actions = torch.zeros((bs, n_agents), dtype=torch.int64)
            for agent in range(n_agents):
                actions[:, agent] = local_actors[agent].sample([bs])
            buffer.put(actions, _joint_rewards(payoff, actions))

        # Critic: regress Q[a0, a1] onto rewards sampled from the buffer.
        sampled_actions, sampled_rewards = buffer.get(bs)
        sampled_actions = sampled_actions.long()
        centralized_q.requires_grad_(True)
        q = centralized_q[sampled_actions[:, 0], sampled_actions[:, 1]]
        critic_loss = (sampled_rewards.flatten() - q).square().mean()
        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()
        centralized_q.requires_grad_(False)

        # Actors: each agent in turn ascends sum(joint_policy * Q); later agents
        # see the already-updated policies of earlier ones.
        for agent in range(n_agents):
            actor = local_actors[agent]
            actor.turn_on_grad()
            actor.update()
            value_pred = torch.sum(_joint_policy(local_actors) * centralized_q)
            # Entropy bonus is present but disabled (coefficient 0).
            actor_loss = -(value_pred + 0 * actor.entropy())
            actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_optimizer.step()
            actor.turn_off_grad()
            actor.update()

        _log_progress("HASAC", epoch, payoff, _joint_policy(local_actors), log_file)

    print("Algorithm: HASAC, Seed: {}, Joint policy: {}".format(
        seed, _joint_policy(local_actors).detach().flatten()))


def _run_ppo(name, seed, n_agents, payoff, n_epoches, lr, ppo_epoch, log_file,
             sequential):
    """Train one PPO-style run from a single sampled joint action per epoch.

    With ``sequential=True`` (HAPPO), agents are updated in a random order.
    Each agent's surrogate is scaled by the product of ratios of the agents
    already updated this epoch. With ``sequential=False`` (MAPPO), agents are
    updated in index order and the scale stays 1.
    """
    local_actors = [Actor() for _ in range(n_agents)]
    actor_optimizer = torch.optim.SGD(
        [local_actors[0].logits, local_actors[1].logits], lr=lr)

    for epoch in range(1, n_epoches + 1):
        with torch.no_grad():
            actions = torch.zeros((1, n_agents), dtype=torch.int64)
            old_action_log_probs = torch.zeros((1, n_agents), dtype=torch.float32)
            for agent in range(n_agents):
                actions[:, agent] = local_actors[agent].sample([1])
                old_action_log_probs[:, agent] = local_actors[agent].log_prob(actions[:, agent])
            rewards = _joint_rewards(payoff, actions)

        factor = torch.ones(1, dtype=torch.float32)
        if sequential:
            agent_order = torch.randperm(n_agents).tolist()
        else:
            agent_order = list(range(n_agents))

        for agent in agent_order:
            actor = local_actors[agent]
            action = actions[0][agent]
            old_log_prob = old_action_log_probs[0][agent]
            for _ in range(ppo_epoch):
                actor.turn_on_grad()
                actor.update()
                imp_weights = torch.exp(actor.log_prob(action) - old_log_prob)
                surr1 = imp_weights * rewards[0]
                surr2 = torch.clamp(imp_weights, 0.8, 1.2) * rewards[0]
                policy_action_loss = factor * torch.min(surr1, surr2)

                actor_loss = -(policy_action_loss + 5 * actor.entropy())
                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_optimizer.step()
                actor.turn_off_grad()
                actor.update()

            if sequential:
                # Fold this agent's final (grad-free) probability ratio into the
                # scale used by the agents that follow it.
                factor = factor * torch.exp(actor.log_prob(action) - old_log_prob)

        _log_progress(name, epoch, payoff, _joint_policy(local_actors), log_file)

    print("Algorithm: {}, Seed: {}, Joint policy: {}".format(
        name, seed, _joint_policy(local_actors).detach().flatten()))


def main(n_agents, from_scrath=True, algos=None, eps_greedy=False):
    """Train the requested algorithms on the 3x3 coordination game over six seeds.

    Args:
        n_agents: number of agents; the training code assumes exactly 2.
        from_scrath: when False, nothing is trained.
        algos: iterable of algorithm names; 'hasac', 'happo' and 'mappo' are
            supported and always run in that order. Other names are ignored
            with a warning.
        eps_greedy: only selects the epoch count (4000 instead of 2000).

    Side effects:
        For each algorithm and seed, writes ``<algo>/<seed>/progress.txt``
        with one ``epoch,expected_reward`` line per epoch. Also writes
        ``<algo>_perm_loss.npy``.

    Raises:
        AssertionError: if ``algos`` is empty.
        ValueError: if a supported algorithm is requested and ``n_agents != 2``.
    """
    import warnings

    if algos is None:
        algos = []
    elif isinstance(algos, str):
        algos = [algos]
    assert len(algos) > 0, "main() needs at least one algorithm name"

    seeds = [1, 2, 3, 4, 5, 6]
    n_epoches = int(2e3) if not eps_greedy else int(4e3)
    # NOTE: never filled in; each saved *_perm_loss.npy contains zeros.
    losses = torch.zeros(len(seeds), n_epoches)
    bs = 128
    lr = 0.1
    ppo_epoch = 5

    if not from_scrath:
        return

    supported = ("hasac", "happo", "mappo")
    unknown = [name for name in algos if name not in supported]
    if unknown:
        warnings.warn(
            "main(): ignoring unsupported algorithm(s) {}; supported: {}".format(
                unknown, list(supported)))
    requested = [name for name in supported if name in algos]
    if requested and n_agents != 2:
        raise ValueError(
            "main() only supports n_agents == 2, got {}".format(n_agents))

    payoff = _make_payoff()
    runners = {
        "hasac": lambda seed, log_file: _run_hasac(
            seed, n_agents, payoff, n_epoches, bs, lr, log_file),
        "happo": lambda seed, log_file: _run_ppo(
            "HAPPO", seed, n_agents, payoff, n_epoches, lr, ppo_epoch, log_file,
            sequential=True),
        "mappo": lambda seed, log_file: _run_ppo(
            "MAPPO", seed, n_agents, payoff, n_epoches, lr, ppo_epoch, log_file,
            sequential=False),
    }

    for name in requested:
        for seed in seeds:
            dir_path = os.path.join(name, str(seed))
            os.makedirs(dir_path, exist_ok=True)
            # `with` guarantees each per-seed log is closed, even on error.
            with open(os.path.join(dir_path, "progress.txt"), "w",
                      encoding="utf-8") as log_file:
                torch.manual_seed(seed)
                # ReplayBuffer.get samples with NumPy's global RNG; seed it too
                # so a seeded run is reproducible.
                np.random.seed(seed)
                runners[name](seed, log_file)

        with open('{}_perm_loss.npy'.format(name), 'wb') as f:
            np.save(f, losses.numpy())



if __name__ == '__main__':
    # NOTE(review): 'maddpg' is not among the algorithm names main() handles
    # ('hasac', 'happo', 'mappo'), so this invocation trains nothing.
    main(2, from_scrath=True, algos=['maddpg'], eps_greedy=True)
