# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy
import os
import copy
import random
import time
from dataclasses import dataclass
from datetime import datetime

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from cleanrl_utils.buffers import ReplayBuffer


@dataclass
class Args:
    """Configuration for this (double) DQN Atari training script.

    An instance is built with ``tyro.cli(Args)`` in the ``__main__`` section and the
    resulting field values drive everything below (run naming, environment, optimizer,
    loss surrogate, beta schedule, alpha watcher and periodic evaluation).
    """

    # --- Experiment bookkeeping ---
    exp_name: str = os.path.basename(__file__)[: -len(".py")]  # script file name without ".py"; used in the saved-model path
    seed: int = 42  # seeds `random`, numpy, torch and the environment reset
    torch_deterministic: bool = True  # assigned to torch.backends.cudnn.deterministic
    cuda: bool = True  # use CUDA when it is available; otherwise the script falls back to CPU
    track: bool = False  # if True, wandb.init(...) is called with sync_tensorboard=True
    wandb_project_name: str = "cleanRL"  # passed as `project` to wandb.init
    wandb_entity: str = None  # passed as `entity` to wandb.init (None is forwarded as-is)
    capture_video: bool = False  # wrap env index 0 in RecordVideo (videos/<run_name>)
    save_model: bool = False  # save the online network's state_dict after training and run the offline eval
    upload_model: bool = False  # only consulted when save_model is True; pushes the run via push_to_hub
    hf_entity: str = ""  # optional "<entity>/" prefix of the hub repo id

    # Algorithm specific
    env_id: str = "DemonAttackNoFrameskip-v4"  # Atari env
    total_timesteps: int = 15_000_000  # number of iterations of the main environment-step loop
    learning_rate: float = 5e-5  # learning rate given to the AdamW optimizer
    num_envs: int = 1  # the script asserts this equals 1
    buffer_size: int = 200_000  # capacity passed to ReplayBuffer
    gamma: float = 0.99  # discount factor in the TD target
    tau: float = 1.0  # lerp weight for the target-network sync (1.0 copies the online weights)
    target_network_frequency: int = 2000  # a sync happens on training steps where global_step is a multiple of this
    batch_size: int = 32  # replay samples per gradient step
    start_e: float = 1.0  # initial exploration epsilon
    end_e: float = 0.01  # floor of the exploration epsilon
    exploration_fraction: float = 0.1  # fraction of total_timesteps over which epsilon is annealed
    learning_starts: int = 80_000  # gradient updates require global_step > learning_starts
    train_frequency: int = 4  # a gradient step is taken when global_step is a multiple of this
    use_double_q: bool = True  # select next actions with the online net, evaluate them with the target net

    # Surrogate options: "mse", "batch_lsc", "elem_lsc", "trimmed_batch_lsc", "huber"
    # (any other string ends up in the MSE branch of the training loop)
    surrogate_type: str = "huber"
    beta: float = 1.0  # fixed temperature when adaptive_beta is off; for "huber" the delta is 1/beta
    trim_q: float = 0.95  # used only by trimmed_batch_lsc

    # Adaptive beta for TRAINING (can change every step)
    adaptive_beta: bool = False     # enable BetaScheduler (ignored when surrogate_type == "mse")
    beta_c: float = 1.1             # recommend ~[0.7, 1.5]; numerator c in beta = c / (E_hat + eps)
    beta_min: float = 1e-1          # lower clamp for the scheduled beta
    beta_max: float = 20.0          # upper clamp for the scheduled beta
    beta_eps: float = 1e-6          # avoids division by 0
    beta_quantile: float = 0.95     # quantile of |TD error| used as the scale estimate E
    beta_ema: float = 0.1           # EMA smoothing for E_hat (weight given to the newest quantile)

    # α-watcher + evaluation settings (watcher only; algo untouched)
    alpha_val_multiplier: int = 200  # |V_t| = multiplier * batch_size
    alpha_ddqn_targets: bool = True  # build Y_t with DDQN rule if training uses DDQN

    eval_frequency: int = 50000  # evaluate when global_step is a multiple of this (values <= 0 disable evaluation)
    eval_episodes: int = 5  # episodes per periodic evaluation
    eval_epsilon: float = 0.0  # greedy eval


def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one preprocessed training env.

    The environment itself is only created when the returned callable is invoked.
    Video recording (to ``videos/<run_name>``) is enabled solely for the env with
    ``idx == 0`` and only when ``capture_video`` is set. ``seed`` is applied to the
    action space of the finished env.
    """

    def thunk():
        if capture_video and idx == 0:
            base = gym.wrappers.RecordVideo(
                gym.make(env_id, render_mode="rgb_array"), f"videos/{run_name}"
            )
        else:
            base = gym.make(env_id)

        # Wrappers are applied strictly in this order; FireResetEnv is only added
        # for games whose action set contains "FIRE".
        pipeline = (
            gym.wrappers.RecordEpisodeStatistics,
            lambda e: NoopResetEnv(e, noop_max=30),
            lambda e: MaxAndSkipEnv(e, skip=4),
            EpisodicLifeEnv,
            lambda e: FireResetEnv(e) if "FIRE" in e.unwrapped.get_action_meanings() else e,
            ClipRewardEnv,
            lambda e: gym.wrappers.ResizeObservation(e, (84, 84)),
            gym.wrappers.GrayScaleObservation,
            lambda e: gym.wrappers.FrameStack(e, 4),
        )
        wrapped = base
        for stage in pipeline:
            wrapped = stage(wrapped)

        wrapped.action_space.seed(seed)
        return wrapped

    return thunk


class QNetwork(nn.Module):
    """Convolutional Q-network: three conv layers followed by a two-layer MLP head.

    The first convolution takes 4 input channels and the flattened conv output is
    expected to have 3136 features. The number of outputs equals
    ``env.single_action_space.n``.
    """

    def __init__(self, env):
        super().__init__()
        feature_layers = [
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        head_layers = [
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, env.single_action_space.n),
        ]
        # One flat Sequential so parameter names stay "network.<index>.*".
        self.network = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return Q-values for ``x`` after scaling it by 1/255 (raw pixel range)."""
        scaled = x / 255.0
        return self.network(scaled)


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linearly anneal a value from ``start_e`` towards ``end_e`` over ``duration`` steps.

    Args:
        start_e: value at ``t == 0``.
        end_e: final value; the result is never smaller than this.
        duration: number of steps the anneal takes.
        t: current step.

    Returns:
        ``max(start_e + (end_e - start_e) * t / duration, end_e)``. When
        ``duration`` is not positive the anneal is treated as already finished
        and ``end_e`` is returned (previously ``duration == 0`` raised
        ``ZeroDivisionError``).
    """
    if duration <= 0:
        # A zero-length schedule has nothing to interpolate over.
        return end_e
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


# -------- Batch-LSC (soft L∞): (1/β) log mean_i cosh(β δ_i) ----------
def batch_lsc_loss(td_errors: torch.Tensor, beta: float) -> torch.Tensor:
    """Return ``(1/beta) * log(mean_i cosh(beta * td_errors_i))`` as a scalar tensor.

    The largest ``|beta * td|`` is subtracted inside the exponentials and added
    back after the log, so large errors do not overflow. That shift is detached
    and therefore carries no gradient itself.
    """
    scaled = beta * td_errors
    shift = scaled.abs().max().detach()
    shifted_cosh = 0.5 * (torch.exp(scaled - shift) + torch.exp(-scaled - shift))
    log_mean_cosh = shift + torch.log(shifted_cosh.mean())
    return log_mean_cosh / beta

# -------- Elementwise-LSC (smooth Huber-like, delta ~ 1/β): mean_i [(1/β²) log cosh(β δ_i)] ----------
def elemwise_lsc_loss(td_errors: torch.Tensor, beta: float) -> torch.Tensor:
    """Return ``mean_i log(cosh(beta * td_errors_i)) / beta**2`` as a scalar tensor.

    ``log cosh(x)`` is evaluated as ``|x| + log1p(exp(-2|x|)) - log 2`` so the
    exponential never sees a positive argument.
    """
    magnitude = (td_errors * beta).abs()
    per_sample = magnitude + torch.log1p(torch.exp(-2.0 * magnitude)) - np.log(2.0)
    return per_sample.mean() / (beta) ** 2

# -------- Trimmed batch-LSC (soft L∞ on central mass) ----------
def trimmed_batch_lsc(td_errors: torch.Tensor, beta: float, trim_q: float = 0.95) -> torch.Tensor:
    """Batch log-sum-cosh loss restricted to errors at or below a magnitude quantile.

    The cutoff is the ``trim_q`` quantile of ``|td_errors|`` (floored at 1e-6 and
    computed without gradient). Errors with larger magnitude are dropped; if
    nothing survives, the full batch is used. The result is
    ``(1/beta) * log(mean cosh(beta * kept))``, computed with a detached max-shift.
    """
    abs_td = td_errors.abs()
    with torch.no_grad():
        cutoff = torch.quantile(abs_td, trim_q).clamp_min(1e-6)
    kept = td_errors[abs_td <= cutoff]
    selected = kept if kept.numel() != 0 else td_errors

    scaled = beta * selected
    shift = scaled.abs().max().detach()
    shifted_cosh = 0.5 * (torch.exp(scaled - shift) + torch.exp(-scaled - shift))
    return (shift + torch.log(shifted_cosh.mean())) / beta


# --- Adaptive beta scheduler (for TRAINING) ---
class BetaScheduler:
    """Track a smoothed TD-error scale and turn it into a loss temperature ``beta``.

    ``E_hat`` is an exponential moving average of the ``q``-quantile of
    ``|td_errors|``. The adaptive rule is
    ``beta = clamp(c / (E_hat + eps), beta_min, beta_max)``.

    Attributes:
        c: numerator of the adaptive rule.
        q: quantile of ``|td_errors|`` used as the per-batch scale estimate.
        ema: weight given to the newest quantile when updating ``E_hat``.
        beta_min, beta_max: clamp bounds applied to every returned beta.
        eps: added to ``E_hat`` in the denominator.
        E_hat: running scale estimate (``None`` until the first update).
    """

    def __init__(self, c: float, q: float, ema: float, beta_min: float, beta_max: float, eps: float):
        self.c = c
        self.q = q
        self.ema = ema
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.eps = eps
        self.E_hat = None
        # Two-phase bookkeeping: once switched, the warm-up beta is never used again.
        self._switched = False
        self._violations = 0

    def _adaptive_beta(self, E: float) -> float:
        """Return ``clamp(c / (E + eps), beta_min, beta_max)`` as a Python float."""
        raw = self.c / (E + self.eps)
        # The clamp goes through a torch scalar, exactly as before, so returned
        # values keep the same (default-dtype) rounding.
        beta_eff = torch.clamp(torch.as_tensor(raw), self.beta_min, self.beta_max)
        return float(beta_eff.item())

    @torch.no_grad()
    def update_and_get(
        self,
        td_errors: torch.Tensor,
        step: int = None,           # optional: provide to enable two-phase
        warmup_steps: int = None,   # optional: provide to enable two-phase
        beta_warmup: float = 0.1,
        patience: int = 0           # require N consecutive violations to switch
    ) -> float:
        """Update ``E_hat`` from ``td_errors`` and return the beta to train with.

        Args:
            td_errors: TD errors of the current batch (only magnitudes are used).
            step: current global step; together with ``warmup_steps`` it enables
                the two-phase behavior.
            warmup_steps: step at which the warm-up phase ends unconditionally.
            beta_warmup: constant beta returned during the warm-up phase.
            patience: number of consecutive "violations"
                (``beta_warmup * E_hat > c``) that ends the warm-up early.

        Returns:
            ``beta_warmup`` while still in the warm-up phase, otherwise the
            clamped adaptive beta. If ``step`` or ``warmup_steps`` is ``None``
            the adaptive beta is always returned.
        """
        # --- update E_hat ---
        E_q = torch.quantile(td_errors.detach().abs(), self.q)
        if self.E_hat is None:
            self.E_hat = E_q
        else:
            self.E_hat = (1 - self.ema) * self.E_hat + self.ema * E_q
        E = self.get_E_hat()

        # --- single-phase mode: no warm-up arguments given ---
        if step is None or warmup_steps is None:
            return self._adaptive_beta(E)

        if not self._switched:
            if step >= warmup_steps:
                self._switched = True
            elif (beta_warmup * E) > self.c:
                self._violations += 1
                if self._violations >= patience:
                    self._switched = True
            else:
                self._violations = 0  # violations must be consecutive

            if not self._switched:
                # Phase 1: constant small beta
                return float(beta_warmup)

        # Phase 2: adaptive rule
        return self._adaptive_beta(E)

    def get_E_hat(self) -> float:
        """Return the current ``E_hat`` as a float (``0.0`` before the first update)."""
        if self.E_hat is None:
            return 0.0
        return float(self.E_hat.item() if isinstance(self.E_hat, torch.Tensor) else self.E_hat)

    def fixed_beta(
        self,
        step: int,
        total_steps: int,
        mode: str = "linear",         # {"constant","linear","cosine","exponential"}
        beta_start: float = 1.0,
        beta_end: float = 10.0,
    ) -> float:
        """
        Return a scheduled beta at `step` in [0, total_steps].
        - constant: beta = beta_start (beta_end ignored)
        - linear:   beta = beta_start + (beta_end - beta_start) * (step/total_steps)
        - cosine:   beta = beta_end + 0.5*(beta_start - beta_end)*(1 + cos(pi * step/total_steps))
                    (smoothly interpolates start -> end)
        - exponential: geometric interpolation:
                    beta = beta_start * (beta_end / beta_start)^(step/total_steps)

        Any other `mode` behaves like "constant". `mode` is matched
        case-insensitively. A non-positive `total_steps` is treated as a finished
        schedule (progress 1.0); previously `total_steps == 0` raised
        ZeroDivisionError.

        Clamped to [self.beta_min, self.beta_max].
        """
        if total_steps <= 0:
            # Nothing to interpolate over: report the end of the schedule.
            frac = 1.0
        else:
            frac = float(np.clip(step, 0, total_steps)) / float(total_steps)

        m = mode.lower()
        if m == "linear":
            val = beta_start + (beta_end - beta_start) * frac
        elif m == "cosine":
            # cosine anneal from start->end
            w = 0.5 * (1.0 + np.cos(np.pi * frac))  # w goes 1->0
            val = beta_end + (beta_start - beta_end) * w
        elif m == "exponential":
            # geometric interpolation (works for start > end as well)
            ratio = beta_end / beta_start
            val = beta_start * (ratio ** frac)
        else:
            # "constant" and any unrecognised mode
            val = beta_start

        # clamp to scheduler bounds
        return float(np.clip(val, self.beta_min, self.beta_max))
    
@torch.no_grad()
def make_fixed_targets(
    q_online_snap: nn.Module, q_target_snap: nn.Module, batch, gamma: float, use_double_q: bool
) -> torch.Tensor:
    """Compute one-step TD targets for ``batch`` from frozen network snapshots.

    With ``use_double_q`` the next action is chosen by ``q_online_snap`` and its
    value is read from ``q_target_snap``; otherwise the target snapshot's maximum
    Q-value is used. Runs without gradient tracking.

    Returns:
        ``rewards + gamma * q_next * (1 - dones)`` with rewards and dones
        flattened to one dimension.
    """
    successors = batch.next_observations
    if not use_double_q:
        bootstrap = q_target_snap(successors).max(dim=1).values
    else:
        greedy_actions = q_online_snap(successors).argmax(dim=1, keepdim=True)
        bootstrap = q_target_snap(successors).gather(1, greedy_actions).squeeze(1)
    not_done = 1.0 - batch.dones.flatten()
    return batch.rewards.flatten() + gamma * bootstrap * not_done  # (B,)


def loss_against_fixed_Y(
    q_net: nn.Module,
    obs: torch.Tensor,
    acts: torch.Tensor,
    Y_fixed: torch.Tensor,
    surrogate_type: str,
    beta: float,
    trim_q: float,
) -> tuple:
    """Evaluate the chosen surrogate loss of ``q_net`` against precomputed targets.

    Args:
        q_net: network mapping ``obs`` to per-action Q-values.
        obs: observations fed to ``q_net``.
        acts: action indices; a 1-D tensor is reshaped to a column.
        Y_fixed: fixed regression targets, one per sample.
        surrogate_type: "batch_lsc", "elem_lsc", "trimmed_batch_lsc" or "huber";
            any other value selects MSE.
        beta: temperature for the LSC losses; the Huber delta is ``1/beta``.
        trim_q: quantile used only by "trimmed_batch_lsc".

    Returns:
        ``(loss, td)`` where ``td = Q(obs, acts) - Y_fixed``.
    """
    action_idx = acts.view(-1, 1) if acts.ndim == 1 else acts
    q_pred = q_net(obs).gather(1, action_idx.long()).squeeze(1)
    td = q_pred - Y_fixed

    # Lazily evaluated so only the selected surrogate is ever computed.
    surrogates = {
        "batch_lsc": lambda: batch_lsc_loss(td, beta=beta),
        "elem_lsc": lambda: elemwise_lsc_loss(td, beta=beta),
        "trimmed_batch_lsc": lambda: trimmed_batch_lsc(td, beta=beta, trim_q=trim_q),
        "huber": lambda: F.huber_loss(q_pred, Y_fixed, delta=1/beta),
    }
    compute = surrogates.get(surrogate_type, lambda: F.mse_loss(q_pred, Y_fixed))
    return compute(), td


class AlphaWatcher:
    """Measure how much the loss on a frozen validation problem shrinks over a block.

    ``start_block`` samples a batch, freezes its targets and records the loss of
    the snapshot network on it. ``finalize_alpha`` later re-evaluates the current
    network on the same batch and targets and reports the ratio of the two losses.
    """

    def __init__(self):
        self.active = False
        self.obs = self.acts = self.Y_fixed = None
        self.beta_ref = None
        self.baseline_loss = None

    @torch.no_grad()
    def start_block(
        self,
        q_online_snap: nn.Module,
        q_target_snap: nn.Module,
        rb: ReplayBuffer,
        batch_size: int,
        gamma: float,
        surrogate_type: str,
        beta_ref: float,
        device: torch.device,
        use_double_q: bool,
        trim_q: float,
    ):
        """Freeze a validation batch, its targets and the baseline loss for a new block.

        The batch of ``batch_size`` transitions is drawn from ``rb``; targets come
        from ``make_fixed_targets`` on the two snapshots, and the baseline is the
        loss of ``q_online_snap`` under ``surrogate_type`` with ``beta_ref``.
        """
        sample = rb.sample(batch_size)
        self.obs = sample.observations.to(device)
        self.acts = sample.actions.to(device)
        targets = make_fixed_targets(
            q_online_snap, q_target_snap, sample, gamma=gamma, use_double_q=use_double_q
        )
        self.Y_fixed = targets.to(device)
        self.beta_ref = float(beta_ref)

        baseline, _ = loss_against_fixed_Y(
            q_online_snap, self.obs, self.acts, self.Y_fixed, surrogate_type, self.beta_ref, trim_q
        )
        self.baseline_loss = float(baseline.item())
        self.active = True

    @torch.no_grad()
    def finalize_alpha(self, q_online_current: nn.Module, surrogate_type: str, trim_q: float) -> tuple:
        """Return ``(alpha, max_abs_td)`` for the block and deactivate the watcher.

        ``alpha`` is the current loss on the frozen batch divided by the baseline
        loss. When no block is active, or the baseline is missing or not positive,
        ``(1.0, 0.0)`` is returned and the watcher state is left untouched.
        """
        if not self.active:
            return 1.0, 0.0
        if self.baseline_loss is None or self.baseline_loss <= 0:
            return 1.0, 0.0

        current_loss, residual = loss_against_fixed_Y(
            q_online_current, self.obs, self.acts, self.Y_fixed, surrogate_type, self.beta_ref, trim_q
        )
        ratio = float(current_loss.item()) / self.baseline_loss
        worst_residual = float(residual.abs().max().item())
        self.active = False  # a block can only be finalized once
        return ratio, worst_residual


# --- Validation loss (portable; still useful for ad-hoc checks) ---
@torch.no_grad()
def compute_val_loss(
    q_net: QNetwork,
    target_net: QNetwork,
    batch,
    gamma: float,
    surrogate_type: str,
    beta: float,
    device: torch.device,
    use_double_q: bool = True,
    trim_q: float = 0.95,
) -> tuple:
    """Evaluate the surrogate TD loss of ``q_net`` on ``batch`` without gradients.

    Batch fields are converted to tensors on ``device`` (observations, rewards and
    dones as float, actions as long). Targets use the double-Q rule when
    ``use_double_q`` is set, otherwise the target network's max.

    Returns:
        ``(loss_value, td)``: the loss as a Python float and the per-sample TD
        errors ``Q(s, a) - target``.
    """

    def _as_float(value):
        return torch.as_tensor(value, device=device).float()

    obs = _as_float(batch.observations)
    actions = torch.as_tensor(batch.actions, device=device).long()
    if actions.ndim == 1:
        actions = actions.view(-1, 1)
    next_obs = _as_float(batch.next_observations)
    rewards = _as_float(batch.rewards.flatten())
    dones = _as_float(batch.dones.flatten())

    if not use_double_q:
        q_next = target_net(next_obs).max(dim=1).values
    else:
        greedy_actions = q_net(next_obs).argmax(dim=1, keepdim=True)
        q_next = target_net(next_obs).gather(1, greedy_actions).squeeze(1)

    target_q = rewards + (1.0 - dones) * gamma * q_next

    q_pred = q_net(obs).gather(1, actions).squeeze(1)
    td = q_pred - target_q

    if surrogate_type == "batch_lsc":
        value = batch_lsc_loss(td, beta=beta)
    elif surrogate_type == "elem_lsc":
        value = elemwise_lsc_loss(td, beta=beta)
    elif surrogate_type == "trimmed_batch_lsc":
        value = trimmed_batch_lsc(td, beta=beta, trim_q=trim_q)
    elif surrogate_type == "huber":
        value = F.huber_loss(q_pred, target_q, delta=1/beta)
    else:  # "mse" and anything unrecognised
        value = F.mse_loss(q_pred, target_q)
    return float(value.item()), td


def format_loss_tag(surrogate_type: str, beta: float, trim_q: float = 0.95) -> str:
    """Build the ``loss=...`` fragment used in run names.

    ``surrogate_type`` is matched case-insensitively; an unrecognised value is
    echoed back unchanged as ``loss=<surrogate_type>``.
    """
    kind = surrogate_type.lower()
    if kind == "mse":
        return "loss=mse"
    if kind == "batch_lsc":
        return f"loss=soft_sup(beta={beta:g})"
    if kind == "elem_lsc":
        return f"loss=soft_Huber(beta={beta:g})"
    if kind == "trimmed_batch_lsc":
        return f"loss=trimmed_batch_lsc(beta={beta:g},q={trim_q:.2f})"
    if kind == "huber":
        # Huber is parameterised by delta = 1/beta.
        return f"loss=Huber(delta={1/beta:.3f})"
    return f"loss={surrogate_type}"


def _to_torch_obs_single(obs, device):
    """Convert one observation to a float32 tensor with a leading batch axis.

    ``obs`` is materialized with ``np.asarray``. A 3-D array whose last axis has
    length 1, 3 or 4 is treated as channels-last and permuted to channels-first;
    any other array is used as-is.
    """
    frame = np.asarray(obs)
    looks_channels_last = frame.ndim == 3 and frame.shape[-1] in (1, 3, 4)
    if looks_channels_last:
        frame = np.transpose(frame, (2, 0, 1))
    batched = torch.from_numpy(frame).unsqueeze(0)
    return batched.to(device=device, dtype=torch.float32)

def _to_torch_obs_batch(obs, device):
    """Convert a batch of observations to a float32 tensor on ``device``.

    ``obs`` is materialized with ``np.asarray``. A 4-D array whose last axis has
    length 1, 3 or 4 is treated as (B,H,W,C) and permuted to (B,C,H,W); any other
    array is used as-is.
    """
    stacked = np.asarray(obs)
    looks_channels_last = stacked.ndim == 4 and stacked.shape[-1] in (1, 3, 4)
    if looks_channels_last:
        stacked = np.transpose(stacked, (0, 3, 1, 2))
    as_tensor = torch.from_numpy(stacked)
    return as_tensor.to(device=device, dtype=torch.float32)

@torch.no_grad()
def evaluate_policy(
    q_net: nn.Module,
    env_id: str,
    episodes: int,
    epsilon: float,
    device: torch.device,
) -> tuple[float, float, float]:
    """Roll out ``q_net`` for full episodes in a fresh env and summarize the returns.

    The evaluation env uses the same observation preprocessing as training
    (no-op reset, frame skip, optional fire reset, 84x84 grayscale, 4-frame
    stack) but NO EpisodicLifeEnv and NO ClipRewardEnv, so returns are
    unclipped and episodes run until the env reports terminated or truncated.

    Args:
        q_net: network producing per-action Q-values for a batched observation.
        env_id: gymnasium environment id.
        episodes: number of episodes to play.
        epsilon: probability of taking a uniformly sampled action instead of the
            greedy one.
        device: device the observation tensor is moved to before the forward pass.

    Returns:
        ``(mean_return, std_return, mean_length)`` over the played episodes.

    The env is closed even if a rollout raises (previously it was only closed
    on the success path).
    """
    env = gym.make(env_id)
    try:
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        # no ClipRewardEnv
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)

        returns, lengths = [], []
        for _ in range(episodes):
            obs, _ = env.reset()
            ep_ret, ep_len = 0.0, 0
            terminated = truncated = False
            while not (terminated or truncated):
                if random.random() < epsilon:
                    action = env.action_space.sample()
                else:
                    obs_t = _to_torch_obs_single(obs, device)
                    q = q_net(obs_t)
                    action = int(torch.argmax(q, dim=1).item())
                obs, r, terminated, truncated, _ = env.step(action)
                ep_ret += float(r)   # unclipped reward
                ep_len += 1
            returns.append(ep_ret)
            lengths.append(ep_len)
    finally:
        # Release the emulator regardless of how the rollouts ended.
        env.close()
    return float(np.mean(returns)), float(np.std(returns)), float(np.mean(lengths))



# Script entry point: parse Args, set up logging/seeding/env/networks, then run the
# epsilon-greedy interaction loop with periodic gradient steps, target syncs,
# alpha-watcher blocks and greedy evaluations; optionally save/upload the model.
if __name__ == "__main__":
    args = tyro.cli(Args)
    assert args.num_envs == 1, "vectorized envs are not supported at the moment"

    # Human-readable run name: [ddqn__]<env_id>__<loss tag>__seed<seed>__<timestamp>
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    loss_tag = format_loss_tag(args.surrogate_type, args.beta, args.trim_q)
    if args.adaptive_beta and args.surrogate_type != "mse":
        loss_tag += f"_beta=adaptive(C={args.beta_c:g},[{args.beta_min:g},{args.beta_max:g}])"
    if args.use_double_q:
        run_name = "__".join(['ddqn', args.env_id, loss_tag, f"seed{args.seed}", timestamp])
    else:
        run_name = "__".join([args.env_id, loss_tag, f"seed{args.seed}", timestamp])

    # wandb is imported lazily so it is only required when tracking is requested.
    if args.track:
        import wandb
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )

    # TensorBoard writer; the full argument set is stored as a markdown table.
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{k}|{v}|" for k, v in vars(args).items()])),
    )

    # Seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # Env (a SyncVectorEnv over num_envs factories; num_envs is asserted to be 1 above)
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    # Online and target networks start from identical weights.
    q_network = QNetwork(envs).to(device)
    target_network = QNetwork(envs).to(device)
    target_network.load_state_dict(q_network.state_dict())

    optimizer = optim.AdamW(q_network.parameters(), lr=args.learning_rate)

    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        optimize_memory_usage=True,
        handle_timeout_termination=False,
    )

    start_time = time.time()
    obs, _ = envs.reset(seed=args.seed)

    # The scheduler only exists when adaptive beta is requested for a non-MSE surrogate.
    beta_sched = None
    if args.adaptive_beta and args.surrogate_type != "mse":
        beta_sched = BetaScheduler(
            c=args.beta_c,
            q=args.beta_quantile,
            ema=args.beta_ema,
            beta_min=args.beta_min,
            beta_max=args.beta_max,
            eps=args.beta_eps,
        )

    # Alpha watcher (block-end only)
    alpha_watcher = AlphaWatcher()

    for global_step in range(args.total_timesteps):
        # Action selection: epsilon-greedy, with epsilon annealed over the first
        # exploration_fraction * total_timesteps steps.
        epsilon = linear_schedule(args.start_e, args.end_e, int(args.exploration_fraction * args.total_timesteps), global_step)
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            with torch.no_grad():
                q_values = q_network(_to_torch_obs_batch(obs, device))
                actions = torch.argmax(q_values, dim=1).cpu().numpy()

        # Step
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # Logging episodic stats (present in infos["final_info"] when an episode ended)
        if "final_info" in infos:
            for info in infos["final_info"]:
                if info and "episode" in info:
                    writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                    print(f"global_step={global_step}, episodic_return={info['episode']['r']}")

        # Store with time-limit correction: for truncated envs the stored next
        # observation is the episode's final observation rather than next_obs.
        real_next_obs = next_obs.copy()
        if "final_observation" in infos:
            for idx, trunc in enumerate(truncations):
                if trunc and infos["final_observation"][idx] is not None:
                    real_next_obs[idx] = infos["final_observation"][idx]
        # Only `terminations` is stored as the done flag, so truncated transitions still bootstrap.
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        obs = next_obs

        # Learning: one gradient step every train_frequency steps once past learning_starts.
        if global_step > args.learning_starts and (global_step % args.train_frequency == 0):
            batch = rb.sample(args.batch_size)

            # --- compute target (no grad) ---
            with torch.no_grad():
                if args.use_double_q:
                    # action selection with ONLINE net
                    q_next_online = q_network(batch.next_observations)
                    next_actions = q_next_online.argmax(dim=1, keepdim=True)
                    # action evaluation with TARGET net
                    q_next_target = target_network(batch.next_observations)
                    q_next = q_next_target.gather(1, next_actions).squeeze(1)
                else:
                    q_next = target_network(batch.next_observations).max(dim=1).values
                td_target = batch.rewards.flatten() + args.gamma * q_next * (1.0 - batch.dones.flatten())

            # --- current Q for chosen actions ---
            q_pred_all = q_network(batch.observations)
            acts = batch.actions
            if acts.ndim == 1:
                acts = acts.view(-1, 1)
            q_pred = q_pred_all.gather(1, acts).squeeze(1)
            td_errors = q_pred - td_target

            # TRAINING temperature can be adaptive per step.
            # The two-phase scheduler is used here: beta stays at 0.5 until either
            # 20% of total_timesteps has passed or 3 consecutive violations occurred.
            if args.adaptive_beta and args.surrogate_type != "mse":
                beta_train = beta_sched.update_and_get(
                        td_errors=td_errors,
                        step=global_step,
                        warmup_steps=int(0.2 * args.total_timesteps),
                        beta_warmup=0.5,
                        patience=3
                    )
            else:
                beta_train = args.beta

            # Progress line, rewritten in place (carriage return) on every training step.
            print(f"step={global_step} | beta_train={beta_train:.3f}", end="\r")

            # NOTE(review): surrogate_type is compared case-sensitively here, whereas
            # format_loss_tag lowercases it for the run name; an unrecognised or
            # differently-cased value silently trains with MSE — confirm this is intended.
            if args.surrogate_type == "batch_lsc":
                loss = batch_lsc_loss(td_errors, beta=beta_train)
            elif args.surrogate_type == "elem_lsc":
                loss = elemwise_lsc_loss(td_errors, beta=beta_train)
            elif args.surrogate_type == "trimmed_batch_lsc":
                loss = trimmed_batch_lsc(td_errors, beta=beta_train, trim_q=args.trim_q)
            elif args.surrogate_type == "huber":
                loss = F.huber_loss(q_pred, td_target, delta=1/beta_train)
            else:  # "mse"
                loss = F.mse_loss(q_pred, td_target)

            # --- optimize ---
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            # === periodic logging (match MuJoCo) ===
            # Only reached on training steps, i.e. steps that are multiples of both
            # train_frequency and 100.
            if global_step % 100 == 0:
                td_train = td_errors.detach()
                train_abs = td_train.abs()
                writer.add_scalar("losses/td_loss", loss.item(), global_step)
                writer.add_scalar("losses/q_values", q_pred.mean().item(), global_step)
                sps = int(global_step / (time.time() - start_time))
                print("SPS:", sps)
                writer.add_scalar("charts/SPS", sps, global_step)
                writer.add_scalar("train/max_abs_td", train_abs.max().item(), global_step)
                writer.add_scalar("train/q95_abs_td", torch.quantile(train_abs, 0.95).item(), global_step)
                writer.add_scalar("train/q99_abs_td", torch.quantile(train_abs, 0.99).item(), global_step)
                writer.add_scalar("train/td_target_mean", td_target.mean().item(), global_step)
                writer.add_scalar("train/td_target_std", td_target.std().item(), global_step)
                if beta_sched is not None:
                    writer.add_scalar("beta/value", beta_train, global_step)
                    writer.add_scalar("beta/E_hat", beta_sched.get_E_hat(), global_step)

            # --- END OF BLOCK: finalize alpha, then sync target, then start next block ---
            # This check sits inside the learning branch, so it only fires on training steps.
            if global_step % args.target_network_frequency == 0:
                # 1) finalize alpha for the previous block (if active)
                if alpha_watcher.active:
                    alpha_final, max_abs_td_final = alpha_watcher.finalize_alpha(
                        q_network, args.surrogate_type, args.trim_q
                    )
                    writer.add_scalar("val/alpha_block_final", alpha_final, global_step)
                    writer.add_scalar("val/block_max_abs_td_final", max_abs_td_final, global_step)

                # 2) target sync (Polyak); with the default tau = 1.0 this copies the online weights
                with torch.no_grad():
                    for tp, p in zip(target_network.parameters(), q_network.parameters()):
                        tp.data.lerp_(p.data, args.tau)

                # 3) start next block: snapshots, pick beta_ref, build V_t & Y_t, record baseline
                # Deep copies are taken after the sync so the snapshots reflect the new target.
                q_online_snap = copy.deepcopy(q_network).eval()
                q_target_snap = copy.deepcopy(target_network).eval()

                # choose beta_ref per block for α measurement
                if args.adaptive_beta and args.surrogate_type != "mse":
                    beta_ref = beta_train
                else:
                    beta_ref = args.beta
                writer.add_scalar("beta/ref", beta_ref, global_step)

                # init next block (validation batch size = batch_size * alpha_val_multiplier)
                alpha_watcher.start_block(
                    q_online_snap=q_online_snap,
                    q_target_snap=q_target_snap,
                    rb=rb,
                    batch_size=args.batch_size * args.alpha_val_multiplier,
                    gamma=args.gamma,
                    surrogate_type=args.surrogate_type,
                    beta_ref=beta_ref,
                    device=device,
                    use_double_q=(args.alpha_ddqn_targets and args.use_double_q),
                    trim_q=args.trim_q,
                )

        # --- periodic greedy eval ---
        # Runs on every step (not only training steps) that is a multiple of
        # eval_frequency once global_step >= learning_starts; evaluate_policy builds
        # and closes its own environment.
        if (
            global_step >= args.learning_starts
            and args.eval_frequency > 0
            and global_step % args.eval_frequency == 0
        ):
            q_network.eval()
            ret_mean, ret_std, len_mean = evaluate_policy(
                q_net=q_network,
                env_id=args.env_id,
                episodes=args.eval_episodes,
                epsilon=args.eval_epsilon,
                device=device,
            )
            q_network.train()
            writer.add_scalar("eval/episodic_return_greedy", ret_mean, global_step)
            writer.add_scalar("eval/episodic_return_std", ret_std, global_step)
            writer.add_scalar("eval/episodic_length_mean", len_mean, global_step)

    # Optional: final save and offline eval
    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        torch.save(q_network.state_dict(), model_path)
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.dqn_eval import evaluate

        # NOTE(review): `evaluate` receives this file's make_env (training wrappers,
        # including reward clipping); its exact calling convention is defined in
        # cleanrl_utils and is not visible here — confirm the signatures match.
        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=QNetwork,
            device=device,
            epsilon=args.eval_epsilon,
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub
            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()
