import torch
import torch.nn.functional as F
from Network import PolicyNet, ValueNet
from util import compute_advantage


class ComaCritic(torch.nn.Module):
    """Centralized critic network used by COMA.

    A three-layer MLP (two ReLU hidden layers) mapping a critic input vector
    to one Q-value per action of a single agent (``action_dim`` outputs).
    """

    def __init__(self, input_dim, hidden_dim, action_dim):
        super(ComaCritic, self).__init__()
        # Creation order fc1 -> fc2 -> fc3 fixes both the RNG consumption of
        # the default initialisation and the state_dict key order.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        # Output layer is linear: Q-values are unbounded.
        return self.fc3(hidden)


class COMA:
    """Counterfactual Multi-Agent (COMA) policy gradient.

    Each agent i has its own actor (``PolicyNet`` over its local state) and its
    own centralized critic (``ComaCritic``). Critic i receives
    ``[global state, joint action of the OTHER agents (one-hot), agent-id one-hot]``
    and outputs one Q-value per action of agent i, i.e. Q(s, a_{-i}, a_i') for
    every a_i'. The actor is trained with the COMA counterfactual advantage

        A_i(s, a) = Q(s, a) - sum_{a'} pi_i(a' | s_i) Q(s, (a_{-i}, a')).
    """

    def __init__(self, agent_num, state_dim_list, hidden_dim, action_num_list,
                 actor_lr, critic_lr, epochs, eps, gamma,
                 device, sample_size, entropy_soft=False):
        self.N = agent_num
        self.A_dims = action_num_list
        self.device = device
        self.gamma = gamma
        # NOTE(review): `epochs`, `eps`, `gae_lmbda` are stored and
        # `sample_size` is accepted, but none of them is read by this class.
        self.epochs = epochs
        self.eps = eps
        self.entropy = entropy_soft   # whether to add an entropy bonus to the actor loss
        self.entropy_c = 0.01         # entropy bonus coefficient
        self.gae_lmbda = 0.95

        self.actors = [
            PolicyNet(state_dim_list[i], hidden_dim, action_num_list[i]).to(device)
            for i in range(agent_num)
        ]
        self.actor_opts = [
            torch.optim.AdamW(actor.parameters(), lr=actor_lr) for actor in self.actors
        ]

        # Critic input = global state ++ joint-action one-hot ++ agent-id one-hot.
        self.gs_dim = sum(state_dim_list)
        self.ja_dim = sum(action_num_list)
        self.id_dim = agent_num
        critic_input_dim = self.gs_dim + self.ja_dim + self.id_dim

        self.critics = [
            ComaCritic(critic_input_dim, hidden_dim, action_num_list[i]).to(device)
            for i in range(agent_num)
        ]
        self.critic_opts = [
            torch.optim.AdamW(critic.parameters(), lr=critic_lr) for critic in self.critics
        ]

    @torch.no_grad()
    def take_action(self, state_list, eval: bool = False):
        """Select one action per agent.

        Args:
            state_list: per-agent state tensors; ``state_list[i]`` is fed to actor i.
            eval: if True take the greedy (argmax) action, otherwise sample from
                the categorical policy.

        Returns:
            List of per-agent action tensors, moved to CPU.
        """
        actions = []
        for i in range(self.N):
            probs = self.actors[i](state_list[i].to(self.device))
            dist = torch.distributions.Categorical(probs)
            a = torch.argmax(probs, dim=-1) if eval else dist.sample()
            actions.append(a.cpu())
        return actions

    def update(self, trans):
        """Run one critic step and one actor step for every agent.

        Args:
            trans: dict of tensors for T consecutive transitions with keys
                'states', 'actions', 'rewards', 'dones', 'next_states'.
                ``states[:, i, :]`` is agent i's state, so states are rank 3
                with the agent on axis 1 (presumably (T, N, state_dim) with
                equal per-agent state dims — TODO confirm). ``actions`` is
                (T, N) or (T, N, 1). ``rewards`` / ``dones`` must reshape to
                (T, 1), i.e. one shared team reward per transition.
        """
        s = trans['states'].to(self.device)
        a = trans['actions'].long()
        if a.dim() == 3:
            a = a.squeeze(-1)
        a = a.to(self.device)
        r = trans['rewards'].to(self.device).view(-1, 1)
        d = trans['dones'].float().to(self.device).view(-1, 1)
        ns = trans['next_states'].to(self.device)

        T = s.shape[0]
        gs = s.view(T, -1)
        ngs = ns.view(T, -1)

        ja_oh = self._joint_action_onehot(a)
        # The next-step joint action is taken from the following transition.
        # The last row wraps around to row 0; that bootstrap term is only
        # cancelled when the final transition is terminal (d == 1).
        na = torch.roll(a, shifts=-1, dims=0)
        ja_oh_next = self._joint_action_onehot(na)

        crt_in = self._build_critic_inputs(gs, ja_oh)
        crt_in_next = self._build_critic_inputs(ngs, ja_oh_next)

        for i in range(self.N):
            # ---- Critic: one-step TD toward r + gamma * Q(s', a') ----
            q_vals = self.critics[i](crt_in[i])                 # Q(s, a_{-i}, .)
            q_next = self.critics[i](crt_in_next[i]).detach()   # Q(s', a'_{-i}, .)
            q_taken = q_vals.gather(1, a[:, i:i+1])
            q_next_taken = q_next.gather(1, na[:, i:i+1])

            td_target = r + self.gamma * q_next_taken * (1 - d)
            critic_loss = F.mse_loss(q_taken, td_target)

            self.critic_opts[i].zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critics[i].parameters(), 40.0)
            self.critic_opts[i].step()

            # ---- Actor: counterfactual-advantage policy gradient ----
            probs = self.actors[i](s[:, i, :])
            dist = torch.distributions.Categorical(probs)
            # Categorical clamps probabilities away from 0 before taking logs,
            # so an underflowed softmax cannot produce log(0) = -inf here.
            logp_taken = dist.log_prob(a[:, i]).unsqueeze(-1)   # (T, 1)

            with torch.no_grad():
                # Counterfactual baseline: marginalise agent i's own action
                # under its current policy, keeping the other agents' actions.
                baseline = (probs * q_vals.detach()).sum(dim=1, keepdim=True)
                advantage = q_taken.detach() - baseline

            actor_loss = -(logp_taken * advantage).mean()

            if self.entropy:
                # dist.entropy() avoids the 0 * log(0) = nan of the naive form.
                actor_loss = actor_loss - self.entropy_c * dist.entropy().mean()

            self.actor_opts[i].zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actors[i].parameters(), 40.0)
            self.actor_opts[i].step()

    def _joint_action_onehot(self, actions):
        """Concatenate per-agent one-hot actions into a (T, ja_dim) float tensor.

        Agent i occupies columns [sum(A_dims[:i]), sum(A_dims[:i+1])).
        ``actions`` is (T, N) or (T, N, 1) of integer action indices.
        """
        if actions.dim() == 3:
            actions = actions.squeeze(-1)

        T = actions.size(0)
        ja = torch.zeros(T, self.ja_dim, device=self.device)

        offset = 0
        for i, adim in enumerate(self.A_dims):
            act_i = actions[:, i]
            if act_i.dim() > 1:
                act_i = act_i.squeeze(-1)
            oh = F.one_hot(act_i, num_classes=adim).float()
            ja[:, offset:offset + adim] = oh
            offset += adim

        return ja

    def _build_critic_inputs(self, gstates, ja_onehot):
        """Build one critic input tensor per agent.

        Input i is ``[gstates, ja_onehot with agent i's slot zeroed, onehot(i)]``.
        Hiding agent i's own action is what makes critic i's output head a
        counterfactual Q(s, a_{-i}, a_i') over every a_i'. ``ja_onehot`` is not
        modified.
        """
        T = gstates.shape[0]
        agent_ids = torch.eye(self.N, device=self.device)
        inputs = []
        offset = 0
        for i in range(self.N):
            adim = self.A_dims[i]
            others_ja = ja_onehot.clone()
            others_ja[:, offset:offset + adim] = 0.0
            offset += adim
            id_i = agent_ids[i].unsqueeze(0).expand(T, -1)
            inputs.append(torch.cat([gstates, others_ja, id_i], dim=1))
        return inputs
