from __future__ import annotations

from dataclasses import dataclass
import copy
import torch
import torch.nn.functional as F

from .sac_agent import SACAgent, SACConfig
from .networks import QNetwork


@dataclass
class LagrangianSACConfig(SACConfig):
    """SAC configuration extended with Lagrangian cost-constraint and KL-penalty settings."""

    # Threshold that the batch-mean cost-critic estimate min(Qc1, Qc2)(s, pi(s)) is
    # pushed below by the multiplier lambda_c. The cost critics regress onto a
    # discounted bootstrap target (cost + gamma * Qc'), so this limit is in
    # discounted-cumulative-cost units, not per-step cost units.
    cost_limit: float = 0.02
    # Step size of the projected dual-ascent update of lambda_c (clamped at >= 0).
    lambda_lr: float = 0.05
    # When True, add eta_kl * KL(reference policy || current policy) to the actor
    # objective and adapt eta_kl after each actor step.
    use_kl: bool = False
    # Mean KL that the eta_kl adaptation steers towards.
    kl_target: float = 0.01
    # Step size of the projected update of eta_kl (clamped at >= 0).
    kl_lr: float = 0.05


def gaussian_kl(mu_p, logstd_p, mu_q, logstd_q):
    """Return KL(p || q) for diagonal Gaussians, summed over the last axis.

    Each distribution is given by its means and log standard deviations.
    All four tensors are expected to broadcast together. The result has the
    last axis reduced away.

    Closed form per dimension:
        log(s_q / s_p) + (s_p^2 + (m_p - m_q)^2) / (2 s_q^2) - 1/2
    A tiny epsilon (1e-12) is added to the q-variance denominator.
    """
    dim = mu_p.shape[-1]
    numerator = torch.exp(2 * logstd_p) + (mu_p - mu_q) ** 2
    denominator = torch.exp(2 * logstd_q) + 1e-12
    trace_term = (numerator / denominator).sum(dim=-1)
    log_det_term = 2 * (logstd_q - logstd_p).sum(dim=-1)
    return 0.5 * (trace_term + log_det_term - dim)


class LagrangianSACAgent(SACAgent):
    """SAC with a Lagrangian cost constraint and an optional KL penalty on the actor.

    In addition to the reward critics handled by ``SACAgent``, this agent keeps
    twin cost critics ``qc1``/``qc2`` (with Polyak-averaged targets). Each
    update minimizes, per sample,

        alpha * logp - min(Q1, Q2) + lambda_c * min(Qc1, Qc2)
        [+ eta_kl * KL(pi_ref(.|s) || pi(.|s))   when ``use_kl``]

    ``lambda_c`` follows projected dual ascent on ``mean(Qc) - cost_limit``.
    ``eta_kl`` follows projected dual ascent on ``mean(KL) - kl_target``.
    Both are clamped at zero.

    The KL reference policy ``actor_old`` is Polyak-averaged towards the actor
    with ``cfg.tau``. It therefore lags the actor, and the penalty actually
    constrains the update.

    The following are used but not defined here and are expected to come
    from ``SACAgent``: ``actor``, ``actor_opt``, ``q1``, ``q2``, ``alpha``,
    ``log_alpha``, ``alpha_opt``, ``target_entropy``, ``device``,
    ``_update_critics`` and ``_soft_update_targets``.
    """

    def __init__(self, cfg: LagrangianSACConfig, device: torch.device):
        super().__init__(cfg, device)
        self.lcfg: LagrangianSACConfig = cfg

        def make_cost_critic() -> QNetwork:
            return QNetwork(
                cfg.obs_dim,
                cfg.act_dim,
                cfg.hidden_dims,
                spectral_norm=cfg.critic_spectral_norm,
            ).to(device)

        # Twin cost critics plus target copies. They are constructed in the same
        # order as before (qc1, qc2, qc1_targ, qc2_targ) to keep the RNG stream
        # of the initialisation unchanged. The targets start identical to the
        # online critics.
        self.qc1 = make_cost_critic()
        self.qc2 = make_cost_critic()
        self.qc1_targ = make_cost_critic()
        self.qc2_targ = make_cost_critic()
        self.qc1_targ.load_state_dict(self.qc1.state_dict())
        self.qc2_targ.load_state_dict(self.qc2.state_dict())
        self.qc1_opt = torch.optim.Adam(self.qc1.parameters(), lr=cfg.critic_lr)
        self.qc2_opt = torch.optim.Adam(self.qc2.parameters(), lr=cfg.critic_lr)

        # Dual variables. Plain tensors updated by hand, not by an optimizer.
        self.lambda_c = torch.tensor(0.0, device=device)
        self.eta_kl = torch.tensor(1.0, device=device)
        # Frozen, slowly moving reference policy for the KL penalty.
        self.actor_old = copy.deepcopy(self.actor).eval().requires_grad_(False)

    @staticmethod
    def _polyak_inplace(src: torch.nn.Module, dst: torch.nn.Module, tau: float) -> None:
        """Set ``dst <- (1 - tau) * dst + tau * src`` in place, parameter by parameter.

        Parameters are paired positionally via ``zip``, so ``src`` and ``dst``
        must share an architecture.
        """
        with torch.no_grad():
            for p_src, p_dst in zip(src.parameters(), dst.parameters()):
                p_dst.mul_(1.0 - tau).add_(p_src, alpha=tau)

    def _clip_grad_norm_if_enabled(self, module: torch.nn.Module) -> None:
        """Clip ``module``'s total gradient norm to ``cfg.grad_clip_norm`` when it is set and > 0."""
        clip = self.cfg.grad_clip_norm
        if clip and clip > 0:
            torch.nn.utils.clip_grad_norm_(module.parameters(), clip)

    def _update_cost_critics(self, obs, act, cost, next_obs, done):
        """Take one TD(0) regression step for each cost critic.

        The target is ``cost + (1 - done) * gamma * min(Qc1', Qc2')(s', a')``,
        with ``a'`` sampled from the current actor. Unlike the reward target,
        no entropy term is added.

        Returns:
            ``{"loss/qc1": float, "loss/qc2": float}`` holding the MSE losses.
        """
        with torch.no_grad():
            next_act, _ = self.actor.sample(next_obs)
            qc_next = torch.min(
                self.qc1_targ(next_obs, next_act),
                self.qc2_targ(next_obs, next_act),
            )
            # Align the batch columns with the critic output. Otherwise a (B,)
            # cost/done would silently broadcast against a (B, 1) critic
            # output into a (B, B) target.
            cost = cost.reshape(qc_next.shape)
            done = done.reshape(qc_next.shape)
            target_c = cost + (1.0 - done) * self.cfg.gamma * qc_next

        logs = {}
        for key, critic, opt in (
            ("loss/qc1", self.qc1, self.qc1_opt),
            ("loss/qc2", self.qc2, self.qc2_opt),
        ):
            loss = F.mse_loss(critic(obs, act), target_c)
            # zero_grad first: this also clears gradients that leaked into
            # the cost critics from the previous actor backward pass.
            opt.zero_grad()
            loss.backward()
            self._clip_grad_norm_if_enabled(critic)
            opt.step()
            logs[key] = loss.item()
        return logs

    def _soft_update_cost_targets(self):
        """Polyak-average both cost-critic targets towards their online critics with ``cfg.tau``."""
        tau = self.cfg.tau
        self._polyak_inplace(self.qc1, self.qc1_targ, tau)
        self._polyak_inplace(self.qc2, self.qc2_targ, tau)

    def _actor_step_with_penalties(self, obs):
        """Take one actor gradient step on the penalised SAC objective.

        Returns:
            ``(loss_pi, logp, qc)``. ``logp`` and ``qc`` are the per-sample log
            probabilities and cost estimates at the pre-step policy's actions.
            They feed the temperature and dual updates.
        """
        pi_act, logp = self.actor.sample(obs)
        q = torch.min(self.q1(obs, pi_act), self.q2(obs, pi_act))
        qc = torch.min(self.qc1(obs, pi_act), self.qc2(obs, pi_act))
        policy_obj = self.alpha.detach() * logp - q + self.lambda_c.detach() * qc
        if self.lcfg.use_kl:
            with torch.no_grad():
                mu_ref, logstd_ref = self.actor_old(obs)
            # presumably actor(obs) returns (mean, log_std) of the
            # pre-squash Gaussian; verify against the actor network
            mu_new, logstd_new = self.actor(obs)
            kl = gaussian_kl(mu_ref, logstd_ref, mu_new, logstd_new)
            # kl has one value per sample. Match policy_obj's layout so that
            # (B,) and (B, 1) conventions both add element-wise.
            policy_obj = policy_obj + self.eta_kl.detach() * kl.reshape(policy_obj.shape)
        loss_pi = policy_obj.mean()
        self.actor_opt.zero_grad()
        loss_pi.backward()
        self._clip_grad_norm_if_enabled(self.actor)
        self.actor_opt.step()
        return loss_pi, logp, qc

    def _temperature_step(self, logp):
        """Take one gradient step on ``log_alpha`` towards ``target_entropy``; return the loss."""
        alpha_loss = (-self.log_alpha * (logp + self.target_entropy).detach()).mean()
        self.alpha_opt.zero_grad()
        alpha_loss.backward()
        self.alpha_opt.step()
        return alpha_loss

    def _dual_ascent_step(self, obs, qc):
        """Update ``lambda_c`` (and ``eta_kl`` / the KL reference when ``use_kl``)."""
        expected_cost = qc.detach().mean()
        self.lambda_c = torch.clamp(
            self.lambda_c + self.lcfg.lambda_lr * (expected_cost - self.lcfg.cost_limit),
            min=0.0,
        )
        if not self.lcfg.use_kl:
            return

        # Divergence of the freshly stepped actor from the reference policy.
        with torch.no_grad():
            mu_ref, logstd_ref = self.actor_old(obs)
            mu_new, logstd_new = self.actor(obs)
            kl_val = gaussian_kl(mu_ref, logstd_ref, mu_new, logstd_new).mean()
        self.eta_kl = torch.clamp(
            self.eta_kl + self.lcfg.kl_lr * (kl_val - self.lcfg.kl_target),
            min=0.0,
        )

        # Move the reference slowly towards the actor. Hard-copying the actor
        # here would make actor_old == actor at the start of every update,
        # which makes the KL penalty, and its gradient, identically zero.
        self._polyak_inplace(self.actor, self.actor_old, self.cfg.tau)
        # Buffers (if any) are copied outright, as load_state_dict would.
        with torch.no_grad():
            for buf, buf_ref in zip(self.actor.buffers(), self.actor_old.buffers()):
                buf_ref.copy_(buf)

    def update(self, batch):
        """Run one full update on ``batch`` and return a dict of scalar logs.

        Order: reward critics (SACAgent), cost critics, actor, temperature,
        dual variables, then soft target updates.
        ``batch`` must expose ``obs``, ``act``, ``next_obs``, ``done`` and
        ``cost``, plus whatever ``SACAgent._update_critics`` reads from it.
        """

        def to_tensor(x):
            return torch.as_tensor(x, dtype=torch.float32, device=self.device)

        obs = to_tensor(batch.obs)
        act = to_tensor(batch.act)
        next_obs = to_tensor(batch.next_obs)
        done = to_tensor(batch.done)
        cost = to_tensor(batch.cost)

        # The reward critics are handled by SACAgent, which reads the raw batch.
        logs = super()._update_critics(batch)
        logs.update(self._update_cost_critics(obs, act, cost, next_obs, done))

        loss_pi, logp, qc = self._actor_step_with_penalties(obs)
        alpha_loss = self._temperature_step(logp)
        self._dual_ascent_step(obs, qc)

        self._soft_update_targets()
        self._soft_update_cost_targets()

        logs.update({
            "loss/pi": loss_pi.item(),
            "loss/alpha": alpha_loss.item(),
            "alpha": float(self.alpha.item()),
            "lagrange/lambda_c": self.lambda_c.item(),
            "lagrange/eta_kl": self.eta_kl.item(),
        })
        return logs

