import gc
import os
import sys
import random
import warnings
from collections import defaultdict
from typing import Dict, List
import jsonlines

import lmdb
import msgpack_numpy
import numpy as np
import math
import time
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn as nn  # === ADAPTERS: added ===

from image_manipulate import add_noise, black_out, defocus_blur, flare_effect, spatter, foreign_object, motion_blur, low_lighting_gradient, train_corruption_episode    
from depth_manipulate import apply_gaussian_noise, apply_missing_data, apply_multipath_interference, apply_depth_quantization, train_corruption_episode_depth

import tqdm
from gym import Space
from habitat import Config, logger
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.environments import get_env_class
from habitat_baselines.common.obs_transformers import (
    apply_obs_transforms_batch,
    apply_obs_transforms_obs_space,
    get_active_obs_transforms,
)
from habitat_baselines.common.tensorboard_utils import TensorboardWriter
from habitat_baselines.utils.common import batch_obs

from vlnce_baselines.common.aux_losses import AuxLosses
from vlnce_baselines.common.base_il_trainer import BaseVLNCETrainer
from vlnce_baselines.common.env_utils import construct_envs, construct_envs_for_rl, is_slurm_batch_job
from vlnce_baselines.common.utils import extract_instruction_tokens
from vlnce_baselines.models.graph_utils import GraphMap, MAX_DIST
from vlnce_baselines.utils import reduce_loss

from .utils import get_camera_orientations12
from .utils import (
    length2mask, dir_angle_feature_with_ele,
)
from vlnce_baselines.common.utils import dis_to_con, gather_list_and_concat
from habitat_extensions.measures import NDTW, StepsTaken
from fastdtw import fastdtw

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    import tensorflow as tf  # noqa: F401

import torch.distributed as distr
import gzip
import json
from copy import deepcopy
from torch.cuda.amp import autocast, GradScaler
from vlnce_baselines.common.ops import pad_tensors_wgrad, gen_seq_masks
from torch.nn.utils.rnn import pad_sequence


# ======================================================================================
# === ROBUST ADAPTERS: Multi-scale adapters for robustness ===
# ======================================================================================

class CLIPBasedAdapter(nn.Module):
    """
    Adapter built from copies of the last transformer blocks of CLIP ViT-B/32.

    The copied blocks give the adapter a warm start with pre-trained visual
    weights. Input and output are 512-d CLIP image features. Internally each
    feature is:
      1. projected to 768-d,
      2. run through the copied blocks as its own single-token sequence,
      3. projected back to 512-d, and
      4. added to the input through a learned sigmoid gate.

    Args:
        device: device the CLIP model is loaded on and the adapter is moved to.
        dropout: dropout probability applied after the output projection.
        num_layers: number of trailing CLIP visual transformer blocks to copy.
            The default of 2 copies blocks 10 and 11 of the 12-block ViT-B/32
            visual transformer, as before.

    Raises:
        ValueError: if ``num_layers`` is not between 1 and the number of CLIP
            visual blocks.
    """
    def __init__(self, device="cuda", dropout=0.1, num_layers: int = 2):
        super().__init__()
        import copy
        import clip

        clip_model, _ = clip.load("ViT-B/32", device=device)
        original_resblocks = clip_model.visual.transformer.resblocks
        total_layers = len(original_resblocks)
        if not 1 <= num_layers <= total_layers:
            raise ValueError(
                f"num_layers must be in [1, {total_layers}], got {num_layers}"
            )

        # CLIP's visual transformer works at width 768, but its projected image
        # features (what this adapter receives) are 512-d.
        self.hidden_dim = 512

        # Lift the 512-d features into the transformer width of the copied blocks.
        self.input_projection = nn.Linear(512, 768).to(device)

        # Deep-copy the trailing blocks so the original CLIP weights stay untouched.
        # clip.load may return half-precision weights on GPU, hence .float().
        self.adapter_layers = nn.ModuleList()
        for i in range(total_layers - num_layers, total_layers):
            layer_copy = copy.deepcopy(original_resblocks[i]).to(device).float()
            for param in layer_copy.parameters():
                param.requires_grad_(True)
            self.adapter_layers.append(layer_copy)
        # Only the copies are kept; drop the references to the full CLIP model.
        del clip_model, original_resblocks

        # Project back to the 512-d feature space.
        self.output_projection = nn.Sequential(
            nn.Linear(768, 512),
            nn.LayerNorm(512),
            nn.Dropout(dropout)
        ).to(device)

        # Scalar gate in (0, 1), computed from the input, that scales the adapter update.
        self.residual_gate = nn.Sequential(
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        ).to(device)

        print(f"✅ CLIPBasedAdapter initialized with {sum(p.numel() for p in self.parameters()):,} parameters")

    def forward(self, x):
        """
        Args:
            x: (N, 512) CLIP image features, with N the number of feature vectors
               (e.g. batch * views).

        Returns:
            (N, 512) tensor ``x + gate(x) * adapted(x)``.
        """
        orig_x = x

        x = self.input_projection(x)  # (N, 768)

        # CLIP's ResidualAttentionBlock wraps nn.MultiheadAttention in its
        # default batch_first=False layout, i.e. it expects (seq_len, batch, dim).
        # CLIP's own VisionTransformer permutes NLD -> LND before the blocks.
        # Every feature must be its own length-1 sequence, so the token axis goes
        # FIRST. Inserting it second would make all N features one sequence and
        # let attention mix information across unrelated samples and views.
        x = x.unsqueeze(0)  # (1, N, 768)
        for layer in self.adapter_layers:
            x = layer(x)  # (1, N, 768)
        x = x.squeeze(0)  # (N, 768)

        x = self.output_projection(x)  # (N, 512)

        gate = self.residual_gate(orig_x)  # (N, 1)
        return orig_x + gate * x

class MultiScaleRobustAdapter(nn.Module):
    """
    Residual adapter that mixes several bottleneck MLP branches.

    Branch ``i`` has hidden width ``bottleneck // 2**i``. Each branch has a
    Softplus uncertainty head. The branch outputs are fused with weights equal
    to a learned softmax gate times inverse uncertainty, renormalised. The
    fused output is refined by an MLP and added back to the input through a
    learned sigmoid gate.

    Args:
        dim: feature size of the last input axis.
        bottleneck: base hidden width. All derived hidden widths are clamped to
            at least 1 so small values never create zero-width layers.
        dropout: dropout probability inside the branches, fusion gate and refiner.
        num_scales: number of branches; must be >= 1.
        noise_std: std of Gaussian noise added to each branch's input in
            training mode only. 0 disables it.

    Raises:
        ValueError: if ``num_scales < 1``.
    """
    def __init__(self, dim: int, bottleneck: int = 128, dropout: float = 0.1, num_scales: int = 3,
                 noise_std: float = 0.05):
        super().__init__()
        if num_scales < 1:
            raise ValueError(f"num_scales must be >= 1, got {num_scales}")
        self.dim = dim
        self.num_scales = num_scales
        self.noise_std = noise_std

        def _width(w):
            # Guard against zero-width layers when bottleneck is small.
            return max(1, w)

        # Multi-scale processing branches: capacity halves at each scale.
        self.scale_branches = nn.ModuleList()
        for i in range(num_scales):
            scale_bottleneck = _width(bottleneck // (2 ** i))
            branch = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, scale_bottleneck),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(scale_bottleneck, scale_bottleneck),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(scale_bottleneck, dim)
            )
            self.scale_branches.append(branch)

        # One uncertainty head per scale; Softplus keeps the estimate positive.
        unc_hidden = _width(bottleneck // 4)
        self.uncertainty_estimators = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, unc_hidden),
                nn.GELU(),
                nn.Linear(unc_hidden, 1),
                nn.Softplus()
            ) for _ in range(num_scales)
        ])

        # Fusion gate: from the input and the per-scale uncertainties to softmax scale weights.
        fusion_hidden = _width(bottleneck)
        self.fusion_gate = nn.Sequential(
            nn.Linear(dim + num_scales, fusion_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_hidden, num_scales),
            nn.Softmax(dim=-1)
        )

        # Scalar gate in (0, 1) on the residual update.
        gate_hidden = _width(bottleneck // 2)
        self.residual_gate = nn.Sequential(
            nn.Linear(dim, gate_hidden),
            nn.GELU(),
            nn.Linear(gate_hidden, 1),
            nn.Sigmoid()
        )

        # Refinement MLP applied to the fused output.
        refine_hidden = _width(bottleneck)
        self.refiner = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, refine_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(refine_hidden, dim)
        )

    def forward(self, x):  # x: (..., D)
        """Return a tensor of the same shape as ``x`` (last axis must equal ``dim``)."""
        orig_shape = x.shape
        x_flat = x.reshape(-1, self.dim)  # (N, D)

        scale_outputs = []
        scale_uncertainties = []
        for branch, uncertainty_est in zip(self.scale_branches, self.uncertainty_estimators):
            # Independent light noise per branch, only while training.
            if self.training and self.noise_std > 0:
                x_noisy = x_flat + torch.randn_like(x_flat) * self.noise_std
            else:
                x_noisy = x_flat

            scale_outputs.append(branch(x_noisy))
            # Uncertainty is estimated from the clean input.
            scale_uncertainties.append(uncertainty_est(x_flat).squeeze(-1))

        scale_stack = torch.stack(scale_outputs, dim=-2)  # (N, num_scales, D)
        uncertainty_stack = torch.stack(scale_uncertainties, dim=-1)  # (N, num_scales)

        # Learned fusion weights conditioned on the input and the uncertainties.
        fusion_input = torch.cat([x_flat, uncertainty_stack], dim=-1)
        fusion_weights = self.fusion_gate(fusion_input)  # (N, num_scales)

        # Down-weight uncertain scales, then renormalise to sum to ~1.
        reliability_weights = 1.0 / (uncertainty_stack + 1e-6)
        combined_weights = fusion_weights * reliability_weights
        combined_weights = combined_weights / (combined_weights.sum(dim=-1, keepdim=True) + 1e-8)

        fused_output = torch.sum(scale_stack * combined_weights.unsqueeze(-1), dim=-2)  # (N, D)
        refined_output = self.refiner(fused_output)

        residual_weight = self.residual_gate(x_flat)  # (N, 1)
        output = x_flat + residual_weight * refined_output

        return output.reshape(orig_shape)


class TokenRobustAdapter(nn.Module):
    """Apply a ``MultiScaleRobustAdapter`` to token embeddings.

    Any input whose last axis has size ``dim`` is accepted, e.g. (B, T, D) or
    (T, D). The tokens are flattened to (N, D), adapted, and returned in the
    original shape.
    """
    def __init__(self, dim: int, bottleneck: int = 128, dropout: float = 0.1, num_scales: int = 3):
        super().__init__()
        self.core = MultiScaleRobustAdapter(dim, bottleneck, dropout, num_scales)

    def forward(self, x):
        in_shape = x.shape
        tokens = x.reshape(-1, in_shape[-1])
        adapted = self.core(tokens)
        return adapted.reshape(in_shape)


def _reliability_from_features(x, mask, alpha: float = 2.0):
    """
    Compute per-view reliabilities from feature quality indicators.
    x: (B, V, D), mask: (B, V) in {0,1} or bool
    returns w: (B, V) in [0,1], higher = more reliable.
    """
    with torch.no_grad():
        B, V, D = x.shape
        mask_f = mask.float()
        
        # Feature norm-based reliability
        norms = torch.norm(x, dim=-1)  # (B, V)
        mean_norm = (norms * mask_f).sum(dim=1, keepdim=True) / (mask_f.sum(dim=1, keepdim=True) + 1e-6)
        var_norm = (((norms - mean_norm) ** 2) * mask_f).sum(dim=1, keepdim=True) / (mask_f.sum(dim=1, keepdim=True) + 1e-6)
        std_norm = (var_norm + 1e-6).sqrt()
        z_norm = (norms - mean_norm) / (std_norm + 1e-6)
        norm_reliability = torch.sigmoid(alpha * z_norm)
        
        # Feature consistency-based reliability
        x_centered = x - torch.mean(x, dim=-1, keepdim=True)
        feature_std = torch.std(x_centered, dim=-1)  # (B, V)
        mean_std = (feature_std * mask_f).sum(dim=1, keepdim=True) / (mask_f.sum(dim=1, keepdim=True) + 1e-6)
        var_std = (((feature_std - mean_std) ** 2) * mask_f).sum(dim=1, keepdim=True) / (mask_f.sum(dim=1, keepdim=True) + 1e-6)
        std_std = (var_std + 1e-6).sqrt()
        z_std = (feature_std - mean_std) / (std_std + 1e-6)
        consistency_reliability = torch.sigmoid(-alpha * torch.abs(z_std))  # Lower std variation = higher reliability
        
        # Combine reliability measures
        combined_reliability = (norm_reliability + consistency_reliability) / 2.0
        
        # Apply mask and ensure positive weights
        w = combined_reliability * mask_f
        w = w + 1e-6  # Ensure numerical stability
        
        return w
# ======================================================================================


@baseline_registry.register_trainer(name="SS-ETP")
class RLTrainer(BaseVLNCETrainer):
    def __init__(self, config=None):
        """Create the trainer and set the adapter and corruption-schedule defaults.

        Args:
            config: run ``Config``. If it has no ``local_rank`` attribute, one is
                added with value 0 before the base-class constructor runs. The
                config's frozen/unfrozen state is preserved.
        """
        # ------------------------------------------------------------------
        # Defensive patch: ensure `config.local_rank` is always defined.
        # Some standalone scripts/tests instantiate RLTrainer directly
        # without `run.py` / distributed launch utilities that usually
        # populate this field. `_make_dirs` and `_set_config` read it.
        # A default of 0 (single-GPU) is added if the key is missing.
        # ------------------------------------------------------------------
        if config is not None:
            try:
                # Raises AttributeError if the key is absent.
                _ = config.local_rank  # noqa: F841
            except AttributeError:
                # Only re-freeze if the caller's config was frozen. Unconditionally
                # freezing would leave a previously mutable config frozen as a side
                # effect on the caller's object.
                was_frozen = config.is_frozen()
                config.defrost()
                config.local_rank = 0
                if was_frozen:
                    config.freeze()

        super().__init__(config)
        # Maximum trajectory length, taken from IL.max_traj_len.
        self.max_len = int(config.IL.max_traj_len)  # * 0.97 transferred gt path got 0.96 spl

        # === CLIP-BASED ADAPTERS: warm start with pre-trained CLIP layers ===
        self.adapter_use_image = False       # turn ON/OFF image feature adapter - DISABLED
        self.adapter_use_text  = False       # set True to adapt text embeddings too - DISABLED
        self.adapter_bottleneck = 256        # bottleneck width of the text adapter (CLIP image adapter ignores it)
        self.adapter_dropout   = 0.1         # dropout for adapter layers
        self.adapter_num_scales = 3          # number of scales of the text adapter (CLIP image adapter ignores it)
        self.adapter_only_finetune = True    # freeze base policy; train adapters only
        self.adapter_conf_alpha = 2.0        # reliability sharpness for view fusion

        # Corruption schedule settings (their consumers are outside this constructor).
        self.corruption_start_prob = 0.3     # initial corruption probability
        self.corruption_end_prob = 0.8       # final corruption probability
        self.corruption_warmup_steps = 1000  # steps to reach full corruption
        self.current_step = 0                # training-step counter for the schedule

        # Adapters are created lazily by _ensure_adapters (may stay None).
        self.img_adapter: nn.Module = None
        self.txt_adapter: nn.Module = None
        self._adapters_in_optim = False
        self._pending_adapter_state = None   # adapter weights loaded from a checkpoint, if any

    def _get_depth_adapter_state_dict(self):
        """Collect the state dicts of all depth adapters in the depth backbone.

        Returns:
            ``None`` if the depth encoder does not have ``use_depth_adapters``
            set. Otherwise, a dict mapping ``"stage_{s}_block_{b}"`` (stage 1-4,
            block index within the stage) to that block's ``depth_adapter``
            state dict. Blocks without a ``depth_adapter`` are skipped.
        """
        net = self.policy.net
        # Unwrap DDP *before* reading depth_encoder. A DistributedDataParallel
        # wrapper only exposes the wrapped model as `.module`, so reading
        # `net.depth_encoder` on the wrapper raises AttributeError.
        if hasattr(net, 'module'):
            net = net.module
        depth_encoder = net.depth_encoder

        if not getattr(depth_encoder, 'use_depth_adapters', False):
            return None

        backbone = depth_encoder.visual_encoder.backbone
        stage_layers = (backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4)

        return {
            f"stage_{stage_idx}_block_{block_idx}": block.depth_adapter.state_dict()
            for stage_idx, layer in enumerate(stage_layers, start=1)
            for block_idx, block in enumerate(layer)
            if hasattr(block, 'depth_adapter')
        }

    def _load_depth_adapter_state_dict(self, adapter_state):
        """Load depth-adapter weights keyed like ``_get_depth_adapter_state_dict`` output.

        Args:
            adapter_state: dict mapping ``"stage_{s}_block_{b}"`` to a
                ``depth_adapter`` state dict, or ``None`` (then this is a no-op).

        Loading is best-effort:
          * each block is loaded with ``strict=False``;
          * a failure for one block is logged as a warning and loading continues;
          * adapter blocks with no entry in ``adapter_state`` keep their current
            weights and are listed in one warning.
        """
        if adapter_state is None:
            return

        net = self.policy.net
        # Unwrap DDP before reading depth_encoder; the wrapper does not expose it.
        if hasattr(net, 'module'):
            net = net.module
        depth_encoder = net.depth_encoder

        if not getattr(depth_encoder, 'use_depth_adapters', False):
            logger.warning("Trying to load depth adapter state but depth adapters are not enabled")
            return

        backbone = depth_encoder.visual_encoder.backbone
        stage_layers = [backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4]

        loaded_count = 0
        missing = []
        for stage_idx, layer in enumerate(stage_layers, start=1):
            for block_idx, block in enumerate(layer):
                if not hasattr(block, 'depth_adapter'):
                    continue
                key = f"stage_{stage_idx}_block_{block_idx}"
                if key not in adapter_state:
                    missing.append(key)
                    continue
                try:
                    block.depth_adapter.load_state_dict(adapter_state[key], strict=False)
                    loaded_count += 1
                except Exception as e:  # best-effort: keep loading the remaining blocks
                    logger.warning(f"Failed to load adapter state for {key}: {e}")

        if missing:
            logger.warning(f"No saved depth adapter state for {len(missing)} block(s): {missing}")
        logger.info(f"Loaded state for {loaded_count} depth adapters")

    def _make_dirs(self):
        """Create output directories; only the rank-0 process does this.

        Always creates the checkpoint directory. The results directory is
        created only when ``EVAL.SAVE_RESULTS`` is set.
        """
        if self.config.local_rank != 0:
            return
        self._make_ckpt_dir()
        if self.config.EVAL.SAVE_RESULTS:
            self._make_results_dir()

    # === ADAPTERS: save adapter weights along with policy ===
    def save_checkpoint(self, iteration: int):
        """Write a checkpoint to ``CHECKPOINT_FOLDER/ckpt.iter{iteration}.pth``.

        The checkpoint holds the policy state dict, config, optimizer state and
        iteration. An ``"adapter_state"`` entry is added when an image or text
        adapter exists, or when the depth encoder has ``use_depth_adapters`` set.
        It has ``"img"`` / ``"txt"`` / ``"depth"`` sub-entries, each ``None`` for
        an absent adapter. DDP-wrapped adapters are saved through their
        ``.module``.

        Args:
            iteration: training iteration. It is stored in the checkpoint and
                used in the file name.
        """
        def _unwrap(mod):
            # DDP keeps the wrapped model under `.module` and does not forward
            # attribute lookups such as `depth_encoder` to it.
            return mod.module if hasattr(mod, 'module') else mod

        ckpt = {
            "state_dict": self.policy.state_dict(),
            "config": self.config,
            "optim_state": self.optimizer.state_dict(),
            "iteration": iteration,
        }
        # Unwrap BEFORE reading depth_encoder: `self.policy.net.depth_encoder`
        # raises AttributeError when the net is DDP-wrapped (multi-GPU runs).
        depth_encoder = _unwrap(self.policy.net).depth_encoder
        has_depth_adapters = getattr(depth_encoder, 'use_depth_adapters', False)

        if self.img_adapter is not None or self.txt_adapter is not None or has_depth_adapters:
            ckpt["adapter_state"] = {
                "img": None if self.img_adapter is None else _unwrap(self.img_adapter).state_dict(),
                "txt": None if self.txt_adapter is None else _unwrap(self.txt_adapter).state_dict(),
                "depth": self._get_depth_adapter_state_dict() if has_depth_adapters else None,
            }
        torch.save(
            obj=ckpt,
            f=os.path.join(self.config.CHECKPOINT_FOLDER, f"ckpt.iter{iteration}.pth"),
        )

    def _set_config(self):
        """Finalize ``self.config`` for this process and set up distributed state.

        Mutates ``self.config``, defrosting it and freezing it again:
          * copies the dataset split into the NDTW / SDTW measure configs;
          * sets ``MAX_SCENE_REPEAT_STEPS = -1``;
          * replaces the ``SIMULATOR_GPU_IDS`` list with this rank's entry;
          * sets ``use_pbar`` to ``not is_slurm_batch_job()``;
          * for RGB and DEPTH, clones the base sensor once per orientation from
            ``get_camera_orientations12()`` and registers each clone with the
            agent's sensor list and the per-sensor resize / center-crop lists;
          * when ``VIDEO_OPTION`` is set, adds the top-down-map / distance /
            success / SPL measurements, creates ``VIDEO_DIR``, and adds six
            224x224 RGB cameras with HFOV 90 (Back/Down/Front/Right/Left/Up).

        It then records ``world_size`` / ``local_rank`` / ``batch_size`` and
        selects the CUDA device. With more than one GPU it initialises the NCCL
        process group and switches ``self.device`` and ``TORCH_GPU_ID`` to this
        rank's GPU.

        NOTE(review): sensors and transform entries are appended to existing
        lists in place, so calling this method twice would register duplicates.
        """
        self.split = self.config.TASK_CONFIG.DATASET.SPLIT
        self.config.defrost()
        self.config.TASK_CONFIG.TASK.NDTW.SPLIT = self.split
        self.config.TASK_CONFIG.TASK.SDTW.SPLIT = self.split
        self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
        # From here on SIMULATOR_GPU_IDS holds this rank's single entry, not a list.
        self.config.SIMULATOR_GPU_IDS = self.config.SIMULATOR_GPU_IDS[self.config.local_rank]
        self.config.use_pbar = not is_slurm_batch_job()
        ''' if choosing image '''
        # Lists of (sensor_uuid, size) pairs; entries for the new cameras are appended below.
        resize_config = self.config.RL.POLICY.OBS_TRANSFORMS.RESIZER_PER_SENSOR.SIZES
        crop_config = self.config.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER_PER_SENSOR.SENSOR_CROPS
        task_config = self.config.TASK_CONFIG
        camera_orientations = get_camera_orientations12()
        for sensor_type in ["RGB", "DEPTH"]:
            # New cameras reuse the base sensor's resize / crop sizes.
            resizer_size = dict(resize_config)[sensor_type.lower()]
            cropper_size = dict(crop_config)[sensor_type.lower()]
            sensor = getattr(task_config.SIMULATOR, f"{sensor_type}_SENSOR")
            # `orient` is unused; the orientation is looked up again by key.
            for action, orient in camera_orientations.items():
                camera_template = f"{sensor_type}_{action}"
                camera_config = deepcopy(sensor)
                camera_config.ORIENTATION = camera_orientations[action]
                camera_config.UUID = camera_template.lower()
                setattr(task_config.SIMULATOR, camera_template, camera_config)
                task_config.SIMULATOR.AGENT_0.SENSORS.append(camera_template)
                resize_config.append((camera_template.lower(), resizer_size))
                crop_config.append((camera_template.lower(), cropper_size))
        self.config.RL.POLICY.OBS_TRANSFORMS.RESIZER_PER_SENSOR.SIZES = resize_config
        self.config.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER_PER_SENSOR.SENSOR_CROPS = crop_config
        self.config.TASK_CONFIG = task_config
        self.config.SENSORS = task_config.SIMULATOR.AGENT_0.SENSORS
        if self.config.VIDEO_OPTION:
            # Extra measurements and six cube-face RGB cameras used only for video output.
            self.config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP_VLNCE")
            self.config.TASK_CONFIG.TASK.MEASUREMENTS.append("DISTANCE_TO_GOAL")
            self.config.TASK_CONFIG.TASK.MEASUREMENTS.append("SUCCESS")
            self.config.TASK_CONFIG.TASK.MEASUREMENTS.append("SPL")
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)
            # Offset (radians) added to the second rotation component of every video camera.
            shift = 0.
            orient_dict = {
                'Back': [0, math.pi + shift, 0],            # Back
                'Down': [-math.pi / 2, 0 + shift, 0],       # Down
                'Front':[0, 0 + shift, 0],                  # Front
                'Right':[0, math.pi / 2 + shift, 0],        # Right
                'Left': [0, 3 / 2 * math.pi + shift, 0],    # Left
                'Up':   [math.pi / 2, 0 + shift, 0],        # Up
            }
            # Collected but not used after the loop in this method.
            sensor_uuids = []
            H = 224
            for sensor_type in ["RGB"]:
                sensor = getattr(self.config.TASK_CONFIG.SIMULATOR, f"{sensor_type}_SENSOR")
                for camera_id, orient in orient_dict.items():
                    camera_template = f"{sensor_type}{camera_id}"
                    camera_config = deepcopy(sensor)
                    camera_config.WIDTH = H
                    camera_config.HEIGHT = H
                    camera_config.ORIENTATION = orient
                    camera_config.UUID = camera_template.lower()
                    camera_config.HFOV = 90
                    sensor_uuids.append(camera_config.UUID)
                    setattr(self.config.TASK_CONFIG.SIMULATOR, camera_template, camera_config)
                    self.config.TASK_CONFIG.SIMULATOR.AGENT_0.SENSORS.append(camera_template)
        self.config.freeze()

        self.world_size = self.config.GPU_NUMBERS
        self.local_rank = self.config.local_rank
        self.batch_size = self.config.IL.batch_size
        # NOTE(review): self.device is not assigned above, so it is presumably
        # set before this call (e.g. by the base trainer) — confirm.
        torch.cuda.set_device(self.device)
        if self.world_size > 1:
            # Multi-GPU: join the NCCL group and move to this rank's GPU.
            distr.init_process_group(backend='nccl', init_method='env://')
            self.device = self.config.TORCH_GPU_IDS[self.local_rank]
            self.config.defrost()
            self.config.TORCH_GPU_ID = self.config.TORCH_GPU_IDS[self.local_rank]
            self.config.freeze()
            torch.cuda.set_device(self.device)

    def _init_envs(self):
        """Build this rank's vectorized environments and observation transforms.

        The task seed is offset by ``local_rank`` so each DDP process loads
        different data. Also sets ``self.envs`` and ``self.obs_transforms``.

        Returns:
            ``(observation_space, action_space)`` of the first environment. The
            observation space has been passed through the active obs transforms.
        """
        # Per-rank seed offset so DDP workers draw different episodes.
        self.config.defrost()
        self.config.TASK_CONFIG.SEED += self.local_rank
        self.config.freeze()

        env_cls = get_env_class(self.config.ENV_NAME)
        self.envs = construct_envs(self.config, env_cls, auto_reset_done=False)
        logger.info(
            f'LOCAL RANK: {self.local_rank}, ENV NUM: {self.envs.num_envs}, '
            f'DATASET LEN: {sum(self.envs.number_of_episodes)}'
        )

        raw_obs_space = self.envs.observation_spaces[0]
        act_space = self.envs.action_spaces[0]
        self.obs_transforms = get_active_obs_transforms(self.config)
        obs_space = apply_obs_transforms_obs_space(raw_obs_space, self.obs_transforms)
        return obs_space, act_space

    def _ensure_adapters(self, rgb_dim: int = None, txt_dim: int = None):
        """Create each enabled adapter that does not exist yet (call before DDP wrapping).

        Args:
            rgb_dim: RGB feature size. It is only checked against ``None``,
                because the CLIP-based image adapter has fixed 512-d input and
                output. The image adapter is built only when this is given.
            txt_dim: text embedding size for the ``TokenRobustAdapter``. The
                text adapter is built only when this is given.

        Parameter freezing for adapter-only training is NOT done here; it is
        done after DDP wrapping to avoid DDP "no gradient" errors.
        """
        build_image = self.adapter_use_image and self.img_adapter is None and rgb_dim is not None
        if build_image:
            # CLIP-based adapter for a pre-trained warm start.
            image_adapter = CLIPBasedAdapter(device=self.device, dropout=self.adapter_dropout)
            self.img_adapter = image_adapter.to(self.device)
            logger.info("Created CLIP-based image adapter with pre-trained CLIP layers")

        build_text = self.adapter_use_text and self.txt_adapter is None and txt_dim is not None
        if build_text:
            text_adapter = TokenRobustAdapter(
                txt_dim,
                self.adapter_bottleneck,
                self.adapter_dropout,
                self.adapter_num_scales,
            )
            self.txt_adapter = text_adapter.to(self.device)
            logger.info(f"Created robust text adapter: {txt_dim} -> {self.adapter_bottleneck} (scales={self.adapter_num_scales})")
    
    def _augment_rgb_in_batch_train(self, batch, episode_id, intensity: float = 0.6):
        """Corrupt every RGB observation in ``batch`` in place with ``add_noise``.

        Only one corruption type (``add_noise``) is applied.

        Args:
            batch: dict of observation tensors. Every key starting with ``'rgb'``
                whose tensor is 4-D is corrupted one image at a time.
            episode_id: unused; kept for interface compatibility.
            intensity: noise strength passed to ``add_noise``. The default 0.6
                is the previously hard-coded value.

        Returns:
            The same ``batch`` object, with its RGB tensors modified in place.
        """
        rgb_keys = [key for key in batch.keys() if key.startswith('rgb')]
        for rgb_key in rgb_keys:
            rgb_tensor = batch[rgb_key]
            # Presumably (batch, H, W, C) — TODO confirm layout against the env output.
            if rgb_tensor.dim() != 4:
                continue
            for i in range(rgb_tensor.shape[0]):
                img_corrupted = add_noise(rgb_tensor[i].cpu().numpy(), intensity=intensity)
                # NOTE(review): cast back to the tensor's dtype; confirm add_noise
                # keeps values in that dtype's valid range (e.g. 0-255 for uint8).
                rgb_tensor[i] = torch.from_numpy(img_corrupted).to(dtype=rgb_tensor.dtype, device=rgb_tensor.device)
        return batch

    def _augment_depth_in_batch_train(self, batch, episode_id, corruption_type=0):
        """Apply a depth corruption in place, for robust depth-adapter training.

        Does nothing unless the depth encoder has ``use_depth_adapters`` set.

        Args:
            batch: dict of observation tensors. Every key starting with
                ``'depth'`` whose tensor is 4-D is corrupted one image at a time.
            episode_id: unused; kept for interface compatibility.
            corruption_type: selects the corruption:
                0 = gaussian noise, 1 = missing data, 2 = multipath interference,
                3 = quantization. Any other value leaves the batch unchanged.

        Returns:
            The same ``batch`` object.
        """
        net = self.policy.net.module if hasattr(self.policy.net, 'module') else self.policy.net
        depth_encoder = net.depth_encoder
        if not getattr(depth_encoder, 'use_depth_adapters', False):
            # Report the skip once per trainer instead of on every call.
            if not getattr(self, '_depth_corruption_skip_logged', False):
                logger.info(f"Depth adapters not enabled, skipping corruption. corruption_type={corruption_type}")
                self._depth_corruption_skip_logged = True
            return batch

        depth_keys = [key for key in batch.keys() if key.startswith('depth')]

        if not hasattr(self, '_corruption_logged'):
            logger.info(f"Applying depth corruption type {corruption_type} to keys: {depth_keys}")
            self._corruption_logged = True

        # Pick the corruption once rather than re-dispatching for every image.
        if corruption_type == 0:
            corrupt = apply_gaussian_noise
        elif corruption_type == 1:
            corrupt = apply_missing_data
        elif corruption_type == 2:
            corrupt = apply_multipath_interference
        elif corruption_type == 3:
            corrupt = apply_depth_quantization
        else:
            # Unknown type: data stays unchanged, so skip the per-image round trip.
            return batch

        for depth_key in depth_keys:
            depth_tensor = batch[depth_key]
            # Presumably (batch, H, W, C) — TODO confirm layout against the env output.
            if depth_tensor.dim() != 4:
                continue
            for i in range(depth_tensor.shape[0]):
                img_corrupted = corrupt(depth_tensor[i].cpu().numpy())
                depth_tensor[i] = torch.from_numpy(img_corrupted).to(dtype=depth_tensor.dtype, device=depth_tensor.device)
        return batch

    def _initialize_policy(
        self,
        config: Config,
        load_from_ckpt: bool,
        observation_space: Space,
        action_space: Space,
    ):
        start_iter = 0
        policy = baseline_registry.get_policy(self.config.MODEL.policy_name)
        self.policy = policy.from_config(
            config=config,
            observation_space=observation_space,
            action_space=action_space,
        )
        ''' initialize the waypoint predictor here '''
        from vlnce_baselines.waypoint_pred.TRM_net import BinaryDistPredictor_TRM
        self.waypoint_predictor = BinaryDistPredictor_TRM(device=self.device)
        cwp_fn = 'data/wp_pred/check_cwp_bestdist_hfov63' if self.config.MODEL.task_type == 'rxr' else 'data/wp_pred/check_cwp_bestdist_hfov90'
        self.waypoint_predictor.load_state_dict(torch.load(cwp_fn, map_location = torch.device('cpu'))['predictor']['state_dict'])
        for param in self.waypoint_predictor.parameters():
            param.requires_grad_(False)

        self.policy.to(self.device)
        self.waypoint_predictor.to(self.device)
        self.num_recurrent_layers = self.policy.net.num_recurrent_layers

        # === ADAPTERS: Create adapters BEFORE DDP wrapper ===
        # Use actual model feature dimensions from config
        if self.adapter_use_image or self.adapter_use_text:
            rgb_dim = self.config.MODEL.RGB_ENCODER.output_size if self.adapter_use_image else None
            txt_dim = 768 if self.adapter_use_text else None  # BERT-style text embeddings are typically 768
            self._ensure_adapters(rgb_dim=rgb_dim, txt_dim=txt_dim)
            logger.info(f"Created adapters before DDP: image={self.adapter_use_image} (dim={rgb_dim}), text={self.adapter_use_text} (dim={txt_dim})")

        if self.config.GPU_NUMBERS > 1:
            print('Using', self.config.GPU_NUMBERS,'GPU!')
            # find_unused_parameters=True to handle unused parameters in model
            self.policy.net = DDP(self.policy.net.to(self.device), device_ids=[self.device],
                output_device=self.device, find_unused_parameters=True, broadcast_buffers=False)
            
            # === ADAPTERS: Wrap adapters in DDP too if they exist ===
            if self.img_adapter is not None:
                self.img_adapter = DDP(self.img_adapter, device_ids=[self.device],
                    output_device=self.device, find_unused_parameters=True, broadcast_buffers=False)
            if self.txt_adapter is not None:
                self.txt_adapter = DDP(self.txt_adapter, device_ids=[self.device],
                    output_device=self.device, find_unused_parameters=True, broadcast_buffers=False)

        # === ADAPTERS: Freeze policy parameters AFTER DDP wrapping ===
        depth_encoder = self.policy.net.module.depth_encoder if hasattr(self.policy.net, 'module') else self.policy.net.depth_encoder
        has_depth_adapters = getattr(depth_encoder, 'use_depth_adapters', False)
        
        if self.adapter_only_finetune and (self.img_adapter is not None or self.txt_adapter is not None or has_depth_adapters):
            # First, freeze all policy parameters
            for p in self.policy.parameters():
                p.requires_grad_(False)
            logger.info("Froze base policy parameters for adapter-only training")
            
            # Then, explicitly unfreeze adapter parameters
            if self.img_adapter is not None:
                for p in self.img_adapter.parameters():
                    p.requires_grad_(True)
                logger.info("Unfroze image adapter parameters")
            if self.txt_adapter is not None:
                for p in self.txt_adapter.parameters():
                    p.requires_grad_(True)
                logger.info("Unfroze text adapter parameters")
            if has_depth_adapters:
                # Use the depth encoder's method to freeze base and unfreeze adapters
                depth_encoder.freeze_base_parameters()
                
                # CRITICAL FIX: Explicitly ensure adapter parameters are unfrozen
                # This fixes the bug where adapters were frozen and never unfrozen
                backbone = depth_encoder.visual_encoder.backbone
                stage_layers = [backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4]
                adapter_count = 0
                for layer in stage_layers:
                    for block in layer:
                        if hasattr(block, 'depth_adapter'):
                            for param in block.depth_adapter.parameters():
                                param.requires_grad_(True)
                                adapter_count += 1
                
                logger.info(f"Froze depth encoder base parameters, unfroze {adapter_count} adapter parameters")
        # === ADAPTERS: Create optimizer with both policy and adapter parameters ===
        params_to_optimize = []
        
        # Check if we have depth adapters  
        depth_encoder = self.policy.net.module.depth_encoder if hasattr(self.policy.net, 'module') else self.policy.net.depth_encoder
        has_depth_adapters = getattr(depth_encoder, 'use_depth_adapters', False)
        
        # Only include policy parameters if not doing adapter-only finetuning
        if not (self.adapter_only_finetune and (self.img_adapter is not None or self.txt_adapter is not None or has_depth_adapters)):
            params_to_optimize.extend(self.policy.parameters())
        
        # Always include adapter parameters if they exist
        if self.img_adapter is not None:
            params_to_optimize.extend(self.img_adapter.parameters())
        if self.txt_adapter is not None:
            params_to_optimize.extend(self.txt_adapter.parameters())
        if has_depth_adapters:
            depth_adapter_params = list(depth_encoder.get_adapter_parameters())
            
            # CRITICAL DEBUG: Check if adapter parameters have requires_grad=True
            adapter_trainable = sum(1 for p in depth_adapter_params if p.requires_grad)
            adapter_total = len(depth_adapter_params)
            logger.info(f"Depth adapter parameters: {adapter_total} total, {adapter_trainable} trainable")
            
            if adapter_trainable == 0 and adapter_total > 0:
                logger.error("❌ CRITICAL BUG: Depth adapter parameters have requires_grad=False!")
                logger.error("This will cause training failure - adapters won't update")
            elif adapter_trainable == adapter_total:
                logger.info("✅ All depth adapter parameters are trainable")
            
            params_to_optimize.extend(depth_adapter_params)
        
        self.optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.IL.lr)
        
        # Log parameter counts for debugging
        total_params = sum(p.numel() for p in params_to_optimize)
        trainable_params = sum(p.numel() for p in params_to_optimize if p.requires_grad)
        logger.info(f"Optimizer created: {total_params} total params, {trainable_params} trainable params")
        
        # CRITICAL CHECK: Verify optimizer has trainable parameters
        if trainable_params == 0:
            logger.error("❌ FATAL ERROR: Optimizer has 0 trainable parameters!")
            logger.error("Training will not work - no parameters will be updated")
            raise RuntimeError("Optimizer has no trainable parameters - training setup failed")
        else:
            logger.info("✅ Optimizer setup successful with trainable parameters")
        
        if self.adapter_only_finetune and (self.img_adapter is not None or self.txt_adapter is not None):
            logger.info("ADAPTER-ONLY MODE: Base policy frozen, only adapters will be trained")

        ckpt_dict = None  # === ADAPTERS: keep for adapter state
        adapter_ckpt_dict = None  # === ADAPTERS: separate checkpoint for adapters
        
        if load_from_ckpt:
            if config.IL.is_requeue:
                import glob
                ckpt_list = list(filter(os.path.isfile, glob.glob(config.CHECKPOINT_FOLDER + "/*")) )
                ckpt_list.sort(key=os.path.getmtime)
                ckpt_path = ckpt_list[-1]
            else:
                ckpt_path = config.IL.ckpt_to_load
            
            # === ADAPTERS: Check if we should use separate checkpoints ===
            # If we have a separate base policy checkpoint specified, use it for base policy
            base_policy_ckpt = getattr(config.IL, 'base_policy_ckpt', None)
            if base_policy_ckpt and os.path.exists(base_policy_ckpt) and self.adapter_only_finetune:
                logger.info(f"Loading base policy from: {base_policy_ckpt}")
                logger.info(f"Loading adapters from: {ckpt_path}")
                base_ckpt_dict = self.load_checkpoint(base_policy_ckpt, map_location="cpu")
                adapter_ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu")
                ckpt_dict = base_ckpt_dict  # Use base policy checkpoint for main loading
            else:
                # Use same checkpoint for both base policy and adapters (current behavior)
                ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu")
                adapter_ckpt_dict = ckpt_dict
            # === ADAPTERS: For adapter fine-tuning, always start from iteration 0 ===
            # Reset when any type of adapter (image, text, OR depth) is being trained exclusively.
            has_depth_adapters_load = adapter_ckpt_dict.get("adapter_state", {}).get("depth") is not None if adapter_ckpt_dict else False
            if self.adapter_only_finetune and (
                (self.img_adapter is not None) or (self.txt_adapter is not None) or has_depth_adapters or has_depth_adapters_load
            ):
                start_iter = 0
                logger.info("ADAPTER FINE-TUNING: Starting from iteration 0 (ignoring checkpoint iteration)")
            else:
                start_iter = ckpt_dict["iteration"]

            if 'module' in list(ckpt_dict['state_dict'].keys())[0] and self.config.GPU_NUMBERS == 1:
                self.policy.net = torch.nn.DataParallel(self.policy.net.to(self.device),
                    device_ids=[self.device], output_device=self.device)
                self.policy.load_state_dict(ckpt_dict["state_dict"], strict=False)
                self.policy.net = self.policy.net.module
                self.waypoint_predictor = torch.nn.DataParallel(self.waypoint_predictor.to(self.device),
                    device_ids=[self.device], output_device=self.device)
            else:
                self.policy.load_state_dict(ckpt_dict["state_dict"], strict=False)
            if config.IL.is_requeue:
                self.optimizer.load_state_dict(ckpt_dict["optim_state"])
            logger.info(f"Loaded weights from checkpoint: {ckpt_path}, iteration: {start_iter}")

        # === ADAPTERS: load adapter state from checkpoint if available ===
        if adapter_ckpt_dict is not None and "adapter_state" in adapter_ckpt_dict:
            try:
                adapter_state = adapter_ckpt_dict["adapter_state"]
                if adapter_state.get("img") is not None and self.img_adapter is not None:
                    # Handle DDP wrapper
                    adapter = self.img_adapter.module if hasattr(self.img_adapter, 'module') else self.img_adapter
                    adapter.load_state_dict(adapter_state["img"], strict=False)
                    logger.info("Loaded image adapter state from checkpoint")
                if adapter_state.get("txt") is not None and self.txt_adapter is not None:
                    # Handle DDP wrapper
                    adapter = self.txt_adapter.module if hasattr(self.txt_adapter, 'module') else self.txt_adapter
                    adapter.load_state_dict(adapter_state["txt"], strict=False)
                    logger.info("Loaded text adapter state from checkpoint")
                if adapter_state.get("depth") is not None:
                    self._load_depth_adapter_state_dict(adapter_state["depth"])
                    logger.info("Loaded depth adapter state from checkpoint")
            except Exception as e:
                logger.warning(f"Failed to load adapter state from checkpoint: {e}")
        else:
            if adapter_ckpt_dict is not None:
                logger.warning("No adapter_state found in checkpoint - adapters will use random initialization")

        params = sum(param.numel() for param in self.policy.parameters())
        params_t = sum(
            p.numel() for p in self.policy.parameters() if p.requires_grad
        )
        
        # === ADAPTERS: count adapter parameters ===
        adapter_params = 0
        adapter_params_t = 0
        if self.img_adapter is not None:
            # Handle DDP wrapper
            adapter = self.img_adapter.module if hasattr(self.img_adapter, 'module') else self.img_adapter
            img_params = sum(p.numel() for p in adapter.parameters())
            img_params_t = sum(p.numel() for p in adapter.parameters() if p.requires_grad)
            adapter_params += img_params
            adapter_params_t += img_params_t
            logger.info(f"Image adapter parameters: {img_params/1e3:.1f}K. Trainable: {img_params_t/1e3:.1f}K")
            
        if self.txt_adapter is not None:
            # Handle DDP wrapper
            adapter = self.txt_adapter.module if hasattr(self.txt_adapter, 'module') else self.txt_adapter
            txt_params = sum(p.numel() for p in adapter.parameters())
            txt_params_t = sum(p.numel() for p in adapter.parameters() if p.requires_grad)
            adapter_params += txt_params
            adapter_params_t += txt_params_t
            logger.info(f"Text adapter parameters: {txt_params/1e3:.1f}K. Trainable: {txt_params_t/1e3:.1f}K")
            
        # Handle DDP wrapper for depth encoder
        depth_encoder = self.policy.net.module.depth_encoder if hasattr(self.policy.net, 'module') else self.policy.net.depth_encoder
            
        if getattr(depth_encoder, 'use_depth_adapters', False):
            depth_params = depth_encoder.count_adapter_parameters()
            depth_params_t = sum(p.numel() for p in depth_encoder.get_adapter_parameters() if p.requires_grad)
            adapter_params += depth_params
            adapter_params_t += depth_params_t
            logger.info(f"Depth adapter parameters: {depth_params/1e3:.1f}K. Trainable: {depth_params_t/1e3:.1f}K")
        
        total_params = params + adapter_params
        total_params_t = params_t + adapter_params_t
        
        logger.info(f"Agent parameters: {params/1e6:.2f} MB. Trainable: {params_t/1e6:.2f} MB.")
        logger.info(f"Adapter parameters: {adapter_params/1e3:.1f}K. Trainable: {adapter_params_t/1e3:.1f}K")
        logger.info(f"Total parameters: {total_params/1e6:.2f} MB. Trainable: {total_params_t/1e6:.2f} MB.")
        logger.info("Finished setting up policy.")

        return start_iter

    def _teacher_action(self, batch_angles, batch_distances, candidate_lengths):
        if self.config.MODEL.task_type == 'r2r':
            cand_dists_to_goal = [[] for _ in range(len(batch_angles))]
            oracle_cand_idx = []
            for j in range(len(batch_angles)):
                for k in range(len(batch_angles[j])):
                    angle_k = batch_angles[j][k]
                    forward_k = batch_distances[j][k]
                    dist_k = self.envs.call_at(j, "cand_dist_to_goal", {"angle": angle_k, "forward": forward_k})
                    cand_dists_to_goal[j].append(dist_k)
                curr_dist_to_goal = self.envs.call_at(j, "current_dist_to_goal")
                # if within target range (which def as 3.0)
                if curr_dist_to_goal < 1.5:
                    oracle_cand_idx.append(candidate_lengths[j] - 1)
                else:
                    oracle_cand_idx.append(np.argmin(cand_dists_to_goal[j]))
            return oracle_cand_idx
        elif self.config.MODEL.task_type == 'rxr':
            kargs = []
            current_episodes = self.envs.current_episodes()
            for i in range(self.envs.num_envs):
                kargs.append({
                    'ref_path':self.gt_data[str(current_episodes[i].episode_id)]['locations'],
                    'angles':batch_angles[i],
                    'distances':batch_distances[i],
                    'candidate_length':candidate_lengths[i]
                })
            oracle_cand_idx = self.envs.call(["get_cand_idx"]*self.envs.num_envs, kargs)
            return oracle_cand_idx

    def _teacher_action_new(self, batch_gmap_vp_ids, batch_no_vp_left):
        """Compute the expert target on the global map for each environment.

        Environments are zipped with ``self.gmaps``. Per env the target is:
          * ``0`` when ``current_dist_to_goal`` is below 1.5;
          * ``-100`` when ``batch_no_vp_left`` flags the env, or when the map
            has no entries in ``ghost_real_pos``;
          * otherwise the index in that env's ``gmap_vp_ids`` of the ghost
            selected by ``config.IL.expert_policy``. ``'spl'`` takes the ghost
            whose randomly sampled position is closest to the goal. ``'ndtw'``
            delegates to the env's ``ghost_dist_to_ref`` with the reference
            path from ``self.gt_data``.

        Any other expert policy raises ``NotImplementedError``. The result is
        a 1-D tensor moved with ``.cuda()``.
        """
        episodes = self.envs.current_episodes()

        def expert_target(env_idx, vp_ids, gmap, exhausted):
            if self.envs.call_at(env_idx, "current_dist_to_goal") < 1.5:
                return 0
            if exhausted:
                return -100
            if self.config.IL.expert_policy == 'spl':
                ghosts = [(vp, random.choice(pos)) for vp, pos in gmap.ghost_real_pos.items()]
                if not ghosts:
                    # no ghost viewpoints available
                    return -100
                goal_dists = [
                    self.envs.call_at(env_idx, "point_dist_to_goal", {"pos": ghost_pos})
                    for _, ghost_pos in ghosts
                ]
                return vp_ids.index(ghosts[np.argmin(goal_dists)][0])
            if self.config.IL.expert_policy == 'ndtw':
                ghosts = [(vp, random.choice(pos)) for vp, pos in gmap.ghost_real_pos.items()]
                if not ghosts:
                    # no ghost viewpoints available
                    return -100
                chosen_vp = self.envs.call_at(env_idx, "ghost_dist_to_ref", {
                    "ghost_vp_pos": ghosts,
                    "ref_path": self.gt_data[str(episodes[env_idx].episode_id)]['locations'],
                })
                return vp_ids.index(chosen_vp)
            raise NotImplementedError

        targets = [
            expert_target(env_idx, vp_ids, gmap, exhausted)
            for env_idx, (vp_ids, gmap, exhausted) in enumerate(
                zip(batch_gmap_vp_ids, self.gmaps, batch_no_vp_left)
            )
        ]
        return torch.tensor(targets).cuda()

    def _vp_feature_variable(self, obs):
        """Collate per-environment view features for the current viewpoint.

        For every environment the candidate views come first (``nav_types``
        1), followed by the panorama views whose index is not in
        ``obs['cand_img_idxes'][i]`` (``nav_types`` 0).

        Args:
            obs: batched observation dict. ``cand_img_idxes``, ``cand_rgb``,
                ``cand_depth``, ``cand_angle_fts``, ``cand_angles``,
                ``pano_rgb`` and ``pano_depth`` are indexed per env;
                ``pano_angle_fts`` is indexed without the env index.

        Returns:
            dict with padded ``rgb_fts``, ``dep_fts``, ``loc_fts``,
            ``nav_types`` and ``view_lens`` (number of views per env).
            ``loc_fts``, ``nav_types`` and ``view_lens`` are moved with
            ``.cuda()``.
        """
        batch_rgb_fts, batch_dep_fts, batch_loc_fts = [], [], []
        batch_nav_types, batch_view_lens = [], []

        for i in range(self.envs.num_envs):
            rgb_fts, dep_fts, loc_fts, nav_types = [], [], [], []
            # Builtin ``bool``: the ``np.bool`` alias was deprecated in NumPy
            # 1.20 and removed in 1.24 (AttributeError there).
            # 12 presumably matches the 12 panoramic headings — TODO confirm.
            cand_idxes = np.zeros(12, dtype=bool)
            cand_idxes[obs['cand_img_idxes'][i]] = True
            non_cand = ~cand_idxes
            # candidate views
            rgb_fts.append(obs['cand_rgb'][i])
            dep_fts.append(obs['cand_depth'][i])
            loc_fts.append(obs['cand_angle_fts'][i])
            nav_types += [1] * len(obs['cand_angles'][i])
            # remaining (non-candidate) panorama views; pano_angle_fts has no
            # per-env index here
            rgb_fts.append(obs['pano_rgb'][i][non_cand])
            dep_fts.append(obs['pano_depth'][i][non_cand])
            loc_fts.append(obs['pano_angle_fts'][non_cand])
            nav_types += [0] * int(non_cand.sum())

            batch_rgb_fts.append(torch.cat(rgb_fts, dim=0))
            batch_dep_fts.append(torch.cat(dep_fts, dim=0))
            batch_loc_fts.append(torch.cat(loc_fts, dim=0))
            batch_nav_types.append(torch.LongTensor(nav_types))
            batch_view_lens.append(len(nav_types))
        # collate
        batch_rgb_fts = pad_tensors_wgrad(batch_rgb_fts)
        batch_dep_fts = pad_tensors_wgrad(batch_dep_fts)
        batch_loc_fts = pad_tensors_wgrad(batch_loc_fts).cuda()
        batch_nav_types = pad_sequence(batch_nav_types, batch_first=True).cuda()
        batch_view_lens = torch.LongTensor(batch_view_lens).cuda()

        return {
            'rgb_fts': batch_rgb_fts, 'dep_fts': batch_dep_fts, 'loc_fts': batch_loc_fts,
            'nav_types': batch_nav_types, 'view_lens': batch_view_lens,
        }
        
    def _nav_gmap_variable(self, cur_vp, cur_pos, cur_ori):
        """Build padded global-map inputs for every environment in the batch.

        For each ``gmap`` in ``self.gmaps`` the viewpoint list is
        ``[None] + node ids (gmap.node_pos) + ghost ids (gmap.ghost_pos)``.
        Slot 0 is a placeholder entry: zero image embedding, step id 0,
        not visited, and zero pair distance. It is presumably the stop
        token; TODO confirm.

        Args:
            cur_vp, cur_pos, cur_ori: per-env current viewpoint, position and
                orientation; element ``i`` is forwarded to
                ``self.gmaps[i].get_pos_fts``.

        Returns:
            dict with:
              * ``gmap_vp_ids``: list of per-env id lists, unpadded.
              * padded ``gmap_step_ids``, ``gmap_img_fts``, ``gmap_pos_fts``,
                ``gmap_masks`` and ``gmap_visited_masks``.
              * ``gmap_pair_dists``: pairwise distances divided by
                ``MAX_DIST``.
              * ``no_vp_left``: True for envs whose map has no ghost ids.
            Every tensor except ``gmap_img_fts`` is moved with ``.cuda()``.
        """
        batch_gmap_vp_ids, batch_gmap_step_ids, batch_gmap_lens = [], [], []
        batch_gmap_img_fts, batch_gmap_pos_fts = [], []
        batch_gmap_pair_dists, batch_gmap_visited_masks = [], []
        batch_no_vp_left = []

        for i, gmap in enumerate(self.gmaps):
            node_vp_ids = list(gmap.node_pos.keys())
            ghost_vp_ids = list(gmap.ghost_pos.keys())
            # flag maps that have no ghost viewpoints left
            if len(ghost_vp_ids) == 0:
                batch_no_vp_left.append(True)
            else:
                batch_no_vp_left.append(False)

            # layout: [placeholder] + nodes + ghosts; nodes always precede ghosts
            gmap_vp_ids = [None] + node_vp_ids + ghost_vp_ids
            gmap_step_ids = [0] + [gmap.node_stepId[vp] for vp in node_vp_ids] + [0]*len(ghost_vp_ids)
            # 1 for node entries; placeholder and ghosts are 0
            gmap_visited_masks = [0] + [1] * len(node_vp_ids) + [0] * len(ghost_vp_ids)

            gmap_img_fts = [gmap.get_node_embeds(vp) for vp in node_vp_ids] + \
                           [gmap.get_node_embeds(vp) for vp in ghost_vp_ids]
            # prepend a zero embedding for the placeholder slot 0
            # NOTE(review): gmap_img_fts[0] raises IndexError if a map has
            # neither nodes nor ghosts — confirm that cannot happen.
            gmap_img_fts = torch.stack(
                [torch.zeros_like(gmap_img_fts[0])] + gmap_img_fts, dim=0
            )

            gmap_pos_fts = gmap.get_pos_fts(
                cur_vp[i], cur_pos[i], cur_ori[i], gmap_vp_ids
            )
            # Symmetric distance matrix over real entries (row/col 0 stay 0).
            # Ids starting with 'g' are treated as ghosts; a ghost is reached
            # via gmap.front_to_ghost_dist, which returns
            # (distance, front node id).
            gmap_pair_dists = np.zeros((len(gmap_vp_ids), len(gmap_vp_ids)), dtype=np.float32)
            for j in range(1, len(gmap_vp_ids)):
                for k in range(j+1, len(gmap_vp_ids)):
                    vp1 = gmap_vp_ids[j]
                    vp2 = gmap_vp_ids[k]
                    if not vp1.startswith('g') and not vp2.startswith('g'):
                        # node <-> node
                        dist = gmap.shortest_dist[vp1][vp2]
                    elif not vp1.startswith('g') and vp2.startswith('g'):
                        # node <-> ghost: node -> ghost's front node -> ghost
                        front_dis2, front_vp2 = gmap.front_to_ghost_dist(vp2)
                        dist = gmap.shortest_dist[vp1][front_vp2] + front_dis2
                    elif vp1.startswith('g') and vp2.startswith('g'):
                        # ghost <-> ghost: via both ghosts' front nodes
                        front_dis1, front_vp1 = gmap.front_to_ghost_dist(vp1)
                        front_dis2, front_vp2 = gmap.front_to_ghost_dist(vp2)
                        dist = front_dis1 + gmap.shortest_dist[front_vp1][front_vp2] + front_dis2
                    else:
                        # ghost-then-node is unreachable: nodes precede ghosts and k > j
                        raise NotImplementedError
                    gmap_pair_dists[j, k] = gmap_pair_dists[k, j] = dist / MAX_DIST
            
            batch_gmap_vp_ids.append(gmap_vp_ids)
            batch_gmap_step_ids.append(torch.LongTensor(gmap_step_ids))
            batch_gmap_lens.append(len(gmap_vp_ids))
            batch_gmap_img_fts.append(gmap_img_fts)
            batch_gmap_pos_fts.append(torch.from_numpy(gmap_pos_fts))
            batch_gmap_pair_dists.append(torch.from_numpy(gmap_pair_dists))
            batch_gmap_visited_masks.append(torch.BoolTensor(gmap_visited_masks))
        
        # collate
        batch_gmap_step_ids = pad_sequence(batch_gmap_step_ids, batch_first=True).cuda()
        batch_gmap_img_fts = pad_tensors_wgrad(batch_gmap_img_fts)
        batch_gmap_pos_fts = pad_tensors_wgrad(batch_gmap_pos_fts).cuda()
        batch_gmap_lens = torch.LongTensor(batch_gmap_lens)
        batch_gmap_masks = gen_seq_masks(batch_gmap_lens).cuda()
        batch_gmap_visited_masks = pad_sequence(batch_gmap_visited_masks, batch_first=True).cuda()

        # zero-pad each env's distance matrix to the largest map in the batch
        bs = self.envs.num_envs
        max_gmap_len = max(batch_gmap_lens)
        gmap_pair_dists = torch.zeros(bs, max_gmap_len, max_gmap_len).float()
        for i in range(bs):
            gmap_pair_dists[i, :batch_gmap_lens[i], :batch_gmap_lens[i]] = batch_gmap_pair_dists[i]
        gmap_pair_dists = gmap_pair_dists.cuda()

        return {
            'gmap_vp_ids': batch_gmap_vp_ids, 'gmap_step_ids': batch_gmap_step_ids,
            'gmap_img_fts': batch_gmap_img_fts, 'gmap_pos_fts': batch_gmap_pos_fts, 
            'gmap_masks': batch_gmap_masks, 'gmap_visited_masks': batch_gmap_visited_masks, 'gmap_pair_dists': gmap_pair_dists,
            'no_vp_left': batch_no_vp_left,
        }

    def _history_variable(self, obs):
        """Return the current step's history features, moved with ``.cuda()``.

        Returns:
            A 3-tuple:
              * the first panorama RGB view of every env
                (``pano_rgb[:, 0]``);
              * all panorama RGB views;
              * ``pano_angle_fts`` broadcast along a new leading dimension of
                size ``pano_rgb.shape[0]``.
        """
        pano_rgb = obs['pano_rgb']
        n_envs = pano_rgb.shape[0]
        first_view = pano_rgb[:, 0, ...].cuda()
        all_views = pano_rgb.cuda()
        angle_fts = obs['pano_angle_fts'].unsqueeze(0).expand(n_envs, -1, -1).cuda()
        return first_view, all_views, angle_fts

    @staticmethod
    def _pause_envs(envs, batch, envs_to_pause):
        if len(envs_to_pause) > 0:
            state_index = list(range(envs.num_envs))
            for idx in reversed(envs_to_pause):
                state_index.pop(idx)
                envs.pause_at(idx)
            
            for k, v in batch.items():
                batch[k] = v[state_index]

        return envs, batch

    def train(self):
        """Run the imitation-learning training loop.

        Steps:
          * Load the RxR ground-truth paths when
            ``MODEL.task_type == 'rxr'``.
          * Build the environments and the policy.
          * Train from the iteration returned by ``_initialize_policy`` up to
            ``IL.iters``, in chunks of ``IL.log_every`` iterations.
          * After each chunk, rank 0 (``local_rank < 1``) averages the chunk's
            logs, writes them to TensorBoard and saves a checkpoint.

        Any exception is logged with its traceback and re-raised. The
        TensorBoard writer is used as a context manager, so it is closed (and
        its pending events flushed) on every exit path, including crashes.
        """
        self._set_config()
        if self.config.MODEL.task_type == 'rxr':
            self.gt_data = {}
            for role in self.config.TASK_CONFIG.DATASET.ROLES:
                with gzip.open(
                    self.config.TASK_CONFIG.TASK.NDTW.GT_PATH.format(
                        split=self.split, role=role
                    ), "rt") as f:
                    self.gt_data.update(json.load(f))

        observation_space, action_space = self._init_envs()
        start_iter = self._initialize_policy(
            self.config,
            self.config.IL.load_from_ckpt,
            observation_space=observation_space,
            action_space=action_space,
        )

        total_iter = self.config.IL.iters
        log_every = self.config.IL.log_every
        logger.info('About to create TensorboardWriter...')
        # Non-zero ranks get a writer without a log directory.
        tb_dir = self.config.TENSORBOARD_DIR if self.local_rank < 1 else None
        with TensorboardWriter(tb_dir) as writer:
            logger.info('Created TensorboardWriter successfully')

            # GradScaler is only created for multi-GPU runs; otherwise
            # self.scaler is None.
            # NOTE(review): the training rollout runs under autocast(); without
            # loss scaling fp16 gradients can underflow — confirm the unscaled
            # single-GPU path is intended.
            self.scaler = GradScaler() if self.config.GPU_NUMBERS > 1 else None
            logger.info('Traning Starts... GOOD LUCK!')
            logger.info('About to enter training loop...')
            try:
                for idx in range(start_iter, total_iter, log_every):
                    interval = min(log_every, max(total_iter - idx, 0))
                    cur_iter = idx + interval

                    logger.info(f"Starting training interval: idx={idx}, interval={interval}, cur_iter={cur_iter}")
                    # sample_ratio is raised to (number of elapsed decay intervals + 1)
                    sample_ratio = self.config.IL.sample_ratio ** (idx // self.config.IL.decay_interval + 1)
                    logger.info(f"About to call _train_interval with ml_weight={self.config.IL.ml_weight}, sample_ratio={sample_ratio}")
                    logs = self._train_interval(interval, self.config.IL.ml_weight, sample_ratio)
                    logger.info(f"Completed _train_interval, got logs: {list(logs.keys())}")

                    if self.local_rank < 1:
                        loss_str = f'iter {cur_iter}: '
                        for k, v in logs.items():
                            logs[k] = np.mean(v)
                            loss_str += f'{k}: {logs[k]:.3f}, '
                            writer.add_scalar(f'loss/{k}', logs[k], cur_iter)
                        logger.info(loss_str)
                        self.save_checkpoint(cur_iter)
            except Exception as e:
                logger.error(f"Training crashed with error: {e}")
                import traceback
                logger.error(f"Traceback: {traceback.format_exc()}")
                raise
        
    def _train_interval(self, interval, ml_weight, sample_ratio):
        self.policy.train()
        if self.world_size > 1:
            self.policy.net.module.rgb_encoder.eval()
            self.policy.net.module.depth_encoder.eval()
        else:
            self.policy.net.rgb_encoder.eval()
            # Handle DDP wrapper for depth encoder
            depth_encoder = self.policy.net.module.depth_encoder if hasattr(self.policy.net, 'module') else self.policy.net.depth_encoder
            depth_encoder.eval()
        self.waypoint_predictor.eval()

        if self.local_rank < 1:
            pbar = tqdm.trange(interval, leave=False, dynamic_ncols=True)
        else:
            pbar = range(interval)
        self.logs = defaultdict(list)

        for idx in pbar:
            # CRITICAL: Clear all gradients and ensure no gradient accumulation
            self.optimizer.zero_grad()
            
            # Clear any cached states that might retain computation graphs
            if hasattr(self, 'loss'):
                del self.loss
            self.loss = 0.
            
            # Update step counter for adaptive corruption scheduling
            self.current_step += 1

            # Run forward pass WITH gradients to enable proper training
            # CRITICAL FIX: Remove torch.no_grad() to allow gradient computation
            with autocast():
                self.rollout('train', ml_weight, sample_ratio)
                
            # Skip if no loss computed
            if not isinstance(self.loss, torch.Tensor):
                continue
                
            # Verify loss has gradients (should be True now)
            if not self.loss.requires_grad:
                logger.warning(f"Loss doesn't require gradients! Loss: {self.loss}")
                continue
            
            # Now do backward with proper gradient flow
            if self.config.GPU_NUMBERS == 1:
                self.loss.backward()
                self.optimizer.step()
            else:
                self.scaler.scale(self.loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            
            # CRITICAL: Explicitly clear gradients again and delete loss reference
            self.optimizer.zero_grad()
            del self.loss

            if self.local_rank < 1:
                pbar.set_postfix({'iter': f'{idx+1}/{interval}'})
            
        return deepcopy(self.logs)

    def eval(self, corruption_type=0):
        """Evaluate checkpoints under the requested observation corruption.

        The corruption type is stored on the trainer as
        ``self._current_corruption_type``, which ``_eval_checkpoint`` prefers
        over its own argument. It is also forwarded to the parent ``eval``.
        """
        self._current_corruption_type = corruption_type
        base_eval = super().eval
        base_eval(corruption_type)
        
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
        corruption_type: int = 0,
    ):
        """Evaluate one checkpoint, then log and save the aggregated metrics.

        Args:
            checkpoint_path: checkpoint file loaded into the policy.
            writer: TensorBoard writer used for ``eval_<metric>/<split>``
                scalars (step ``checkpoint_index + 1``).
            checkpoint_index: index used in result file names.
            corruption_type: forwarded to ``rollout('eval', ...)``. It is
                overridden by ``self._current_corruption_type`` when that
                attribute exists.

        Per-rank episode stats are always written to ``RESULTS_DIR``. The
        aggregated stats are logged by rank 0 only, and written by it when
        ``EVAL.SAVE_RESULTS`` is set.

        The method returns early if aggregated results for this checkpoint
        already exist and ``EVAL.CKPT_PATH_DIR`` is not a single file.
        """
        # Use stored corruption_type if available (from eval method)
        if hasattr(self, '_current_corruption_type'):
            corruption_type = self._current_corruption_type
        if self.local_rank < 1:
            logger.info(f"checkpoint_path: {checkpoint_path}")
        self.config.defrost()
        self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
        self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
        self.config.IL.ckpt_to_load = checkpoint_path
        if self.config.VIDEO_OPTION:
            # self.config is mutated in place and never reset here, so only
            # append entries that are missing. Otherwise a second call on the
            # same trainer (or a config already listing e.g. "SUCCESS") would
            # register duplicate measurements / sensors.
            measurements = self.config.TASK_CONFIG.TASK.MEASUREMENTS
            for measure in ("TOP_DOWN_MAP_VLNCE", "DISTANCE_TO_GOAL", "SUCCESS", "SPL"):
                if measure not in measurements:
                    measurements.append(measure)
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)
            shift = 0.
            orient_dict = {
                'Back': [0, math.pi + shift, 0],            # Back
                'Down': [-math.pi / 2, 0 + shift, 0],       # Down
                'Front':[0, 0 + shift, 0],                  # Front
                'Right':[0, math.pi / 2 + shift, 0],        # Right
                'Left': [0, 3 / 2 * math.pi + shift, 0],    # Left
                'Up':   [math.pi / 2, 0 + shift, 0],        # Up
            }
            H = 224
            agent_sensors = self.config.TASK_CONFIG.SIMULATOR.AGENT_0.SENSORS
            for sensor_type in ["RGB"]:
                sensor = getattr(self.config.TASK_CONFIG.SIMULATOR, f"{sensor_type}_SENSOR")
                for camera_id, orient in orient_dict.items():
                    camera_template = f"{sensor_type}{camera_id}"
                    camera_config = deepcopy(sensor)
                    camera_config.WIDTH = H
                    camera_config.HEIGHT = H
                    camera_config.ORIENTATION = orient
                    camera_config.UUID = camera_template.lower()
                    camera_config.HFOV = 90
                    setattr(self.config.TASK_CONFIG.SIMULATOR, camera_template, camera_config)
                    if camera_template not in agent_sensors:
                        agent_sensors.append(camera_template)
        self.config.freeze()

        if self.config.EVAL.SAVE_RESULTS:
            fname = os.path.join(
                self.config.RESULTS_DIR,
                f"stats_ckpt_{checkpoint_index}_{self.config.TASK_CONFIG.DATASET.SPLIT}.json",
            )
            if os.path.exists(fname) and not os.path.isfile(self.config.EVAL.CKPT_PATH_DIR):
                print("skipping -- evaluation exists.")
                return
        self.envs = construct_envs(
            self.config, 
            get_env_class(self.config.ENV_NAME),
            episodes_allowed=self.traj[::5] if self.config.EVAL.fast_eval else self.traj,
            auto_reset_done=False, # unseen: 11006 
        )
        dataset_length = sum(self.envs.number_of_episodes)
        print('local rank:', self.local_rank, '|', 'dataset length:', dataset_length)

        obs_transforms = get_active_obs_transforms(self.config)
        observation_space = apply_obs_transforms_obs_space(
            self.envs.observation_spaces[0], obs_transforms
        )
        self._initialize_policy(
            self.config,
            load_from_ckpt=True,
            observation_space=observation_space,
            action_space=self.envs.action_spaces[0],
        )
        self.policy.eval()
        self.waypoint_predictor.eval()
        
        # === ADAPTERS: Keep adapters in training mode during eval for continued adaptation ===
        # NOTE(review): train() mode enables any dropout / batch-norm updates
        # inside the adapters during evaluation — confirm this is intended.
        if self.img_adapter is not None:
            self.img_adapter.train()
        if self.txt_adapter is not None:
            self.txt_adapter.train()

        if self.config.EVAL.EPISODE_COUNT == -1:
            eps_to_eval = sum(self.envs.number_of_episodes)
        else:
            eps_to_eval = min(self.config.EVAL.EPISODE_COUNT, sum(self.envs.number_of_episodes))
        self.stat_eps = {}
        self.pbar = tqdm.tqdm(total=eps_to_eval) if self.config.use_pbar else None

        while len(self.stat_eps) < eps_to_eval:
            self.rollout('eval', corruption_type=corruption_type)
        self.envs.close()

        if self.world_size > 1:
            distr.barrier()
        aggregated_states = {}
        num_episodes = len(self.stat_eps)
        for stat_key in next(iter(self.stat_eps.values())).keys():
            aggregated_states[stat_key] = (
                sum(v[stat_key] for v in self.stat_eps.values()) / num_episodes
            )
        total = torch.tensor(num_episodes).cuda()
        if self.world_size > 1:
            # reduce() only leaves the global episode count on rank 0; the
            # global averages below are therefore meaningful on rank 0 only.
            distr.reduce(total,dst=0)
        total = total.item()

        if self.world_size > 1:
            logger.info(f"rank {self.local_rank}'s {num_episodes}-episode results: {aggregated_states}")
            for k,v in aggregated_states.items():
                v = torch.tensor(v*num_episodes).cuda()
                cat_v = gather_list_and_concat(v,self.world_size)
                v = (sum(cat_v)/total).item()
                aggregated_states[k] = v
        
        split = self.config.TASK_CONFIG.DATASET.SPLIT
        fname = os.path.join(
            self.config.RESULTS_DIR,
            f"stats_ep_ckpt_{checkpoint_index}_{split}_r{self.local_rank}_w{self.world_size}.json",
        )
        with open(fname, "w") as f:
            json.dump(self.stat_eps, f, indent=2)

        if self.local_rank < 1:
            if self.config.EVAL.SAVE_RESULTS:
                fname = os.path.join(
                    self.config.RESULTS_DIR,
                    f"stats_ckpt_{checkpoint_index}_{split}.json",
                )
                with open(fname, "w") as f:
                    json.dump(aggregated_states, f, indent=2)

            logger.info(f"Episodes evaluated: {total}")
            checkpoint_num = checkpoint_index + 1
            for k, v in aggregated_states.items():
                logger.info(f"Average episode {k}: {v:.6f}")
                writer.add_scalar(f"eval_{k}/{split}", v, checkpoint_num)

    def inference(self):
        """Run the policy on the inference split and write a predictions file.

        Setup:
        - Loads the checkpoint at ``INFERENCE.CKPT_PATH``.
        - Reconfigures the task for inference: ``POSITION_INFER`` as the only
          measurement, only instruction task sensors, and one extra RGB and
          DEPTH camera per orientation returned by
          ``get_camera_orientations12()``.

        Run and output:
        - Rolls out every allowed episode in ``'infer'`` mode.
        - Saves the predicted paths to ``INFERENCE.PREDICTIONS_FILE``. For
          ``task_type == "r2r"`` this is a JSON dict of per-episode paths;
          otherwise it is a JSON-lines file in the RxR leaderboard format.
        - With several processes the results are all-gathered and the file is
          written once, by rank 0.
        """
        checkpoint_path = self.config.INFERENCE.CKPT_PATH
        logger.info(f"checkpoint_path: {checkpoint_path}")
        self.config.defrost()
        self.config.IL.ckpt_to_load = checkpoint_path
        self.config.TASK_CONFIG.DATASET.SPLIT = self.config.INFERENCE.SPLIT
        self.config.TASK_CONFIG.DATASET.ROLES = ["guide"]
        self.config.TASK_CONFIG.DATASET.LANGUAGES = self.config.INFERENCE.LANGUAGES
        self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
        self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
        self.config.TASK_CONFIG.TASK.MEASUREMENTS = ['POSITION_INFER']
        # Keep only the instruction-related task sensors.
        self.config.TASK_CONFIG.TASK.SENSORS = [s for s in self.config.TASK_CONFIG.TASK.SENSORS if "INSTRUCTION" in s]
        self.config.SIMULATOR_GPU_IDS = [self.config.SIMULATOR_GPU_IDS[self.config.local_rank]]

        # Register one RGB and one DEPTH camera per panorama orientation.
        # Each camera copies the base sensor config and reuses its
        # resize / center-crop sizes.
        resize_config = self.config.RL.POLICY.OBS_TRANSFORMS.RESIZER_PER_SENSOR.SIZES
        crop_config = self.config.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER_PER_SENSOR.SENSOR_CROPS
        task_config = self.config.TASK_CONFIG
        camera_orientations = get_camera_orientations12()
        for sensor_type in ["RGB", "DEPTH"]:
            resizer_size = dict(resize_config)[sensor_type.lower()]
            cropper_size = dict(crop_config)[sensor_type.lower()]
            sensor = getattr(task_config.SIMULATOR, f"{sensor_type}_SENSOR")
            for action, orient in camera_orientations.items():
                camera_template = f"{sensor_type}_{action}"
                camera_config = deepcopy(sensor)
                camera_config.ORIENTATION = orient
                camera_config.UUID = camera_template.lower()
                setattr(task_config.SIMULATOR, camera_template, camera_config)
                task_config.SIMULATOR.AGENT_0.SENSORS.append(camera_template)
                resize_config.append((camera_template.lower(), resizer_size))
                crop_config.append((camera_template.lower(), cropper_size))
        self.config.RL.POLICY.OBS_TRANSFORMS.RESIZER_PER_SENSOR.SIZES = resize_config
        self.config.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER_PER_SENSOR.SENSOR_CROPS = crop_config
        self.config.TASK_CONFIG = task_config
        self.config.SENSORS = task_config.SIMULATOR.AGENT_0.SENSORS
        self.config.freeze()

        torch.cuda.set_device(self.device)
        self.world_size = self.config.GPU_NUMBERS
        self.local_rank = self.config.local_rank
        if self.world_size > 1:
            distr.init_process_group(backend='nccl', init_method='env://')
            self.device = self.config.TORCH_GPU_IDS[self.local_rank]
            torch.cuda.set_device(self.device)
            self.config.defrost()
            self.config.TORCH_GPU_ID = self.config.TORCH_GPU_IDS[self.local_rank]
            self.config.freeze()
        self.traj = self.collect_infer_traj()

        self.envs = construct_envs(
            self.config,
            get_env_class(self.config.ENV_NAME),
            episodes_allowed=self.traj,
            auto_reset_done=False,
        )

        obs_transforms = get_active_obs_transforms(self.config)
        observation_space = apply_obs_transforms_obs_space(
            self.envs.observation_spaces[0], obs_transforms
        )
        self._initialize_policy(
            self.config,
            load_from_ckpt=True,
            observation_space=observation_space,
            action_space=self.envs.action_spaces[0],
        )
        self.policy.eval()
        self.waypoint_predictor.eval()

        # === ADAPTERS: adapters are left in train() mode ===
        # NOTE(review): rollout('infer') computes no loss and steps no
        # optimizer, so this only changes layers whose forward pass depends
        # on the train/eval flag (e.g. dropout / batch-norm, if the adapters
        # contain any). Confirm this is intended.
        if self.img_adapter is not None:
            self.img_adapter.train()
        if self.txt_adapter is not None:
            self.txt_adapter.train()

        if self.config.INFERENCE.EPISODE_COUNT == -1:
            eps_to_infer = sum(self.envs.number_of_episodes)
        else:
            eps_to_infer = min(self.config.INFERENCE.EPISODE_COUNT, sum(self.envs.number_of_episodes))
        self.path_eps = defaultdict(list)
        self.inst_ids: Dict[str, int] = {}   # episode_id -> RxR instruction_id (submission format)
        self.pbar = tqdm.tqdm(total=eps_to_infer)

        while len(self.path_eps) < eps_to_infer:
            self.rollout('infer')
        self.envs.close()
        self.pbar.close()

        if self.world_size > 1:
            # Merge every rank's predictions so all ranks hold the full set.
            aggregated_path_eps = [None for _ in range(self.world_size)]
            distr.all_gather_object(aggregated_path_eps, self.path_eps)
            tmp_eps_dict = {}
            for x in aggregated_path_eps:
                tmp_eps_dict.update(x)
            self.path_eps = tmp_eps_dict

            aggregated_inst_ids = [None for _ in range(self.world_size)]
            distr.all_gather_object(aggregated_inst_ids, self.inst_ids)
            tmp_inst_dict = {}
            for x in aggregated_inst_ids:
                tmp_inst_dict.update(x)
            self.inst_ids = tmp_inst_dict

            # After the gather every rank holds identical results. Let only
            # rank 0 write, so processes do not race on the same file.
            if self.local_rank != 0:
                return

        if self.config.MODEL.task_type == "r2r":
            with open(self.config.INFERENCE.PREDICTIONS_FILE, "w") as f:
                json.dump(self.path_eps, f, indent=2)
            logger.info(f"Predictions saved to: {self.config.INFERENCE.PREDICTIONS_FILE}")
        else:  # use 'rxr' format for rxr-habitat leaderboard
            preds = []
            for k, v in self.path_eps.items():
                # save only positions that changed
                path = [v[0]["position"]]
                for p in v[1:]:
                    if p["position"] != path[-1]:
                        path.append(p["position"])
                preds.append({"instruction_id": self.inst_ids[k], "path": path})
            preds.sort(key=lambda x: x["instruction_id"])
            with jsonlines.open(self.config.INFERENCE.PREDICTIONS_FILE, mode="w") as writer:
                writer.write_all(preds)
            logger.info(f"Predictions saved to: {self.config.INFERENCE.PREDICTIONS_FILE}")

    def get_pos_ori(self):
        """Return agent positions and orientations for every active env.

        Issues one batched ``get_pos_ori`` call on the vectorized envs. Each
        per-env result is indexed as ``(position, orientation)``.

        Returns:
            Tuple ``(positions, orientations)`` of lists, ordered like the envs.
        """
        results = self.envs.call(['get_pos_ori'] * self.envs.num_envs)
        positions, orientations = [], []
        for entry in results:
            positions.append(entry[0])
            orientations.append(entry[1])
        return positions, orientations

    def rollout(self, mode, ml_weight=None, sample_ratio=None, corruption_type=0):
        """Run one batched rollout over all active vectorized environments.

        The envs are reset and every instruction is encoded once. Then, for up
        to ``self.max_len`` steps, the agent:
        - predicts candidate waypoints;
        - encodes the panorama (optionally through the image adapter);
        - updates its per-env ``GraphMap``;
        - chooses a graph node to move to, or stops.
        Envs are paused as soon as they report ``done``.

        Args:
            mode: selects feedback and outputs; anything else raises
                ``NotImplementedError``.
                - ``'train'``: sampled actions; the imitation loss is stored
                  in ``self.loss`` and appended to ``self.logs['IL_loss']``.
                - ``'eval'``: argmax actions; per-episode metrics are written
                  to ``self.stat_eps``.
                - ``'infer'``: argmax actions; predicted paths are written to
                  ``self.path_eps``.
            ml_weight: scale applied to the imitation loss (train mode only).
            sample_ratio: probability of replacing a sampled action with the
                teacher action (train mode only).
            corruption_type: forwarded to ``_augment_depth_in_batch_train``
                in eval mode only. Train/infer use that method's default.
        """
        if mode == 'train':
            feedback = 'sample'
        elif mode == 'eval' or mode == 'infer':
            feedback = 'argmax'
        else:
            raise NotImplementedError

        def _augment_depth(obs_batch):
            """Apply the depth augmentation/corruption configured for ``mode``.

            Keyed on the episode id of the first active env (``None`` if no
            env is active). Only eval forwards ``corruption_type``. RGB
            augmentation (``_augment_rgb_in_batch_train``) is currently disabled.
            """
            cur_episodes = self.envs.current_episodes()
            episode_id = cur_episodes[0].episode_id if cur_episodes else None
            if mode == 'eval':
                return self._augment_depth_in_batch_train(
                    obs_batch, episode_id, corruption_type=corruption_type
                )
            return self._augment_depth_in_batch_train(obs_batch, episode_id)

        self.envs.resume_all()
        observations = self.envs.reset()
        instr_max_len = self.config.IL.max_text_len # r2r 80, rxr 200
        instr_pad_id = 1 if self.config.MODEL.task_type == 'rxr' else 0
        observations = extract_instruction_tokens(observations, self.config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID,
                                                  max_length=instr_max_len, pad_id=instr_pad_id)
        batch = batch_obs(observations, self.device)
        batch = apply_obs_transforms_batch(batch, self.obs_transforms)
        batch = _augment_depth(batch)

        # Pause envs whose episode has already been evaluated / inferred.
        if mode == 'eval':
            env_to_pause = [i for i, ep in enumerate(self.envs.current_episodes())
                            if ep.episode_id in self.stat_eps]
            self.envs, batch = self._pause_envs(self.envs, batch, env_to_pause)
            if self.envs.num_envs == 0:
                return
        if mode == 'infer':
            env_to_pause = [i for i, ep in enumerate(self.envs.current_episodes())
                            if ep.episode_id in self.path_eps]
            self.envs, batch = self._pause_envs(self.envs, batch, env_to_pause)
            if self.envs.num_envs == 0:
                return
            curr_eps = self.envs.current_episodes()
            for i in range(self.envs.num_envs):
                if self.config.MODEL.task_type == 'rxr':
                    ep_id = curr_eps[i].episode_id
                    k = curr_eps[i].instruction.instruction_id
                    self.inst_ids[ep_id] = int(k)

        # Encode instructions once; each step indexes the rows still running.
        all_txt_ids = batch['instruction']
        all_txt_masks = (all_txt_ids != instr_pad_id)
        all_txt_embeds = self.policy.net(
            mode='language',
            txt_ids=all_txt_ids,
            txt_masks=all_txt_masks,
        )
        # === ADAPTERS: optional text adapter ===
        if self.adapter_use_text and self.txt_adapter is not None:
            # Unwrap a DDP wrapper if present.
            adapter = self.txt_adapter.module if hasattr(self.txt_adapter, 'module') else self.txt_adapter
            all_txt_embeds = adapter(all_txt_embeds)

        loss = torch.tensor(0.0, device=self.device, requires_grad=True)
        total_actions = 0.
        not_done_index = list(range(self.envs.num_envs))

        # Real simulator positions are tracked only in train mode or when
        # recording videos. These are also the only cases in which teacher
        # actions are computed below.
        have_real_pos = (mode == 'train' or self.config.VIDEO_OPTION)
        ghost_aug = self.config.IL.ghost_aug if mode == 'train' else 0
        self.gmaps = [GraphMap(have_real_pos,
                               self.config.IL.loc_noise,
                               self.config.MODEL.merge_ghost,
                               ghost_aug) for _ in range(self.envs.num_envs)]
        prev_vp = [None] * self.envs.num_envs

        for stepk in range(self.max_len):
            total_actions += self.envs.num_envs
            txt_masks = all_txt_masks[not_done_index]
            txt_embeds = all_txt_embeds[not_done_index]

            # cand waypoint prediction
            wp_outputs = self.policy.net(
                mode = "waypoint",
                waypoint_predictor = self.waypoint_predictor,
                observations = batch,
                in_train = (mode == 'train' and self.config.IL.waypoint_aug),
            )

            # pano encoder (NOVEL: image adapter + reliability-weighted fusion)
            vp_inputs = self._vp_feature_variable(wp_outputs)
            vp_inputs.update({
                'mode': 'panorama',
            })

            # --- robust image adapter ---
            if self.adapter_use_image and self.img_adapter is not None:
                rgb_fts = vp_inputs['rgb_fts']  # (B, V, D)
                # ensure device
                if rgb_fts.device != self.device:
                    rgb_fts = rgb_fts.to(self.device)
                B, V, D = rgb_fts.shape

                # Unwrap a DDP wrapper if present.
                adapter = self.img_adapter.module if hasattr(self.img_adapter, 'module') else self.img_adapter

                # Adapt each view's feature independently.
                rgb_fts_adapted = adapter(rgb_fts.reshape(B*V, D)).reshape(B, V, D)
                vp_inputs['rgb_fts'] = rgb_fts_adapted  # keep on GPU

            pano_embeds, pano_masks = self.policy.net(**vp_inputs)  # pano_embeds (B,V,Dp), pano_masks (B,V)

            # --- robust reliability-weighted average of views ---
            if self.adapter_use_image and self.img_adapter is not None:
                with torch.no_grad():
                    # build weights from adapted rgb_fts
                    rgb_for_conf = vp_inputs['rgb_fts'].detach()
                    pano_masks_device = pano_masks.to(rgb_for_conf.device)

                    weights = _reliability_from_features(
                        rgb_for_conf,
                        pano_masks_device.float(),
                        alpha=self.adapter_conf_alpha
                    )  # (B,V)

                    # Normalize weights per sample to sum to ~1.
                    weights = weights / (torch.sum(weights, dim=1, keepdim=True) + 1e-8)

                # Reliability-weighted averaging
                avg_pano_embeds = torch.sum(pano_embeds * weights.unsqueeze(2), 1)

                # Log reliability statistics for monitoring
                if mode == 'train' and hasattr(self, 'logs'):
                    # Current corruption probability: linear ramp from start to
                    # end prob over the warmup. A non-positive warmup counts as
                    # fully warmed up (avoids ZeroDivisionError).
                    warmup_steps = self.corruption_warmup_steps
                    if warmup_steps > 0:
                        progress = min(1.0, self.current_step / warmup_steps)
                    else:
                        progress = 1.0
                    current_prob = self.corruption_start_prob + progress * (self.corruption_end_prob - self.corruption_start_prob)
                    self.logs.setdefault('corruption_prob', []).append(current_prob)

                    avg_weight = torch.mean(weights).item()
                    weight_std = torch.std(weights).item()
                    self.logs.setdefault('reliability_avg', []).append(avg_weight)
                    self.logs.setdefault('reliability_std', []).append(weight_std)
            else:
                # fallback to simple masked average
                avg_pano_embeds = torch.sum(pano_embeds * pano_masks.unsqueeze(2), 1) / \
                                  torch.sum(pano_masks, 1, keepdim=True)

            # get vp_id, vp_pos of cur_node and cand_node
            cur_pos, cur_ori = self.get_pos_ori()
            cur_vp, cand_vp, cand_pos = [], [], []
            for i in range(self.envs.num_envs):
                cur_vp_i, cand_vp_i, cand_pos_i = self.gmaps[i].identify_node(
                    cur_pos[i], cur_ori[i], wp_outputs['cand_angles'][i], wp_outputs['cand_distances'][i]
                )
                cur_vp.append(cur_vp_i)
                cand_vp.append(cand_vp_i)
                cand_pos.append(cand_pos_i)

            if mode == 'train' or self.config.VIDEO_OPTION:
                cand_real_pos = []
                for i in range(self.envs.num_envs):
                    cand_real_pos_i = [
                        self.envs.call_at(i, "get_cand_real_pos", {"angle": ang, "forward": dis})
                        for ang, dis in zip(wp_outputs['cand_angles'][i], wp_outputs['cand_distances'][i])
                    ]
                    cand_real_pos.append(cand_real_pos_i)
            else:
                cand_real_pos = [None] * self.envs.num_envs

            for i in range(self.envs.num_envs):
                cur_embeds = avg_pano_embeds[i]
                cand_embeds = pano_embeds[i][vp_inputs['nav_types'][i]==1]
                self.gmaps[i].update_graph(prev_vp[i], stepk+1,
                                           cur_vp[i], cur_pos[i], cur_embeds,
                                           cand_vp[i], cand_pos[i], cand_embeds,
                                           cand_real_pos[i])

            nav_inputs = self._nav_gmap_variable(cur_vp, cur_pos, cur_ori)
            nav_inputs.update({
                'mode': 'navigation',
                'txt_embeds': txt_embeds,
                'txt_masks': txt_masks,
            })
            no_vp_left = nav_inputs.pop('no_vp_left')
            nav_outs = self.policy.net(**nav_inputs)
            nav_logits = nav_outs['global_logits']
            nav_probs = F.softmax(nav_logits, 1)
            # Index 0 of the global logits is the stop action; remember its
            # probability at the current node.
            for i, gmap in enumerate(self.gmaps):
                gmap.node_stop_scores[cur_vp[i]] = nav_probs[i, 0].data.item()

            if mode == 'train' or self.config.VIDEO_OPTION:
                teacher_actions = self._teacher_action_new(nav_inputs['gmap_vp_ids'], no_vp_left)
            if mode == 'train':
                step_loss = F.cross_entropy(nav_logits, teacher_actions, reduction='sum', ignore_index=-100)
                loss = loss + step_loss
            # NOTE: no loss is computed in eval mode. Online adapter updates
            # during evaluation are disabled, so such a loss had no consumer.

            # determine action
            if feedback == 'sample':
                c = torch.distributions.Categorical(nav_probs)
                a_t = c.sample().detach()
                # With probability `sample_ratio` take the teacher action instead.
                a_t = torch.where(torch.rand_like(a_t, dtype=torch.float)<=sample_ratio, teacher_actions, a_t)
            elif feedback == 'argmax':
                a_t = nav_logits.argmax(dim=-1)
            else:
                raise NotImplementedError
            cpu_a_t = a_t.cpu().numpy()

            # make equiv action
            env_actions = []
            use_tryout = (self.config.IL.tryout and not self.config.TASK_CONFIG.SIMULATOR.HABITAT_SIM_V0.ALLOW_SLIDING)
            for i, gmap in enumerate(self.gmaps):
                if cpu_a_t[i] == 0 or stepk == self.max_len - 1 or no_vp_left[i]:
                    # stop at node with max stop_prob
                    vp_stop_scores = [(vp, stop_score) for vp, stop_score in gmap.node_stop_scores.items()]
                    stop_scores = [s[1] for s in vp_stop_scores]
                    stop_vp = vp_stop_scores[np.argmax(stop_scores)][0]
                    stop_pos = gmap.node_pos[stop_vp]
                    if self.config.IL.back_algo == 'control':
                        back_path = [(vp, gmap.node_pos[vp]) for vp in gmap.shortest_path[cur_vp[i]][stop_vp]]
                        back_path = back_path[1:]
                    else:
                        back_path = None
                    vis_info = {
                            'nodes': list(gmap.node_pos.values()),
                            'ghosts': list(gmap.ghost_aug_pos.values()),
                            'predict_ghost': stop_pos,
                    }
                    env_actions.append(
                        {
                            'action': {
                                'act': 0,
                                'cur_vp': cur_vp[i],
                                'stop_vp': stop_vp, 'stop_pos': stop_pos,
                                'back_path': back_path,
                                'tryout': use_tryout,
                            },
                            'vis_info': vis_info,
                        }
                    )
                else:
                    ghost_vp = nav_inputs['gmap_vp_ids'][i][cpu_a_t[i]]
                    ghost_pos = gmap.ghost_aug_pos[ghost_vp]
                    _, front_vp = gmap.front_to_ghost_dist(ghost_vp)
                    front_pos = gmap.node_pos[front_vp]
                    if self.config.VIDEO_OPTION:
                        teacher_action_cpu = teacher_actions[i].cpu().item()
                        if teacher_action_cpu in [0, -100]:
                            teacher_ghost = None
                        else:
                            teacher_ghost = gmap.ghost_aug_pos[nav_inputs['gmap_vp_ids'][i][teacher_action_cpu]]
                        vis_info = {
                            'nodes': list(gmap.node_pos.values()),
                            'ghosts': list(gmap.ghost_aug_pos.values()),
                            'predict_ghost': ghost_pos,
                            'teacher_ghost': teacher_ghost,
                        }
                    else:
                        vis_info = None
                    # teleport to front, then forward to ghost
                    if self.config.IL.back_algo == 'control':
                        back_path = [(vp, gmap.node_pos[vp]) for vp in gmap.shortest_path[cur_vp[i]][front_vp]]
                        back_path = back_path[1:]
                    else:
                        back_path = None
                    env_actions.append(
                        {
                            'action': {
                                'act': 4,
                                'cur_vp': cur_vp[i],
                                'front_vp': front_vp, 'front_pos': front_pos,
                                'ghost_vp': ghost_vp, 'ghost_pos': ghost_pos,
                                'back_path': back_path,
                                'tryout': use_tryout,
                            },
                            'vis_info': vis_info,
                        }
                    )
                    prev_vp[i] = front_vp
                    if self.config.MODEL.consume_ghost:
                        gmap.delete_ghost(ghost_vp)

            outputs = self.envs.step(env_actions)
            observations, _, dones, infos = [list(x) for x in zip(*outputs)]

            # calculate metric
            if mode == 'eval':
                curr_eps = self.envs.current_episodes()
                for i in range(self.envs.num_envs):
                    if not dones[i]:
                        continue
                    info = infos[i]
                    ep_id = curr_eps[i].episode_id
                    # `np.float` was removed in NumPy 1.24; it aliased builtin float.
                    gt_path = np.array(self.gt_data[str(ep_id)]['locations']).astype(float)
                    pred_path = np.array(info['position']['position'])
                    distances = np.array(info['position']['distance'])
                    metric = {}
                    metric['steps_taken'] = info['steps_taken']
                    metric['distance_to_goal'] = distances[-1]
                    metric['success'] = 1. if distances[-1] <= 3. else 0.
                    metric['oracle_success'] = 1. if (distances <= 3.).any() else 0.
                    metric['path_length'] = float(np.linalg.norm(pred_path[1:] - pred_path[:-1],axis=1).sum())
                    gt_length = distances[0]
                    metric['spl'] = metric['success'] * gt_length / max(gt_length, metric['path_length'])
                    dtw_distance = fastdtw(pred_path, gt_path, dist=NDTW.euclidean_distance)[0]
                    metric['ndtw'] = np.exp(-dtw_distance / (len(gt_path) * 3.))
                    metric['sdtw'] = metric['ndtw'] * metric['success']
                    metric['ghost_cnt'] = self.gmaps[i].ghost_cnt
                    self.stat_eps[ep_id] = metric
                    # The progress bar is optional in eval (None when disabled).
                    if self.pbar is not None:
                        self.pbar.update()

            # record path
            if mode == 'infer':
                curr_eps = self.envs.current_episodes()
                for i in range(self.envs.num_envs):
                    if not dones[i]:
                        continue
                    info = infos[i]
                    ep_id = curr_eps[i].episode_id
                    self.path_eps[ep_id] = [
                        {
                            'position': info['position_infer']['position'][0],
                            'heading': info['position_infer']['heading'][0],
                            'stop': False
                        }
                    ]
                    # Keep only consecutive-distinct positions, capped at 500 entries.
                    for p, h in zip(info['position_infer']['position'][1:], info['position_infer']['heading'][1:]):
                        if p != self.path_eps[ep_id][-1]['position']:
                            self.path_eps[ep_id].append({
                                'position': p,
                                'heading': h,
                                'stop': False
                            })
                    self.path_eps[ep_id] = self.path_eps[ep_id][:500]
                    self.path_eps[ep_id][-1]['stop'] = True
                    if self.pbar is not None:
                        self.pbar.update()

            # pause env (iterate in reverse so pops keep earlier indices valid)
            if sum(dones) > 0:
                for i in reversed(list(range(self.envs.num_envs))):
                    if dones[i]:
                        not_done_index.pop(i)
                        self.envs.pause_at(i)
                        observations.pop(i)
                        # graph stop
                        self.gmaps.pop(i)
                        prev_vp.pop(i)

            if self.envs.num_envs == 0:
                break

            # obs for next step
            observations = extract_instruction_tokens(observations,self.config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID)
            batch = batch_obs(observations, self.device)
            batch = apply_obs_transforms_batch(batch, self.obs_transforms)
            batch = _augment_depth(batch)

        if mode == 'train':
            loss = ml_weight * loss / total_actions
            self.loss = loss  # Directly assign, don't accumulate
            self.logs['IL_loss'].append(loss.item())

        # === ADAPTERS: no online adapter updates during eval/infer ===
        # Per the authors' analysis, pre-trained adapters performed better
        # without online updates during evaluation. Those updates caused
        # overfitting and catastrophic forgetting, so no eval-time
        # loss/backward/optimizer step is performed here.
