# rl_framework/algo/hfps.py
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from algo.base import BaseAgent
from models import build_q_network, build_value_network
from utils import ReplayBuffer


@dataclass
class HFPSHyperParams:
    """Hyperparameters consumed by ``HFPSAgent``.

    Field order is part of the generated ``__init__`` signature, so new
    fields should be appended rather than inserted.
    """

    # --- Q-learning core -------------------------------------------------
    # Discount factor used in the TD target and in the potential difference.
    gamma: float = 0.99
    # Adam learning rates for the Q network and the potential (u) network.
    lr_q: float = 3e-4
    lr_u: float = 3e-4
    # Replay buffer capacity and the number of transitions sampled per update.
    buffer_size: int = 100_000
    batch_size: int = 64
    # Training is skipped until the replay buffer holds this many transitions.
    train_start: int = 1_000
    # A training update runs only on global steps divisible by this value.
    train_interval: int = 4
    # Spacing, in global steps, between copies of the online Q weights into
    # the target Q network.
    target_update_interval: int = 1_000

    # --- Exploration -----------------------------------------------------
    # Epsilon moves linearly from eps_start down to eps_end, taking
    # eps_decay decrements to get there.
    eps_start: float = 1.0
    eps_end: float = 0.05
    eps_decay: int = 50_000

    # Not consumed by HFPSAgent itself; kept on the config for callers.
    log_interval: int = 1_000

    # --- Stabilization ---------------------------------------------------
    # Symmetric clamp applied to the TD error before the u network fits it.
    delta_clip: float = 10.0
    # Max gradient norms for the Q and u optimizers; the agent skips
    # clipping when the value is None.
    grad_clip_q: float = 10.0
    grad_clip_u: float = 10.0

    # --- HFPS-specific ---------------------------------------------------
    # Number of potential-network gradient steps per training update.
    u_updates_per_step: int = 5


class HFPSAgent(BaseAgent):
    """Epsilon-greedy Q-learning agent with an auxiliary state-potential network.

    Every training update works on one replay batch and does two things:

    * fits ``potential_net`` (``u``) so that ``u(s') - gamma * u(s)``
      regresses onto the clipped one-step TD error of ``q_net``;
    * takes one MSE step on ``q_net`` toward the target-network TD target.

    The potential is not fed back into the Q update in this class; it is
    trained and surfaced only through the diagnostics ``update`` returns.
    """

    def __init__(self, env, args):
        """Build networks, optimizers and the replay buffer.

        Args:
            env: Environment exposing ``obs_shape`` and ``num_actions``.
            args: Config object. Every hyperparameter is read with
                ``getattr`` and falls back to a default when missing;
                ``lr_q`` / ``lr_u`` fall back to ``args.lr`` first.
        """
        super().__init__(env, args)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.obs_shape = env.obs_shape
        self.num_actions = env.num_actions

        hp = HFPSHyperParams(
            gamma=getattr(args, "gamma", 0.99),
            lr_q=getattr(args, "lr_q", getattr(args, "lr", 3e-4)),
            lr_u=getattr(args, "lr_u", getattr(args, "lr", 3e-4)),
            buffer_size=getattr(args, "buffer_size", 100_000),
            batch_size=getattr(args, "batch_size", 64),
            train_start=getattr(args, "train_start", 1_000),
            train_interval=getattr(args, "train_interval", 4),
            target_update_interval=getattr(args, "target_update_interval", 1_000),
            eps_start=getattr(args, "eps_start", 1.0),
            eps_end=getattr(args, "eps_end", 0.05),
            eps_decay=getattr(args, "eps_decay", 50_000),
            log_interval=getattr(args, "log_interval", 1_000),
            delta_clip=getattr(args, "delta_clip", 10.0),
            grad_clip_q=getattr(args, "grad_clip_q", 10.0),
            grad_clip_u=getattr(args, "grad_clip_u", 10.0),
            u_updates_per_step=getattr(args, "u_updates_per_step", 5),
        )
        self.hp = hp

        # Read once so all three networks are guaranteed the same encoder type.
        use_cnn = getattr(args, "use_cnn", False)

        self.q_net = build_q_network(
            obs_shape=self.obs_shape,
            num_actions=self.num_actions,
            use_cnn=use_cnn,
        ).to(self.device)

        # Target network starts as an exact copy and is only ever updated by
        # copying weights from q_net, so it stays in eval mode.
        self.target_q_net = build_q_network(
            obs_shape=self.obs_shape,
            num_actions=self.num_actions,
            use_cnn=use_cnn,
        ).to(self.device)
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.target_q_net.eval()

        self.potential_net = build_value_network(
            obs_shape=self.obs_shape,
            use_cnn=use_cnn,
        ).to(self.device)

        self.q_optimizer = optim.Adam(self.q_net.parameters(), lr=hp.lr_q)
        self.potential_optimizer = optim.Adam(self.potential_net.parameters(), lr=hp.lr_u)

        self.replay = ReplayBuffer(
            obs_shape=self.obs_shape,
            capacity=hp.buffer_size,
            device=self.device,
        )

        self.global_step = 0
        self.eps = hp.eps_start
        # global_step of the most recent training update (None before the
        # first one); used to decide when a target-sync boundary was crossed.
        self._prev_update_step = None

    def _update_epsilon(self):
        """Lower ``self.eps`` by one linear-decay increment, floored at ``eps_end``."""
        decay = (self.hp.eps_start - self.hp.eps_end) / max(1, self.hp.eps_decay)
        self.eps = max(self.hp.eps_end, self.eps - decay)

    def act(self, obs, evaluate: bool = False):
        """Choose an action for a single observation.

        Args:
            obs: One observation (no batch dimension); converted to float32.
            evaluate: When True, exploration is disabled and the greedy
                action is always returned.

        Returns:
            The selected action index as a Python ``int``.
        """
        if (not evaluate) and (np.random.rand() < self.eps):
            return int(np.random.randint(self.num_actions))

        # Conversion is only needed on the greedy path.
        obs = np.array(obs, dtype=np.float32)
        obs_t = torch.as_tensor(obs, device=self.device).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_net(obs_t)
            action = q_values.argmax(dim=1).item()
        return int(action)

    # ---------- interaction ----------
    def store_transition(self, obs, action, reward, next_obs, done):
        """Append one transition to the replay buffer."""
        self.replay.add(obs, action, reward, next_obs, done)

    # ---------- HFPS update ----------
    def _potential_diff(self, obs, next_obs):
        """Return ``u(s') - gamma * u(s)`` for a batch as a flat ``(B,)`` tensor."""
        # Flatten the network output so that a (B, 1) value head cannot
        # silently broadcast against the (B,) TD error into a (B, B) matrix.
        # For an output that is already (B,) this is a no-op.
        u_s = self.potential_net(obs).reshape(-1)
        u_next = self.potential_net(next_obs).reshape(-1)
        # NOTE(review): conventional potential-based shaping uses
        # gamma * u(s') - u(s); this form is kept as written -- confirm it
        # is the intended HFPS definition.
        return u_next - self.hp.gamma * u_s

    def update(self):
        """Advance the step counter and, when due, run one training update.

        Training is skipped (returning ``None``) until the replay buffer
        holds ``train_start`` transitions, and on steps that are not a
        multiple of ``train_interval``. Epsilon decays only on calls that
        actually train.

        Returns:
            ``None`` when no training happened, otherwise a dict with the
            keys ``q_loss``, ``potential_loss``, ``delta_norm``,
            ``du_norm``, ``residual_norm`` and ``eps``.
        """
        self.global_step += 1

        if self.replay.size < self.hp.train_start:
            return None
        if self.global_step % self.hp.train_interval != 0:
            return None

        hp = self.hp

        # Remember where the previous training update happened. On the very
        # first update, pretend it was the step just before this one so a
        # boundary is only "crossed" if this step lands exactly on it.
        prev_update_step = getattr(self, "_prev_update_step", None)
        if prev_update_step is None:
            prev_update_step = self.global_step - 1
        self._prev_update_step = self.global_step

        obs, actions, rewards, next_obs, dones = self.replay.sample(hp.batch_size)

        q_values = self.q_net(obs)                                    # (B, A)
        q_sa = q_values.gather(1, actions.unsqueeze(-1)).squeeze(-1)  # (B,)

        with torch.no_grad():
            next_q_max = self.target_q_net(next_obs).max(dim=1)[0]    # (B,)
            td_target = rewards + hp.gamma * (1.0 - dones) * next_q_max

        # Clipped TD error delta(s, a, s'), detached once: it is a fixed
        # regression target for u and must not push gradients into q_net.
        delta_clipped = (td_target - q_sa).detach().clamp(-hp.delta_clip, hp.delta_clip)

        # Inner loop: several gradient steps on the potential network.
        potential_loss = None
        for _ in range(hp.u_updates_per_step):
            du = self._potential_diff(obs, next_obs)
            potential_loss = (du - delta_clipped).pow(2).mean()

            self.potential_optimizer.zero_grad()
            potential_loss.backward()
            if hp.grad_clip_u is not None:
                nn.utils.clip_grad_norm_(self.potential_net.parameters(),
                                         hp.grad_clip_u)
            self.potential_optimizer.step()

        # Diagnostics, measured after the potential updates.
        with torch.no_grad():
            du = self._potential_diff(obs, next_obs)
            if potential_loss is None:
                # No inner steps were configured; still report the current
                # fit instead of failing on an unbound name.
                potential_loss = (du - delta_clipped).pow(2).mean()
            du_clipped = du.clamp(-hp.delta_clip, hp.delta_clip)
            residual = delta_clipped - du_clipped

            delta_norm = delta_clipped.pow(2).mean().sqrt()
            du_norm = du_clipped.pow(2).mean().sqrt()
            residual_norm = residual.pow(2).mean().sqrt()

        # The Q loss uses the unclipped TD target.
        q_loss = nn.MSELoss()(q_sa, td_target)

        self.q_optimizer.zero_grad()
        q_loss.backward()
        if hp.grad_clip_q is not None:
            nn.utils.clip_grad_norm_(self.q_net.parameters(), hp.grad_clip_q)
        self.q_optimizer.step()

        # Sync the target network on the first training update at or after
        # each multiple of target_update_interval. A plain
        # `global_step % interval == 0` test is only evaluated on multiples
        # of train_interval, so it would skip boundaries whenever the two
        # intervals do not divide each other.
        sync_every = hp.target_update_interval
        if self.global_step // sync_every > prev_update_step // sync_every:
            self.target_q_net.load_state_dict(self.q_net.state_dict())

        self._update_epsilon()

        return {
            "q_loss": q_loss.item(),
            "potential_loss": potential_loss.detach().item(),
            "delta_norm": delta_norm.item(),
            "du_norm": du_norm.item(),
            "residual_norm": residual_norm.item(),
            "eps": self.eps,
        }




