from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn.functional as F

from .sac_agent import SACAgent, SACConfig
from .geometry import calculate_geometric_risk


@dataclass
class GDCSACConfig(SACConfig):
    """Hyperparameters that ``GDCSACAgent`` adds on top of ``SACConfig``.

    The per-field notes describe how ``GDCSACAgent.update`` consumes each
    value. Several fields are forwarded unchanged to
    ``calculate_geometric_risk``; their exact meaning is defined there.
    """

    # Base curvature weight, forwarded as ``curvature_weight`` to
    # calculate_geometric_risk (scaled first when dynamic curvature is active).
    curvature_weight: float = 1.0
    # Slope of the risk gate: sigma = sigmoid(sigmoid_slope * (kappa - kappa0)).
    sigmoid_slope: float = 1.0
    # Forwarded as ``lanczos_steps`` to calculate_geometric_risk.
    lanczos_steps: int = 6
    # Step size of the kappa0 update:
    # kappa0 += adaptive_rate * (cost_ema - target_cost), clamped at >= 0.
    adaptive_rate: float = 0.05
    # Weight given to the newest batch-mean cost in the EMA:
    # cost_ema = (1 - d) * cost_ema + d * mean(cost). Larger tracks faster.
    cost_ema_decay: float = 0.05
    # Cost level the EMA is compared against when adapting kappa0.
    target_cost: float = 0.01
    # Initial value of the adaptive threshold kappa0.
    init_kappa0: float = 0.0
    # When > 0, clamped to [0, 1] as ``w`` and the bootstrapped next-state
    # value is scaled by (1 - w * sigma); values <= 0 disable the scaling.
    bootstrap_risk_weight: float = 0.0
    # Forwarded as ``method`` to calculate_geometric_risk.
    # NOTE(review): accepted values are defined in the geometry module.
    curvature_method: str = "power"
    # Forwarded as ``mode`` to calculate_geometric_risk.
    # NOTE(review): accepted values are defined in the geometry module.
    curvature_mode: str = "action"
    # Forwarded as ``state_weight`` / ``action_weight`` respectively.
    curvature_state_weight: float = 1.0
    curvature_action_weight: float = 1.0
    # Enables rescaling of the curvature weight per update; it only takes
    # effect when curv_alpha > 0 or curv_beta > 0:
    # c_eff = curvature_weight * (1 + curv_alpha * cost_ema + curv_beta * var(kappa)).
    dynamic_curvature: bool = False
    # Coefficient on the cost EMA in the c_eff formula above.
    curv_alpha: float = 0.0
    # Coefficient on the batch variance of kappa in the c_eff formula above.
    curv_beta: float = 0.0


class GDCSACAgent(SACAgent):
    """SAC agent whose TD target is gated by a geometric risk signal.

    On every update the agent asks ``calculate_geometric_risk`` for a risk
    value ``kappa`` evaluated on the target critics at the sampled next
    state/action, turns it into a gate ``sigma = sigmoid(slope * (kappa -
    kappa0))`` and uses that gate to suppress positive rewards (and,
    optionally, the bootstrapped value) in the critic target. The threshold
    ``kappa0`` is adapted from an exponential moving average of the batch cost.
    """

    def __init__(self, cfg: GDCSACConfig, device: torch.device):
        """Build the base SAC agent and initialise the GDC state tensors."""
        super().__init__(cfg, device)
        self.gdc_cfg: GDCSACConfig = cfg
        # Adaptive gate threshold, updated at the end of every update() call.
        self.kappa0 = torch.tensor(cfg.init_kappa0, dtype=torch.float32, device=device)
        # Exponential moving average of the batch-mean cost.
        self.cost_ema = torch.tensor(0.0, dtype=torch.float32, device=device)

    def _q_targ_min(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        """Return the element-wise minimum of both target critics, last dim squeezed."""
        q1_next = self.q1_targ(obs, act)
        q2_next = self.q2_targ(obs, act)
        return torch.min(q1_next, q2_next).squeeze(-1)

    def _gdc_dynamic_curvature_enabled(self) -> bool:
        """Tell whether the curvature weight is rescaled on each update."""
        cfg = self.gdc_cfg
        return bool(cfg.dynamic_curvature and (cfg.curv_alpha > 0.0 or cfg.curv_beta > 0.0))

    def _gdc_geometric_risk(self, q_fn, next_obs: torch.Tensor, next_act: torch.Tensor, curvature_weight: float):
        """Call ``calculate_geometric_risk`` with the configured options.

        Returns whatever the helper returns; ``update`` unpacks it as a
        3-tuple and only uses the first element (``kappa``).
        """
        cfg = self.gdc_cfg
        return calculate_geometric_risk(
            q_fn,
            next_obs.detach(),
            next_act.detach(),
            curvature_weight=curvature_weight,
            lanczos_steps=cfg.lanczos_steps,
            method=cfg.curvature_method,
            mode=cfg.curvature_mode,
            state_weight=cfg.curvature_state_weight,
            action_weight=cfg.curvature_action_weight,
        )

    def _gdc_optimize(self, optimizer, loss: torch.Tensor, module) -> None:
        """Zero grads, backprop ``loss``, optionally clip ``module`` grads, then step."""
        optimizer.zero_grad()
        loss.backward()
        clip = self.cfg.grad_clip_norm
        if clip and clip > 0:
            torch.nn.utils.clip_grad_norm_(module.parameters(), clip)
        optimizer.step()

    def update(self, batch) -> dict[str, float]:
        """Run one gradient step on critics, actor and temperature.

        Args:
            batch: Object exposing ``obs``, ``act``, ``rew``, ``next_obs``,
                ``done`` and ``cost`` attributes that ``torch.as_tensor`` can
                convert to float32 tensors.

        Returns:
            Dict of Python floats with losses, ``alpha`` and GDC diagnostics.
            ``gdc/c_eff`` is present only when dynamic curvature is active,
            and ``gdc/kappa_var`` only if the probe ``kappa`` additionally
            has more than one element.
        """
        device = self.device
        cfg = self.gdc_cfg

        def as_f32(value) -> torch.Tensor:
            return torch.as_tensor(value, dtype=torch.float32, device=device)

        obs = as_f32(batch.obs)
        act = as_f32(batch.act)
        rew = as_f32(batch.rew).squeeze(-1)
        next_obs = as_f32(batch.next_obs)
        done = as_f32(batch.done).squeeze(-1)
        cost = as_f32(batch.cost).squeeze(-1)

        # Next action is sampled once and reused for both the risk estimate
        # and the bootstrapped target.
        with torch.no_grad():
            next_act, next_logp = self.actor.sample(next_obs)

        def q_fn(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
            # Force grad mode on so the risk helper can differentiate the
            # target critics even if it is invoked under no_grad.
            with torch.enable_grad():
                q = self._q_targ_min(states, actions)
            return q

        # Effective curvature weight; optionally inflated by the running cost
        # and by the spread of a probe kappa computed with the base weight.
        dynamic = self._gdc_dynamic_curvature_enabled()
        c_eff = cfg.curvature_weight
        kappa_var = None  # cached so the log below does not recompute it
        if dynamic:
            kappa_probe, _, _ = self._gdc_geometric_risk(q_fn, next_obs, next_act, cfg.curvature_weight)
            if kappa_probe.numel() > 1:
                kappa_var = float(torch.var(kappa_probe).item())
            var_term = kappa_var if kappa_var is not None else 0.0
            c_eff = cfg.curvature_weight * (
                1.0 + cfg.curv_alpha * float(self.cost_ema.item()) + cfg.curv_beta * var_term
            )

        # Only kappa is consumed; the other two returned values are unused here.
        kappa, _, _ = self._gdc_geometric_risk(q_fn, next_obs, next_act, c_eff)
        sigma = torch.sigmoid(cfg.sigmoid_slope * (kappa - self.kappa0)).detach()

        # Gated reward: as sigma -> 1 positive rewards are suppressed while
        # non-positive rewards pass through unchanged.
        r_neg = torch.minimum(rew, torch.zeros_like(rew))
        r_g = (1.0 - sigma) * rew + sigma * r_neg

        # Entropy-regularised TD target.
        with torch.no_grad():
            q_next = self._q_targ_min(next_obs, next_act) - self.alpha * next_logp.squeeze(-1)
            if cfg.bootstrap_risk_weight and cfg.bootstrap_risk_weight > 0.0:
                # Shrink the bootstrapped value where the gate is open.
                w = max(0.0, min(1.0, float(cfg.bootstrap_risk_weight)))
                q_next = (1.0 - w * sigma) * q_next
            target = r_g + (1.0 - done) * self.cfg.gamma * q_next

        # Critic update: both forward passes happen before either step.
        q1 = self.q1(obs, act).squeeze(-1)
        q2 = self.q2(obs, act).squeeze(-1)
        loss_q1 = F.mse_loss(q1, target)
        loss_q2 = F.mse_loss(q2, target)
        self._gdc_optimize(self.q1_opt, loss_q1, self.q1)
        self._gdc_optimize(self.q2_opt, loss_q2, self.q2)

        # Actor update against the freshly stepped critics.
        pi_act, logp = self.actor.sample(obs)
        q_pi = torch.min(self.q1(obs, pi_act), self.q2(obs, pi_act)).squeeze(-1)
        loss_pi = (self.alpha.detach() * logp.squeeze(-1) - q_pi).mean()
        self._gdc_optimize(self.actor_opt, loss_pi, self.actor)

        # Temperature update (no gradient clipping here).
        alpha_loss = (-(self.log_alpha) * (logp + self.target_entropy).detach()).mean()
        self.alpha_opt.zero_grad()
        alpha_loss.backward()
        self.alpha_opt.step()

        self._soft_update_targets()

        # Track the running cost and adapt the gate threshold.
        # NOTE(review): kappa0 rises when cost_ema exceeds target_cost, which
        # lowers sigma for a positive slope - confirm this direction is intended.
        beta = cfg.cost_ema_decay
        self.cost_ema = (1 - beta) * self.cost_ema + beta * cost.mean()
        self.kappa0 = torch.clamp(
            self.kappa0 + cfg.adaptive_rate * (self.cost_ema - cfg.target_cost),
            min=0.0,
        )

        logs = {
            "loss/q1": float(loss_q1.item()),
            "loss/q2": float(loss_q2.item()),
            "loss/pi": float(loss_pi.item()),
            "loss/alpha": float(alpha_loss.item()),
            "alpha": float(self.alpha.item()),
            "gdc/kappa_mean": float(kappa.mean().item()),
            "gdc/sigma_mean": float(sigma.mean().item()),
            "gdc/kappa0": float(self.kappa0.item()),
            "gdc/cost_ema": float(self.cost_ema.item()),
        }
        if dynamic:
            logs["gdc/c_eff"] = float(c_eff)
            if kappa_var is not None:
                logs["gdc/kappa_var"] = kappa_var
        return logs
