from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .networks import Actor, QNetwork


@dataclass
class SACConfig:
    """Hyper-parameters consumed by :class:`SACAgent`.

    Only the three size fields are required; every other field has a default.
    """

    # --- problem / architecture sizes -------------------------------------
    obs_dim: int                      # length of an observation vector
    act_dim: int                      # length of an action vector
    hidden_dims: Tuple[int, int]      # forwarded to the Actor / QNetwork constructors

    # --- RL objective -----------------------------------------------------
    gamma: float = 0.99               # discount factor in the Bellman target
    tau: float = 0.005                # Polyak coefficient for target-critic updates

    # --- optimiser learning rates (Adam) -----------------------------------
    actor_lr: float = 3e-4
    critic_lr: float = 3e-4
    alpha_lr: float = 3e-4

    # --- entropy tuning ---------------------------------------------------
    target_entropy_scale: float = 1.0  # target entropy = -scale * act_dim

    # --- regularisation ---------------------------------------------------
    actor_spectral_norm: bool = False   # forwarded to the Actor constructor
    critic_spectral_norm: bool = False  # forwarded to every QNetwork constructor
    grad_clip_norm: float = 0.0         # max grad norm; values <= 0 disable clipping


class SACAgent:
    """Soft Actor-Critic agent with twin Q critics, Polyak-averaged target
    critics and a learned entropy temperature ``alpha = exp(log_alpha)``.

    All networks, optimisers and the temperature live on ``device``.
    """

    def __init__(self, cfg: SACConfig, device: torch.device):
        self.cfg = cfg
        self.device = device

        # Construction order (actor, q1, q2, q1_targ, q2_targ) is kept fixed so
        # seeded runs draw the same initial weights.
        self.actor = Actor(cfg.obs_dim, cfg.act_dim, cfg.hidden_dims, spectral_norm=cfg.actor_spectral_norm).to(device)
        self.q1 = self._make_critic()
        self.q2 = self._make_critic()

        # Target critics start as exact copies of the online critics.
        self.q1_targ = self._make_critic()
        self.q2_targ = self._make_critic()
        self.q1_targ.load_state_dict(self.q1.state_dict())
        self.q2_targ.load_state_dict(self.q2.state_dict())

        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=cfg.actor_lr)
        self.q1_opt = torch.optim.Adam(self.q1.parameters(), lr=cfg.critic_lr)
        self.q2_opt = torch.optim.Adam(self.q2.parameters(), lr=cfg.critic_lr)

        # The temperature is optimised in log space so alpha stays positive;
        # log_alpha = 0 means alpha starts at 1.
        self.log_alpha = torch.tensor(0.0, requires_grad=True, device=device)
        self.alpha_opt = torch.optim.Adam([self.log_alpha], lr=cfg.alpha_lr)
        self.target_entropy = -cfg.target_entropy_scale * float(cfg.act_dim)

    # ------------------------------------------------------------------ helpers

    def _make_critic(self):
        """Build one QNetwork from ``self.cfg`` on ``self.device``."""
        cfg = self.cfg
        return QNetwork(cfg.obs_dim, cfg.act_dim, cfg.hidden_dims, spectral_norm=cfg.critic_spectral_norm).to(self.device)

    def _as_tensor(self, x) -> torch.Tensor:
        """Convert array-like ``x`` to a float32 tensor on the agent's device."""
        return torch.as_tensor(x, dtype=torch.float32, device=self.device)

    def _clip_grads(self, params) -> None:
        """Clip the global grad norm of ``params`` when ``cfg.grad_clip_norm > 0``."""
        max_norm = self.cfg.grad_clip_norm
        if max_norm and max_norm > 0:
            nn.utils.clip_grad_norm_(params, max_norm)

    # ------------------------------------------------------------------ public

    @property
    def alpha(self) -> torch.Tensor:
        """Current entropy temperature, ``exp(log_alpha)`` (still attached to autograd)."""
        return self.log_alpha.exp()

    def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """Return an action for a single (unbatched) observation.

        Args:
            obs: one observation; a batch dimension of 1 is added internally.
            deterministic: if True use ``actor.deterministic``, otherwise sample
                with ``actor.sample``.

        Returns:
            The action as a NumPy array with the batch dimension removed.
        """
        obs_t = self._as_tensor(obs).unsqueeze(0)
        with torch.no_grad():
            if deterministic:
                act = self.actor.deterministic(obs_t)
            else:
                act, _ = self.actor.sample(obs_t)
        return act.squeeze(0).cpu().numpy()

    def _update_critics(self, batch):
        """Take one gradient step on each critic toward the soft Bellman target.

        ``batch`` must expose ``obs``, ``act``, ``rew``, ``next_obs`` and ``done``.
        NOTE(review): assumes ``rew``/``done``/``next_logp`` broadcast cleanly
        against the QNetwork output, e.g. all (B,) or all (B, 1) — TODO confirm,
        since a (B,) vs (B, 1) mix would silently broadcast to (B, B).

        Returns:
            ``{"loss/q1": float, "loss/q2": float}``.
        """
        obs = self._as_tensor(batch.obs)
        act = self._as_tensor(batch.act)
        rew = self._as_tensor(batch.rew)
        next_obs = self._as_tensor(batch.next_obs)
        done = self._as_tensor(batch.done)

        with torch.no_grad():
            next_act, next_logp = self.actor.sample(next_obs)
            q_next = torch.min(self.q1_targ(next_obs, next_act), self.q2_targ(next_obs, next_act))
            # Soft value: clipped double-Q minus the entropy penalty.
            v_next = q_next - self.alpha * next_logp
            target = rew + (1.0 - done) * self.cfg.gamma * v_next

        loss_q1 = F.mse_loss(self.q1(obs, act), target)
        loss_q2 = F.mse_loss(self.q2(obs, act), target)

        self.q1_opt.zero_grad()
        loss_q1.backward()
        self._clip_grads(self.q1.parameters())
        self.q1_opt.step()

        self.q2_opt.zero_grad()
        loss_q2.backward()
        self._clip_grads(self.q2.parameters())
        self.q2_opt.step()

        return {
            "loss/q1": loss_q1.item(),
            "loss/q2": loss_q2.item(),
        }

    def _update_actor_and_alpha(self, obs: torch.Tensor):
        """Take one policy step, then one temperature step.

        The policy minimises ``E[alpha * log pi(a|s) - min(Q1, Q2)(s, a)]``. The
        temperature minimises ``E[-log_alpha * (log pi + target_entropy)]``,
        which raises alpha when policy entropy falls below the target.

        Returns:
            ``{"loss/pi": float, "loss/alpha": float, "alpha": float}``.
        """
        act, logp = self.actor.sample(obs)

        # Freeze the critics while backpropagating the policy loss so gradients
        # only reach the actor; original requires_grad flags are restored.
        critic_params = [p for net in (self.q1, self.q2) for p in net.parameters()]
        saved_flags = [p.requires_grad for p in critic_params]
        for p in critic_params:
            p.requires_grad_(False)
        try:
            q = torch.min(self.q1(obs, act), self.q2(obs, act))
            loss_pi = (self.alpha.detach() * logp - q).mean()
            self.actor_opt.zero_grad()
            loss_pi.backward()
        finally:
            for p, flag in zip(critic_params, saved_flags):
                p.requires_grad_(flag)
        self._clip_grads(self.actor.parameters())
        self.actor_opt.step()

        # logp is detached so this step only touches log_alpha.
        alpha_loss = (-self.log_alpha * (logp + self.target_entropy).detach()).mean()
        self.alpha_opt.zero_grad()
        alpha_loss.backward()
        self.alpha_opt.step()

        return {
            "loss/pi": loss_pi.item(),
            "loss/alpha": alpha_loss.item(),
            "alpha": self.alpha.item(),
        }

    def _soft_update_targets(self):
        """Polyak-average target critic parameters: ``targ <- (1 - tau) * targ + tau * online``."""
        tau = self.cfg.tau
        with torch.no_grad():
            for online, target in ((self.q1, self.q1_targ), (self.q2, self.q2_targ)):
                for p, p_targ in zip(online.parameters(), target.parameters()):
                    p_targ.mul_(1.0 - tau).add_(p, alpha=tau)

    def update(self, batch):
        """Run one full SAC iteration on ``batch``.

        The iteration updates the critics, then the actor and temperature, then
        the target critics. Returns the merged logging dict of scalar floats.
        """
        logs = self._update_critics(batch)
        obs = self._as_tensor(batch.obs)
        logs.update(self._update_actor_and_alpha(obs))
        self._soft_update_targets()
        return logs

    def state_dict(self):
        """Return a checkpoint dict with all networks, optimisers, ``log_alpha`` and ``cfg``."""
        return {
            "actor": self.actor.state_dict(),
            "q1": self.q1.state_dict(),
            "q2": self.q2.state_dict(),
            "q1_targ": self.q1_targ.state_dict(),
            "q2_targ": self.q2_targ.state_dict(),
            "actor_opt": self.actor_opt.state_dict(),
            "q1_opt": self.q1_opt.state_dict(),
            "q2_opt": self.q2_opt.state_dict(),
            "log_alpha": self.log_alpha.detach().cpu(),
            "alpha_opt": self.alpha_opt.state_dict(),
            "cfg": self.cfg,
        }

    def load_state_dict(self, sd):
        """Restore a checkpoint produced by :meth:`state_dict`.

        Missing ``q1_targ``/``q2_targ`` entries fall back to the online critic
        weights. ``log_alpha`` may be a tensor (any shape with one element) or a
        Python number. It is skipped if absent or ``None``.
        """
        self.actor.load_state_dict(sd["actor"])
        self.q1.load_state_dict(sd["q1"])
        self.q2.load_state_dict(sd["q2"])
        self.q1_targ.load_state_dict(sd.get("q1_targ", sd["q1"]))
        self.q2_targ.load_state_dict(sd.get("q2_targ", sd["q2"]))
        self.actor_opt.load_state_dict(sd["actor_opt"])
        self.q1_opt.load_state_dict(sd["q1_opt"])
        self.q2_opt.load_state_dict(sd["q2_opt"])

        log_alpha = sd.get("log_alpha")
        if log_alpha is not None:
            # state_dict() already stores log(alpha); copy it verbatim. Taking
            # .log() again would yield log(log(alpha)), i.e. NaN for alpha < 1.
            value = torch.as_tensor(log_alpha, dtype=self.log_alpha.dtype, device=self.device)
            with torch.no_grad():
                self.log_alpha.copy_(value.reshape(self.log_alpha.shape))
        self.alpha_opt.load_state_dict(sd["alpha_opt"])
