# This code is an anonymized version of Gaze-on-the-Prize code for submission purpose.
# Full codebase will be released upon paper acceptance.


from collections import defaultdict
import os
import random
import time
from dataclasses import dataclass
from typing import Optional, Dict, List, Tuple

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter
import cv2
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import imageio
import faiss
import math

# ManiSkill specific imports
import mani_skill.envs
from mani_skill.utils import gym_utils
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper, FlattenRGBDObservationWrapper
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

@dataclass
class Args:
    """Command-line configuration for PPO training with contrastive attention."""
    exp_name: Optional[str] = None
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "Gaze-on-the-Prize-PPO"
    """the wandb's project name"""
    wandb_entity: Optional[str] = None
    """the entity (team) of wandb's project"""
    wandb_group: str = "PPO"
    """the group of the run for wandb"""
    capture_video: bool = True
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = True
    """whether to save model into the `runs/{run_name}` folder"""
    evaluate: bool = False
    """if toggled, only runs evaluation with the given model checkpoint and saves the evaluation trajectories"""
    checkpoint: Optional[str] = None
    """path to a pretrained checkpoint file to start evaluation/training from"""
    render_mode: str = "all"
    """the environment rendering mode"""

    # Algorithm specific arguments
    env_id: str = "PickCube-v1"
    """the id of the environment"""
    include_state: bool = True
    """whether to include state information in observations"""
    total_timesteps: int = 10000000
    """total timesteps of the experiments"""
    learning_rate: float = 3e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1024
    """the number of parallel environments"""
    num_eval_envs: int = 50
    """the number of parallel evaluation environments"""
    partial_reset: bool = True
    """whether to let parallel environments reset upon termination instead of truncation"""
    eval_partial_reset: bool = False
    """whether to let parallel evaluation environments reset upon termination instead of truncation"""
    num_steps: int = 16
    """the number of steps to run in each environment per policy rollout"""
    num_eval_steps: int = 50
    """the number of steps to run in each evaluation environment during evaluation"""
    reconfiguration_freq: Optional[int] = None
    """how often to reconfigure the environment during training"""
    eval_reconfiguration_freq: Optional[int] = 1
    """for benchmarking purposes we want to reconfigure the eval environment each reset to ensure objects are randomized in some tasks"""
    control_mode: Optional[str] = "pd_joint_delta_pos"
    """the control mode to use for the environment"""
    anneal_lr: bool = False
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.8
    """the discount factor gamma"""
    gae_lambda: float = 0.9
    """the lambda for the general advantage estimation"""
    num_minibatches: int = 32
    """the number of mini-batches"""
    update_epochs: int = 4
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    clip_vloss: bool = False
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.0
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    target_kl: float = 0.2
    """the target KL divergence threshold"""
    reward_scale: float = 1.0
    """Scale the reward by this factor"""
    eval_freq: int = 25
    """evaluation frequency in terms of iterations"""
    save_train_video_freq: Optional[int] = None
    """frequency to save training videos in terms of iterations"""
    finite_horizon_gae: bool = False
    """whether to use the finite-horizon variant of GAE when computing advantages"""

    # Contrastive learning arguments
    use_contrastive: bool = True
    """whether to use contrastive attention learning"""
    lambda_contrast: float = 0.1
    """weight for contrastive loss"""
    contrast_top_k: int = 16
    """number of similar states to consider for triplet mining"""
    contrast_margin: float = 0.5
    """margin for return difference in triplet selection"""
    contrastive_update_freq: int = 1
    """only compute contrastive loss every N PPO iterations"""

    # contrastive buffer arguments
    contrastive_buffer_size: int = 100_000
    """Maximum number of samples stored in the contrastive buffer"""
    use_faiss: bool = True
    """Use FAISS for efficient similarity search"""
    contrast_batch_size: int = 1024
    """Number of anchors to sample for triplet mining"""

    # Attention arguments
    attention_type: str = "foveal"
    """Type of attention: 'foveal' or 'patch'"""
    lambda_spread: float = 0.1
    """Weight for spread regularization in attention"""

    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and hand it back.

    The weight gets an orthogonal init scaled by ``std``; every bias entry is
    set to ``bias_const``. Returning the layer allows inline use inside
    ``nn.Sequential(...)``.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer


class PatchAttention(nn.Module):
    """Softmax attention over the spatial positions of a feature map.

    A single 1x1 convolution scores every (h, w) location. The scores are
    softmax-normalised jointly over all H*W locations of each sample, so each
    sample's attention map sums to one.
    """

    def __init__(self, in_channels):
        super().__init__()
        # One attention logit per spatial location.
        self.attention_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        # Small weights and a zero bias keep the initial logits small.
        nn.init.xavier_uniform_(self.attention_conv.weight, gain=0.1)
        nn.init.zeros_(self.attention_conv.bias)

    def forward(self, features):
        """Compute the attention map for ``features`` of shape (B, C, H, W).

        Returns:
            (attention_map, attention_params): the map has shape (B, 1, H, W);
            the params are the same weights flattened to (B, H*W), used as the
            attention "parameters" by the contrastive loss.
        """
        batch, _, height, width = features.shape
        flat_logits = self.attention_conv(features).view(batch, -1)
        flat_weights = F.softmax(flat_logits, dim=-1)
        return flat_weights.view(batch, 1, height, width), flat_weights

class FovealAttention(nn.Module):
    """Single oriented 2-D Gaussian ("fovea") attention over a feature map.

    A small conv head regresses five numbers per sample: centre (cx, cy),
    spreads (sx, sy) and rotation theta. From these, a Gaussian is rendered
    over the H x W feature grid (coordinates in [0, 1]) and normalised to sum
    to one.

    Args:
        in_channels: number of channels of the incoming feature map.
        feature_hw: (H, W) of the incoming feature map. It is only used to
            size the parameter head's Linear layer (after a 2x2 max-pool). The
            default (12, 12) reproduces the previous hard-coded ``32 * 6 * 6``.
    """

    def __init__(self, in_channels, feature_hw: Tuple[int, int] = (12, 12)):
        super().__init__()
        pooled_h, pooled_w = feature_hw[0] // 2, feature_hw[1] // 2
        # Conv followed by MaxPool2d to shrink the map before the Linear head.
        self.param_net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(32 * pooled_h * pooled_w, 5),
        )
        # Coordinate grids are a cache derived from (H, W). They are buffers so
        # that .to(device) moves them. They are non-persistent so they never
        # enter state_dict: a persistent buffer that is None at construction
        # time would make checkpoints saved after a forward pass fail strict
        # loading into a freshly built model ("Unexpected key(s)").
        self.register_buffer('x_grid', None, persistent=False)
        self.register_buffer('y_grid', None, persistent=False)
        self.cached_size = None

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Checkpoints written while the grids were persistent buffers contain
        # them. Drop those keys so such checkpoints still load strictly.
        for name in ('x_grid', 'y_grid'):
            state_dict.pop(prefix + name, None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def _get_coordinate_grids(self, H, W, device):
        """Return cached (y_grid, x_grid), each (H, W) with values in [0, 1].

        The cache is rebuilt whenever the spatial size or the device changes.
        """
        if (self.x_grid is None or self.cached_size != (H, W)
                or self.x_grid.device != device):
            y_coords = torch.linspace(0, 1, H, device=device)
            x_coords = torch.linspace(0, 1, W, device=device)
            y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing='ij')
            self.register_buffer('x_grid', x_grid, persistent=False)
            self.register_buffer('y_grid', y_grid, persistent=False)
            self.cached_size = (H, W)
        return self.y_grid, self.x_grid

    def forward(self, features):
        """Render the Gaussian attention map for ``features`` (B, C, H, W).

        Returns:
            (gaussian, params): ``gaussian`` has shape (B, 1, H, W) and sums to
            ~1 per sample. ``params`` has shape (B, 5) and holds
            [cx, cy, sx, sy, theta], used by the contrastive loss.
        """
        B, _, H, W = features.shape
        params = self.param_net(features)

        # Constrain the raw outputs to valid ranges.
        cx = params[:, 0].sigmoid()  # [0, 1]
        cy = params[:, 1].sigmoid()  # [0, 1]
        sx = F.softplus(params[:, 2]) / 10.0 + 1e-4  # strictly positive spread
        sy = F.softplus(params[:, 3]) / 10.0 + 1e-4
        theta = params[:, 4].tanh() * np.pi  # [-pi, pi]

        y_grid, x_grid = self._get_coordinate_grids(H, W, features.device)

        cos_t = torch.cos(theta).view(B, 1, 1)
        sin_t = torch.sin(theta).view(B, 1, 1)

        # Offsets of every grid point from each sample's centre: (B, H, W).
        dx = x_grid.unsqueeze(0) - cx.view(B, 1, 1)
        dy = y_grid.unsqueeze(0) - cy.view(B, 1, 1)

        # Rotate the offsets into the Gaussian's principal axes.
        x_rot = cos_t * dx + sin_t * dy
        y_rot = -sin_t * dx + cos_t * dy

        sx_expanded = sx.view(B, 1, 1)
        sy_expanded = sy.view(B, 1, 1)
        gaussian = torch.exp(-0.5 * (x_rot**2 / sx_expanded**2 + y_rot**2 / sy_expanded**2))

        # Add channel dimension and normalise each map to sum to one.
        gaussian = gaussian.unsqueeze(1)
        gaussian = gaussian / (gaussian.sum(dim=[2, 3], keepdim=True) + 1e-8)

        return gaussian, torch.stack([cx, cy, sx, sy, theta], dim=1)

class DictArray(object):
    """Pre-allocated (possibly nested) dict of tensors sharing a leading buffer shape.

    Built either from a ``gym.spaces.Dict`` (allocating zero tensors of shape
    ``buffer_shape + space.shape``) or by wrapping an existing ``data_dict``.
    Indexing with a string returns the stored entry. Any other index is applied
    to every entry and returns a plain dict.
    """

    def __init__(self, buffer_shape, element_space, data_dict=None, device=None):
        self.buffer_shape = buffer_shape
        if data_dict:
            self.data = data_dict
        else:
            assert isinstance(element_space, gym.spaces.dict.Dict)
            # numpy -> torch storage dtypes; float64 spaces are stored as float32.
            # Unlisted dtypes are passed through unchanged, as before.
            np_to_torch = {
                np.float32: torch.float32,
                np.float64: torch.float32,
                np.uint8: torch.uint8,
                np.int16: torch.int16,
                np.int32: torch.int32,
                np.int64: torch.int64,
                np.bool_: torch.bool,
            }
            self.data = {}
            for k, v in element_space.items():
                if isinstance(v, gym.spaces.dict.Dict):
                    self.data[k] = DictArray(buffer_shape, v, device=device)
                else:
                    dtype = np_to_torch.get(np.dtype(v.dtype).type, v.dtype)
                    self.data[k] = torch.zeros(buffer_shape + v.shape, dtype=dtype, device=device)

    def keys(self):
        return self.data.keys()

    def __getitem__(self, index):
        if isinstance(index, str):
            return self.data[index]
        return {
            k: v[index] for k, v in self.data.items()
        }

    def __setitem__(self, index, value):
        if isinstance(index, str):
            # Replace a whole entry; ``value`` is a tensor/DictArray, not a dict,
            # so the per-key assignment below must not run.
            self.data[index] = value
            return
        for k, v in value.items():
            self.data[k][index] = v

    @property
    def shape(self):
        return self.buffer_shape

    def reshape(self, shape):
        """Return a new DictArray whose leading buffer dims are reshaped to ``shape``."""
        t = len(self.buffer_shape)
        new_dict = {}
        for k, v in self.data.items():
            if isinstance(v, DictArray):
                new_dict[k] = v.reshape(shape)
            else:
                new_dict[k] = v.reshape(shape + v.shape[t:])
        new_buffer_shape = next(iter(new_dict.values())).shape[:len(shape)]
        return DictArray(new_buffer_shape, None, data_dict=new_dict)

class ContrastiveBuffer:
    """FIFO store of CNN features, embeddings and returns with FAISS similarity search.

    * ``cnn_features`` keep the pre-attention feature maps so an attention head
      can be re-run on stored samples.
    * ``embeddings`` are L2-normalised on insertion and indexed with a FAISS
      inner-product index (i.e. cosine similarity).
    * ``returns`` are used for triplet selection.

    The FAISS index is only rebuilt every ``rebuild_frequency`` calls to
    ``add_batch`` and when the write pointer wraps. Between rebuilds it may be
    stale or hold fewer vectors than the buffer (it is empty until the first
    rebuild).
    """

    def __init__(self, cnn_feature_shape: Tuple[int, int, int], embedding_dim: int,
                 max_size: int = 100000, device: str = 'cuda'):
        self.cnn_feature_shape = cnn_feature_shape  # (C, H, W)
        self.embedding_dim = embedding_dim
        self.max_size = max_size
        self.device = device

        # Pre-allocated storage.
        # CNN features for gradient flow through attention
        self.cnn_features = torch.zeros((max_size,) + cnn_feature_shape, device=device)
        # Normalised, flattened embeddings for FAISS search
        self.embeddings = torch.zeros((max_size, embedding_dim), device=device)
        # Returns for triplet selection
        self.returns = torch.zeros(max_size, device=device)

        self.current_idx = 0
        self.is_full = False

        # FAISS index for cosine similarity (inner product on normalised vectors)
        self.index = faiss.IndexFlatIP(embedding_dim)

        # Move the index to GPU when possible; any failure falls back to CPU.
        self.use_gpu = False
        device_str = str(device)
        if device_str == 'cuda' or device_str.startswith('cuda:'):
            try:
                if torch.cuda.is_available() and faiss.get_num_gpus() > 0:
                    gpu_success = False

                    # Try StandardGpuResources (should work with FAISS 1.8.0)
                    try:
                        res = faiss.StandardGpuResources()
                        # Configure GPU memory (512 MB temp memory)
                        res.setTempMemory(512 * 1024 * 1024)

                        # First visible GPU (adapts to CUDA_VISIBLE_DEVICES)
                        gpu_id = 0
                        self.gpu_resources = res
                        self.gpu_index = faiss.index_cpu_to_gpu(res, gpu_id, self.index)
                        self.index = self.gpu_index
                        gpu_success = True

                        cuda_device = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "Unknown"
                        print(f"Using FAISS GPU (GPU {gpu_id}: {cuda_device})")
                    except Exception as e:
                        print(f"GPU 0 failed: {e}")

                        # Fallback: try index_cpu_to_all_gpus
                        try:
                            self.gpu_index = faiss.index_cpu_to_all_gpus(self.index)
                            self.index = self.gpu_index
                            gpu_success = True
                            print(f"Using FAISS GPU with index_cpu_to_all_gpus")
                        except Exception as e2:
                            print(f"index_cpu_to_all_gpus also failed: {e2}")

                    if gpu_success:
                        self.use_gpu = True
                        print(f"FAISS GPU acceleration enabled (detected {faiss.get_num_gpus()} GPUs)")
                    else:
                        print("All GPU methods failed, using CPU")
                else:
                    if not torch.cuda.is_available():
                        print("PyTorch CUDA is not available, using FAISS CPU")
                    else:
                        print(f"FAISS reports {faiss.get_num_gpus()} GPUs, using CPU")
            except Exception as e:
                print(f"Failed to initialize FAISS GPU, falling back to CPU: {e}")
                self.use_gpu = False
        else:
            print(f"Using FAISS CPU (device is {device})")

        # Track updates for periodic index rebuilding
        self.updates_since_rebuild = 0
        self.rebuild_frequency = 10

    def add_batch(self, new_cnn_features: torch.Tensor, new_embeddings: torch.Tensor,
                  new_returns: torch.Tensor):
        """Add a batch of CNN features, embeddings and returns with FIFO replacement.

        Inputs are detached: the buffer stores data, and writing grad-tracking
        tensors into the pre-allocated storage would turn it into an autograd
        graph node that keeps old graphs alive.

        Args:
            new_cnn_features: (batch_size, C, H, W) CNN features before attention
            new_embeddings: (batch_size, embedding_dim) flattened features for FAISS
            new_returns: (batch_size,) return values
        """
        new_cnn_features = new_cnn_features.detach()
        new_returns = new_returns.detach()
        # Normalise embeddings for cosine similarity
        new_embeddings_norm = F.normalize(new_embeddings.detach(), dim=1)
        batch_size = new_cnn_features.shape[0]

        if batch_size > self.max_size:
            # Only the newest max_size samples can survive FIFO replacement; the
            # wrap-around branch below cannot handle a batch larger than the buffer.
            start = batch_size - self.max_size
            self.cnn_features[:] = new_cnn_features[start:]
            self.embeddings[:] = new_embeddings_norm[start:]
            self.returns[:] = new_returns[start:]
            self.current_idx = 0
            self.is_full = True
        elif self.current_idx + batch_size <= self.max_size:
            # Simple case: no wrap-around
            end = self.current_idx + batch_size
            self.cnn_features[self.current_idx:end] = new_cnn_features
            self.embeddings[self.current_idx:end] = new_embeddings_norm
            self.returns[self.current_idx:end] = new_returns

            self.current_idx = end
            if self.current_idx >= self.max_size:
                self.is_full = True
                self.current_idx = self.current_idx % self.max_size
                print(f"Buffer is now full, wrapping around. Current idx: {self.current_idx}")
        else:
            # Wrap-around: fill the tail, then continue from the start.
            first_part = self.max_size - self.current_idx
            self.cnn_features[self.current_idx:] = new_cnn_features[:first_part]
            self.embeddings[self.current_idx:] = new_embeddings_norm[:first_part]
            self.returns[self.current_idx:] = new_returns[:first_part]

            second_part = batch_size - first_part
            self.cnn_features[:second_part] = new_cnn_features[first_part:]
            self.embeddings[:second_part] = new_embeddings_norm[first_part:]
            self.returns[:second_part] = new_returns[first_part:]

            self.is_full = True
            self.current_idx = second_part

        # Rebuild periodically, and right after the write pointer wrapped.
        self.updates_since_rebuild += 1
        if self.updates_since_rebuild >= self.rebuild_frequency or (self.is_full and self.current_idx < batch_size):
            self._rebuild_index()
            self.updates_since_rebuild = 0

    def _rebuild_index(self):
        """Rebuild the FAISS index from the currently stored embeddings."""
        size = self.size
        if size == 0:
            return
        embeddings_np = self.embeddings[:size].detach().cpu().numpy()

        if not self.use_gpu:
            self.index.reset()
            self.index.add(embeddings_np)
            return

        try:
            # Reset and rebuild GPU index
            self.index.reset()
            self.index.add(embeddings_np)
        except Exception as e:
            print(f"GPU rebuild failed: {e}, recreating GPU index")
            cpu_index = faiss.IndexFlatIP(self.embedding_dim)
            cpu_index.add(embeddings_np)
            resources = getattr(self, 'gpu_resources', None)
            if resources is not None:
                self.index = faiss.index_cpu_to_gpu(resources, 0, cpu_index)
                print(f"Rebuilt FAISS index with {size} vectors (GPU - recreated)")
            else:
                # The GPU index came from index_cpu_to_all_gpus, so there is no
                # resources handle to recreate it with; continue on CPU.
                self.index = cpu_index
                self.use_gpu = False
                print(f"Rebuilt FAISS index with {size} vectors (CPU fallback)")

    def search_neighbors(self, query_embeddings: torch.Tensor, k: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Search the k most similar stored embeddings for each query.

        ``k`` is capped by both the buffer size and the number of vectors
        currently in the (periodically rebuilt) index. Asking FAISS for more
        results than it holds pads them with -1, and those would otherwise be
        clamped onto index 0.

        Returns:
            distances: (num_queries, k') tensor of inner-product similarities
            indices: (num_queries, k') tensor of buffer indices, with k' <= k
        """
        query_embeddings = F.normalize(query_embeddings.detach().float(), dim=1)

        actual_size = self.size
        k_adjusted = min(k, actual_size, int(self.index.ntotal))

        if k_adjusted == 0:
            # Buffer or index is empty
            return (torch.empty(query_embeddings.shape[0], 0, device=self.device),
                    torch.empty(query_embeddings.shape[0], 0, dtype=torch.long, device=self.device))

        queries_np = query_embeddings.cpu().numpy()
        distances, indices = self.index.search(queries_np, k_adjusted)

        distances_torch = torch.from_numpy(distances).to(self.device)
        indices_torch = torch.from_numpy(indices).to(self.device)

        # Defensive: keep indices inside the stored range.
        indices_torch = torch.clamp(indices_torch, min=0, max=actual_size - 1)

        return distances_torch, indices_torch

    def get_samples(self, indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (cnn_features, embeddings, returns) for the in-range entries of ``indices``.

        Out-of-range indices are silently dropped, so the outputs can be shorter
        than ``indices``.
        """
        max_valid_idx = self.size
        valid_mask = (indices >= 0) & (indices < max_valid_idx)
        valid_indices = indices[valid_mask]

        if len(valid_indices) == 0:
            return (torch.empty(0, *self.cnn_feature_shape, device=self.device),
                    torch.empty(0, self.embedding_dim, device=self.device),
                    torch.empty(0, device=self.device))

        return self.cnn_features[valid_indices], self.embeddings[valid_indices], self.returns[valid_indices]

    def sample_anchors(self, n_anchors: int) -> torch.Tensor:
        """Sample up to ``n_anchors`` distinct stored indices uniformly at random."""
        size = self.size
        if size == 0:
            return torch.tensor([], dtype=torch.long, device=self.device)
        return torch.randperm(size, device=self.device)[:n_anchors]

    @property
    def size(self) -> int:
        """Current number of samples in the buffer."""
        return self.max_size if self.is_full else self.current_idx

class NatureCNN(nn.Module):
    """Visual (+ optional state) feature extractor with a spatial attention head.

    RGB observations go through a three-layer NatureCNN-style conv backbone.
    They are re-weighted by an attention map (``PatchAttention`` or
    ``FovealAttention``, chosen by ``args.attention_type``), then flattened and
    projected to 256 features. If ``sample_obs`` has a ``"state"`` entry, it is
    linearly projected to 256 more features and concatenated.

    Args:
        sample_obs: example observation dict. ``sample_obs["rgb"]`` is treated
            as channels-last (it is permuted with ``permute(0, 3, 1, 2)``).
        args: configuration; only ``args.attention_type`` is read.

    Attributes:
        out_features: width of the feature vector returned by ``forward``.
    """

    def __init__(self, sample_obs, args):
        super().__init__()

        self.out_features = 0
        feature_size = 256
        in_channels = sample_obs["rgb"].shape[-1]

        self.attention_type = args.attention_type

        # 1. Convolutional backbone (8x8/4 -> 4x4/2 -> 3x3/1, 64 output channels).
        self.cnn_backbone = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        )

        # 2. Attention head over the 64-channel backbone output.
        if args.attention_type == "patch":
            self.attention_head = PatchAttention(64)
        elif args.attention_type == "foveal":
            self.attention_head = FovealAttention(64)
        else:
            raise ValueError(f"Unknown attention type: {args.attention_type}")

        # 3. Flattened backbone output size, inferred from the sample observation.
        #    Attention re-weights the features by broadcasting a (B, 1, H, W) map,
        #    so the shape seen by the projection below is the backbone's.
        with torch.no_grad():
            dummy_input = sample_obs["rgb"].float().permute(0, 3, 1, 2).cpu()
            n_flatten = self.cnn_backbone(dummy_input).flatten(1).shape[1]

        # Single feature pathway that projects the attention-weighted features.
        self.feature_fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_flatten, feature_size),
            nn.ReLU()
        )
        self.out_features += feature_size

        # 4. Optional state feature extractor.
        if "state" in sample_obs:
            state_size = sample_obs["state"].shape[-1]
            self.state_extractor = nn.Linear(state_size, 256)
            self.out_features += 256
        else:
            self.state_extractor = None

    def forward(self, observations, return_attention=False, return_cnn_features=False):
        """Encode ``observations`` into a flat feature vector.

        RGB values are scaled by 1/255 before the backbone.

        Returns:
            ``features`` alone by default. With ``return_attention``, also the
            attention map and attention parameters (each wrapped as
            ``{"rgb": ...}``). With ``return_cnn_features``, also the raw
            pre-attention backbone features (``{"rgb": ...}``). When both are
            set, the order is (features, maps, params, cnn_features).
        """
        rgb_obs = observations["rgb"].float().permute(0, 3, 1, 2) / 255
        cnn_features = self.cnn_backbone(rgb_obs)
        attn_map, attn_params = self.attention_head(cnn_features)

        # Apply attention weighting (map broadcasts over channels).
        weighted_features = cnn_features * attn_map

        final_rgb_features = self.feature_fc(weighted_features)

        # Explicit None check: module truthiness is not a reliable presence test
        # (e.g. containers defining __len__ can be falsy).
        if self.state_extractor is not None:
            state_features = self.state_extractor(observations["state"])
            final_features = torch.cat([final_rgb_features, state_features], dim=1)
        else:
            final_features = final_rgb_features

        if return_attention and return_cnn_features:
            return final_features, {"rgb": attn_map}, {"rgb": attn_params}, {"rgb": cnn_features}
        elif return_attention:
            return final_features, {"rgb": attn_map}, {"rgb": attn_params}
        elif return_cnn_features:
            return final_features, {"rgb": cnn_features}
        return final_features

class Agent(nn.Module):
    """Gaussian-policy actor-critic on top of a shared ``NatureCNN`` encoder.

    ``feature_net`` encodes an observation once. ``critic`` maps the features
    to a scalar value and ``actor_mean`` maps them to the action mean. The
    state-independent learnable ``actor_logstd`` (initialised to -0.5) sets the
    diagonal Gaussian's standard deviation.
    """

    def __init__(self, envs, sample_obs, args):
        super().__init__()
        self.feature_net = NatureCNN(sample_obs=sample_obs, args=args)
        hidden_in = self.feature_net.out_features
        action_dim = np.prod(envs.unwrapped.single_action_space.shape)

        self.critic = nn.Sequential(
            layer_init(nn.Linear(hidden_in, 512)),
            nn.ReLU(inplace=True),
            layer_init(nn.Linear(512, 1)),
        )
        # Near-zero final-layer init keeps initial action means close to 0.
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(hidden_in, 512)),
            nn.ReLU(inplace=True),
            layer_init(nn.Linear(512, action_dim), std=0.01 * np.sqrt(2)),
        )
        self.actor_logstd = nn.Parameter(torch.ones(1, action_dim) * -0.5)

    def get_features(self, x, return_attention=False, return_cnn_features=False):
        """Run the encoder, forwarding only the flags that were requested."""
        extra = {}
        if return_attention:
            extra["return_attention"] = return_attention
        if return_cnn_features:
            extra["return_cnn_features"] = return_cnn_features
        return self.feature_net(x, **extra)

    def get_value(self, x):
        """Critic estimate for observation batch ``x``."""
        return self.critic(self.feature_net(x))

    def get_action(self, x, deterministic=False):
        """Return the mean action if ``deterministic``, else a sample from the policy."""
        mean = self.actor_mean(self.feature_net(x))
        if deterministic:
            return mean
        std = self.actor_logstd.expand_as(mean).exp()
        return Normal(mean, std).sample()

    def get_action_and_value(self, x, action=None):
        """Return (action, log_prob, entropy, value) from one encoder pass.

        ``action`` is sampled from the policy unless one is supplied. The
        log-prob and entropy are summed over the action dimensions.
        """
        feats = self.feature_net(x)

        mean = self.actor_mean(feats)
        std = self.actor_logstd.expand_as(mean).exp()
        dist = Normal(mean, std)
        chosen = dist.sample() if action is None else action

        value = self.critic(feats)
        return chosen, dist.log_prob(chosen).sum(1), dist.entropy().sum(1), value

class Logger:
    """Send scalar metrics to TensorBoard and, optionally, Weights & Biases.

    Args:
        log_wandb: when True, every scalar is also sent via ``wandb.log``. A
            global named ``wandb`` must then be available.
        tensorboard: writer receiving ``add_scalar``/``close`` calls. When None
            (the default), TensorBoard logging is skipped instead of crashing.
    """

    def __init__(self, log_wandb=False, tensorboard: "Optional[SummaryWriter]" = None) -> None:
        # String annotation: no tensorboard import is needed just to define the class.
        self.writer = tensorboard
        self.log_wandb = log_wandb

    def add_scalar(self, tag, scalar_value, step):
        """Log ``scalar_value`` under ``tag`` at ``step`` to every enabled backend."""
        if self.log_wandb:
            wandb.log({tag: scalar_value}, step=step)
        if self.writer is not None:
            self.writer.add_scalar(tag, scalar_value, step)

    def close(self):
        """Close the TensorBoard writer, if one was given."""
        if self.writer is not None:
            self.writer.close()

def mine_contrastive_triplets_faiss(buffer: "ContrastiveBuffer",
                                   n_anchors: int = 1024,
                                   top_k: int = 50,
                                   device: str = 'cuda') -> List[Tuple[int, int, int]]:
    """Mine (anchor, positive, negative) index triplets from ``buffer``.

    For each randomly sampled anchor, its nearest stored neighbours (found with
    ``buffer.search_neighbors``) are split by return. Neighbours whose return
    exceeds the neighbourhood median by more than half the neighbourhood return
    std are positive candidates; those below it by more than that margin are
    negative candidates. One of each is picked uniformly at random.

    Args:
        buffer: buffer to mine. Its ``embeddings`` and ``returns`` tensors are
            read directly, which avoids copying the much larger CNN features
            that ``get_samples`` would also gather.
        n_anchors: maximum number of anchors to sample.
        top_k: number of neighbours (excluding the anchor) to request.
        device: unused. It is kept for backward compatibility; random picks are
            drawn on the neighbour tensor's own device so indexing never mixes
            devices.

    Returns:
        List of (anchor_idx, pos_idx, neg_idx) buffer indices as Python ints.
        It is empty when the buffer holds fewer than 3 samples or no
        neighbourhood has a large enough return spread.
    """
    triplets = []

    size = buffer.size
    if size < 3:
        return triplets

    anchor_indices = buffer.sample_anchors(min(n_anchors, size))
    if len(anchor_indices) == 0:
        return triplets

    anchor_embeddings = buffer.embeddings[anchor_indices]

    # k + 1 because the anchor itself is normally among its own neighbours.
    k_search = min(top_k + 1, size)
    _, neighbor_indices = buffer.search_neighbors(anchor_embeddings, k_search)

    for anchor_idx, neighbors in zip(anchor_indices.tolist(), neighbor_indices):
        # Drop the anchor itself and out-of-range entries, then de-duplicate.
        # Padded/clamped search results can repeat the same index, which would
        # skew the return statistics below.
        keep = (neighbors >= 0) & (neighbors < size) & (neighbors != anchor_idx)
        neighbors = torch.unique(neighbors[keep])
        if neighbors.numel() < 2:
            continue

        neighbor_returns = buffer.returns[neighbors]

        # Skip neighbourhoods whose returns are too similar to separate.
        if neighbor_returns.var() < 1e-3:
            continue

        adaptive_margin = 0.5 * neighbor_returns.std()
        median_return = neighbor_returns.median()

        pos_indices = neighbors[neighbor_returns > median_return + adaptive_margin]
        neg_indices = neighbors[neighbor_returns < median_return - adaptive_margin]
        if len(pos_indices) == 0 or len(neg_indices) == 0:
            continue

        pos_idx = pos_indices[torch.randint(len(pos_indices), (1,), device=pos_indices.device)].item()
        neg_idx = neg_indices[torch.randint(len(neg_indices), (1,), device=neg_indices.device)].item()
        triplets.append((anchor_idx, pos_idx, neg_idx))

    return triplets

def contrastive_loss_faiss(agent: nn.Module,
                           buffer: "ContrastiveBuffer",
                           triplets: List[Tuple[int, int, int]],
                           margin: float = 0.5,
                           lambda_spread: float = 0.1,
                           device: str = 'cuda') -> Tuple[torch.Tensor, Dict]:
    """Compute the feature-level triplet contrastive loss on attended CNN features.

    For every ``(anchor, positive, negative)`` triplet of buffer indices, the
    stored CNN feature maps are fetched from ``buffer`` and re-weighted by a
    fresh forward pass through ``agent.feature_net.attention_head``. The
    attended features are flattened, L2-normalized and compared with cosine
    distance; the per-triplet hinge ``relu(margin + d(a, p) - d(a, n))`` pulls
    anchors toward positives and away from negatives. When
    ``agent.feature_net.attention_type == "foveal"`` an additional log-space
    quadratic penalty (weighted by ``lambda_spread``) keeps the predicted
    spreads (columns 2:4 of the attention parameters) near 0.1.

    Args:
        agent: model exposing ``feature_net.attention_head`` (returning
            ``(attention_maps, attention_params)``) and
            ``feature_net.attention_type``.
        buffer: object whose ``get_samples(indices)`` returns a 3-tuple whose
            first element holds the CNN features for ``indices``, in order.
        triplets: ``(anchor, positive, negative)`` buffer indices.
        margin: hinge margin on the cosine-distance difference.
        lambda_spread: weight of the foveal spread regularizer.
        device: device for the index tensor and the zero fallback loss.

    Returns:
        ``(total_loss, metrics)``. ``total_loss`` is a scalar tensor. It is a
        constant zero with an empty ``metrics`` dict when there are no
        triplets, when the buffer returns a different number of feature rows
        than requested, or when the computation raises ``IndexError``,
        ``KeyError`` or ``RuntimeError`` (a warning is emitted in that case).
    """
    metrics: Dict = {}
    if not triplets:
        return torch.tensor(0.0, device=device), metrics

    try:
        # Lay indices out as [anchors..., positives..., negatives...] so one
        # buffer lookup and one attention forward pass cover all triplets.
        anchor_indices, pos_indices, neg_indices = zip(*triplets)
        n_triplets = len(triplets)
        all_indices = torch.tensor(anchor_indices + pos_indices + neg_indices, device=device)

        # Stored CNN features (not embeddings), so gradients reach the attention head.
        cnn_features_batch, _, _ = buffer.get_samples(all_indices)

        # The anchor/positive/negative slicing below requires exactly one row
        # per requested index; if rows were dropped they no longer line up.
        if cnn_features_batch.shape[0] != 3 * n_triplets:
            return torch.tensor(0.0, device=device), metrics

        # Fresh forward pass through the attention head only.
        attn_maps, attn_params = agent.feature_net.attention_head(cnn_features_batch)

        # Attended ("what x where") features; L2 normalization turns a dot
        # product into a cosine similarity.
        feat = F.normalize((cnn_features_batch * attn_maps).flatten(1), dim=1)
        fa = feat[:n_triplets]
        fp = feat[n_triplets:2 * n_triplets]
        fn = feat[2 * n_triplets:]

        # Cosine distances anchor<->positive and anchor<->negative.
        dist_pos_feat = 1 - (fa * fp).sum(dim=1)
        dist_neg_feat = 1 - (fa * fn).sum(dim=1)

        # Per-triplet hinge; non-zero while a triplet violates the margin.
        contrastive_loss = F.relu(margin + (dist_pos_feat - dist_neg_feat))

        is_foveal = agent.feature_net.attention_type == "foveal"
        spread_reg = torch.tensor(0.0, device=device)
        if is_foveal:
            # Spreads (sigma_x, sigma_y) of every anchor/positive/negative row.
            all_spreads = attn_params[:, 2:4]  # (3N, 2)
            # Quadratic pull toward a small-but-nonzero target spread, in log space.
            sigma_t = 0.1
            log_sigma_t = math.log(sigma_t + 1e-8)
            log_spreads = torch.log(all_spreads + 1e-8)
            spread_reg = ((log_spreads - log_sigma_t) ** 2).mean() * lambda_spread

        total_loss = contrastive_loss.mean() + spread_reg

        with torch.no_grad():
            # Positive values mean the negative is farther than the positive.
            dist_diff = dist_neg_feat - dist_pos_feat
            # Fraction of triplets still violating the margin.
            active_mask = (margin + (dist_pos_feat - dist_neg_feat)) > 0

            if is_foveal:
                spreads = attn_params[:, 2:4]
                mean_spread_x = spreads[:, 0].mean().item()
                mean_spread_y = spreads[:, 1].mean().item()
                spread_x_std = spreads[:, 0].std().item()
                spread_y_std = spreads[:, 1].std().item()
            else:
                mean_spread_x = 0.0
                mean_spread_y = 0.0
                spread_x_std = 0.0
                spread_y_std = 0.0

            # The unbiased std of a single value is NaN; report 0.0 instead so
            # a lone triplet does not push NaN into the logs.
            dist_diff_std = dist_diff.std().item() if n_triplets > 1 else 0.0

            metrics = {
                # Feature-level metrics
                'mean_dist_pos_feat': dist_pos_feat.mean().item(),
                'mean_dist_neg_feat': dist_neg_feat.mean().item(),
                'dist_difference': dist_diff.mean().item(),
                'dist_diff_std': dist_diff_std,
                'dist_diff_min': dist_diff.min().item(),
                'dist_diff_max': dist_diff.max().item(),

                # Loss components
                'contrastive_loss': contrastive_loss.mean().item(),
                'spread_reg_loss': spread_reg.item(),
                'active_fraction': active_mask.float().mean().item(),

                # Attention-specific parameters
                'mean_spread_x': mean_spread_x,
                'mean_spread_y': mean_spread_y,
                'spread_x_std': spread_x_std,
                'spread_y_std': spread_y_std,
                'attention_type': agent.feature_net.attention_type,
            }

        return total_loss, metrics

    except (IndexError, KeyError, RuntimeError) as err:
        # Best effort: the auxiliary loss must not abort PPO training, but the
        # failure (which may be e.g. a CUDA OOM) should be visible.
        import warnings
        warnings.warn(f"contrastive_loss_faiss failed; using zero loss: {err!r}")
        return torch.tensor(0.0, device=device), {}


def create_attention_overlay(rgb_frame: np.ndarray, attention_map: torch.Tensor, alpha: float = 0.4) -> np.ndarray:
    """Blend a colour-mapped attention map on top of an RGB frame.

    The attention map is resized to the frame's spatial size (bilinear),
    min-max normalized to [0, 1], colour-mapped with viridis and alpha-blended
    onto the frame. When the frame stacks several cameras along the channel
    axis (``C // 3`` cameras of 3 channels each), the same overlay is
    replicated for every camera.

    Args:
        rgb_frame: ``(H, W, C)`` image. Non-``uint8`` input is assumed to be a
            float image in ``[0, 1]`` and is rescaled to ``uint8``.
        attention_map: attention weights; singleton dimensions are squeezed
            away and the result is expected to be 2-D. It may live on any
            device, require grad, or be half precision.
        alpha: weight of the attention colours in the blend.

    Returns:
        ``uint8`` overlay image with the frame's height and width.
    """
    if rgb_frame.dtype != np.uint8:
        rgb_frame = (rgb_frame * 255).astype(np.uint8)

    H, W, C = rgb_frame.shape
    num_cameras = C // 3  # cameras are stacked along the channel axis

    # detach() + float(): .numpy() refuses tensors that require grad and cannot
    # convert bfloat16, so normalize both before leaving torch.
    attention = attention_map.detach().float().squeeze().cpu().numpy()
    # cv2.resize takes the target size as (width, height).
    attention_resized = cv2.resize(attention, (W, H), interpolation=cv2.INTER_LINEAR)

    # Min-max normalize; the epsilon avoids 0/0 for a constant map.
    lo, hi = attention_resized.min(), attention_resized.max()
    attention_norm = (attention_resized - lo) / (hi - lo + 1e-8)

    # Apply colormap and drop the alpha channel -> (H, W, 3) uint8.
    attention_colored = (cm.viridis(attention_norm)[:, :, :3] * 255).astype(np.uint8)

    # Replicate the overlay for each stacked camera.
    if num_cameras > 1:
        attention_colored = np.concatenate([attention_colored] * num_cameras, axis=2)

    return ((1 - alpha) * rgb_frame + alpha * attention_colored).astype(np.uint8)

def save_attention_video(frames: List[np.ndarray], output_path: str, fps: int = 30) -> None:
    """Save ``frames`` as an H.264 video at ``output_path`` using imageio.

    Frames are ``(H, W, C)`` arrays. When ``C > 4`` the frame is treated as
    several RGB cameras stacked along the channel axis (3 channels each); the
    cameras are re-tiled side by side (along the width) before writing.
    An empty ``frames`` list writes nothing.
    """
    if not frames:
        return

    num_channels = frames[0].shape[2]
    if num_channels > 4:  # multi-camera frames
        num_cameras = num_channels // 3
        frames = [
            np.concatenate(
                [frame[:, :, 3 * cam:3 * (cam + 1)] for cam in range(num_cameras)],
                axis=1,
            )
            for frame in frames
        ]

    imageio.mimsave(output_path, frames, fps=fps, codec='libx264')

if __name__ == "__main__":
    # Entry point: PPO training (or checkpoint evaluation with --evaluate) on a
    # ManiSkill task simulated on the GPU (sim_backend="physx_cuda") from RGB
    # observations, with an optional contrastive loss on the agent's attention
    # head (args.use_contrastive / args.use_faiss).
    #
    # Per iteration: (1) periodic evaluation + attention-overlay video +
    # checkpoint, (2) rollout of num_steps vectorized env steps, (3) GAE
    # advantages/returns, (4) push rollout CNN features/returns into the
    # contrastive buffer, (5) PPO clipped-objective updates (plus the
    # contrastive term when triplets were mined), (6) logging.
    args = tyro.cli(Args)
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    if args.exp_name is None:
        args.exp_name = os.path.basename(__file__)[: -len(".py")]
        run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    else:
        run_name = args.exp_name


    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    env_kwargs = dict(obs_mode="rgb", render_mode=args.render_mode, sim_backend="physx_cuda")
    if args.control_mode is not None:
        env_kwargs["control_mode"] = args.control_mode
    eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, reconfiguration_freq=args.eval_reconfiguration_freq, **env_kwargs)
    envs = gym.make(args.env_id, num_envs=args.num_envs if not args.evaluate else 1, reconfiguration_freq=args.reconfiguration_freq, **env_kwargs)

    # Flatten the observation dict into an "rgb" image key (plus a state key
    # when args.include_state is set); depth is not requested.
    envs = FlattenRGBDObservationWrapper(envs, rgb=True, depth=False, state=args.include_state)
    eval_envs = FlattenRGBDObservationWrapper(eval_envs, rgb=True, depth=False, state=args.include_state)

    if isinstance(envs.action_space, gym.spaces.Dict):
        envs = FlattenActionSpaceWrapper(envs)
        eval_envs = FlattenActionSpaceWrapper(eval_envs)
    if args.capture_video:
        eval_output_dir = f"runs/{run_name}/videos"
        if args.evaluate:
            eval_output_dir = f"{os.path.dirname(args.checkpoint)}/test_videos"
        print(f"Saving eval videos to {eval_output_dir}")
        if args.save_train_video_freq is not None:
            save_video_trigger = lambda x : (x // args.num_steps) % args.save_train_video_freq == 0
            envs = RecordEpisode(envs, output_dir=f"runs/{run_name}/train_videos", save_trajectory=False, save_video_trigger=save_video_trigger, max_steps_per_video=args.num_steps, video_fps=30)
        eval_envs = RecordEpisode(eval_envs, output_dir=eval_output_dir, save_trajectory=args.evaluate, trajectory_name="trajectory", max_steps_per_video=args.num_eval_steps, video_fps=30)
    envs = ManiSkillVectorEnv(envs, args.num_envs, ignore_terminations=not args.partial_reset, record_metrics=True)
    eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=not args.eval_partial_reset, record_metrics=True)
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    max_episode_steps = gym_utils.find_max_episode_steps_value(envs._env)
    # logger stays None in evaluation mode; every use below is either guarded
    # or only reached during training.
    logger = None
    if not args.evaluate:
        print("Running training")
        if args.track:
            import wandb
            config = vars(args)
            config["env_cfg"] = dict(**env_kwargs, num_envs=args.num_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=args.partial_reset)
            # Record the flag the eval envs are actually built with (eval_partial_reset).
            config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=args.eval_partial_reset)
            wandb.init(
                project=args.wandb_project_name,
                entity=args.wandb_entity,
                sync_tensorboard=False,
                config=config,
                name=run_name,
                save_code=True,
                group=args.wandb_group,
                tags=["ppo", f"{args.env_id.split('-')[0].split('Level')[0]}"],
            )
        writer = SummaryWriter(f"runs/{run_name}")
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )
        logger = Logger(log_wandb=args.track, tensorboard=writer)
    else:
        print("Running evaluation")

    # ALGO Logic: Storage setup
    obs = DictArray((args.num_steps, args.num_envs), envs.single_observation_space, device=device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # Contrastive learning storage
    if args.use_contrastive and args.use_faiss:
        # Per-rollout lists of (num_envs, ...) tensors, one entry per step
        rollout_cnn_features = []
        rollout_embeddings = []

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    eval_obs, _ = eval_envs.reset(seed=args.seed)
    next_done = torch.zeros(args.num_envs, device=device)
    print(f"####")
    print(f"args.num_iterations={args.num_iterations} args.num_envs={args.num_envs} args.num_eval_envs={args.num_eval_envs}")
    print(f"args.minibatch_size={args.minibatch_size} args.batch_size={args.batch_size} args.update_epochs={args.update_epochs}")
    print(f"####")
    agent = Agent(envs, sample_obs=next_obs, args=args).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    if args.checkpoint:
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        agent.load_state_dict(torch.load(args.checkpoint, map_location=device))
    # Initialize contrastive buffer
    contrastive_buffer = None
    if args.use_contrastive and args.use_faiss:
        # Get CNN feature shape and embedding dimension with a dummy forward pass
        with torch.no_grad():
            dummy_obs = next_obs
            _, cnn_features = agent.feature_net(dummy_obs, return_cnn_features=True)
            cnn_feature_shape = cnn_features['rgb'][0].shape  # (C, H, W)
            embedding_dim = cnn_features['rgb'][0].flatten().shape[0]

        contrastive_buffer = ContrastiveBuffer(
            cnn_feature_shape=cnn_feature_shape,
            embedding_dim=embedding_dim,
            max_size=args.contrastive_buffer_size,
            device=device
        )
        print(f"Initialized Contrastive Buffer with CNN shape={cnn_feature_shape}, "
              f"embedding_dim={embedding_dim}, max_size={args.contrastive_buffer_size}")
        print(f"Using attention type: {args.attention_type}")

    cumulative_times = defaultdict(float)

    for iteration in range(1, args.num_iterations + 1):
        print(f"Epoch: {iteration}, global_step={global_step}")
        final_values = torch.zeros((args.num_steps, args.num_envs), device=device)
        agent.eval()
        # Evaluate and checkpoint on iterations 1, 1 + eval_freq, 1 + 2*eval_freq, ...
        # Written as (iteration - 1) % eval_freq == 0 rather than
        # iteration % eval_freq == 1 so that eval_freq == 1 still evaluates
        # (x % 1 is always 0); for eval_freq > 1 both forms are equivalent.
        is_eval_iteration = (iteration - 1) % args.eval_freq == 0
        if is_eval_iteration:
            print("Evaluating")
            stime = time.perf_counter()
            eval_obs, _ = eval_envs.reset()
            eval_metrics = defaultdict(list)
            num_episodes = 0

            # Attention-overlay frames for the first eval environment
            attention_frames = []

            for step in range(args.num_eval_steps):
                with torch.no_grad():
                    # Compute action batch once for all environments
                    eval_actions = agent.get_action(eval_obs, deterministic=True)

                    # Record the attention of env 0 during the first 50 steps
                    # (separate forward pass on that env's observation only)
                    if step < 50:
                        first_env_obs = {k: v[0:1] for k, v in eval_obs.items()}  # First env only
                        _, attention_maps, _ = agent.get_features(first_env_obs, return_attention=True)

                        # Create attention overlay
                        rgb_frame = eval_obs["rgb"][0].cpu().numpy()  # (H, W, 3 * num_cameras)
                        attention_map = attention_maps["rgb"][0]  # (1, H_attn, W_attn)
                        overlay_frame = create_attention_overlay(rgb_frame, attention_map)
                        attention_frames.append(overlay_frame)

                    # Step with the same action batch
                    eval_obs, eval_rew, eval_terminations, eval_truncations, eval_infos = eval_envs.step(eval_actions)
                    if "final_info" in eval_infos:
                        mask = eval_infos["_final_info"]
                        num_episodes += mask.sum()
                        for k, v in eval_infos["final_info"]["episode"].items():
                            eval_metrics[k].append(v)
            print(f"Evaluated {args.num_eval_steps * args.num_eval_envs} steps resulting in {num_episodes} episodes")
            for k, v in eval_metrics.items():
                mean = torch.stack(v).float().mean()
                if logger is not None:
                    logger.add_scalar(f"eval/{k}", mean, global_step)
                print(f"eval_{k}_mean={mean}")
            # Save attention video and log to wandb
            if attention_frames and logger is not None:
                # Save video locally
                attention_video_path = f"runs/{run_name}/attention_iter_{iteration}.mp4"
                os.makedirs(os.path.dirname(attention_video_path), exist_ok=True)
                save_attention_video(attention_frames, attention_video_path, fps=10)
                print(f"Saved attention video: {attention_video_path}")

                # Log to wandb if tracking enabled
                if args.track:
                    import wandb
                    wandb.log({
                        "attention_video": wandb.Video(attention_video_path, format="mp4"),
                        "iteration": iteration
                    }, step=global_step)
                    print(f"Logged attention video to wandb")

            if logger is not None:
                eval_time = time.perf_counter() - stime
                cumulative_times["eval_time"] += eval_time
                logger.add_scalar("time/eval_time", eval_time, global_step)
            if args.evaluate:
                break
        if args.save_model and is_eval_iteration:
            model_path = f"runs/{run_name}/ckpt_{iteration}.pt"
            torch.save(agent.state_dict(), model_path)
            print(f"model saved to {model_path}")
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow
        rollout_time = time.perf_counter()
        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # Collect CNN features for contrastive learning
            if args.use_contrastive and args.use_faiss:
                with torch.no_grad():
                    _, cnn_features_dict = agent.feature_net(next_obs, return_cnn_features=True)
                    # Get CNN features
                    cnn_features_batch = cnn_features_dict['rgb']
                    # flattened features for FAISS
                    embeddings_batch = cnn_features_batch.flatten(1)

                    rollout_cnn_features.append(cnn_features_batch)
                    rollout_embeddings.append(embeddings_batch)

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminations, truncations, infos = envs.step(action)
            next_done = torch.logical_or(terminations, truncations).to(torch.float32)
            rewards[step] = reward.view(-1) * args.reward_scale

            if "final_info" in infos:
                final_info = infos["final_info"]
                done_mask = infos["_final_info"]
                for k, v in final_info["episode"].items():
                    logger.add_scalar(f"train/{k}", v[done_mask].float().mean(), global_step)

                # Bootstrap from the true final observation of finished envs
                for k in infos["final_observation"]:
                    infos["final_observation"][k] = infos["final_observation"][k][done_mask]
                with torch.no_grad():
                    final_values[step, torch.arange(args.num_envs, device=device)[done_mask]] = agent.get_value(infos["final_observation"]).view(-1)
        rollout_time = time.perf_counter() - rollout_time
        cumulative_times["rollout_time"] += rollout_time
        # bootstrap value according to termination and truncation
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    next_not_done = 1.0 - next_done
                    nextvalues = next_value
                else:
                    next_not_done = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                real_next_values = next_not_done * nextvalues + final_values[t] # t instead of t+1
                # next_not_done means nextvalues is computed from the correct next_obs
                # if next_not_done is 1, final_values is always 0
                # if next_not_done is 0, then use final_values, which is computed according to bootstrap_at_done
                if args.finite_horizon_gae:
                    """
                    See GAE paper equation(16) line 1, we will compute the GAE based on this line only
                    1             *(  -V(s_t)  + r_t                                                               + gamma * V(s_{t+1})   )
                    lambda        *(  -V(s_t)  + r_t + gamma * r_{t+1}                                             + gamma^2 * V(s_{t+2}) )
                    lambda^2      *(  -V(s_t)  + r_t + gamma * r_{t+1} + gamma^2 * r_{t+2}                         + ...                  )
                    lambda^3      *(  -V(s_t)  + r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + gamma^3 * r_{t+3}
                    We then normalize it by the sum of the lambda^i (instead of 1-lambda)
                    """
                    if t == args.num_steps - 1: # initialize
                        lam_coef_sum = 0.
                        reward_term_sum = 0. # the sum of the second term
                        value_term_sum = 0. # the sum of the third term
                    lam_coef_sum = lam_coef_sum * next_not_done
                    reward_term_sum = reward_term_sum * next_not_done
                    value_term_sum = value_term_sum * next_not_done

                    lam_coef_sum = 1 + args.gae_lambda * lam_coef_sum
                    reward_term_sum = args.gae_lambda * args.gamma * reward_term_sum + lam_coef_sum * rewards[t]
                    value_term_sum = args.gae_lambda * args.gamma * value_term_sum + args.gamma * real_next_values

                    advantages[t] = (reward_term_sum + value_term_sum) / lam_coef_sum - values[t]
                else:
                    delta = rewards[t] + args.gamma * real_next_values - values[t]
                    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * next_not_done * lastgaelam # Here actually we should use next_not_terminated, but we don't have lastgamlam if terminated
            returns = advantages + values

        # Add rollout data to contrastive buffer
        if args.use_contrastive and args.use_faiss and rollout_cnn_features:
            # Concatenating the per-step lists gives step-major order, which
            # matches returns.flatten() of the (num_steps, num_envs) tensor.
            all_cnn_features = torch.cat(rollout_cnn_features, dim=0)  # (num_steps * num_envs, C, H, W)
            all_embeddings = torch.cat(rollout_embeddings, dim=0)  # (num_steps * num_envs, embed_dim)
            all_returns = returns.flatten()  # (num_steps * num_envs,)

            # Add to buffer
            contrastive_buffer.add_batch(all_cnn_features, all_embeddings, all_returns)

            # Clear rollout storage for next iteration
            rollout_cnn_features = []
            rollout_embeddings = []

        # flatten the batch
        b_obs = obs.reshape((-1,))
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        agent.train()
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        update_time = time.perf_counter()

        # whether to use contrastive learning
        should_use_contrastive = args.use_contrastive and args.use_faiss and contrastive_buffer.size > 0

        # Mine triplets from contrastive buffer (only every contrastive_update_freq iterations)
        triplets = []
        if should_use_contrastive and iteration % args.contrastive_update_freq == 0:
            triplets = mine_contrastive_triplets_faiss(
                buffer=contrastive_buffer,
                n_anchors=args.contrast_batch_size,
                top_k=args.contrast_top_k,
                device=device
            )

        # Pre-initialize the logged losses so the logging below still works if
        # the target-KL early stop fires before the first minibatch update.
        pg_loss = v_loss = entropy_loss = torch.tensor(0.0, device=device)
        contrastive_loss = torch.tensor(0.0, device=device)
        contrastive_metrics = {}

        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                if args.target_kl is not None and approx_kl > args.target_kl:
                    break

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()

                # Compute contrastive loss fresh for each minibatch to avoid graph reuse
                contrastive_loss = torch.tensor(0.0, device=device)
                contrastive_metrics = {}
                if triplets and args.use_faiss:
                    contrastive_loss, contrastive_metrics = contrastive_loss_faiss(
                        agent,
                        contrastive_buffer,
                        triplets,
                        margin=args.contrast_margin,
                        lambda_spread=args.lambda_spread,
                        device=device
                    )

                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + args.lambda_contrast * contrastive_loss

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            if args.target_kl is not None and approx_kl > args.target_kl:
                break
        update_time = time.perf_counter() - update_time
        cumulative_times["update_time"] += update_time
        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # Loss values below are those of the last minibatch processed.
        logger.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        logger.add_scalar("losses/value_loss", v_loss.item(), global_step)
        logger.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        logger.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        logger.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        logger.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        logger.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        logger.add_scalar("losses/explained_variance", explained_var, global_step)
        if args.use_contrastive:
            if contrastive_loss.item() > 0:
                logger.add_scalar("losses/contrastive_loss", contrastive_loss.item(), global_step)
                logger.add_scalar("contrastive/num_triplets", len(triplets), global_step)

                # Log detailed contrastive metrics (tensorboard/wandb)
                if contrastive_metrics:
                    for metric_name, metric_value in contrastive_metrics.items():
                        # Skip non-numeric metrics
                        if metric_name == 'attention_type':
                            continue
                        logger.add_scalar(f"contrastive/{metric_name}", metric_value, global_step)

            logger.add_scalar("contrastive/update_active", 1.0 if iteration % args.contrastive_update_freq == 0 else 0.0, global_step)
            logger.add_scalar("contrastive/current_weight", args.lambda_contrast, global_step)

        print("SPS:", int(global_step / (time.time() - start_time)))
        logger.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
        logger.add_scalar("time/step", global_step, global_step)
        logger.add_scalar("time/update_time", update_time, global_step)
        logger.add_scalar("time/rollout_time", rollout_time, global_step)
        logger.add_scalar("time/rollout_fps", args.num_envs * args.num_steps / rollout_time, global_step)
        for k, v in cumulative_times.items():
            logger.add_scalar(f"time/total_{k}", v, global_step)
        logger.add_scalar("time/total_rollout+update_time", cumulative_times["rollout_time"] + cumulative_times["update_time"], global_step)

    if args.save_model and not args.evaluate:
        model_path = f"runs/{run_name}/final_ckpt.pt"
        torch.save(agent.state_dict(), model_path)
        print(f"model saved to {model_path}")

    envs.close()
    # The eval envs may be wrapped in RecordEpisode (saving trajectories in
    # evaluate mode); close them too so those wrappers can finalize their output.
    eval_envs.close()
    if logger is not None: logger.close()
