import torch
import torch.nn.functional as F
import numpy as np


from VD_Network import PolicyNetContinuous, LocalValueNet, VDNMixer
# ---------------- GAE ---------------- #
def compute_advantage(gamma, lam, td_delta, dones):
    """Compute Generalized Advantage Estimates by a backward scan over time.

    Args:
        gamma: discount factor.
        lam: GAE lambda; each step's accumulator is decayed by ``gamma * lam``.
        td_delta: 2-D tensor of TD residuals, indexed as ``[t, 0]`` (i.e.
            ``[T, 1]``). It is detached here, so no gradient flows back.
        dones: episode-termination flags, shaped ``[T, 1]`` or ``[T]``.

    Returns:
        A tensor shaped like ``td_delta`` whose ``[t, 0]`` entry holds the
        advantage estimate for step ``t``.
    """
    deltas = td_delta.detach()
    # Accept [T, 1] flags by dropping the trailing axis.
    done_flags = dones.squeeze(-1) if dones.dim() == 2 else dones

    advantages = torch.zeros_like(deltas)
    decay = gamma * lam
    running = 0.0
    for step in range(len(deltas) - 1, -1, -1):
        # A terminal step must not inherit advantage from the next episode.
        if done_flags[step] == 1.0:
            running = 0.0
        running = deltas[step, 0] + decay * running
        advantages[step, 0] = running
    return advantages



class VDPPOContinuous:
    """Multi-agent PPO with per-agent Gaussian actors and a mixed team critic.

    Every agent owns its own policy network and its own local value network.
    The local values are combined by a mixer into a single team value, which
    supplies the TD target and the advantage shared by all actors.
    """

    def __init__(self,
                 agent_num,
                 obs_dim, hidden_dim,
                 act_dim, actor_lr,
                 critic_lr, lam, epochs,
                 eps, gamma,
                 mixer_hidden=32,
                 bound=1.0,
                 device='cpu'):
        """Build the per-agent networks and their optimizers.

        Args:
            agent_num: number of agents (one actor and one value net each).
            obs_dim: per-agent observation size.
            hidden_dim: hidden width passed to the actor and value networks.
            act_dim: per-agent action size.
            actor_lr: Adam learning rate for every actor.
            critic_lr: Adam learning rate for the value nets and the mixer.
            lam: GAE lambda.
            epochs: optimisation passes per ``update`` call.
            eps: PPO clipping range.
            gamma: discount factor.
            mixer_hidden: accepted for interface compatibility; not used by
                the ``VDNMixer`` constructed here.
            bound: forwarded to ``PolicyNetContinuous``.
            device: torch device the networks live on.
        """
        self.n = agent_num
        self.device = device
        self.actors = [PolicyNetContinuous(obs_dim, hidden_dim, act_dim,
                                           bound).to(device) for _ in range(self.n)]
        self.actor_opt = [torch.optim.Adam(a.parameters(), lr=actor_lr, weight_decay=1e-5)
                          for a in self.actors]
        self.v_nets = [LocalValueNet(obs_dim, hidden_dim).to(device)
                       for _ in range(self.n)]
        self.mixer = VDNMixer().to(device)            # or MLPMixer(self.n, mixer_hidden)
        # One optimizer drives the mixer and all local value nets together.
        mix_params = list(self.mixer.parameters()) + \
                     [p for net in self.v_nets for p in net.parameters()]
        self.v_opt = torch.optim.Adam(mix_params, lr=critic_lr, weight_decay=1e-5)

        self.gamma, self.lam, self.epochs, self.eps = gamma, lam, epochs, eps
        self.act_dim = act_dim

    def _policy_params(self, obs_t):
        """Run every actor on its slice of ``obs_t``.

        ``obs_t`` has the agent axis second-to-last (``[..., n, obs_dim]``).
        Returns ``(mu, sig)``, each shaped ``[..., n, act_dim]``.
        """
        mu = torch.zeros(*obs_t.shape[:-2], self.n, self.act_dim, device=self.device)
        sig = torch.zeros_like(mu)
        for i, actor in enumerate(self.actors):
            mu[..., i, :], sig[..., i, :] = actor(obs_t[..., i, :])
        return mu, sig

    def _act(self, obs, deterministic):
        """Shared action selection for the single-env and batched entry points."""
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        # Action selection never needs gradients: the result goes straight to
        # NumPy, so tracking autograd here would only waste memory and time.
        with torch.no_grad():
            mu, sig = self._policy_params(obs_t)
            if deterministic:
                return mu.cpu().numpy()
            return torch.distributions.Normal(mu, sig).sample().cpu().numpy()

    # NOTE: the ``eval`` parameter shadows the builtin, but callers may pass it
    # by keyword, so the name is kept.
    def take_action(self, obs, eval=False):
        """Choose actions for one environment step.

        Args:
            obs: ndarray [n_agents, obs_dim].
            eval: if True return the policy means, otherwise sample from
                ``Normal(mu, sig)``.

        Returns:
            ndarray [n_agents, act_dim].
        """
        return self._act(obs, eval)

    def take_action_async(self, obs_batch, eval=False):
        """Choose actions for a batch of environments.

        Args:
            obs_batch: ndarray [env_num, n_agents, obs_dim].
            eval: if True return the policy means, otherwise sample from
                ``Normal(mu, sig)``.

        Returns:
            ndarray [env_num, n_agents, act_dim].
        """
        return self._act(obs_batch, eval)

    def _total_value(self, states):
        """Mix the per-agent local values of ``states`` ([T, n, obs_dim])."""
        v_local = torch.stack(
            [net(states[:, i, :]) for i, net in enumerate(self.v_nets)], dim=1)
        return self.mixer(v_local)

    def update(self, buf):
        """Run ``self.epochs`` PPO passes on every actor, then on the critic.

        Args:
            buf: mapping with tensor entries ``'states'``, ``'actions'``,
                ``'rewards'``, ``'dones'`` and ``'next_states'``. States and
                actions are indexed as ``[T, n, dim]``; 1-D rewards/dones are
                expanded to ``[T, 1]``.
        """
        S  = buf['states'].to(self.device)        # [T, n, obs_dim]
        A  = buf['actions'].to(self.device)       # [T, n, act_dim]
        R = buf['rewards'].to(self.device)
        D = buf['dones'].float().to(self.device)
        S2 = buf['next_states'].to(self.device)

        # Nothing to learn from an empty buffer; mean-reduced losses over zero
        # samples evaluate to NaN, so skip the optimizer steps entirely.
        if S.shape[0] == 0:
            return

        if R.dim() == 1:
            R = R.unsqueeze(1)  # [T] -> [T,1]
        if D.dim() == 1:
            D = D.unsqueeze(1)

        # Targets and advantages are constants for the optimisation below, so
        # compute them without building an autograd graph.
        with torch.no_grad():
            V_tot = self._total_value(S)
            V_tot2 = self._total_value(S2)        # bootstrap value
            td_target = R + self.gamma * V_tot2 * (1 - D)
            td_delta = td_target - V_tot

        adv = compute_advantage(self.gamma, self.lam, td_delta, D).to(self.device)

        for i in range(self.n):
            s_i = S[:, i, :]                      # [T, obs_dim]
            a_i = A[:, i, :]                      # [T, act_dim]

            # Log-probabilities under the policy as it stands before this
            # agent's update; they stay fixed across the epochs.
            with torch.no_grad():
                mu_old, sig_old = self.actors[i](s_i)
                old_dist = torch.distributions.Normal(mu_old, sig_old)
                old_logp = old_dist.log_prob(a_i).sum(-1, keepdim=True)  # [T,1]

            for _ in range(self.epochs):
                mu, sig = self.actors[i](s_i)
                dist = torch.distributions.Normal(mu, sig)
                logp = dist.log_prob(a_i).sum(-1, keepdim=True)

                # Clipped surrogate objective plus a small entropy bonus.
                ratio = torch.exp(logp - old_logp)
                surr1 = ratio * adv
                surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * adv
                loss_pi = -torch.min(surr1, surr2).mean() - 0.001 * dist.entropy().mean()

                self.actor_opt[i].zero_grad()
                loss_pi.backward()
                torch.nn.utils.clip_grad_norm_(self.actors[i].parameters(), 40.0)
                self.actor_opt[i].step()

        for _ in range(self.epochs):
            loss_v = F.mse_loss(self._total_value(S), td_target.detach())

            self.v_opt.zero_grad()
            loss_v.backward()
            # Gradient norms are clipped per network, not jointly.
            torch.nn.utils.clip_grad_norm_(self.mixer.parameters(), 40.0)
            for net in self.v_nets:
                torch.nn.utils.clip_grad_norm_(net.parameters(), 40.0)
            self.v_opt.step()
