# This code is an anonymized version of Gaze-on-the-Prize code for submission purpose.
# Full codebase will be released upon paper acceptance.

from collections import defaultdict
import os
import random
import time
from dataclasses import dataclass
from typing import Optional, Dict, List, Tuple

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
import tqdm
from torch.utils.tensorboard import SummaryWriter
import cv2
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import imageio
import faiss
import math

# ManiSkill specific imports
import mani_skill.envs
from mani_skill.utils import gym_utils
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper, FlattenRGBDObservationWrapper
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

@dataclass
class Args:
    """Configuration values for a SAC run with optional contrastive attention learning.

    Every field is followed by a bare string describing it. The two fields
    under "to be filled in runtime" default to 0 and are meant to be
    overwritten by the script before use.
    """
    exp_name: Optional[str] = None
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    # NOTE(review): the description below mentions `=False` while the default
    # here is True; where the flag is consumed is not visible in this section,
    # so confirm the wording against that code.
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = True
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "Gaze-on-the-Prize-SAC"
    """the wandb's project name"""
    wandb_entity: Optional[str] = None
    """the entity (team) of wandb's project"""
    wandb_group: str = "SAC"
    """the group of the run for wandb"""
    capture_video: bool = True
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = True
    """whether to save model into the `runs/{run_name}` folder"""
    save_trajectory: bool = False
    """whether to save trajectory data into the `videos` folder"""
    evaluate: bool = False
    """if toggled, only runs evaluation with the given model checkpoint and saves the evaluation trajectories"""
    checkpoint: Optional[str] = None
    """path to a pretrained checkpoint file to start evaluation/training from"""
    render_mode: str = "all"
    """the environment rendering mode"""
    obs_mode: str = "rgb"
    """the observation mode to use"""
    env_vectorization: str = "gpu"
    """the type of environment vectorization to use"""
    log_freq: int = 1_000
    """logging frequency in terms of environment steps"""

    # Algorithm specific arguments (SAC)
    env_id: str = "PickCube-v1"
    """the id of the environment"""
    include_state: bool = True
    """whether to include state information in observations"""
    total_timesteps: int = 1_000_000
    """total timesteps of the experiments"""
    buffer_size: int = 300_000
    """the replay memory buffer size"""
    buffer_device: str = "cuda"
    """where the replay buffer is stored. Can be 'cpu' or 'cuda' for GPU"""
    gamma: float = 0.8
    """the discount factor gamma"""
    tau: float = 0.01
    """target smoothing coefficient"""
    batch_size: int = 512
    """the batch size of sample from the replay memory"""
    learning_starts: int = 4_000
    """timestep to start learning"""
    policy_lr: float = 3e-4
    """the learning rate of the policy network optimizer"""
    q_lr: float = 3e-4
    """the learning rate of the Q network network optimizer"""
    policy_frequency: int = 1
    """the frequency of training policy (delayed)"""
    target_network_frequency: int = 1
    """the frequency of updates for the target nerworks"""
    alpha: float = 0.2
    """Entropy regularization coefficient."""
    autotune: bool = True
    """automatic tuning of the entropy coefficient"""
    training_freq: int = 64
    """training frequency (in steps)"""
    utd: float = 0.5
    """update to data ratio"""
    num_envs: int = 32
    """the number of parallel environments"""
    num_eval_envs: int = 50
    """the number of parallel evaluation environments"""
    partial_reset: bool = False
    """whether to let parallel environments reset upon termination instead of truncation"""
    eval_partial_reset: bool = False
    """whether to let parallel evaluation environments reset upon termination instead of truncation"""
    num_steps: int = 50
    """the number of steps to run in each environment per rollout"""
    num_eval_steps: int = 50
    """the number of steps to run in each evaluation environment during evaluation"""
    reconfiguration_freq: Optional[int] = None
    """how often to reconfigure the environment during training"""
    eval_reconfiguration_freq: Optional[int] = 1
    """for benchmarking purposes we want to reconfigure the eval environment each reset to ensure objects are randomized in some tasks"""
    control_mode: Optional[str] = "pd_ee_delta_pos"
    """the control mode to use for the environment"""
    eval_freq: int = 10_000
    """evaluation frequency in terms of iterations"""
    save_train_video_freq: Optional[int] = None
    """frequency to save training videos in terms of iterations"""
    bootstrap_at_done: str = "always"
    """the bootstrap method to use when a done signal is received. Can be 'always' or 'never'"""
    camera_width: Optional[int] = 64
    """the width of the camera image. If none it will use the default the environment specifies"""
    camera_height: Optional[int] = 64
    """the height of the camera image. If none it will use the default the environment specifies."""

    # Contrastive learning arguments
    use_contrastive: bool = True
    """whether to use contrastive attention learning"""
    lambda_contrast: float = 0.1
    """weight for contrastive loss"""
    contrast_top_k: int = 16
    """number of similar states to consider for triplet mining"""
    contrast_margin: float = 0.5
    """margin for return difference in triplet selection"""
    contrastive_update_freq: int = 1
    """only compute contrastive loss every N gradient updates"""
    
    # contrastive buffer arguments
    contrastive_buffer_size: int = 50_000
    """Maximum size of contrastive contrastive buffer"""
    use_faiss: bool = True
    """Use FAISS for efficient similarity search"""
    contrast_batch_size: int = 1024
    """Number of anchors to sample for triplet mining"""

    # Attention arguments
    attention_type: str = "foveal"
    """Type of attention: 'foveal' or 'patch'"""
    lambda_spread: float = 0.1
    """Weight for spread regularization in attention"""

    # to be filled in runtime
    grad_steps_per_iteration: int = 0
    """the number of gradient updates per iteration"""
    steps_per_env: int = 0
    """the number of steps each parallel env takes per iteration"""

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialize ``layer`` in place with orthogonal weights and a constant bias.

    Args:
        layer: Module exposing a ``weight`` tensor and, optionally, a ``bias``
            tensor (for example ``nn.Linear`` or ``nn.Conv2d``).
        std: Gain handed to ``torch.nn.init.orthogonal_``.
        bias_const: Value every bias element is set to.

    Returns:
        The same ``layer`` object, so the call can wrap a constructor inline.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with ``bias=False`` expose ``bias`` as None; skip them
    # instead of failing inside ``constant_``.
    if getattr(layer, "bias", None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class PatchAttention(nn.Module):
    """Spatial softmax attention scored by a single 1x1 convolution.

    Each spatial location ("patch") of the incoming feature map receives one
    scalar logit; the logits of a sample are normalised with a softmax taken
    over all of its locations.
    """

    def __init__(self, in_channels):
        super().__init__()

        # One output channel: a single logit per spatial location.
        scorer = nn.Conv2d(in_channels, 1, kernel_size=1)

        # Start from small weights and a zero bias.
        nn.init.xavier_uniform_(scorer.weight, gain=0.1)
        nn.init.zeros_(scorer.bias)

        self.attention_conv = scorer

    def forward(self, features):
        """Compute the attention map for a ``(B, C, H, W)`` feature tensor.

        Returns:
            A pair ``(attention_map, attention_params)``: the softmax weights
            shaped ``(B, 1, H, W)`` and the same weights flattened to
            ``(B, H * W)`` (used as the parameters for the contrastive loss).
        """
        batch, _, height, width = features.shape

        # Flatten the per-location logits so the softmax spans the whole map.
        flat_logits = self.attention_conv(features).view(batch, -1)
        attention_params = F.softmax(flat_logits, dim=-1)
        attention_map = attention_params.view(batch, 1, height, width)

        return attention_map, attention_params

class FovealAttention(nn.Module):
    """Predict one rotated 2-D Gaussian attention map per sample.

    A small conv / max-pool / linear head regresses five numbers from the
    feature map: centre ``(cx, cy)``, axis scales ``(sx, sy)`` and rotation
    ``theta``. The Gaussian is evaluated on a coordinate grid spanning
    ``[0, 1]`` along both axes and normalised so every map sums to one (up to
    the ``1e-8`` stabiliser).

    The coordinate grids are cached as *non-persistent* buffers: they follow
    ``.to(device)`` but stay out of ``state_dict()``. They are derived data
    that only exists after the first forward pass, so saving them would make
    a checkpoint written after a forward pass fail strict loading into a
    freshly constructed module.
    """

    # Names of the cached coordinate-grid buffers.
    _GRID_BUFFERS = ("x_grid", "y_grid")

    def __init__(self, in_channels):
        super().__init__()
        # Conv2d followed by MaxPool2d to reduce the feature map size
        self.param_net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            # 32 channels on a 2x2 map after pooling, i.e. the head is sized
            # for 4x4 input feature maps.
            nn.Linear(32 * 2 * 2, 5)
        )
        # Grid cache, filled lazily by the first forward pass.
        for name in self._GRID_BUFFERS:
            self.register_buffer(name, None, persistent=False)
        self.cached_size = None

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Drop cached-grid entries before delegating to the default loader.

        Checkpoints written while the grids were still persistent buffers
        carry ``x_grid``/``y_grid``; removing them keeps strict loading
        working. ``load_state_dict`` passes a copy, so the caller's dict is
        left untouched.
        """
        for name in self._GRID_BUFFERS:
            state_dict.pop(prefix + name, None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def _get_coordinate_grids(self, H, W, device):
        """Return ``(y_grid, x_grid)`` of shape ``(H, W)``, rebuilding the cache if stale."""
        stale = (
            self.cached_size != (H, W)
            or self.x_grid is None
            # Also rebuild when the input lives on another device than the cache.
            or self.x_grid.device != torch.device(device)
        )
        if stale:
            y_coords = torch.linspace(0, 1, H, device=device)
            x_coords = torch.linspace(0, 1, W, device=device)
            y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing='ij')
            self.register_buffer('x_grid', x_grid, persistent=False)
            self.register_buffer('y_grid', y_grid, persistent=False)
            self.cached_size = (H, W)
        return self.y_grid, self.x_grid

    def forward(self, features):
        """Compute the Gaussian attention map for a ``(B, C, H, W)`` feature tensor.

        Returns:
            A pair ``(gaussian, params)``: the normalised map shaped
            ``(B, 1, H, W)`` and the constrained parameters
            ``[cx, cy, sx, sy, theta]`` stacked to ``(B, 5)`` (used by the
            contrastive loss).
        """
        B, C, H, W = features.shape
        params = self.param_net(features)

        # Parse and constrain parameters
        cx = params[:, 0].sigmoid()  # [0, 1]
        cy = params[:, 1].sigmoid()  # [0, 1]
        # softplus keeps the scales positive; the offset keeps them away from 0
        sx = F.softplus(params[:, 2]) / 10.0 + 1e-4
        sy = F.softplus(params[:, 3]) / 10.0 + 1e-4
        theta = params[:, 4].tanh() * np.pi  # [-π, π]

        y_grid, x_grid = self._get_coordinate_grids(H, W, features.device)

        cos_t = torch.cos(theta).view(B, 1, 1)
        sin_t = torch.sin(theta).view(B, 1, 1)

        # Offsets of every grid point from the predicted centre, per sample
        dx = x_grid.unsqueeze(0) - cx.view(B, 1, 1)
        dy = y_grid.unsqueeze(0) - cy.view(B, 1, 1)

        # Rotate the offsets into the Gaussian's own axes
        x_rot = cos_t * dx + sin_t * dy
        y_rot = -sin_t * dx + cos_t * dy

        sx_expanded = sx.view(B, 1, 1)
        sy_expanded = sy.view(B, 1, 1)
        gaussian = torch.exp(-0.5 * (x_rot**2 / sx_expanded**2 + y_rot**2 / sy_expanded**2))

        # Add channel dimension and normalize
        gaussian = gaussian.unsqueeze(1)
        gaussian = gaussian / (gaussian.sum(dim=[2, 3], keepdim=True) + 1e-8)

        return gaussian, torch.stack([cx, cy, sx, sy, theta], dim=1)

class DictArray(object):
    """Dictionary of tensors sharing leading batch dimensions, indexed like one array.

    Built either from a ``gym.spaces.Dict`` (one zero tensor per leaf space,
    nested dict spaces become nested ``DictArray`` objects) or by wrapping an
    existing ``data_dict``.

    Args:
        buffer_shape: Tuple of leading dimensions shared by all entries.
        element_space: ``gym.spaces.Dict`` describing one element; only read
            when ``data_dict`` is empty or None.
        data_dict: Optional pre-built mapping to wrap as-is.
        device: Device on which new tensors are allocated.
    """
    def __init__(self, buffer_shape, element_space, data_dict=None, device=None):
        self.buffer_shape = buffer_shape
        if data_dict:
            self.data = data_dict
        else:
            assert isinstance(element_space, gym.spaces.dict.Dict)
            self.data = {}
            for k, v in element_space.items():
                if isinstance(v, gym.spaces.dict.Dict):
                    self.data[k] = DictArray(buffer_shape, v, device=device)
                else:
                    # Floats are stored as float32; the listed integer types
                    # keep their width; anything else is passed through.
                    dtype = (torch.float32 if v.dtype in (np.float32, np.float64) else
                            torch.uint8 if v.dtype == np.uint8 else
                            torch.int16 if v.dtype == np.int16 else
                            torch.int32 if v.dtype == np.int32 else
                            v.dtype)
                    self.data[k] = torch.zeros(buffer_shape + v.shape, dtype=dtype, device=device)

    def keys(self):
        """Return the keys of the underlying dictionary."""
        return self.data.keys()

    def __getitem__(self, index):
        """Return one entry for a string key, else a dict of every entry indexed by ``index``."""
        if isinstance(index, str):
            return self.data[index]
        return {
            k: v[index] for k, v in self.data.items()
        }

    def __setitem__(self, index, value):
        """Replace one entry for a string key, else write ``value[k]`` into every entry at ``index``."""
        if isinstance(index, str):
            self.data[index] = value
            # Stop here: the per-key loop below is only meant for positional
            # indices and would call ``.items()`` on a tensor value.
            return
        for k, v in value.items():
            self.data[k][index] = v

    @property
    def shape(self):
        """Leading (batch) shape shared by all entries."""
        return self.buffer_shape

    def reshape(self, shape):
        """Return a new ``DictArray`` whose leading dimensions are reshaped to ``shape``.

        Trailing (per-element) dimensions of every tensor are preserved.
        """
        t = len(self.buffer_shape)
        new_dict = {}
        for k,v in self.data.items():
            if isinstance(v, DictArray):
                new_dict[k] = v.reshape(shape)
            else:
                new_dict[k] = v.reshape(shape + v.shape[t:])
        # Read the resolved leading shape back from an entry so that a -1 in
        # ``shape`` ends up as a concrete size.
        new_buffer_shape = next(iter(new_dict.values())).shape[:len(shape)]
        return DictArray(new_buffer_shape, None, data_dict=new_dict)

class ContrastiveBuffer:
    """FIFO store of CNN features, embeddings and returns with FAISS neighbour search.

    Three pre-allocated tensors hold, per slot, the CNN feature map, the
    L2-normalised embedding and the scalar return. Once ``max_size`` slots are
    filled, new batches overwrite the oldest ones. A FAISS inner-product index
    over the normalised embeddings (i.e. cosine similarity) is rebuilt from
    the embedding tensor periodically rather than on every insert, so search
    results may lag behind the most recent writes.
    """
    
    def __init__(self, cnn_feature_shape: Tuple[int, int, int], embedding_dim: int, 
                 max_size: int = 100000, device: str = 'cuda'):
        """Allocate storage and create the FAISS index (GPU if possible, else CPU).

        Args:
            cnn_feature_shape: ``(C, H, W)`` of one stored feature map.
            embedding_dim: Length of one embedding vector.
            max_size: Number of slots in the buffer.
            device: Torch device for the storage tensors; a CUDA device also
                triggers the attempt to move the FAISS index to the GPU.
        """
        self.cnn_feature_shape = cnn_feature_shape  # (C, H, W)
        self.embedding_dim = embedding_dim
        self.max_size = max_size
        self.device = device
        
        # Pre-allocate storage for efficiency
        # CNN features for gradient flow through attention
        self.cnn_features = torch.zeros((max_size,) + cnn_feature_shape, device=device)
        # Flattened embeddings for FAISS search
        self.embeddings = torch.zeros((max_size, embedding_dim), device=device)
        # Returns for triplet selection
        self.returns = torch.zeros(max_size, device=device)
        
        self.current_idx = 0
        self.is_full = False
        
        # FAISS index for cosine similarity (using inner product on normalized vectors)
        self.index = faiss.IndexFlatIP(embedding_dim)
        
        # Try to use GPU if available
        self.use_gpu = False
        # Always defined so the rebuild fallback can tell whether a
        # StandardGpuResources object exists.
        self.gpu_resources = None
        device_str = str(device)
        if device_str == 'cuda' or device_str.startswith('cuda:'):
            try:
                # Check if CUDA is available through PyTorch (already imported at top)
                if torch.cuda.is_available() and faiss.get_num_gpus() > 0:
                    # Try multiple GPU initialization approaches
                    gpu_success = False
                    
                    # Try StandardGpuResources (should work with FAISS 1.8.0)
                    try:
                        res = faiss.StandardGpuResources()
                        # Configure GPU memory (512 MB temp memory)
                        res.setTempMemory(512 * 1024 * 1024)
                        
                        # Try to use the first available GPU (usually GPU 0, but adapt to CUDA_VISIBLE_DEVICES)
                        gpu_id = 0
                        self.gpu_resources = res  # Keep reference to prevent garbage collection
                        self.gpu_index = faiss.index_cpu_to_gpu(res, gpu_id, self.index)
                        self.index = self.gpu_index
                        gpu_success = True
                        
                        # Get the actual GPU device name
                        cuda_device = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "Unknown"
                        print(f"Using FAISS GPU with StandardGpuResources (GPU {gpu_id}: {cuda_device})")
                    except Exception as e:
                        print(f"StandardGpuResources with GPU 0 failed: {e}")
                        
                        # Fallback: try index_cpu_to_all_gpus
                        try:
                            self.gpu_index = faiss.index_cpu_to_all_gpus(self.index)
                            self.index = self.gpu_index
                            gpu_success = True
                            print(f"Using FAISS GPU with index_cpu_to_all_gpus")
                        except Exception as e2:
                            print(f"index_cpu_to_all_gpus also failed: {e2}")
                    
                    if gpu_success:
                        self.use_gpu = True
                        print(f"FAISS GPU acceleration enabled (detected {faiss.get_num_gpus()} GPUs)")
                    else:
                        print("All GPU methods failed, using CPU")
                else:
                    if not torch.cuda.is_available():
                        print("PyTorch CUDA is not available, using FAISS CPU")
                    else:
                        print(f"FAISS reports {faiss.get_num_gpus()} GPUs, using CPU")
            except Exception as e:
                # Best effort: any failure here leaves the CPU index in place.
                print(f"Failed to initialize FAISS GPU, falling back to CPU: {e}")
                self.use_gpu = False
        else:
            print(f"Using FAISS CPU (device is {device})")
        
        # Track updates for periodic index rebuilding
        self.updates_since_rebuild = 0
        self.rebuild_frequency = 10
    
    def add_batch(self, new_cnn_features: torch.Tensor, new_embeddings: torch.Tensor, 
                  new_returns: torch.Tensor):
        """Add a batch of CNN features, embeddings and returns to the buffer with FIFO replacement.
        
        Args:
            new_cnn_features: (batch_size, C, H, W) CNN features before attention
            new_embeddings: (batch_size, embedding_dim) flattened features for FAISS
            new_returns: (batch_size,) return values
        """
        batch_size = new_cnn_features.shape[0]

        # A batch larger than the buffer cannot be stored whole: keep only its
        # newest ``max_size`` rows so the wrap-around write below stays valid.
        if batch_size > self.max_size:
            new_cnn_features = new_cnn_features[-self.max_size:]
            new_embeddings = new_embeddings[-self.max_size:]
            new_returns = new_returns[-self.max_size:]
            batch_size = self.max_size
        
        # Normalize embeddings for cosine similarity
        new_embeddings_norm = F.normalize(new_embeddings, dim=1)
        
        if self.current_idx + batch_size <= self.max_size:
            # Simple case: no wrap-around
            self.cnn_features[self.current_idx:self.current_idx + batch_size] = new_cnn_features
            self.embeddings[self.current_idx:self.current_idx + batch_size] = new_embeddings_norm
            self.returns[self.current_idx:self.current_idx + batch_size] = new_returns
            
            # Note: FAISS index will be updated by periodic _rebuild_index()
            
            self.current_idx += batch_size
            if self.current_idx >= self.max_size:
                self.is_full = True
                self.current_idx = self.current_idx % self.max_size
                print(f"Buffer is now full, wrapping around. Current idx: {self.current_idx}")
        else:
            # Handle wrap-around case: fill the tail, then continue at slot 0
            first_part = self.max_size - self.current_idx
            self.cnn_features[self.current_idx:] = new_cnn_features[:first_part]
            self.embeddings[self.current_idx:] = new_embeddings_norm[:first_part]
            self.returns[self.current_idx:] = new_returns[:first_part]
            
            second_part = batch_size - first_part
            self.cnn_features[:second_part] = new_cnn_features[first_part:]
            self.embeddings[:second_part] = new_embeddings_norm[first_part:]
            self.returns[:second_part] = new_returns[first_part:]
            
            self.is_full = True
            self.current_idx = second_part
        
        # Rebuild every ``rebuild_frequency`` inserts, and right after the
        # write position has wrapped around to the start of the buffer.
        self.updates_since_rebuild += 1
        if self.updates_since_rebuild >= self.rebuild_frequency or (self.is_full and self.current_idx < batch_size):
            self._rebuild_index()
            self.updates_since_rebuild = 0
    
    def _rebuild_index(self):
        """Rebuild the FAISS index with current buffer contents."""
        size = self.size
        if size == 0:
            return

        # detach() so the conversion also works if stored rows carry autograd history
        embeddings_np = self.embeddings[:size].detach().cpu().numpy()

        if not self.use_gpu:
            # For CPU index, reset and add
            self.index.reset()
            self.index.add(embeddings_np)
            return

        try:
            # Reset and refill the existing GPU index
            self.index.reset()
            self.index.add(embeddings_np)
        except Exception as e:
            print(f"GPU rebuild failed: {e}, recreating GPU index")
            # Recreate the GPU index from scratch, the same way it was
            # created: via the stored resources if there are any, otherwise
            # via index_cpu_to_all_gpus.
            cpu_index = faiss.IndexFlatIP(self.embedding_dim)
            cpu_index.add(embeddings_np)
            if self.gpu_resources is not None:
                self.index = faiss.index_cpu_to_gpu(self.gpu_resources, 0, cpu_index)
            else:
                self.index = faiss.index_cpu_to_all_gpus(cpu_index)
            print(f"Rebuilt FAISS index with {size} vectors (GPU - recreated)")
    
    def search_neighbors(self, query_embeddings: torch.Tensor, k: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Search k nearest neighbors for query embeddings.
        
        Returns:
            distances: (num_queries, k) tensor of inner-product scores
            indices: (num_queries, k) tensor of buffer indices

            ``k`` is reduced to the number of stored samples when the buffer
            holds fewer; an empty buffer yields tensors with zero columns.
        """
        # Normalize queries
        query_embeddings = F.normalize(query_embeddings, dim=1)
        
        # Adjust k if buffer has fewer samples
        actual_size = self.size
        k_adjusted = min(k, actual_size)
        
        if k_adjusted == 0:
            # Return empty results if buffer is empty
            return (torch.empty(query_embeddings.shape[0], 0, device=self.device),
                    torch.empty(query_embeddings.shape[0], 0, dtype=torch.long, device=self.device))

        # The index is only rebuilt periodically, so it can hold fewer vectors
        # than requested (or none at all before the first rebuild). FAISS then
        # pads results with -1, which the clamp below would silently turn
        # into slot 0. Bring the index up to date first.
        if self.index.ntotal < k_adjusted:
            self._rebuild_index()
            self.updates_since_rebuild = 0
        
        # Search in FAISS
        queries_np = query_embeddings.detach().cpu().numpy()
        distances, indices = self.index.search(queries_np, k_adjusted)
        
        # Convert to torch and ensure indices are valid
        distances_torch = torch.from_numpy(distances).to(self.device)
        indices_torch = torch.from_numpy(indices).to(self.device)
        
        # Safety net: keep every index inside the filled part of the buffer
        indices_torch = torch.clamp(indices_torch, min=0, max=actual_size-1)
        
        return distances_torch, indices_torch
    
    def get_samples(self, indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Retrieve CNN features, embeddings and returns for given indices.

        Indices outside the filled part of the buffer are dropped, so the
        result can have fewer rows than ``indices`` has entries.
        """
        valid_mask = (indices >= 0) & (indices < self.size)
        valid_indices = indices[valid_mask]
        
        if len(valid_indices) == 0:
            # Return empty tensors if no valid indices
            return (torch.empty(0, *self.cnn_feature_shape, device=self.device),
                    torch.empty(0, self.embedding_dim, device=self.device),
                    torch.empty(0, device=self.device))
        
        return self.cnn_features[valid_indices], self.embeddings[valid_indices], self.returns[valid_indices]
    
    def sample_anchors(self, n_anchors: int) -> torch.Tensor:
        """Randomly sample up to ``n_anchors`` distinct indices from the filled part of the buffer."""
        size = self.size
        if size == 0:
            return torch.tensor([], dtype=torch.long, device=self.device)
        
        # Sample without replacement
        perm = torch.randperm(size, device=self.device)[:n_anchors]
        return perm
    
    @property
    def size(self) -> int:
        """Current number of samples in the buffer."""
        return self.max_size if self.is_full else self.current_idx

@dataclass
class ReplayBufferSample:
    """One batch of transitions as returned by ``ReplayBuffer.sample``."""
    # NOTE(review): ``ReplayBuffer.sample`` fills ``obs`` and ``next_obs`` with
    # dicts mapping observation keys to tensors, not with bare tensors.
    obs: torch.Tensor
    next_obs: torch.Tensor
    actions: torch.Tensor
    rewards: torch.Tensor
    dones: torch.Tensor

class ReplayBuffer:
    """Circular transition store for vectorised environments.

    Every array is laid out as ``(per_env_buffer_size, num_envs, ...)``: one
    call to ``add`` writes a single row holding one transition per
    environment, and ``sample`` draws (row, env) pairs uniformly at random.
    """

    def __init__(self, env, num_envs: int, buffer_size: int, storage_device: torch.device, sample_device: torch.device):
        self.buffer_size = buffer_size
        self.pos = 0
        self.full = False
        self.num_envs = num_envs
        self.storage_device = storage_device
        self.sample_device = sample_device
        self.per_env_buffer_size = buffer_size // num_envs

        slots = (self.per_env_buffer_size, num_envs)
        obs_space = env.single_observation_space
        # note 128x128x3 RGB data with replay buffer size 100_000 takes up around 4.7GB of GPU memory
        # 32 parallel envs with rendering uses up around 2.2GB of GPU memory.
        self.obs = DictArray(slots, obs_space, device=storage_device)
        # TODO (stao): optimize final observation storage
        self.next_obs = DictArray(slots, obs_space, device=storage_device)
        self.actions = torch.zeros(slots + env.single_action_space.shape, device=storage_device)
        # One scalar per (row, env) slot for each of the remaining arrays.
        self.logprobs = torch.zeros(slots, device=storage_device)
        self.rewards = torch.zeros(slots, device=storage_device)
        self.dones = torch.zeros(slots, device=storage_device)
        self.values = torch.zeros(slots, device=storage_device)

    def add(self, obs: torch.Tensor, next_obs: torch.Tensor, action: torch.Tensor, reward: torch.Tensor, done: torch.Tensor):
        """Write one transition per environment at the current row and advance it.

        ``obs`` and ``next_obs`` are iterated with ``.items()``, i.e. they are
        mappings from observation key to tensor.
        """
        if self.storage_device == torch.device("cpu"):
            # Storage lives on the CPU: bring every incoming tensor over first.
            obs = {key: tensor.cpu() for key, tensor in obs.items()}
            next_obs = {key: tensor.cpu() for key, tensor in next_obs.items()}
            action, reward, done = action.cpu(), reward.cpu(), done.cpu()

        row = self.pos
        self.obs[row] = obs
        self.next_obs[row] = next_obs
        self.actions[row] = action
        self.rewards[row] = reward
        self.dones[row] = done

        # Advance the write row; wrap to the start once the last row is used.
        next_row = row + 1
        if next_row == self.per_env_buffer_size:
            self.full = True
            next_row = 0
        self.pos = next_row

    def sample(self, batch_size: int):
        """Draw ``batch_size`` random transitions and move them to ``sample_device``."""
        # Only rows written so far are eligible until the buffer has wrapped.
        row_limit = self.per_env_buffer_size if self.full else self.pos
        batch_inds = torch.randint(0, row_limit, size=(batch_size, ))
        env_inds = torch.randint(0, self.num_envs, size=(batch_size, ))

        target = self.sample_device
        obs_sample = {
            key: tensor.to(target) for key, tensor in self.obs[batch_inds, env_inds].items()
        }
        next_obs_sample = {
            key: tensor.to(target) for key, tensor in self.next_obs[batch_inds, env_inds].items()
        }
        actions = self.actions[batch_inds, env_inds].to(target)
        rewards = self.rewards[batch_inds, env_inds].to(target)
        dones = self.dones[batch_inds, env_inds].to(target)
        return ReplayBufferSample(
            obs=obs_sample,
            next_obs=next_obs_sample,
            actions=actions,
            rewards=rewards,
            dones=dones,
        )

class PlainConv(nn.Module):
    """Convolutional image encoder whose feature map is re-weighted by an attention head.

    The RGB observation passes through a conv/pool backbone, the resulting
    feature map is multiplied by the attention map, and a linear layer turns
    it into a 256-d vector. When the sample observation carries a ``"state"``
    entry, a 256-d linear projection of the state is concatenated.
    """
    def __init__(self, sample_obs, args):
        super().__init__()

        self.out_features = 0
        feature_size = 256
        rgb = sample_obs["rgb"]
        in_channels = rgb.shape[-1]

        # Remember which attention variant was requested
        self.attention_type = args.attention_type

        # 1. CNN backbone: four conv/ReLU/pool stages and a final 1x1 conv.
        # 128x128 inputs get a 4x first pooling step, everything else 2x, so
        # both 128x128 and 64x64 images are 32x32 after the first stage.
        image_size = rgb.shape[1:3]  # (H, W)
        first_pool = 4 if (image_size[0] == 128 and image_size[1] == 128) else 2

        stages = []
        prev_channels = in_channels
        for out_channels, pool in ((16, first_pool), (32, 2), (64, 2), (64, 2)):
            stages.append(nn.Conv2d(prev_channels, out_channels, 3, padding=1, bias=True))
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.MaxPool2d(pool, pool))
            prev_channels = out_channels
        stages.append(nn.Conv2d(64, 64, 1, padding=0, bias=True))
        stages.append(nn.ReLU(inplace=True))
        self.cnn_backbone = nn.Sequential(*stages)

        # 2. Attention head selected by type
        if args.attention_type == "patch":
            self.attention_head = PatchAttention(64)
        elif args.attention_type == "foveal":
            self.attention_head = FovealAttention(64)
        else:
            raise ValueError(f"Unknown attention type: {args.attention_type}")

        # 3. Fully-connected layer; its input width is measured by pushing the
        # sample image through the backbone once.
        with torch.no_grad():
            probe = rgb.float().permute(0,3,1,2).cpu()
            n_flatten = self.cnn_backbone(probe).flatten(1).shape[1]

        # Single feature pathway: consumes the attention-weighted features
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_flatten, feature_size),
            nn.ReLU()
        )

        self.out_features = feature_size

        # 4. Optional state feature extractor
        if "state" in sample_obs:
            self.state_extractor = nn.Linear(sample_obs["state"].shape[-1], 256)
            self.out_features += 256
        else:
            self.state_extractor = None

    def forward(self, observations, return_attention=False, return_cnn_features=False):
        """Encode a batch of observations.

        Returns:
            ``features`` alone by default. With ``return_attention`` the tuple
            additionally carries ``attn_map, attn_params``; with
            ``return_cnn_features`` the un-weighted backbone output is
            appended last.
        """
        # Channels-last image -> channels-first, scaled by 1/255
        rgb_obs = observations["rgb"].float().permute(0, 3, 1, 2) / 255
        cnn_features = self.cnn_backbone(rgb_obs)
        attn_map, attn_params = self.attention_head(cnn_features)

        # Attention-weight the feature map, then project it
        features = self.fc(cnn_features * attn_map)

        # Append state features if present
        if self.state_extractor is not None:
            state_features = self.state_extractor(observations["state"])
            features = torch.cat([features, state_features], dim=1)

        if not (return_attention or return_cnn_features):
            return features

        outputs = [features]
        if return_attention:
            outputs.extend((attn_map, attn_params))
        if return_cnn_features:
            outputs.append(cnn_features)
        return tuple(outputs)

# SAC Actor and Critic Networks
# Bounds for the actor's log standard deviation: Actor.forward squashes its
# raw log-std output with tanh and rescales it linearly into
# [LOG_STD_MIN, LOG_STD_MAX].
LOG_STD_MAX = 2
LOG_STD_MIN = -5

class SoftQNetwork(nn.Module):
    """SAC critic: an MLP over encoder features concatenated with the action."""
    def __init__(self, envs, encoder: PlainConv):
        super().__init__()
        self.encoder = encoder  # Encoder with attention, passed in by the caller
        action_dim = np.prod(envs.single_action_space.shape)
        input_dim = encoder.out_features + action_dim
        # Three linear layers ending in a single Q-value
        q_layers = [
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        ]
        self.mlp = nn.Sequential(*q_layers)

    def forward(self, obs, action, visual_feature=None, detach_encoder=False):
        """Return the Q-value for ``(obs, action)``.

        Args:
            obs: Observations; only encoded when ``visual_feature`` is None.
            action: Actions, concatenated to the features along dim 1.
            visual_feature: Optional pre-computed encoder output to reuse.
            detach_encoder: If true, the features are detached so no gradient
                reaches whatever produced them.
        """
        feature = self.encoder(obs) if visual_feature is None else visual_feature
        if detach_encoder:
            feature = feature.detach()
        return self.mlp(torch.cat([feature, action], dim=1))

class Actor(nn.Module):
    """Stochastic actor for SAC with attention-based encoder.

    The encoder output is passed through a 512-256 MLP, followed by two linear
    heads that produce the mean and the (squashed) log standard deviation of a
    Gaussian. Sampled actions are tanh-squashed and then rescaled with
    ``action_scale`` / ``action_bias``, which are derived from the bounds of
    ``envs.single_action_space``.

    ``action_scale`` and ``action_bias`` are registered as *non-persistent*
    buffers: they follow the module through ``.to()``, ``.cuda()``, ``.cpu()``
    and dtype casts like any other module tensor, but are not written to the
    ``state_dict``, so checkpoint keys are unaffected.
    """
    def __init__(self, envs, encoder: PlainConv):
        super().__init__()
        self.encoder = encoder  # Shared encoder with attention
        action_space = envs.single_action_space
        action_dim = int(np.prod(action_space.shape))

        # Actor MLP
        self.mlp = nn.Sequential(
            nn.Linear(encoder.out_features, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )
        self.fc_mean = nn.Linear(256, action_dim)
        self.fc_logstd = nn.Linear(256, action_dim)

        # Action rescaling: maps tanh output in [-1, 1] to [low, high].
        # Buffers (rather than plain attributes) so that every device/dtype
        # move of the module also moves these tensors; persistent=False keeps
        # them out of saved checkpoints.
        self.register_buffer(
            "action_scale",
            torch.as_tensor((action_space.high - action_space.low) / 2.0, dtype=torch.float32),
            persistent=False,
        )
        self.register_buffer(
            "action_bias",
            torch.as_tensor((action_space.high + action_space.low) / 2.0, dtype=torch.float32),
            persistent=False,
        )

    def get_feature(self, obs, detach_encoder=False, return_attention=False):
        """Run the encoder and the actor MLP.

        Args:
            obs: Observation passed straight to the encoder.
            detach_encoder: When true, the encoder output is detached before
                the MLP so no gradient reaches the encoder through this call.
            return_attention: When true, the encoder is also asked for its
                attention map, attention parameters and raw CNN features.

        Returns:
            ``(x, visual_feature)`` or, with ``return_attention``,
            ``(x, visual_feature, attn_map, attn_params, cnn_features)`` where
            ``x`` is the MLP output.
        """
        if return_attention:
            visual_feature, attn_map, attn_params, cnn_features = self.encoder(
                obs, return_attention=True, return_cnn_features=True
            )
        else:
            visual_feature = self.encoder(obs)

        if detach_encoder:
            visual_feature = visual_feature.detach()
        x = self.mlp(visual_feature)

        if return_attention:
            return x, visual_feature, attn_map, attn_params, cnn_features
        return x, visual_feature

    def forward(self, obs, detach_encoder=False, return_attention=False):
        """Compute the Gaussian policy parameters.

        Returns:
            ``(mean, log_std, visual_feature)``, extended with
            ``attn_map, attn_params, cnn_features`` when ``return_attention``
            is set. ``log_std`` lies in ``[LOG_STD_MIN, LOG_STD_MAX]``.
        """
        if return_attention:
            x, visual_feature, attn_map, attn_params, cnn_features = self.get_feature(obs, detach_encoder, return_attention=True)
        else:
            x, visual_feature = self.get_feature(obs, detach_encoder)
        mean = self.fc_mean(x)
        log_std = self.fc_logstd(x)
        # tanh bounds the raw head output to [-1, 1]; the affine map below
        # stretches that interval onto [LOG_STD_MIN, LOG_STD_MAX].
        log_std = torch.tanh(log_std)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats

        if return_attention:
            return mean, log_std, visual_feature, attn_map, attn_params, cnn_features
        return mean, log_std, visual_feature

    def get_eval_action(self, obs):
        """Return the deterministic action: the squashed, rescaled policy mean."""
        mean, log_std, _ = self(obs)
        action = torch.tanh(mean) * self.action_scale + self.action_bias
        return action

    def get_action(self, obs, detach_encoder=False, return_attention=False):
        """Sample an action with the reparameterization trick.

        Returns:
            ``(action, log_prob, mean, visual_feature)``, extended with
            ``attn_map, attn_params, cnn_features`` when ``return_attention``
            is set. ``log_prob`` is summed over dim 1 (``keepdim=True``) and
            includes the change-of-variables correction for the tanh squash
            and the rescaling; ``mean`` is the squashed, rescaled mean action.
        """
        if return_attention:
            mean, log_std, visual_feature, attn_map, attn_params, cnn_features = self(obs, detach_encoder, return_attention=True)
        else:
            mean, log_std, visual_feature = self(obs, detach_encoder)
        std = log_std.exp()
        normal = torch.distributions.Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Enforcing Action Bound (1e-6 keeps the log finite when |y_t| -> 1)
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias

        if return_attention:
            return action, log_prob, mean, visual_feature, attn_map, attn_params, cnn_features
        return action, log_prob, mean, visual_feature

class Logger:
    """Fan scalar metrics out to TensorBoard and, optionally, Weights & Biases.

    Args:
        log_wandb: When true, every scalar is also sent through ``wandb.log``.
        tensorboard: TensorBoard writer to forward scalars to. May be ``None``
            (the default), in which case the TensorBoard side is a no-op.
    """
    def __init__(self, log_wandb=False, tensorboard: Optional[SummaryWriter] = None) -> None:
        self.writer = tensorboard
        self.log_wandb = log_wandb

    def add_scalar(self, tag, scalar_value, step):
        """Record ``scalar_value`` under ``tag`` at ``step`` on every enabled backend."""
        if self.log_wandb:
            # Imported here so the class does not depend on a module-level
            # `wandb` name that only exists when the script enables tracking.
            import wandb
            wandb.log({tag: scalar_value}, step=step)
        if self.writer is not None:
            self.writer.add_scalar(tag, scalar_value, step)

    def close(self):
        """Close the TensorBoard writer, if one was supplied."""
        if self.writer is not None:
            self.writer.close()

def mine_contrastive_triplets_faiss(buffer: ContrastiveBuffer,
                                    n_anchors: int = 1024,
                                    top_k: int = 50,
                                    device: str = 'cuda') -> List[Tuple[int, int, int]]:
    """
    Build (anchor, positive, negative) index triplets from the buffer.

    Anchors are drawn with ``buffer.sample_anchors``; for each one, its
    nearest neighbours (via ``buffer.search_neighbors``) are split by stored
    return: neighbours above ``median + 0.5 * std`` are positive candidates,
    those below ``median - 0.5 * std`` are negative candidates. One of each is
    picked at random. Anchors whose neighbourhood has fewer than two members,
    return variance below 1e-3, or no candidate on either side are skipped.

    Args:
        buffer: Source of embeddings, returns and the neighbour search.
        n_anchors: Upper bound on the number of anchors to sample.
        top_k: Neighbours requested per anchor (one extra is requested to
            leave room for the anchor itself).
        device: Device on which the random selection indices are created.

    Returns:
        List of ``(anchor_idx, pos_idx, neg_idx)`` tuples of Python ints;
        empty when the buffer holds fewer than three entries.
    """
    mined: List[Tuple[int, int, int]] = []

    # A triplet needs three distinct entries at the very least.
    if buffer.size < 3:
        return mined

    anchors = buffer.sample_anchors(min(n_anchors, buffer.size))
    if len(anchors) == 0:
        return mined

    _, anchor_emb, anchor_returns = buffer.get_samples(anchors)

    # +1 because the anchor is expected to show up among its own neighbours.
    n_query = min(top_k + 1, buffer.size)
    _, knn = buffer.search_neighbors(anchor_emb, n_query)

    for row, anchor in enumerate(anchors):
        candidates = knn[row]
        if candidates.numel() == 0:
            continue

        # Drop the anchor itself from its neighbour list.
        candidates = candidates[candidates != anchor]
        if len(candidates) < 2:
            continue

        _, _, cand_returns = buffer.get_samples(candidates)
        if cand_returns.numel() == 0:
            continue

        # Nearly constant returns carry no better/worse signal.
        if cand_returns.var() < 1e-3:
            continue

        # Band around the median whose width adapts to the return spread.
        half_band = 0.5 * cand_returns.std()
        center = cand_returns.median()

        better = candidates[cand_returns > center + half_band]
        worse = candidates[cand_returns < center - half_band]
        if len(better) == 0 or len(worse) == 0:
            continue

        # One random representative from each side (positive drawn first).
        pick_pos = torch.randint(len(better), (1,), device=device)
        pick_neg = torch.randint(len(worse), (1,), device=device)
        mined.append((anchor.item(), better[pick_pos].item(), worse[pick_neg].item()))

    return mined

def contrastive_loss_faiss(agent: nn.Module,
                           buffer: ContrastiveBuffer,
                           triplets: List[Tuple[int, int, int]],
                           margin: float = 0.5,
                           lambda_spread: float = 0.1,
                           device: str = 'cuda') -> Tuple[torch.Tensor, Dict]:
    """Compute feature-level contrastive loss.

    Stored CNN features for every anchor/positive/negative index are fetched
    from ``buffer`` and re-weighted by a fresh pass through
    ``agent.feature_net.attention_head``. The attended features are flattened
    and L2-normalized, and a triplet margin loss on cosine distance is
    averaged over the triplets. When ``agent.feature_net.attention_type`` is
    ``"foveal"``, a log-space quadratic penalty pulls columns 2:4 of the
    attention parameters (the spreads) toward 0.1, scaled by ``lambda_spread``.

    Args:
        agent: Object exposing ``feature_net.attention_head`` and
            ``feature_net.attention_type``.
        buffer: Provides ``get_samples(indices)`` whose first return value is
            the stored CNN feature batch.
        triplets: ``(anchor_idx, pos_idx, neg_idx)`` tuples indexing the buffer.
        margin: Margin of the triplet loss.
        lambda_spread: Weight of the spread regularizer.
        device: Device for the index tensor and for the zero-loss fallback.

    Returns:
        ``(total_loss, metrics)``. With no triplets, no returned features, or
        a caught ``IndexError`` / ``KeyError`` / ``RuntimeError``, the loss is
        a constant zero tensor and ``metrics`` is empty. A caught error is
        reported through a ``RuntimeWarning`` rather than dropped silently.
    """
    # Local import: only needed for the failure path below.
    import warnings

    metrics = {}
    if not triplets:
        return torch.tensor(0.0, device=device), metrics

    try:
        # Extract indices for vectorized processing
        anchor_indices, pos_indices, neg_indices = zip(*triplets)
        n_triplets = len(triplets)

        # Get stored CNN features (NOT embeddings!) for gradient flow
        all_indices = torch.tensor(anchor_indices + pos_indices + neg_indices, device=device)
        cnn_features_batch, _, _ = buffer.get_samples(all_indices)  # (3*n_triplets, C, H, W)

        # Check if we got valid features
        if cnn_features_batch.shape[0] == 0:
            return torch.tensor(0.0, device=device), metrics

        # Fresh forward pass through attention head only
        attn_maps, attn_params = agent.feature_net.attention_head(cnn_features_batch)

        # Attended features (what x where), L2-normalized so that the dot
        # product below is a cosine similarity.
        feat = F.normalize((cnn_features_batch * attn_maps).flatten(1), dim=1)

        # Rows are laid out as [anchors | positives | negatives].
        fa = feat[:n_triplets]
        fp = feat[n_triplets:2 * n_triplets]
        fn = feat[2 * n_triplets:]

        # Cosine distances between anchor and positive / negative
        dist_pos_feat = 1 - (fa * fp).sum(dim=1)
        dist_neg_feat = 1 - (fa * fn).sum(dim=1)

        # Pure feature-level contrastive loss; also reused for active_fraction.
        margin_violation = margin + (dist_pos_feat - dist_neg_feat)
        contrastive_loss = F.relu(margin_violation)

        # Spread regularization (only for Foveal attention)
        is_foveal = agent.feature_net.attention_type == "foveal"
        spread_reg = torch.tensor(0.0, device=device)
        all_spreads = None
        if is_foveal:
            # Columns 2:4 hold the spreads (sigma_x, sigma_y) of every row,
            # anchors, positives and negatives alike: shape (3N, 2).
            all_spreads = attn_params[:, 2:4]

            # log-space anchor: pull toward a small-but-nonzero target
            sigma_t = 0.1  # target spread in normalized coords
            log_sigma_t = math.log(sigma_t + 1e-8)
            log_spreads = torch.log(all_spreads + 1e-8)

            # smooth quadratic around the target in log-space
            spread_reg = ((log_spreads - log_sigma_t) ** 2).mean() * lambda_spread

        # Compute final loss
        total_loss = contrastive_loss.mean() + spread_reg

        # Collect comprehensive metrics for logging
        with torch.no_grad():
            # Distance differences (key metric - should be positive for good separation)
            dist_diff = dist_neg_feat - dist_pos_feat

            # Active fraction (how many triplets violate the margin)
            active_mask = margin_violation > 0

            # The sample std of a single value is undefined (NaN); log 0.0.
            dist_diff_std = dist_diff.std().item() if n_triplets > 1 else 0.0

            if is_foveal:
                mean_spread_x = all_spreads[:, 0].mean().item()
                mean_spread_y = all_spreads[:, 1].mean().item()
                spread_x_std = all_spreads[:, 0].std().item()
                spread_y_std = all_spreads[:, 1].std().item()
            else:
                mean_spread_x = 0.0
                mean_spread_y = 0.0
                spread_x_std = 0.0
                spread_y_std = 0.0

            metrics = {
                # Feature-level metrics
                'mean_dist_pos_feat': dist_pos_feat.mean().item(),
                'mean_dist_neg_feat': dist_neg_feat.mean().item(),
                'dist_difference': dist_diff.mean().item(),
                'dist_diff_std': dist_diff_std,
                'dist_diff_min': dist_diff.min().item(),
                'dist_diff_max': dist_diff.max().item(),

                # Loss components
                'contrastive_loss': contrastive_loss.mean().item(),
                'spread_reg_loss': spread_reg.item(),
                'active_fraction': active_mask.float().mean().item(),

                # Attention-specific parameters
                'mean_spread_x': mean_spread_x,
                'mean_spread_y': mean_spread_y,
                'spread_x_std': spread_x_std,
                'spread_y_std': spread_y_std,
                'attention_type': agent.feature_net.attention_type,
            }

        return total_loss, metrics

    except (IndexError, KeyError, RuntimeError) as err:
        # Best-effort: training continues with a zero contrastive term, but
        # the failure is surfaced so a broken contrastive path is noticeable.
        warnings.warn(
            f"contrastive_loss_faiss failed; returning zero loss: {err!r}",
            RuntimeWarning,
        )
        return torch.tensor(0.0, device=device), metrics

def create_attention_overlay(rgb_frame: np.ndarray, attention_map: torch.Tensor, alpha: float = 0.4) -> np.ndarray:
    """Blend a viridis-colored attention heatmap over an RGB frame.

    Args:
        rgb_frame: Image array of shape (H, W, C). If it is not ``uint8`` it
            is multiplied by 255 and cast (assumes values in [0, 1] — TODO
            confirm against callers). ``C // 3`` is taken as the number of
            cameras stacked along the channel axis.
        attention_map: Attention tensor that squeezes to a 2-D map. It may be
            on any device, may require grad, and may be half precision.
        alpha: Blend weight of the heatmap; the frame gets ``1 - alpha``.

    Returns:
        ``uint8`` array with the same shape as ``rgb_frame``.
    """
    if rgb_frame.dtype != np.uint8:
        rgb_frame = (rgb_frame * 255).astype(np.uint8)

    H, W, C = rgb_frame.shape
    num_cameras = C // 3  # Detect number of cameras

    # detach(): .numpy() raises on tensors that require grad.
    # float(): NumPy/OpenCV cannot consume half/bfloat16 tensors.
    attention = attention_map.detach().squeeze().float().cpu().numpy()
    # cv2.resize takes the target size as (width, height).
    attention_resized = cv2.resize(attention, (W, H), interpolation=cv2.INTER_LINEAR)

    # Min-max normalize to [0, 1]; the epsilon guards a constant map.
    attn_min = attention_resized.min()
    attn_max = attention_resized.max()
    attention_norm = (attention_resized - attn_min) / (attn_max - attn_min + 1e-8)

    # Apply colormap (drop the alpha channel returned by matplotlib)
    attention_colored = cm.viridis(attention_norm)[:, :, :3]  # (H, W, 3)
    attention_colored = (attention_colored * 255).astype(np.uint8)

    # Replicate attention overlay for each camera
    if num_cameras > 1:
        attention_colored = np.concatenate([attention_colored] * num_cameras, axis=2)

    # Blend with original image
    overlay = ((1 - alpha) * rgb_frame + alpha * attention_colored).astype(np.uint8)

    return overlay

def save_attention_video(frames: List[np.ndarray], output_path: str, fps: int = 30) -> None:
    """
    Write ``frames`` to ``output_path`` as a libx264 video via imageio.

    An empty frame list is a no-op. When the first frame has more than four
    channels, every frame is treated as several 3-channel camera images
    stacked along the channel axis (count derived from the first frame) and
    the cameras are laid out side by side before encoding.
    """
    if not frames:
        return

    channel_count = frames[0].shape[2]
    if channel_count > 4:  # more than RGB/RGBA -> stacked cameras
        num_cameras = channel_count // 3
        tiled = []
        for frame in frames:
            # Unpacking enforces that each frame is exactly 3-D.
            _height, _width, _channels = frame.shape
            per_camera = [frame[:, :, 3 * cam:3 * (cam + 1)] for cam in range(num_cameras)]
            tiled.append(np.concatenate(per_camera, axis=1))  # side-by-side
        frames = tiled

    imageio.mimsave(output_path, frames, fps=fps, codec='libx264')

if __name__ == "__main__":
    args = tyro.cli(Args)
    args.grad_steps_per_iteration = int(args.training_freq * args.utd)
    args.steps_per_env = args.training_freq // args.num_envs
    if args.exp_name is None:
        args.exp_name = os.path.basename(__file__)[: -len(".py")]
        run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    else:
        run_name = args.exp_name

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    ####### Environment setup #######
    env_kwargs = dict(obs_mode=args.obs_mode, render_mode=args.render_mode, sim_backend="gpu", sensor_configs=dict())
    if args.control_mode is not None:
        env_kwargs["control_mode"] = args.control_mode
    if args.camera_width is not None:
        # this overrides every sensor used for observation generation
        env_kwargs["sensor_configs"]["width"] = args.camera_width
    if args.camera_height is not None:
        env_kwargs["sensor_configs"]["height"] = args.camera_height
    envs = gym.make(args.env_id, num_envs=args.num_envs if not args.evaluate else 1, reconfiguration_freq=args.reconfiguration_freq, **env_kwargs)
    eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, reconfiguration_freq=args.eval_reconfiguration_freq, human_render_camera_configs=dict(shader_pack="default"), **env_kwargs)

    # rgbd obs mode returns a dict of data, we flatten it so there is just a rgbd key and state key
    envs = FlattenRGBDObservationWrapper(envs, rgb=True, depth=False, state=args.include_state)
    eval_envs = FlattenRGBDObservationWrapper(eval_envs, rgb=True, depth=False, state=args.include_state)

    if isinstance(envs.action_space, gym.spaces.Dict):
        envs = FlattenActionSpaceWrapper(envs)
        eval_envs = FlattenActionSpaceWrapper(eval_envs)
    if args.capture_video or args.save_trajectory:
        eval_output_dir = f"runs/{run_name}/videos"
        if args.evaluate:
            eval_output_dir = f"{os.path.dirname(args.checkpoint)}/test_videos"
        print(f"Saving eval trajectories/videos to {eval_output_dir}")
        if args.save_train_video_freq is not None:
            save_video_trigger = lambda x : (x // args.num_steps) % args.save_train_video_freq == 0
            envs = RecordEpisode(envs, output_dir=f"runs/{run_name}/train_videos", save_trajectory=False, save_video_trigger=save_video_trigger, max_steps_per_video=args.num_steps, video_fps=30)
        eval_envs = RecordEpisode(eval_envs, output_dir=eval_output_dir, save_trajectory=args.save_trajectory, save_video=args.capture_video, trajectory_name="trajectory", max_steps_per_video=args.num_eval_steps, video_fps=30)
    envs = ManiSkillVectorEnv(envs, args.num_envs, ignore_terminations=not args.partial_reset, record_metrics=True)
    eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=not args.eval_partial_reset, record_metrics=True)
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    max_episode_steps = gym_utils.find_max_episode_steps_value(envs._env)
    logger = None
    if not args.evaluate:
        print("Running training")
        if args.track:
            import wandb
            config = vars(args)
            config["env_cfg"] = dict(**env_kwargs, num_envs=args.num_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=args.partial_reset)
            config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=False)
            wandb.init(
                project=args.wandb_project_name,
                entity=args.wandb_entity,
                sync_tensorboard=False,
                config=config,
                name=run_name,
                save_code=True,
                group=args.wandb_group,
                tags=["sac", f"{args.env_id.split('-')[0].split('Level')[0]}"],
            )
        writer = SummaryWriter(f"runs/{run_name}")
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )
        logger = Logger(log_wandb=args.track, tensorboard=writer)
    else:
        print("Running evaluation")

    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        env=envs,
        num_envs=args.num_envs,
        buffer_size=args.buffer_size,
        storage_device=torch.device(args.buffer_device),
        sample_device=device
    )

    # TRY NOT TO MODIFY: start the game
    obs, info = envs.reset(seed=args.seed) # in Gymnasium, seed is given to reset() instead of seed()
    eval_obs, _ = eval_envs.reset(seed=args.seed)

    # Create shared encoder and networks
    encoder = PlainConv(sample_obs=obs, args=args).to(device)
    actor = Actor(envs, encoder).to(device)
    qf1 = SoftQNetwork(envs, encoder).to(device)
    qf2 = SoftQNetwork(envs, encoder).to(device)
    qf1_target = SoftQNetwork(envs, encoder).to(device)
    qf2_target = SoftQNetwork(envs, encoder).to(device)
    
    if args.checkpoint is not None:
        ckpt = torch.load(args.checkpoint)
        actor.load_state_dict(ckpt['actor'])
        qf1.load_state_dict(ckpt['qf1'])
        qf2.load_state_dict(ckpt['qf2'])
    
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    
    q_optimizer = optim.Adam(
        list(qf1.mlp.parameters()) +
        list(qf2.mlp.parameters()) +
        list(encoder.parameters()),
        lr=args.q_lr)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)

    # Automatic entropy tuning
    if args.autotune:
        target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
    else:
        alpha = args.alpha

    # Initialize contrastive components
    
    contrastive_buffer = None
    if args.use_contrastive and args.use_faiss:
        # Get CNN feature shape with a dummy forward pass
        with torch.no_grad():
            dummy_features, dummy_attn, dummy_params, dummy_cnn = encoder(obs, return_attention=True, return_cnn_features=True)
            cnn_feature_shape = dummy_cnn.shape[1:]  # (C, H, W)
            embedding_dim = dummy_cnn[0].flatten().shape[0]

        contrastive_buffer = ContrastiveBuffer(
            cnn_feature_shape=cnn_feature_shape,
            embedding_dim=embedding_dim,
            max_size=args.contrastive_buffer_size,
            device=device
        )
        print(f"Initialized Contrastive Buffer with CNN shape={cnn_feature_shape}, "
              f"embedding_dim={embedding_dim}, max_size={args.contrastive_buffer_size}")
        print(f"Using attention type: {args.attention_type}")

    global_step = 0
    global_update = 0
    learning_has_started = False

    global_steps_per_iteration = args.num_envs * (args.steps_per_env)
    pbar = tqdm.tqdm(range(args.total_timesteps))
    cumulative_times = defaultdict(float)

    while global_step < args.total_timesteps:
        if args.eval_freq > 0 and (global_step - args.training_freq) // args.eval_freq < global_step // args.eval_freq:
            # evaluate
            actor.eval()
            stime = time.perf_counter()
            eval_obs, _ = eval_envs.reset()
            eval_metrics = defaultdict(list)
            num_episodes = 0
            for _ in range(args.num_eval_steps):
                with torch.no_grad():
                    eval_obs, eval_rew, eval_terminations, eval_truncations, eval_infos = eval_envs.step(actor.get_eval_action(eval_obs))
                    if "final_info" in eval_infos:
                        mask = eval_infos["_final_info"]
                        num_episodes += mask.sum()
                        for k, v in eval_infos["final_info"]["episode"].items():
                            eval_metrics[k].append(v)
            eval_metrics_mean = {}
            for k, v in eval_metrics.items():
                mean = torch.stack(v).float().mean()
                eval_metrics_mean[k] = mean
                if logger is not None:
                    logger.add_scalar(f"eval/{k}", mean, global_step)
            pbar.set_description(
                f"success_once: {eval_metrics_mean.get('success_once', 0):.2f}, "
                f"return: {eval_metrics_mean.get('return', 0):.2f}"
            )
            if logger is not None:
                eval_time = time.perf_counter() - stime
                cumulative_times["eval_time"] += eval_time
                logger.add_scalar("time/eval_time", eval_time, global_step)
            if args.evaluate:
                break
            actor.train()

            if args.save_model:
                model_path = f"runs/{run_name}/ckpt_{global_step}.pt"
                torch.save({
                    'actor': actor.state_dict(),
                    'qf1': qf1_target.state_dict(),
                    'qf2': qf2_target.state_dict(),
                    'log_alpha': log_alpha if args.autotune else None,
                }, model_path)
                print(f"model saved to {model_path}")

        # Collect samples from environments
        rollout_time = time.perf_counter()
        for local_step in range(args.steps_per_env):
            global_step += 1 * args.num_envs

            # ALGO LOGIC: put action logic here
            if not learning_has_started:
                actions = 2 * torch.rand(size=envs.action_space.shape, dtype=torch.float32, device=device) - 1
            else:
                actions, _, _, _, attn_map, attn_params, cnn_features = actor.get_action(obs, return_attention=True)
                actions = actions.detach()
                
                # Add to contrastive buffer periodically
                if args.use_contrastive and args.use_faiss and global_step % args.contrastive_update_freq == 0:
                    with torch.no_grad():
                        # cnn_features shape: (num_envs, C, H, W)
                        embeddings = cnn_features.flatten(1)  # (num_envs, C*H*W)
                        # We don't have returns yet, so we'll add them after episode completion
                        # For now, use placeholder returns (will be updated later)
                        placeholder_returns = torch.zeros(args.num_envs, device=device)
                        contrastive_buffer.add_batch(cnn_features, embeddings, placeholder_returns)

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, rewards, terminations, truncations, infos = envs.step(actions)
            real_next_obs = {k:v.clone() for k, v in next_obs.items()}
            if args.bootstrap_at_done == 'never':
                need_final_obs = torch.ones_like(terminations, dtype=torch.bool)
                stop_bootstrap = truncations | terminations # always stop bootstrap when episode ends
            else:
                if args.bootstrap_at_done == 'always':
                    need_final_obs = truncations | terminations # always need final obs when episode ends
                    stop_bootstrap = torch.zeros_like(terminations, dtype=torch.bool) # never stop bootstrap
                else: # bootstrap at truncated
                    need_final_obs = truncations & (~terminations) # only need final obs when truncated and not terminated
                    stop_bootstrap = terminations # only stop bootstrap when terminated, don't stop when truncated
            if "final_info" in infos:
                final_info = infos["final_info"]
                done_mask = infos["_final_info"]
                for k in real_next_obs.keys():
                    real_next_obs[k][need_final_obs] = infos["final_observation"][k][need_final_obs].clone()
                for k, v in final_info["episode"].items():
                    if logger is not None:
                        logger.add_scalar(f"train/{k}", v[done_mask].float().mean(), global_step)

            rb.add(obs, real_next_obs, actions, rewards, stop_bootstrap)

            # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
            obs = next_obs
        rollout_time = time.perf_counter() - rollout_time
        cumulative_times["rollout_time"] += rollout_time
        pbar.update(args.num_envs * args.steps_per_env)

        # ALGO LOGIC: training.
        if global_step < args.learning_starts:
            continue

        update_time = time.perf_counter()
        learning_has_started = True
        
        # Determine whether to use contrastive learning
        should_use_contrastive = args.use_contrastive and args.use_faiss and contrastive_buffer.size > 0
        
        # Mine triplets once per iteration (not per gradient step)
        triplets = []
        if should_use_contrastive and global_update % args.contrastive_update_freq == 0:
            triplets = mine_contrastive_triplets_faiss(
                buffer=contrastive_buffer,
                n_anchors=min(args.contrast_batch_size, contrastive_buffer.size // 2),
                top_k=args.contrast_top_k,
                device=device
            )
            if triplets:
                print(f"Step {global_step}: Buffer size={contrastive_buffer.size}, "
                      f"found {len(triplets)} triplets, weight={args.lambda_contrast:.4f}")
        
        for local_update in range(args.grad_steps_per_iteration):
            global_update += 1
            data = rb.sample(args.batch_size)

            # Update Q-networks with contrastive loss
            with torch.no_grad():
                next_state_actions, next_state_log_pi, _, visual_feature = actor.get_action(data.next_obs)
                qf1_next_target = qf1_target(data.next_obs, next_state_actions, visual_feature)
                qf2_next_target = qf2_target(data.next_obs, next_state_actions, visual_feature)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)
                # data.dones is "stop_bootstrap", which is computed earlier according to args.bootstrap_at_done
            
            # Forward pass with attention for contrastive loss
            visual_feature, attn_map, attn_params, cnn_features = encoder(data.obs, return_attention=True, return_cnn_features=True)
            qf1_a_values = qf1(data.obs, data.actions, visual_feature).view(-1)
            qf2_a_values = qf2(data.obs, data.actions, visual_feature).view(-1)
            qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
            
            # Compute contrastive loss if applicable
            contrastive_loss = torch.tensor(0.0, device=device)
            contrastive_metrics = {}
            if triplets and should_use_contrastive:
                # Create a dummy agent-like object for compatibility with contrastive loss function
                class DummyAgent:
                    def __init__(self, encoder):
                        self.feature_net = encoder
                        self.feature_net.attention_type = args.attention_type
                
                dummy_agent = DummyAgent(encoder)
                contrastive_loss, contrastive_metrics = contrastive_loss_faiss(
                    dummy_agent,
                    contrastive_buffer,
                    triplets,
                    margin=args.contrast_margin,
                    lambda_spread=args.lambda_spread,
                    device=device
                )
            
            qf_loss = qf1_loss + qf2_loss + args.lambda_contrast * contrastive_loss

            q_optimizer.zero_grad()
            qf_loss.backward()
            q_optimizer.step()

            # Update the policy network
            if global_update % args.policy_frequency == 0:  # TD 3 Delayed update support
                pi, log_pi, _, visual_feature = actor.get_action(data.obs)
                qf1_pi = qf1(data.obs, pi, visual_feature, detach_encoder=True)
                qf2_pi = qf2(data.obs, pi, visual_feature, detach_encoder=True)
                min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1)
                actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_optimizer.step()

                if args.autotune:
                    with torch.no_grad():
                        _, log_pi, _, _ = actor.get_action(data.obs)
                    alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()

                    a_optimizer.zero_grad()
                    alpha_loss.backward()
                    a_optimizer.step()
                    alpha = log_alpha.exp().item()

            # Soft (Polyak) update of the target critics, performed every
            # args.target_network_frequency-th update:
            #     target <- tau * online + (1 - tau) * target
            # qf1's pairs are processed first, then qf2's.
            if global_update % args.target_network_frequency == 0:
                for online_critic, target_critic in ((qf1, qf1_target), (qf2, qf2_target)):
                    for param, target_param in zip(
                        online_critic.parameters(), target_critic.parameters()
                    ):
                        target_param.data.copy_(
                            args.tau * param.data + (1 - args.tau) * target_param.data
                        )
        
        # update_time is rebound from its previous value (expected to be the
        # perf_counter reading taken before the updates, outside this view) to the
        # elapsed duration, which is then added to the running total.
        update_time = time.perf_counter() - update_time
        cumulative_times["update_time"] += update_time

        # Log training-related data.
        # The condition is true when a multiple of args.log_freq lies in the
        # interval (global_step - args.training_freq, global_step].
        if (global_step - args.training_freq) // args.log_freq < global_step // args.log_freq:
            if logger is not None:
                logger.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
                logger.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                logger.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                logger.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                # NOTE(review): qf_loss also contains the weighted contrastive term
                # (see where qf_loss is assembled), so halving it is not the mean
                # of the two critic losses -- confirm this is the intended metric.
                logger.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
                # actor_loss is only bound once a policy update has run.
                if 'actor_loss' in locals():
                    logger.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
                logger.add_scalar("losses/alpha", alpha, global_step)
                logger.add_scalar("time/update_time", update_time, global_step)
                logger.add_scalar("time/rollout_time", rollout_time, global_step)
                logger.add_scalar("time/rollout_fps", global_steps_per_iteration / rollout_time, global_step)
                for k, v in cumulative_times.items():
                    logger.add_scalar(f"time/total_{k}", v, global_step)
                logger.add_scalar("time/total_rollout+update_time", cumulative_times["rollout_time"] + cumulative_times["update_time"], global_step)
                # alpha_loss is only bound once an auto-tuned temperature update has run.
                if args.autotune and 'alpha_loss' in locals():
                    logger.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)
                
                # Log contrastive metrics
                if args.use_contrastive:
                    # The loss, triplet count and metrics are only reported when the
                    # most recent contrastive loss is strictly positive.
                    if contrastive_loss.item() > 0:
                        logger.add_scalar("losses/contrastive_loss", contrastive_loss.item(), global_step)
                        logger.add_scalar("contrastive/num_triplets", len(triplets), global_step)
                        
                        if contrastive_metrics:
                            for metric_name, metric_value in contrastive_metrics.items():
                                # 'attention_type' is skipped
                                # (presumably a non-numeric label -- TODO confirm).
                                if metric_name == 'attention_type':
                                    continue
                                logger.add_scalar(f"contrastive/{metric_name}", metric_value, global_step)
                    
                    # The contrastive weight is logged whenever contrastive learning is enabled.
                    logger.add_scalar("contrastive/current_weight", args.lambda_contrast, global_step)
                
                # No scheduler metrics to log

    # Persist the final checkpoint after training (skipped for evaluation-only
    # runs or when --save_model is off).
    if not args.evaluate and args.save_model:
        model_path = f"runs/{run_name}/final_ckpt.pt"
        torch.save({
            'actor': actor.state_dict(),
            # The critic entries are taken from the target networks.
            'qf1': qf1_target.state_dict(),
            'qf2': qf2_target.state_dict(),
            # Without entropy auto-tuning there is no learned temperature to store.
            'log_alpha': log_alpha if args.autotune else None,
        }, model_path)
        print(f"model saved to {model_path}")
    # Close the TensorBoard writer whenever logging was active, independent of
    # --save_model, so the writer is not left open when no checkpoint is written.
    # As before, this relies on `writer` being bound whenever `logger` is not None.
    if logger is not None:
        writer.close()
    envs.close()