"""
FPVR Agent with Experience Replay
Future-Past Visitation Redundancy (FPVR) exploration agent with experience replay

References:
- This supplementary package is self-contained; no external code is required.
"""
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

# Ensure local modules are importable
import os
import sys
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
if _THIS_DIR not in sys.path:
    sys.path.insert(0, _THIS_DIR)

from model import FPVRNetwork, make_q_network
from replay_buffer import ReplayBuffer, PrioritizedReplayBuffer


class FPVRAgent:
    """
    Future-Past Visitation Redundancy (FPVR) exploration agent with experience replay
    
    Mechanism:
    - φ: state features (CNN encoder)
    - φ̃: whitened features (ZCA whitening)
    - ψ(s,a): successor features (slow time-scale future-visitation representation)
    - c: persistence representation (discounted accumulator of past whitened features)
    - redundancy: cosine similarity between ψ and c; prefer actions with lower redundancy
    """
    
    def __init__(self, state_shape, n_actions, config):
        """Build the FPVR network, the Q-networks, whitening state, replay buffer and policy settings.

        Args:
            state_shape: Observation shape, indexed here as (channels, height, width);
                channel count is treated as the environment frame-stack depth.
            n_actions: Number of discrete actions.
            config: Dict-like configuration. Keys read with ``config[...]`` are required:
                "phi_dim", "sf_gamma", "whitening_update_every", "whitening_ema_alpha",
                "prioritized_replay", "buffer_size", "sf_target" (plus
                "prioritized_alpha" / "prioritized_beta" when prioritized replay is on).
                Keys read with ``config.get(...)`` are optional; their defaults are visible
                inline below. NOTE: ``config`` is mutated — "q_net_type" is filled in
                with "nature" when missing or None.

        Raises:
            ValueError: If ``num_sf_channel`` is outside ``[1, frame_stack]``.
            KeyError: If a required key is missing (assuming ``config`` is a plain dict).
        """
        self.config = config
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.n_actions = n_actions
        self.state_shape = state_shape

        # Backward compatibility: old configs may lack q_net_type.
        # Ensure a single canonical key exists in self.config and is saved in checkpoints.
        if "q_net_type" not in self.config or self.config.get("q_net_type") is None:
            self.config["q_net_type"] = "nature"
        self.q_net_type = str(self.config.get("q_net_type", "nature")).lower()
        
        # ========== FPVR Network input channels (optional) ==========
        # Env/replay-buffer observations are stacked with `frame_stack` channels (state_shape[0]).
        # Here we optionally feed only the last K channels into the FPVR network to reduce
        # sensitivity to stack alignment (e.g., K=1 means only the current frame).
        self.env_frame_stack = int(state_shape[0])
        num_sf_channel = config.get("num_sf_channel", None)
        # Default: the FPVR network sees the full frame stack.
        self.num_sf_channel = int(num_sf_channel) if num_sf_channel is not None else self.env_frame_stack
        if not (1 <= self.num_sf_channel <= self.env_frame_stack):
            raise ValueError(
                f"num_sf_channel must be in [1, frame_stack={self.env_frame_stack}], got {self.num_sf_channel}"
            )
        self.fpvr_state_shape = (self.num_sf_channel, state_shape[1], state_shape[2])

        # ========== Network ==========
        self.network = FPVRNetwork(self.fpvr_state_shape, n_actions, phi_dim=config["phi_dim"]).to(self.device)
        # sf_lr falls back to the generic "lr" key, then to 2.5e-4.
        sf_lr = float(config.get("sf_lr", config.get("lr", 2.5e-4)))
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=sf_lr)

        # ========== Q-Network (DQN / Double DQN) ==========
        # Always enabled (no fpvr_only mode).
        self.dqn_type = str(config.get("dqn_type", "dqn")).lower()
        # Unrecognized values silently fall back to plain DQN.
        if self.dqn_type not in ("dqn", "ddqn"):
            self.dqn_type = "dqn"
        self.q_gamma = float(config.get("q_gamma", 0.99))
        self.q_target_update = int(config.get("q_target_update", 10000))
        self.q_train_step = 0
        # The Q-networks consume the full stacked observation (state_shape), unlike the FPVR network.
        self.q_net = make_q_network(state_shape, n_actions, self.q_net_type).to(self.device)
        self.q_target = make_q_network(state_shape, n_actions, self.q_net_type).to(self.device)
        # Target network starts as an exact copy of the online network.
        self.q_target.load_state_dict(self.q_net.state_dict())
        self.q_target.eval()
        # Paper DQN RMSprop: step-size=2.5e-4, epsilon=0.01, decay(alpha)=0.95, momentum=0.0
        q_lr = float(config.get("q_lr", 2.5e-4))
        self.q_optimizer = torch.optim.RMSprop(
            self.q_net.parameters(), lr=q_lr, momentum=0.0, alpha=0.95, eps=0.01
        )
        
        # ========== FPVR Parameters ==========
        # Persistence representation decay
        self.lambda_c = float(config.get("fpvr_lambda_c", 0.9))
        self.sf_gamma = float(config["sf_gamma"])
        self.phi_dim = int(config["phi_dim"])
        # Note: fixed-state c mode removed; redundancy is always cosine similarity.
        
        # ========== Persistence Representation c (past visitation accumulator) ==========
        # The persistence representation c accumulates discounted past features φ̃.
        # It represents the slow-timescale past visitation pattern.
        # Training uses a single environment by default, so keep a single c vector.
        # Shape is [1, phi_dim] to keep downstream code batching-friendly.
        self.c_vec = torch.zeros(1, self.phi_dim, device=self.device)
        
        # ========== ZCA Whitening Statistics ==========
        self.whitening_update_every = int(config["whitening_update_every"])
        self.whitening_update_count = 0
        self.whitening_ema_alpha = float(config["whitening_ema_alpha"])
        self.whitening_eps = float(config.get("whitening_eps", 1e-5))
        # Note: statistics are always estimated from the sliding window cov_buffer (no online mode).
        # Identity covariance / whitening matrix means whitening starts as a no-op (mean is zero).
        self.running_mean = torch.zeros(self.phi_dim, device=self.device)
        self.running_cov = torch.eye(self.phi_dim, device=self.device)
        self.whitening_matrix = torch.eye(self.phi_dim, device=self.device)

        # Sliding window buffer: store latest cov_buffer features (in interaction time order)
        self.cov_buffer_size = int(config.get("cov_buffer", 50000))
        self.cov_buffer = torch.zeros(self.cov_buffer_size, self.phi_dim, device=self.device)
        self.cov_buffer_ptr = 0  # next write position (ring buffer)
        self.cov_buffer_filled = 0  # number of valid rows, capped at cov_buffer_size
        
        # ========== Replay Buffer ==========
        # If mixed_mc is enabled, pass n_step and gamma to enable precomputation
        use_mixed_mc = bool(config.get("mixed_mc", False))
        # full_mc is only honored when mixed_mc is on.
        use_full_mc = bool(config.get("full_mc", False)) if use_mixed_mc else False
        if use_mixed_mc and use_full_mc and ("n_step" in config):
            # n_step always exists in our argparse config; warn to avoid confusion.
            print(f"[MixedMC] full_mc=True -> ignoring n_step={config.get('n_step')}.")
        # n_step is forwarded to the buffer only for fixed-horizon mixed MC.
        n_step = None
        if use_mixed_mc and (not use_full_mc):
            n_step = config.get("n_step", None)
        gamma = self.q_gamma if use_mixed_mc else None
        
        if config["prioritized_replay"]:
            self.replay_buffer = PrioritizedReplayBuffer(
                config["buffer_size"],
                state_shape,
                alpha=config["prioritized_alpha"],
                beta=config["prioritized_beta"],
                device=self.device,
                n_step=n_step,
                gamma=gamma,
                full_mc=use_full_mc,
            )
        else:
            self.replay_buffer = ReplayBuffer(
                config["buffer_size"],
                state_shape,
                device=self.device,
                n_step=n_step,
                gamma=gamma,
                full_mc=use_full_mc,
            )
        
        # ========== Training State ==========
        self.train_step_count = 0
        self.env_step_count = 0  # environment step counter (for logging)
        self.policy_type = str(config.get("policy_type", "q_bias"))
        # Unrecognized policy types silently fall back to "q_bias".
        if self.policy_type not in ("q_bias", "filtered_zscore"):
            self.policy_type = "q_bias"

        # Fixed alpha for redundancy term (single source of truth)
        self.policy_alpha = float(config.get("policy_alpha", 0.05))

        # Filter threshold for Q (only used when policy_type == filtered_zscore)
        self.q_abs_threshold = float(config.get("q_abs_threshold", 0.1))
        
        # Startup summary of the effective configuration.
        print("[FPVR-DQN] Initialized with single-environment training")
        print(f"[FPVR] φ_dim={self.phi_dim}, alpha={self.policy_alpha}, sf_gamma={self.sf_gamma}")
        print(f"[FPVR] FPVR network input channels: {self.num_sf_channel} (env stack={self.env_frame_stack})")
        print(f"[FPVR] SR target: {config['sf_target']} (done is considered)")
        print(f"[FPVR] Buffer size: {config['buffer_size']}, prioritized: {config['prioritized_replay']}")
        print(f"[FPVR-DQN] Q-learning target type: {self.dqn_type.upper()}")
        print("[FPVR-DQN] Combining Q-values and FPVR redundancy for action selection")
        print(f"[FPVR-DQN] policy_type = {self.policy_type}")
        print(f"[FPVR-DQN] policy_alpha = {self.policy_alpha}")
        if self.policy_type == "filtered_zscore":
            print(f"[FPVR-DQN] q_abs_threshold = {self.q_abs_threshold}")

    # Note: fixed-state c mode removed.
    
    def _add_to_cov_buffer(self, phi_batch: torch.Tensor):
        """Add features to sliding window buffer used for covariance/mean estimation."""
        with torch.no_grad():
            if phi_batch.ndim == 1:
                phi_batch = phi_batch.unsqueeze(0)
            B = phi_batch.size(0)
            for i in range(B):
                self.cov_buffer[self.cov_buffer_ptr] = phi_batch[i]
                self.cov_buffer_ptr = (self.cov_buffer_ptr + 1) % self.cov_buffer_size
                self.cov_buffer_filled = min(self.cov_buffer_filled + 1, self.cov_buffer_size)
    
    def _update_whitening_matrix(self, phi_batch):
        """Update ZCA whitening matrix (EMA).
        
        Note: statistics are always updated from cov_buffer (no online mode).
        """
        with torch.no_grad():
            alpha = self.whitening_ema_alpha

            # Use a recent window from cov_buffer for whitening statistics
            if self.cov_buffer_filled == 0:
                return  # no samples yet
            # To control SVD cost, sample at most 2048 feature vectors
            B = min(self.cov_buffer_filled, 2048)
            idxs = torch.randint(0, self.cov_buffer_filled, (B,), device=self.device)
            phi_batch = self.cov_buffer[idxs]
            
            if phi_batch.ndim == 1:
                phi_batch = phi_batch.unsqueeze(0)  # [1, phi_dim]
            
            B = phi_batch.size(0)
            
            # Multi-sample batch update (buffer mode)
            batch_mean = phi_batch.mean(dim=0)
            phi_centered = phi_batch - batch_mean
            batch_cov = (phi_centered.T @ phi_centered) / B
            
            self.running_mean = (1 - alpha) * self.running_mean + alpha * batch_mean
            self.running_cov = (1 - alpha) * self.running_cov + alpha * batch_cov
            
            # Recompute whitening matrix periodically
            if self.whitening_update_count % self.whitening_update_every == 0:
                # Paper-aligned: ZCA whitening uses (Σ + εI)^(-1/2).
                eps = float(self.whitening_eps)
                cov_reg = self.running_cov + eps * torch.eye(self.phi_dim, device=self.device)
                U, S, _ = torch.svd(cov_reg)
                # Numerical guard only (should be >= eps); keeps behavior stable if tiny negatives appear.
                S_inv_sqrt = 1.0 / torch.sqrt(S.clamp(min=1e-12))
                self.whitening_matrix = U @ torch.diag(S_inv_sqrt) @ U.T
            
            self.whitening_update_count += 1
    
    def _apply_whitening(self, phi_raw):
        """Apply ZCA whitening: φ̃ = W^T (φ - μ)"""
        phi_centered = phi_raw - self.running_mean
        phi_tilde = phi_centered @ self.whitening_matrix
        return phi_tilde

    def _sf_input(self, obs: torch.Tensor) -> torch.Tensor:
        """Select the last num_sf_channel frames from a stacked observation [B,C,H,W]."""
        if self.num_sf_channel == self.env_frame_stack:
            return obs
        return obs[:, -self.num_sf_channel:, :, :]
    
    def _compute_redundancy(self, psi_all, c_used, return_details=False):
        """
        Compute raw FPVR redundancy score between ψ(s,a) and c (before normalization).
        
        Args:
            psi_all: [B, A, D] successor features (whitened)
            c_used: [B, D] accumulated feature vector (whitened)
            return_details: if True, return dict with intermediate values
        
        Returns:
            raw_score: [B, A] unnormalized redundancy score
            details: dict (only if return_details=True)
        """
        with torch.no_grad():
            psi_det = psi_all.detach()

            dot = torch.einsum('bad,bd->ba', psi_det, c_used)
            psi_norm = psi_det.norm(dim=2)  # [B,A]
            c_norm = c_used.norm(dim=1, keepdim=True)  # [B,1]
            cosine_sim = dot / (psi_norm * (c_norm + 1e-8) + 1e-8)

            # Note: redundancy is cosine similarity
            raw_score = cosine_sim

            if return_details:
                details = {
                    'cosine_sim': cosine_sim,
                    'cosine_for_redundancy': raw_score,
                    'psi_norm': psi_norm,
                    'redundancy_base': cosine_sim,
                    'c_norm': c_norm,
                }
                return raw_score, details
        return raw_score

    @staticmethod
    def _normalize_redundancy(redundancy_raw):
        """Apply a single z-score normalization to redundancy scores."""
        mean = redundancy_raw.mean(dim=1, keepdim=True)
        std = redundancy_raw.std(dim=1, keepdim=True)
        return (redundancy_raw - mean) / (std + 1e-6)
    
    @staticmethod
    def _normalize_q_values(q_values):
        """Apply z-score normalization to Q-values."""
        mean = q_values.mean(dim=1, keepdim=True)
        std = q_values.std(dim=1, keepdim=True)
        return (q_values - mean) / (std + 1e-6)

    # (adaptive alpha logic removed by design; policy_type is now q_bias / filtered_zscore)
    
    @torch.no_grad()
    def select_action(self, state, epsilon=0.0):
        """
        Select action using combined decision (Q - alpha * redundancy).
        
        Args:
            state: [C, H, W] or [1, C, H, W] observation
            epsilon: epsilon-greedy exploration (for Q-network decision)
        
        Returns:
            action: selected action
            phi_tilde: whitened feature (for c update)
        """
        if isinstance(state, np.ndarray):
            if state.ndim == 3:
                state = np.expand_dims(state, 0)
            state = torch.from_numpy(state).to(self.device)
        
        # Forward pass through the SF/FPVR network (needed for c update and SR training)
        state_sf = self._sf_input(state)
        phi_raw, _ = self.network(state_sf)

        # Maintain cov_buffer window for whitening statistics
        self._add_to_cov_buffer(phi_raw.detach())
        
        # Note: online statistics mode removed.
        
        phi_tilde = self._apply_whitening(phi_raw)

        # ===== Combined strategy (always): Q - alpha * redundancy =====
        # Compute ψ from whitened features
        _, psi_all = self.network(state_sf, phi_whitened=phi_tilde)

        # Single-env: always use c_vec[0]
        c_used = self.c_vec[0:1]

        # Compute raw redundancy score
        if self.config.get("verbose", False):
            redundancy_raw, details = self._compute_redundancy(psi_all, c_used, return_details=True)
        else:
            redundancy_raw = self._compute_redundancy(psi_all, c_used)

        redundancy = self._normalize_redundancy(redundancy_raw)

        if self.config.get("verbose", False):
            details['redundancy_zscore'] = redundancy
            self._print_action_info(details)

        # Q(s,a) (raw)
        q_values = self.q_net(state).detach()  # [1, A]

        # policy_type:
        # - q_bias: scores = raw_Q - alpha * zscore(redundancy)
        # - filtered_zscore: if abs(max_a Q) < threshold => use ONLY redundancy; else use zscore(Q) - alpha*zscore(redundancy)
        alpha = float(self.policy_alpha)
        if self.policy_type == "filtered_zscore":
            q_absmax = float(q_values.abs().max().item())
            if q_absmax < float(self.q_abs_threshold):
                scores = -alpha * redundancy
            else:
                q_values_norm = self._normalize_q_values(q_values)
                scores = q_values_norm - alpha * redundancy
        else:
            scores = q_values - alpha * redundancy

        if np.random.rand() < epsilon:
            action = np.random.randint(0, self.n_actions)
        else:
            action = int(torch.argmax(scores, dim=1).item())

        # Increment environment step counter
        self.env_step_count += 1
        
        return action, phi_tilde.squeeze(0).cpu().numpy()
    
    def _print_action_info(self, details):
        """Print detailed action selection info"""
        cosine_raw = details['cosine_sim'][0].cpu().numpy()
        cosine_used = details.get('cosine_for_redundancy', details['cosine_sim'])[0].cpu().numpy()
        redundancy_z = details['redundancy_zscore'][0].cpu().numpy()
        psi_norm = details['psi_norm'][0].cpu().numpy()
        c_norm = details['c_norm'][0].cpu().numpy()
        
        print("\n" + "="*80)
        print("FPVR Action Selection")
        print("="*80)
        print(f"||c|| = {c_norm[0]:.4f} | α = {self.policy_alpha:.4f}")
        print("-"*80)
        extra_col = "Cosine*" if 'cosine_for_redundancy' in details else None
        if extra_col:
            print(f"{'Action':<8} {'Cos(raw)':<12} {'Cos(use)':<12} {'Z-score':<12} {'||ψ||':<12}")
        else:
            print(f"{'Action':<8} {'Cosine':<12} {'Z-score':<12} {'||ψ||':<12}")
        print("-"*80)
        for a in range(min(self.n_actions, 10)):
            if extra_col:
                print(f"{a:<8} {cosine_raw[a]:>11.4f} {cosine_used[a]:>11.4f} {redundancy_z[a]:>11.4f} {psi_norm[a]:>11.4f}")
            else:
                print(f"{a:<8} {cosine_raw[a]:>11.4f} {redundancy_z[a]:>11.4f} {psi_norm[a]:>11.4f}")
        if self.n_actions > 10:
            print(f"... ({self.n_actions - 10} more actions)")
        print("="*80)
    
    def update_c(self, phi_tilde):
        """Update c vector: c ← λ·c + φ̃"""
        with torch.no_grad():
            phi_t = torch.from_numpy(phi_tilde).to(self.device) if isinstance(phi_tilde, np.ndarray) else phi_tilde
            if phi_t.ndim == 1:
                phi_t = phi_t.unsqueeze(0)
            self.c_vec[0].mul_(self.lambda_c).add_(phi_t.squeeze(0))
    
    def reset_c(self):
        """Reset c vector to zero (e.g., at episode boundaries)."""
        with torch.no_grad():
            self.c_vec.zero_()
    
    def train(self):
        """
        Train SR from replay buffer
        
        Returns:
            dict with losses: {sr_loss, total_loss}
        """
        if len(self.replay_buffer) < self.config["learning_starts"]:
            return None
        
        # ========== Sample batch for FPVR network ==========
        # Critical: the FPVR (SR) network always uses *uniform* sampling, independent of PER.
        # Even if PER is enabled, SR training should remain diverse and stable.
        if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
            # If using PER for Q-learning, still use uniform sampling for FPVR (SR) training.
            # We construct a temporary uniform batch.
            batch_size = self.config["batch_size"]
            # Note: since this is a PrioritizedReplayBuffer, we implement uniform sampling explicitly here.
            uniform_idxs = []
            attempts = 0
            max_attempts = batch_size * 10
            while len(uniform_idxs) < batch_size and attempts < max_attempts:
                idx = np.random.randint(0, self.replay_buffer.size)
                if self.replay_buffer._has_enough_history(idx):
                    if idx not in uniform_idxs:
                        uniform_idxs.append(idx)
                attempts += 1
            # If valid indices are insufficient, allow some fallback indices
            while len(uniform_idxs) < batch_size:
                idx = np.random.randint(0, self.replay_buffer.size)
                if idx not in uniform_idxs:
                    uniform_idxs.append(idx)
            
            uniform_idxs = np.array(uniform_idxs[:batch_size], dtype=np.int32)
            
            # Build batch (use _encode_observation to construct obs and next_obs)
            obs_list = []
            next_obs_list = []
            for idx in uniform_idxs:
                obs = self.replay_buffer._encode_observation(idx)
                next_idx = (idx + 1) % self.replay_buffer.capacity
                done = self.replay_buffer.done_buf[idx]
                if done > 0.5:
                    # done=True: next_obs should be the first state of the next episode
                    next_frame = self.replay_buffer.frames_buf[next_idx, 0]
                    next_obs = np.stack([next_frame] * self.replay_buffer.frame_stack, axis=0)
                else:
                    next_obs = self.replay_buffer._encode_observation(next_idx)
                obs_list.append(obs)
                next_obs_list.append(next_obs)
            
            obs_batch = torch.from_numpy(np.stack(obs_list, axis=0)).to(self.device)
            next_obs_batch = torch.from_numpy(np.stack(next_obs_list, axis=0)).to(self.device)
            uniform_batch = {
                'obs': obs_batch,
                'actions': torch.from_numpy(self.replay_buffer.action_buf[uniform_idxs]).to(self.device),
                'rewards': torch.from_numpy(self.replay_buffer.reward_buf[uniform_idxs]).to(self.device),
                'next_obs': next_obs_batch,
                'dones': torch.from_numpy(self.replay_buffer.done_buf[uniform_idxs]).to(self.device),
                'idxs': uniform_idxs,
            }
            
            # If n-step is enabled, attach n-step fields
            if self.replay_buffer.use_n_step:
                uniform_batch['n_step_returns'] = torch.from_numpy(self.replay_buffer.n_step_return_buf[uniform_idxs]).to(self.device)
                # Route A: reconstruct n-step next_obs from indices.
                n_next_obs_list = []
                next_idxs = self.replay_buffer.n_step_next_idx_buf[uniform_idxs]
                for ni in next_idxs:
                    n_next_obs_list.append(self.replay_buffer._encode_observation(int(ni)))
                uniform_batch['n_step_next_obs'] = torch.from_numpy(np.stack(n_next_obs_list, axis=0)).to(self.device)
                uniform_batch['n_step_dones'] = torch.from_numpy(self.replay_buffer.n_step_done_buf[uniform_idxs]).to(self.device)
                uniform_batch['n_step_valid'] = torch.from_numpy(self.replay_buffer.n_step_valid_buf[uniform_idxs]).to(self.device)
            fpvr_batch = uniform_batch
        else:
            # Standard uniform replay buffer: sample directly
            fpvr_batch = self.replay_buffer.sample(self.config["batch_size"])
        
        # Slice stacked observations for the FPVR network (num_sf_channel)
        sf_obs = self._sf_input(fpvr_batch['obs'])
        sf_next_obs = self._sf_input(fpvr_batch['next_obs'])

        # Forward pass for FPVR (SR) using the uniformly sampled batch
        phi_raw, _ = self.network(sf_obs)
        
        # Update whitening statistics from cov_buffer window
        # cov_buffer maintenance is done during interaction (select_action)
        self._update_whitening_matrix(phi_raw.detach())
        
        # Apply whitening
        phi_tilde = self._apply_whitening(phi_raw)
        
        # Compute ψ from whitened features
        _, psi_all = self.network(sf_obs, phi_whitened=phi_tilde)
        
        # ========== SR Loss ==========
        phi_raw_next, _ = self.network(sf_next_obs)
        phi_tilde_next = self._apply_whitening(phi_raw_next)
        _, psi_next_all = self.network(sf_next_obs, phi_whitened=phi_tilde_next)
        
        with torch.no_grad():
            # SR target selection
            if self.config["sf_target"] == "min_redundancy":
                # Use the current persistence vector c (updated online during interaction)
                # Single-env uses c_vec[0]; extend as needed for multi-env
                B = phi_tilde.size(0)
                c_current = self.c_vec[0].unsqueeze(0).expand(B, -1)
                redundancy_next_raw = self._compute_redundancy(psi_next_all, c_current)
                redundancy_next = self._normalize_redundancy(redundancy_next_raw)
                a_min = torch.argmin(redundancy_next, dim=1)
                psi_exp = psi_next_all[torch.arange(psi_next_all.size(0), device=psi_next_all.device), a_min]
            else:
                # Uniform policy expectation
                psi_exp = psi_next_all.mean(dim=1)
        
        # SR target: φ̃(s) + γ (1-d) E[ψ(s',a')]
        gamma = self.sf_gamma
        sr_target = phi_tilde.detach() + gamma * (1.0 - fpvr_batch['dones'].unsqueeze(-1)) * psi_exp.detach()
        
        # Current ψ for taken action
        psi_curr = psi_all[torch.arange(psi_all.size(0), device=psi_all.device), fpvr_batch['actions'].long()]
        sr_loss = F.mse_loss(psi_curr, sr_target.detach())
        
        # ========== Total SR Loss ==========
        total_loss = self.config["sr_coeff"] * sr_loss
        
        # ========== Optimize FPVR network ==========
        self.optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=10.0)
        self.optimizer.step()
        self.train_step_count += 1

        q_loss_val = None
        q_batch = None  # initialized; used for PER updates

        # ========== DDQN Loss & Update ==========
        # Critical: Q-network may use PER (if enabled) to prioritize learning from reward-relevant transitions
        if self.q_net is not None and self.q_optimizer is not None:
            # If PER is enabled, sample with priorities; otherwise sample uniformly
            if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
                q_batch = self.replay_buffer.sample(self.config["batch_size"])  # PER sampling
            else:
                q_batch = self.replay_buffer.sample(self.config["batch_size"])  # uniform sampling
            
            obs = q_batch['obs'].to(self.device)
            next_obs = q_batch['next_obs'].to(self.device)
            actions = q_batch['actions'].long().to(self.device)
            rewards = q_batch['rewards'].to(self.device)
            dones = q_batch['dones'].to(self.device)

            # Q(s,a)
            q = self.q_net(obs)
            q_a = q.gather(1, actions.view(-1, 1)).squeeze(1)

            with torch.no_grad():
                # 1-step TD target: DQN vs Double DQN differs in how we choose a* at s'.
                if self.dqn_type == "ddqn":
                    q_next_online = self.q_net(next_obs)
                    a_star = torch.argmax(q_next_online, dim=1, keepdim=True)
                    q_next_target = self.q_target(next_obs).gather(1, a_star).squeeze(1)
                else:
                    # Classic DQN target: max_a Q_target(s', a)
                    q_next_target = self.q_target(next_obs).max(dim=1).values
                target_1step = rewards + (1.0 - dones) * (self.q_gamma * q_next_target)
                
                # Mixed Monte Carlo: combine 1-step and n-step targets
                if self.config.get("mixed_mc", False):
                    full_mc = bool(self.config.get("full_mc", False))
                    n_step = int(self.config.get("n_step", 10))
                    mixed_mc_weight = self.config.get("mixed_mc_weight", 0.5)
                    
                    if full_mc:
                        # Full-episode MC return-to-go: no bootstrap term.
                        if 'mc_returns' in q_batch and 'mc_valid' in q_batch:
                            mc_returns = q_batch['mc_returns']
                            mc_valid = q_batch['mc_valid']
                        else:
                            raise ValueError("full_mc=True but replay buffer did not provide mc_returns/mc_valid")
                        target_mc = mc_returns
                        target = torch.where(
                            mc_valid,
                            (1.0 - mixed_mc_weight) * target_1step + mixed_mc_weight * target_mc,
                            target_1step
                        )
                    else:
                        # Fixed-horizon n-step target (with bootstrap).
                        # From batch, read precomputed n-step info.
                        if 'n_step_returns' in q_batch:
                            n_step_returns = q_batch['n_step_returns']
                            n_step_dones = q_batch['n_step_dones']
                            n_step_next_obs = q_batch['n_step_next_obs']
                            n_step_valid = q_batch['n_step_valid']
                        else:
                            # Backward-compat dynamic computation (not recommended).
                            idxs = q_batch['idxs'].cpu().numpy()
                            n_step_returns, n_step_dones, n_step_next_obs, n_step_valid = \
                                self.replay_buffer.get_n_step_return(idxs, n_step, self.q_gamma)
                            n_step_next_obs = n_step_next_obs.to(self.device)

                        if self.dqn_type == "ddqn":
                            q_next_online_n = self.q_net(n_step_next_obs)
                            a_star_n = torch.argmax(q_next_online_n, dim=1, keepdim=True)
                            q_next_target_n = self.q_target(n_step_next_obs).gather(1, a_star_n).squeeze(1)
                        else:
                            q_next_target_n = self.q_target(n_step_next_obs).max(dim=1).values
                        gamma_n = float(self.q_gamma) ** int(n_step)
                        target_nstep = n_step_returns + (1.0 - n_step_dones) * (gamma_n * q_next_target_n)
                    
                    # Mixed target: use the mixed target for valid n-step samples, otherwise fall back to 1-step
                        target = torch.where(
                            n_step_valid,
                            (1.0 - mixed_mc_weight) * target_1step + mixed_mc_weight * target_nstep,
                            target_1step
                        )
                else:
                    # Standard 1-step TD target
                    target = target_1step

            # Q-network loss: if PER is enabled, apply importance-sampling weights
            td_errors = (q_a - target).abs()  # TD error magnitude
            if isinstance(self.replay_buffer, PrioritizedReplayBuffer) and 'weights' in q_batch:
                # PER importance-sampling weights
                weights = q_batch['weights']
                td_loss = (weights * F.smooth_l1_loss(q_a, target, reduction='none')).mean()
            else:
                # Uniform sampling: no weights
                td_loss = F.smooth_l1_loss(q_a, target)
            
            self.q_optimizer.zero_grad()
            td_loss.backward()
            nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=10.0)
            self.q_optimizer.step()
            q_loss_val = td_loss.item()
            
            # PER priorities: use absolute TD errors from the Q-network batch.
            # Note: q_a - target can be positive/negative; priorities must be non-negative.
            with torch.no_grad():
                q_mc_errors = td_errors.cpu().numpy()

            # Periodic target update
            self.q_train_step += 1
            if self.q_train_step % self.q_target_update == 0:
                self.q_target.load_state_dict(self.q_net.state_dict())
        else:
            q_mc_errors = None
        
        # Update priorities for PER (Q TD errors only)
        # Note: PER priority updates use the Q-network batch (q_batch), not the FPVR(SR) batch (fpvr_batch).
        if (self.config["prioritized_replay"] and q_mc_errors is not None 
            and q_batch is not None and isinstance(self.replay_buffer, PrioritizedReplayBuffer) 
            and 'tree_idxs' in q_batch):
            self.replay_buffer.update_priorities(q_batch['tree_idxs'], q_mc_errors)
        
        out = {
            'sr_loss': sr_loss.item(),
            'total_loss': total_loss.item(),
        }
        if q_loss_val is not None:
            out['q_loss'] = q_loss_val
        return out
    
    def save(self, path, extra_state=None):
        """Save checkpoint

        Collects network weights, Q-network weights, whitening statistics and
        step counters into one dict and writes it to `path` via a temporary
        file that is then renamed over the destination.

        Config keys read (all optional):
            save_optimizers (default True): include optimizer state dicts.
            save_q_target (default True): include the target Q-network weights.
            checkpoint_fp16 (default False): store float32/float64 state-dict
                tensors as float16; this also forces save_optimizers off.

        Args:
            path: checkpoint file path; a missing parent directory is created
            extra_state: dict with additional training state (e.g., episode_count, positive_reward_count).
                It is merged last, so a key that matches a built-in checkpoint
                key replaces that entry.
        """
        save_optimizers = bool(self.config.get("save_optimizers", True))
        save_q_target = bool(self.config.get("save_q_target", True))
        checkpoint_fp16 = bool(self.config.get("checkpoint_fp16", False))

        if checkpoint_fp16 and save_optimizers:
            # Optimizer state is dtype-sensitive; fp16 checkpoints are intended as lightweight weights.
            print("[Checkpoint] checkpoint_fp16=True -> forcing save_optimizers=False (weights-only).")
            save_optimizers = False

        def _maybe_fp16_state_dict(sd: dict):
            """Return `sd` unchanged, or a copy with float32/float64 tensors halved."""
            if not checkpoint_fp16:
                return sd
            return {
                k: v.half() if torch.is_tensor(v) and v.dtype in (torch.float32, torch.float64) else v
                for k, v in sd.items()
            }

        checkpoint = {
            'network': _maybe_fp16_state_dict(self.network.state_dict()),
            'q_net': _maybe_fp16_state_dict(self.q_net.state_dict()) if self.q_net is not None else None,
            'q_target': _maybe_fp16_state_dict(self.q_target.state_dict()) if (save_q_target and self.q_target is not None) else None,
            'q_train_step': self.q_train_step,
            'c_vec': self.c_vec.detach().cpu(),
            'running_mean': self.running_mean.detach().cpu(),
            'running_cov': self.running_cov.detach().cpu(),
            'whitening_matrix': self.whitening_matrix.detach().cpu(),
            'whitening_update_count': self.whitening_update_count,
            'train_step_count': self.train_step_count,
            'env_step_count': self.env_step_count,
            'config': self.config,
        }

        if save_optimizers:
            checkpoint['optimizer'] = self.optimizer.state_dict()
            checkpoint['q_optimizer'] = self.q_optimizer.state_dict() if self.q_optimizer is not None else None

        # Add extra training state
        if extra_state:
            checkpoint.update(extra_state)

        # Make sure the destination directory exists before writing. Without this,
        # a missing directory fails inside torch.save below and can be mistaken
        # for a serialization-format problem by the retry path.
        parent_dir = os.path.dirname(str(path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Atomic save: write to temp then replace.
        # This prevents corrupted checkpoints if the process is interrupted mid-write.
        tmp_path = str(path) + ".tmp"
        try:
            try:
                torch.save(checkpoint, tmp_path)
            except RuntimeError as err:
                print(f"[Warning] torch.save failed with new zip serialization ({err}). Retrying with legacy format.")
                torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
            os.replace(tmp_path, path)
        finally:
            # Best-effort cleanup if something went wrong before replace.
            # Only filesystem errors are ignored here so that a failed cleanup
            # never masks the exception that caused it.
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except OSError:
                pass
        print(f"[Checkpoint] Saved to {path}")
    
    def load(self, path):
        """Load checkpoint

        Restores, in this order: config (including q_net_type), FPVR network
        weights and optimizer, Q-network / target network, Q optimizer,
        persistence vector and whitening statistics, and step counters.

        Args:
            path: checkpoint file path

        Returns:
            The checkpoint's 'config' entry, or an empty dict when the key is absent.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load
        # checkpoints that come from a trusted source.
        # Some PyTorch versions reject the `weights_only` keyword entirely; in that
        # case repeat the call without it. Any other TypeError is re-raised.
        try:
            checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        except TypeError as exc:
            if "weights_only" not in str(exc):
                raise
            checkpoint = torch.load(path, map_location=self.device)

        # Config goes first because the Q-network architecture depends on q_net_type.
        # Checkpoints without that key fall back to "nature".
        saved_config = checkpoint.get("config", {}) or {}
        net_type = str(saved_config.get("q_net_type", "nature")).lower()
        self.config.update(saved_config)
        self.config["q_net_type"] = net_type
        self.q_net_type = net_type

        self.network.load_state_dict(checkpoint['network'])

        optimizer_state = checkpoint.get('optimizer')
        if optimizer_state is None:
            print("[Checkpoint] Optimizer state not found; optimizer will be re-initialized (momentum/Adam moments reset).")
        else:
            self.optimizer.load_state_dict(optimizer_state)

        # DDQN components: rebuild both networks with the checkpoint's architecture
        # before copying weights in, but only when Q weights were actually stored.
        q_state = checkpoint.get('q_net')
        if q_state is not None:
            def _fresh_q_net():
                return make_q_network(self.state_shape, self.n_actions, self.q_net_type).to(self.device)

            self.q_net = _fresh_q_net()
            self.q_target = _fresh_q_net()
            self.q_net.load_state_dict(q_state)

        # Target network: stored weights if available, otherwise a copy of the online net.
        target_state = checkpoint.get('q_target')
        self.q_target.load_state_dict(
            self.q_net.state_dict() if target_state is None else target_state
        )

        # The Q optimizer must reference the (possibly rebuilt) q_net parameters,
        # so it is always re-created and only then given its saved state.
        self.q_optimizer = torch.optim.RMSprop(
            self.q_net.parameters(),
            lr=float(self.config.get("q_lr", 2.5e-4)),
            momentum=0.0,
            alpha=0.95,
            eps=0.01,
        )
        q_optimizer_state = checkpoint.get('q_optimizer')
        if q_optimizer_state is None:
            print("[Checkpoint] Q optimizer state not found; RMSprop state will be reset.")
        else:
            self.q_optimizer.load_state_dict(q_optimizer_state)

        self.q_train_step = int(checkpoint.get('q_train_step', 0))

        # Tensors were saved on CPU; move each one back to the agent's device.
        for attr in ('c_vec', 'running_mean', 'running_cov', 'whitening_matrix'):
            setattr(self, attr, checkpoint[attr].to(self.device))

        self.whitening_update_count = checkpoint['whitening_update_count']
        self.train_step_count = checkpoint['train_step_count']
        self.env_step_count = checkpoint.get('env_step_count', 0)

        print(f"[Checkpoint] Loaded from {path}")
        print(f"  → Train steps: {self.train_step_count}")
        print(f"  → Environment steps: {self.env_step_count}")
        print(f"  → Whitening updates: {self.whitening_update_count}")

        return checkpoint.get('config', {})

