import torch.nn as nn
import os
import torch
import numpy as np

from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Data, Batch
from torch_geometric.nn import Sequential
from torch import Tensor
from torch.optim import Adam
from typing import List, Tuple, Optional

from gcbf.nn import MLP, CBFGNNLayer
from gcbf.controller import GNNController
from gcbf.env import MultiAgentEnv
from gcbf.controller.mappo_controller import StateIndependentPolicy, MAPPOCritic

from .base import Algorithm
from .buffer import RolloutBuffer


def calculate_gae(
        values: torch.Tensor,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        next_values: torch.Tensor,
        gamma: float,
        lambd: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate the generalized advantage estimator (GAE).

    The first axis of every tensor is treated as time; the recursion runs
    backwards over it. The remaining axis of `values` is treated as one column
    per value estimate (the caller in this file passes one column per agent).

    Parameters
    ----------
    values: torch.Tensor,
        values of the states, shape (T, n)
    rewards: torch.Tensor,
        rewards given by the reward function, any shape that broadcasts
        against `values` (e.g. (T, n) or (T, 1))
    dones: torch.Tensor,
        if this state is the end of the episode, shape (T, 1) or (T,);
        a non-zero entry stops bootstrapping across that step
    next_values: torch.Tensor,
        values of the next states, shape (T, n)
    gamma: float,
        discount factor
    lambd: float,
        lambd factor

    Returns
    -------
    targets: torch.Tensor,
        un-normalized gae plus `values`, i.e. the regression targets for the
        value function, shape (T, n)
    gaes: torch.Tensor,
        gae normalized per column (mean and std both taken over the time
        axis), shape (T, n). With a single time step the std is undefined
        and the result is NaN.
    """
    # Let broadcasting spread the termination flag over the value columns
    # instead of materializing a repeated copy.
    if dones.dim() == 1:
        dones = dones.unsqueeze(-1)
    not_done = 1.0 - dones.to(values.dtype)

    # calculate TD errors
    deltas = rewards + gamma * next_values * not_done - values

    # Allocate from deltas (not rewards) so that rewards which were broadcast
    # against the values still get a full-size buffer.
    gaes = torch.empty_like(deltas)

    # calculate gae recursively from behind
    decay = gamma * lambd * not_done
    gaes[-1] = deltas[-1]
    for t in reversed(range(deltas.shape[0] - 1)):
        gaes[t] = deltas[t] + decay[t] * gaes[t + 1]

    targets = gaes + values
    # Use the same reduction axis for mean and std so every column ends up
    # zero-mean / unit-std.
    normalized = (gaes - gaes.mean(dim=0)) / (gaes.std(dim=0) + 1e-8)
    return targets, normalized


class MAPPO(Algorithm):
    """
    Multi-agent PPO

    Holds a `StateIndependentPolicy` actor and a `MAPPOCritic` critic, each
    with its own Adam optimizer, plus a `RolloutBuffer` of `rollout_length`
    transitions. `is_update` signals an update every `rollout_length` steps;
    `update` then runs `epoch_ppo` critic/actor passes over the buffer content.

    References
    ----------
    [1] Chao Yu, Akash Velu, Eugene Vinitsky, Yu Wang, Alexandre Bayen, and Yi Wu.
    The surprising effectiveness of ppo in cooperative, multi-agent games.
    arXiv preprint arXiv:2103.01955, 2021.
    """

    def __init__(
            self,
            env: MultiAgentEnv,
            num_agents: int,
            node_dim: int,
            edge_dim: int,
            action_dim: int,
            device: torch.device,
            gamma: float = 0.995,
            rollout_length: int = 2048,
            lr_actor: float = 3e-4,
            lr_critic: float = 3e-4,
            units_critic: tuple = (256, 128),
            epoch_ppo: int = 20,
            clip_eps: float = 0.2,
            lambd: float = 0.97,
            coef_ent: float = 0.0,
            max_grad_norm: float = 10.0
    ):
        """
        Build the rollout buffer, the actor/critic networks and their optimizers.

        Parameters
        ----------
        env: MultiAgentEnv,
            environment, forwarded to the `Algorithm` base class
        num_agents: int,
            number of agents, forwarded to the base class and the buffer
        node_dim: int,
            node feature dimension handed to the actor and the critic
        edge_dim: int,
            edge feature dimension handed to the actor and the critic
        action_dim: int,
            action dimension handed to the buffer, the actor and the critic
        device: torch.device,
            device the networks are moved to and the buffer is created with
        gamma: float,
            discount factor used for the advantage estimation
        rollout_length: int,
            buffer size, and also the period (in steps) between updates
        lr_actor: float,
            learning rate of the actor's Adam optimizer
        lr_critic: float,
            learning rate of the critic's Adam optimizer
        units_critic: tuple,
            currently unused; kept so existing callers keep working
        epoch_ppo: int,
            number of critic + actor gradient steps per update
        clip_eps: float,
            half-width of the clamp applied to the probability ratio
        lambd: float,
            lambd factor used for the advantage estimation
        coef_ent: float,
            weight of the entropy term in the actor objective
        max_grad_norm: float,
            gradient-norm clipping threshold for both networks
        """
        super().__init__(
            env=env,
            num_agents=num_agents,
            node_dim=node_dim,
            edge_dim=edge_dim,
            action_dim=action_dim,
            device=device
        )

        # rollout buffer
        self.buffer = RolloutBuffer(
            num_agents=num_agents,
            buffer_size=rollout_length,
            action_dim=action_dim,
            device=device
        )

        # actor
        self.actor = StateIndependentPolicy(
            node_dim=node_dim,
            edge_dim=edge_dim,
            action_dim=action_dim,
            phi_dim=256
        ).to(device)
        self.optim_actor = Adam(self.actor.parameters(), lr=lr_actor)

        # critic
        self.critic = MAPPOCritic(
            node_dim=node_dim,
            edge_dim=edge_dim,
            action_dim=action_dim,
            phi_dim=256
        ).to(device)
        self.optim_critic = Adam(self.critic.parameters(), lr=lr_critic)

        # counters: one per call to update() / one per PPO epoch
        self.learning_steps_ppo = 0
        self.learning_steps = 0

        # hyper-parameters
        self.rollout_length = rollout_length
        self.epoch_ppo = epoch_ppo
        self.clip_eps = clip_eps
        self.lambd = lambd
        self.coef_ent = coef_ent
        self.max_grad_norm = max_grad_norm
        self.gamma = gamma

        # log-probability of the most recently sampled action; written by
        # step() and consumed by post_step()
        self._log_pi = None

    @torch.no_grad()
    def act(self, data: Data) -> Tensor:
        """Return the actor's output for `data`, without tracking gradients."""
        return self.actor(data)

    @torch.no_grad()
    def step(self, data: Data, prob: float) -> Tensor:
        """
        Sample an action from the actor and remember its log-probability.

        `prob` is not used by this algorithm. The cached log-probability is
        picked up by the next call to `post_step`.
        """
        action, log_pi = self.actor.sample(data)
        self._log_pi = log_pi  # (n, 1)
        return action  # (n, action_dim)

    def post_step(self, data: Data, action: Tensor, reward: float, done: bool, next_data: Data):
        """
        Store one transition in the rollout buffer.

        Must be called after `step`, because the log-probability cached there
        is stored together with the transition.
        """
        # NOTE(review): squeeze() drops every size-1 axis, so with a single
        # agent the (1, 1) log-prob becomes 0-d — confirm the buffer accepts that.
        self.buffer.append(data, action, reward, done, self._log_pi.squeeze(), next_data)

    def is_update(self, step: int) -> bool:
        """Return True on every positive multiple of `rollout_length`."""
        return step % self.rollout_length == 0 and step >= self.rollout_length

    def update(self, step: int, writer: Optional[SummaryWriter] = None) -> dict:
        """
        Run one PPO update on the current buffer content.

        `step` and `writer` are not used here. Always returns an empty dict.
        """
        self.learning_steps += 1
        data, actions, rewards, dones, log_pis, next_data = self.buffer.get()
        self.update_ppo(data, actions, rewards, dones, log_pis, next_data)
        return {}

    def update_ppo(
            self,
            data: List[Data],
            actions: torch.Tensor,
            rewards: torch.Tensor,
            dones: torch.Tensor,
            log_pis: torch.Tensor,
            next_data: List[Data]
    ):
        """
        Update PPO's actor and critic for some steps.

        The graphs are batched once, critic values of the current and next
        graphs are computed without gradients and reshaped to
        (-1, num_agents), and the value targets / normalized advantages are
        derived from them. Then `epoch_ppo` rounds of one critic step followed
        by one actor step are run on the full batch.
        """
        with torch.no_grad():
            data = Batch.from_data_list(data)
            values = self.critic(data).view(-1, self.num_agents)

            next_data = Batch.from_data_list(next_data)
            next_values = self.critic(next_data).view(-1, self.num_agents)

        targets, gaes = calculate_gae(
            values, rewards, dones, next_values, self.gamma, self.lambd)

        for _ in range(self.epoch_ppo):
            self.learning_steps_ppo += 1
            self.update_critic(data, targets)
            self.update_actor(data, actions, log_pis, gaes)

    def update_critic(self, data: Data, targets: torch.Tensor):
        """
        Update the critic for one step.

        Minimizes the mean squared error between the critic output (reshaped
        to (-1, num_agents)) and `targets`, with gradient-norm clipping.
        """
        values = self.critic(data).view(-1, self.num_agents)
        loss_critic = (values - targets).pow_(2).mean()

        self.optim_critic.zero_grad(set_to_none=True)
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.optim_critic.step()

    def update_actor(
            self,
            data: Data,
            actions: torch.Tensor,  # (bs, num_agents, action_dim)
            log_pis_old: torch.Tensor,
            gaes: torch.Tensor
    ):
        """
        Update the actor for one step.

        Uses the clipped surrogate objective: the probability ratio between
        the current policy and the one that collected the data is clamped to
        [1 - clip_eps, 1 + clip_eps], and the more pessimistic of the clipped
        and unclipped terms is kept. An entropy term weighted by `coef_ent`
        is subtracted from the loss.
        """
        log_pis = self.actor.evaluate_log_pi(data, actions, num_agents=self.num_agents)

        # negative mean log-probability of the stored actions, used as the
        # entropy term of the objective
        entropy = -log_pis.mean()

        ratios = (log_pis - log_pis_old).exp_()
        loss_actor1 = -ratios * gaes
        loss_actor2 = -torch.clamp(
            ratios,
            1.0 - self.clip_eps,
            1.0 + self.clip_eps
        ) * gaes
        loss_actor = torch.max(loss_actor1, loss_actor2).mean()

        self.optim_actor.zero_grad(set_to_none=True)
        (loss_actor - self.coef_ent * entropy).backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.optim_actor.step()

    def save(self, save_dir: str):
        """
        Write the actor and critic weights to `save_dir`.

        The directory (including missing parents) is created when needed.
        The files are named 'actor.pkl' and 'critic.pkl'.
        """
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.actor.state_dict(), os.path.join(save_dir, 'actor.pkl'))
        torch.save(self.critic.state_dict(), os.path.join(save_dir, 'critic.pkl'))

    def load(self, load_dir: str):
        """
        Load the actor and critic weights written by `save` from `load_dir`.

        Raises
        ------
        FileNotFoundError
            if `load_dir` is not an existing directory
        """
        if not os.path.isdir(load_dir):
            raise FileNotFoundError(f'checkpoint directory not found: {load_dir}')
        # Deserialize onto the CPU so checkpoints written on a GPU machine can
        # be read anywhere; load_state_dict copies the tensors onto whatever
        # device the modules already live on.
        actor_state = torch.load(os.path.join(load_dir, 'actor.pkl'), map_location='cpu')
        critic_state = torch.load(os.path.join(load_dir, 'critic.pkl'), map_location='cpu')
        self.actor.load_state_dict(actor_state)
        self.critic.load_state_dict(critic_state)

    def apply(self, data: Data, rand: Optional[float] = 30) -> Tensor:
        """Return the action from `act`; `rand` is not used by this algorithm."""
        return self.act(data)
