"""Given an objective, learn the coefficients and singular values of the
SVD components of merged task vectors on a target dataset.

Features:
- Uses merge_task_vectors to precompute and select SVD components (U, S, Vh)
- Learns selected singular values and layer-wise coefficients
- Maintains original parameter dtype and shape information
"""

import os
import time
import json
import torch
import torchvision
import sys
import argparse

# Configure wandb through environment variables before it is imported:
# silent mode, console capture off, and no background service process.
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_CONSOLE"] = "off"
os.environ["WANDB_DISABLE_SERVICE"] = "true"

# Hide anything wandb writes to stderr while it is being imported. stderr is
# restored in `finally` so that a failing import (e.g. wandb not installed)
# does not leave sys.stderr bound to a closed file and swallow the traceback.
old_stderr = sys.stderr
try:
    with open(os.devnull, 'w') as devnull:
        sys.stderr = devnull
        import wandb
finally:
    sys.stderr = old_stderr

import random
import numpy as np
import socket
import subprocess
import platform
import psutil
from datetime import datetime, timedelta

from torch.cuda.amp import GradScaler
from src.linearize import LinearizedImageEncoder
from src.modeling import ImageEncoder, ImageClassifier
from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
from src.composition import WeightedImageEncoder, WeightedLinearizedModel
from src.composition import TopValuesTaskBasedSVDWeightedImageEncoder
from src.composition import MergedTaskVectorImageEncoder
from torch import nn

from src.utils import cosine_lr
from src.args import parse_arguments
from src.eval import eval_single_dataset
from src.datasets.registry import get_dataset
from src.heads import get_classification_head
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp

# Parameter names excluded from SVD processing (their deltas are averaged
# instead) and from task-vector corruption. Both bare and "model."-prefixed
# spellings are listed because task-vector keys may carry that prefix.
SKIP_PARAMS = [
    "positional_embedding", "visual.positional_embedding",
    "visual.proj", "token_embedding", "text_projection", "token_embedding.weight",
    "model.positional_embedding", "model.visual.positional_embedding", "model.token_embedding.weight",
    "model.visual.proj", "model.token_embedding", "model.text_projection",
]

def corrupt_task_vector(task_vector, skip_params, noise_level=0.5, generator=None):
    """
    Corrupt a task vector in place by adding layer-wise Gaussian noise.

    Only floating-point 2-D entries of ``task_vector.vector`` whose key is not in
    ``skip_params`` are modified. For each such entry ``delta`` the noise has
    standard deviation ``sigma = noise_level * ||delta||_F`` (Frobenius norm,
    computed in float32); it is cast to ``delta``'s dtype and added in place.
    Entries whose ``sigma`` is zero or non-finite are left untouched.

    Args:
        task_vector: Object exposing a ``vector`` dict (name -> tensor); it is
            mutated in place.
        skip_params: Collection of parameter names to ignore.
        noise_level (float): The scaling factor for the noise standard deviation.
        generator: Optional ``torch.Generator`` (on the tensors' device) used to
            draw the noise. ``None`` (default) uses PyTorch's global RNG, in which
            case the noise depends on the current process seed; when processes
            are seeded differently, pass an identically seeded generator to get
            the same corruption in every process.

    Returns:
        The same ``task_vector`` object, after corruption.
    """
    if is_main_process():
        print(f"Corrupting a task vector with noise level (sigma scale): {noise_level}")

    with torch.no_grad():
        for key, delta in task_vector.vector.items():
            # Corrupt only floating-point 2-D layers that are not in the skip list.
            if key in skip_params or delta.dim() != 2 or not delta.is_floating_point():
                continue

            delta_f32 = delta.float()
            sigma = noise_level * torch.linalg.norm(delta_f32, 'fro')
            # sigma is 0 for all-zero layers and may be NaN/inf for degenerate input.
            if not (torch.isfinite(sigma) and sigma > 0):
                continue

            if generator is None:
                noise = torch.randn_like(delta_f32)
            else:
                noise = torch.randn(
                    delta_f32.shape,
                    generator=generator,
                    dtype=delta_f32.dtype,
                    device=delta_f32.device,
                )
            task_vector.vector[key] += (noise * sigma).to(delta.dtype)

    if is_main_process():
        print("Task vector corruption complete.")

    return task_vector

@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Return the per-row entropy of softmax(x) over dim 1, computed from logits."""
    probs = torch.softmax(x, dim=1)
    log_probs = torch.log_softmax(x, dim=1)
    return -torch.sum(probs * log_probs, dim=1)

def lp_reg(x, p=None, gamma=0.5) -> torch.Tensor:
    """Return ``gamma`` times the mean of L-p norms of ``x`` taken over dim 0.

    The plain Python int ``0`` is returned instead of a tensor when ``x`` is
    None, ``p`` is None, or ``x`` does not require grad (a frozen quantity is
    not regularised; e.g. with SVD-based encoders the learnable values are
    singular values rather than these coefficients).
    """
    nothing_to_regularise = x is None or p is None or not x.requires_grad
    if nothing_to_regularise:
        return 0
    norms_over_dim0 = torch.norm(x, p=p, dim=0)
    return gamma * norms_over_dim0.mean()

def set_seed(seed: int) -> None:
    """
    Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs and
    put cuDNN into deterministic, non-benchmarking mode.

    cuDNN determinism can slow training down. Note that assigning
    PYTHONHASHSEED at runtime only influences child processes launched later;
    the current interpreter's hash seed was fixed at startup.

    Args:
        seed: The random seed to set
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

    if is_main_process():
        print(f"Random seed set to {seed} for deterministic results")

def is_port_available(port):
    """
    Report whether a TCP socket can currently bind ``port`` on all IPv4 interfaces.

    A throw-away socket (no SO_REUSEADDR) attempts ``bind(('', port))`` and is
    closed straight away, so the answer can be stale by the time it is used.

    Args:
        port: The port number to check

    Returns:
        bool: True if the bind succeeded, False if it raised ``OSError``.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        try:
            probe.bind(('', port))
        except OSError:
            return False
    return True

def find_available_port(port_list):
    """
    Return the first port from ``port_list`` (probed in random order) that can be bound.

    Each candidate is bound with SO_REUSEADDR; after a 0.1 s pause a second
    socket binds the same port while the first is still held. A candidate that
    passes both binds is returned. Both probe sockets are closed before
    returning, so the port is not reserved for the caller.

    Args:
        port_list: Iterable of port numbers to check (not modified).

    Returns:
        int: The selected port, or None if no candidate passed.
    """
    candidates = list(port_list)  # private copy; the caller's sequence stays intact
    random.shuffle(candidates)  # avoid every caller hammering the same first port

    for candidate in candidates:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as first_probe:
            first_probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                first_probe.bind(('', candidate))
                # Re-check after a short pause with a second, independent socket.
                time.sleep(0.1)
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as second_probe:
                    second_probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    try:
                        second_probe.bind(('', candidate))
                        print(f"Port {candidate} is confirmed available")
                        return candidate
                    except OSError:
                        print(f"Port {candidate} failed second availability check")
            except OSError as err:
                print(f"Port {candidate} is not available: {err}")

    print("No available ports found in the provided list")
    return None

def merge_task_vectors(base_model_params_dict, task_vectors, config, skip_params=None):
    """
    Merge task vectors by pooling the SVD components of all tasks, layer by layer.

    For each 2-D parameter that is not skipped:
    1. Compute a thin SVD (float32) of every task's delta matrix.
    2. Pool all rank-1 components (u_i, s_i, vh_i) from all tasks.
    3. Order the pool by singular value: descending if
       ``config.sorting_descending`` is truthy, ascending otherwise. Ties keep
       their original (task, index) order.
    4. Keep the first K components, where K is ``config.num_components_to_keep``
       when that attribute exists and is not None, otherwise 76. K is capped at
       the pool size.
    5. Return the kept U, S and Vh directly (no second SVD is performed here).
    Non-2-D layers, layers whose key contains "text_projection", and layers in
    ``skip_params`` are instead averaged over tasks.

    Args:
        base_model_params_dict: Mapping of base-model parameter name -> tensor.
            It is used only to pick the device for each layer. A leading
            "model." on a task-vector key is stripped before the lookup, and
            keys missing here are skipped with a warning.
        task_vectors: Sequence of objects exposing a ``vector`` dict
            (name -> tensor). The keys of the first one drive the iteration.
        config: Namespace read for ``sorting_descending``, ``svd_threshold``
            (only used in a warning) and, optionally, ``num_components_to_keep``.
        skip_params: Names excluded from SVD processing. Defaults to the
            module-level ``SKIP_PARAMS``.

    Returns:
        dict | None: None if ``task_vectors`` is empty. Otherwise a mapping from
        task-vector key to a dict with:
            - 'U', 'S', 'Vh', 'is_svd' (True), 'num_selected_components' - for SVD layers
            - 'tensor', 'is_svd' (False) - for non-SVD layers
            - 'original_dtype', 'original_shape' - for both
        U/S/Vh are cast to the layer's original dtype. A layer whose SVD fails
        is reported and omitted.
    """
    if not task_vectors:
        print("Warning: No task vectors provided to merge_task_vectors.")
        return None

    if skip_params is None:
        skip_params = SKIP_PARAMS

    print(f"Computing global SVD component selection for {len(task_vectors)} task vectors...")

    with torch.no_grad():
        new_vector = {}
        for task_key in task_vectors[0].vector:
            # Task-vector keys may carry a "model." prefix that the base dict lacks.
            base_key = task_key
            if task_key.startswith("model."):
                base_key = task_key[len("model."):]

            if base_key not in base_model_params_dict:
                print(f"Warning: Base key '{base_key}' (from task key '{task_key}') not in base_model_params_dict. Skipping.")
                continue

            current_device = base_model_params_dict[base_key].device
            original_dtype = task_vectors[0].vector[task_key].dtype
            original_shape = task_vectors[0].vector[task_key].shape

            # Non-2-D and explicitly skipped layers: average the deltas over tasks.
            if len(original_shape) != 2 or "text_projection" in task_key or task_key in skip_params:
                if len(task_vectors) == 1:
                    tensor_val = task_vectors[0].vector[task_key].to(current_device)
                else:
                    tvs = [tv.vector[task_key].to(current_device) for tv in task_vectors]
                    tensor_val = sum(tvs) / len(tvs)
                new_vector[task_key] = {
                    'tensor': tensor_val, 'is_svd': False,
                    'original_dtype': original_dtype, 'original_shape': original_shape
                }
                continue

            try:
                # 1. Thin SVD of every task's delta for this layer.
                u_parts, s_parts, vh_parts = [], [], []
                for task_vector in task_vectors:
                    delta_matrix = task_vector.vector[task_key].to(device=current_device, dtype=torch.float32)
                    U_task, S_task, Vh_task = torch.linalg.svd(delta_matrix, full_matrices=False)
                    u_parts.append(U_task)
                    s_parts.append(S_task)
                    vh_parts.append(Vh_task)

                # 2. Pool: column j of U_all, S_all[j] and row j of Vh_all form
                # one rank-1 component, in task order.
                S_all = torch.cat(s_parts)
                total_components = S_all.numel()
                if total_components == 0:
                    print(f"Warning: No SVD components generated for key {task_key}. Skipping.")
                    continue
                U_all = torch.cat(u_parts, dim=1)
                Vh_all = torch.cat(vh_parts, dim=0)

                # 3. Vectorised stable sort: ties keep their pooled order, which
                # matches a stable Python list.sort on the same keys.
                order = torch.sort(S_all, descending=bool(config.sorting_descending), stable=True).indices

                # 4. Fixed number of components to keep (default 76), capped at the pool size.
                requested = getattr(config, "num_components_to_keep", None)
                if requested is None:
                    requested = 76
                num_to_keep = max(0, min(int(requested), total_components))

                if num_to_keep == 0 and config.svd_threshold > 0:
                    print(f"Warning for {task_key}: num_to_keep is 0 with svd_threshold {config.svd_threshold}. Model will handle 0 components.")

                # 5. Gather the selected components. An empty index yields
                # zero-width U/S/Vh with the right outer dimensions.
                keep = order[:num_to_keep]
                kept_component_count = keep.numel()
                if kept_component_count == 0:
                    print(f"Warning: No singular value components selected for key {task_key}. Creating zero SVD components.")
                S_final = S_all[keep]
                U_final = U_all[:, keep]
                Vh_final = Vh_all[keep, :]

                new_vector[task_key] = {
                    'U': U_final.to(original_dtype),
                    'S': S_final.to(original_dtype),
                    'Vh': Vh_final.to(original_dtype),
                    'is_svd': True,
                    'original_dtype': original_dtype,
                    'original_shape': original_shape,
                    'num_selected_components': kept_component_count
                }

                print(f"Global SVD components extracted for key '{task_key}', original shape: {original_shape}")
                print(f"  Total components from all tasks: {total_components}")
                print(f"  Selected {kept_component_count} components (fixed selection).")
                print(f"  Final U shape: {U_final.shape}, Final S shape: {S_final.shape}, Final Vh shape: {Vh_final.shape}")

            except Exception as e:
                # Best effort: a failing layer is reported and omitted, not fatal.
                print(f"Error: Global SVD component selection failed for key '{task_key}' with error: {str(e)}")
                import traceback
                traceback.print_exc()
                print(f"Skipping key {task_key} due to SVD error.")
                continue

    return new_vector

def create_task_vector_from_merged(merged_vector, task_vectors, args):
    """Wrap an already-materialised merged delta dict in a TaskVector object.

    Kept mainly for compatibility: it expects plain name -> tensor deltas. The
    component dicts produced by ``merge_task_vectors`` (U/S/Vh or 'tensor'
    entries) would need to be reconstructed into tensors first.

    Args:
        merged_vector: Mapping of parameter name -> delta tensor.
        task_vectors: Not used; kept for call-site compatibility.
        args: Namespace read for ``finetuning_mode`` ("linear" selects
            ``LinearizedTaskVector``, anything else ``NonLinearTaskVector``)
            and ``no_use_half``.

    Returns:
        The constructed task vector.
    """
    vector_cls = LinearizedTaskVector if args.finetuning_mode == "linear" else NonLinearTaskVector
    return vector_cls(vector=merged_vector, use_half=not args.no_use_half)

class LearnableSingularValuesMergedEncoder(nn.Module):
    """Learns singular values for precomputed SVD components.

    For every parameter present in ``merged_vector_components`` (the output of
    ``merge_task_vectors``):

    - SVD layers: the merged delta ``U @ diag(S) @ Vh`` is rebuilt and
      re-decomposed with a fresh thin SVD. The first
      ``int(args.svd_threshold * k)`` singular values of that SVD become
      learnable parameters initialised to 0.0, and the rest are frozen buffers.
      The new U/Vh are stored as buffers.
    - Non-SVD layers: the merged delta tensor is stored as a buffer and added as is.
    - Parameters absent from ``merged_vector_components`` get a zero delta.

    The wrapped model is run functionally (functorch) with ``base + delta``
    for every parameter; the base parameters are frozen.

    Args:
        model: Encoder to wrap. It must expose ``train_preprocess``,
            ``val_preprocess`` and ``cache_dir``.
        merged_vector_components: Output from merge_task_vectors containing:
            - SVD components for 2D layers
            - Direct tensors for other layers
        args: Namespace; ``args.svd_threshold`` sets the learnable fraction.
    """
    def __init__(self, model, merged_vector_components, args):
        super().__init__()

        self.model = model
        self.args = args

        # Forwarded so this wrapper can stand in for the original encoder.
        self.train_preprocess = model.train_preprocess
        self.val_preprocess = model.val_preprocess
        self.cache_dir = model.cache_dir

        # Functional form of the model: forward() passes (base + delta) params explicitly.
        from functorch import make_functional_with_buffers
        func, params_from_functional, self.buffer = make_functional_with_buffers(model)
        self.func = lambda p, b, x: func(p, b, x)
        self.params = nn.ParameterList(params_from_functional)
        for p in self.params:
            p.requires_grad = False  # base weights stay frozen

        # Assumes functorch returns params in named_parameters() order (forward()
        # pairs names with self.params by index) -- TODO confirm for new models.
        self.param_names = [name for name, _ in model.named_parameters()]

        self.svd_components_info = {}
        self.learnable_s_values = nn.ParameterDict()
        self.direct_deltas = {}
        self.total_learnable_sv = 0

        print("\n--- Initializing LearnableSingularValuesMergedEncoder ---")
        print("Processing parameters based on pre-computed SVD components or direct deltas:")

        for param_idx, param_name in enumerate(self.param_names):
            if param_name not in merged_vector_components:
                print(f"  No SVD components or direct delta in merged_vector_components for {param_name}. Update for this layer will be scaled by multiplier but start at zero.")
                continue

            layer_data = merged_vector_components[param_name]
            # Buffer / ParameterDict names may not contain '.'.
            svd_key_safe_name = param_name.replace('.', '_')
            original_dtype = layer_data['original_dtype']
            original_shape = layer_data['original_shape']

            if not layer_data['is_svd']:
                direct_delta_tensor = layer_data['tensor']
                self.register_buffer(f'direct_delta_{svd_key_safe_name}', direct_delta_tensor)
                self.direct_deltas[param_name] = {
                     'original_dtype': original_dtype,
                     'key_for_buffer': f'direct_delta_{svd_key_safe_name}'
                }
                print(f"  Non-SVD layer {param_name} (Shape {direct_delta_tensor.shape if isinstance(direct_delta_tensor, torch.Tensor) else 'N/A'}, Dtype {direct_delta_tensor.dtype if isinstance(direct_delta_tensor, torch.Tensor) else 'N/A'}): Using direct delta.")
                continue

            merged_U = layer_data['U']
            merged_S = layer_data['S']
            merged_Vh = layer_data['Vh']
            num_merged_components = layer_data['num_selected_components']

            if num_merged_components > 0:
                # Rebuild the merged delta and re-decompose it, so the learnable
                # values act on an orthogonal basis of the merged update.
                print(f"  SVD layer {param_name}: Reconstructing delta from {num_merged_components} components...")
                reconstructed_delta = self._compose_from_svd(merged_U, merged_S, merged_Vh)

                print(f"  SVD layer {param_name}: Performing second SVD on reconstructed delta...")
                U_new, S_new, Vh_new = torch.linalg.svd(reconstructed_delta, full_matrices=False)

                # Split the new spectrum: the top part is learned, the tail is frozen.
                num_total_new_components = S_new.shape[0]
                num_to_learn = int(self.args.svd_threshold * num_total_new_components)
                S_initial_learnable = S_new[:num_to_learn]
                S_initial_frozen = S_new[num_to_learn:]

                self.register_buffer(f'U_{svd_key_safe_name}', U_new)
                self.register_buffer(f'Vh_{svd_key_safe_name}', Vh_new)
                # Not read by forward(); kept so state_dict keys stay unchanged.
                self.register_buffer(f'initial_selected_S_{svd_key_safe_name}', S_initial_learnable)
                self.register_buffer(f'frozen_S_{svd_key_safe_name}', S_initial_frozen)

                # Learnable singular values start at 0.0 (see class docstring).
                learnable_S_for_layer = torch.full_like(S_initial_learnable, 0.0, dtype=torch.float32)
                self.learnable_s_values[svd_key_safe_name] = nn.Parameter(learnable_S_for_layer)
                self.total_learnable_sv += num_to_learn

                self.svd_components_info[param_name] = {
                    'original_dtype': original_dtype,
                    'original_shape': original_shape,
                    'num_learnable': num_to_learn,
                    'num_frozen': len(S_initial_frozen),
                    'num_total_components': num_total_new_components,
                }
                print(f"  SVD layer {param_name} (Shape {original_shape}): "
                      f"Re-SVD created {num_total_new_components} new components. "
                      f"Learning top {num_to_learn} ({self.args.svd_threshold*100:.1f}%) singular values. "
                      f"Freezing remaining {len(S_initial_frozen)}.")
            else:
                print(f"  SVD layer {param_name} (Shape {original_shape}, Dtype {original_dtype}): "
                      f"No SVD components selected by merge_task_vectors. Delta will be zero.")
                self.svd_components_info[param_name] = {
                    'original_dtype': original_dtype,
                    'original_shape': original_shape,
                    'num_learnable': 0,
                    'num_frozen': 0,
                    'num_total_components': 0,
                }
                device = self.params[param_idx].device
                self.register_buffer(f'U_{svd_key_safe_name}', torch.empty((original_shape[0], 0), dtype=torch.float32, device=device))
                self.register_buffer(f'Vh_{svd_key_safe_name}', torch.empty((0, original_shape[1]), dtype=torch.float32, device=device))
                self.register_buffer(f'frozen_S_{svd_key_safe_name}', torch.empty((0,), dtype=torch.float32, device=device))

        print(f"Total learnable singular values across all layers: {self.total_learnable_sv}")
        print("--- Initialization Complete --- \n")

    @staticmethod
    def _compose_from_svd(U, S, Vh):
        """Return ``U @ diag(S) @ Vh`` in float32 without materialising ``diag(S)``."""
        U = U.to(torch.float32)
        S = S.to(torch.float32)
        Vh = Vh.to(torch.float32)
        # Scaling U's columns by S equals right-multiplying by diag(S) but
        # avoids allocating a k x k matrix.
        return (U * S.unsqueeze(0)) @ Vh

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` (e.g. a device/dtype conversion) to all tensors this module owns.

        ``nn.Module._apply`` covers registered parameters and buffers (learnable
        S values, U/Vh/S buffers, direct deltas, ``self.params``). The functorch
        buffer tuple is a plain attribute, so it is converted here. Extra
        arguments (such as ``recurse`` in newer PyTorch) are forwarded.
        """
        new_self = super()._apply(fn, *args, **kwargs)

        if hasattr(new_self, 'buffer') and new_self.buffer is not None:
            new_self.buffer = tuple(fn(b) for b in new_self.buffer)

        # self.params is normally handled by super()._apply already; re-applying
        # fn keeps the frozen base weights converted even when submodules are
        # not recursed into.
        if hasattr(new_self, 'params') and isinstance(new_self.params, nn.ParameterList):
            for param in new_self.params:
                if param is not None:
                    param.data = fn(param.data)

        return new_self

    def train(self, mode=True):
        """Set training mode and return ``self``, as ``nn.Module.train`` does."""
        return super().train(mode)

    def forward(self, x):
        """Run the wrapped model with ``base_param + delta`` for every parameter.

        Per parameter the delta is:
        - SVD layers: ``U @ diag(concat(learnable_S, frozen_S)) @ Vh`` in float32,
          cast to the layer's original dtype;
        - non-SVD layers: the stored direct-delta buffer;
        - anything else: zeros.
        """
        final_model_params = []

        for param_idx, param_name in enumerate(self.param_names):
            base_param = self.params[param_idx]
            svd_key_safe_name = param_name.replace('.', '_')

            if param_name in self.svd_components_info:
                svd_info = self.svd_components_info[param_name]
                if svd_info.get('num_total_components', 0) > 0:
                    U = getattr(self, f'U_{svd_key_safe_name}')
                    Vh = getattr(self, f'Vh_{svd_key_safe_name}')
                    learnable_S_values = self.learnable_s_values.get(svd_key_safe_name)
                    frozen_S_values = getattr(self, f'frozen_S_{svd_key_safe_name}', None)

                    # Learnable values first (they replace the top of the re-SVD
                    # spectrum), then the frozen tail, matching the split in __init__.
                    s_parts = [
                        s for s in (learnable_S_values, frozen_S_values)
                        if s is not None and s.numel() > 0
                    ]
                    if s_parts:
                        reconstructed_delta = self._compose_from_svd(U, torch.cat(s_parts, dim=0), Vh)
                    else:
                        reconstructed_delta = torch.zeros(svd_info['original_shape'], device=base_param.device)
                    actual_delta_c = reconstructed_delta.to(svd_info['original_dtype'])
                else:
                    # merge_task_vectors selected no components for this layer.
                    actual_delta_c = torch.zeros_like(base_param)
            elif f'direct_delta_{svd_key_safe_name}' in self._buffers:
                actual_delta_c = getattr(self, f'direct_delta_{svd_key_safe_name}')
            else:
                # Not present in merged_vector_components: zero delta.
                actual_delta_c = torch.zeros_like(base_param)

            final_model_params.append(base_param + actual_delta_c)

        return self.func(final_model_params, self.buffer, x)

def main(rank, args):
    """Per-process entry point: set up DDP, then run ``train`` on growing source-dataset prefixes.

    The target dataset is removed from ``args.datasets``. Iteration ``idx`` then
    trains with the first ``idx + 1`` remaining datasets, for
    ``args.resume_from_idx <= idx < args.end_index``. A failing iteration is
    logged and skipped, except that a CUDA out-of-memory error aborts the job
    with exit code 1.

    Args:
        rank: Index of this process; stored in ``args.rank``.
        args: Parsed command-line namespace. It is mutated (``rank``, ``port``,
            ``target_dataset``, ``datasets``, possibly ``no_use_half``), and
            ``args.datasets`` is restored to the target-free list on exit.
    """
    args.rank = rank

    # Whether the process group exists, so it is torn down exactly once.
    distributed_initialized = False
    # Set once the target dataset has been removed; restored in `finally`.
    original_datasets = None

    try:
        # Dedicated port range to reduce "Address already in use" errors.
        available_ports = list(range(29520, 29590))
        if hasattr(args, 'port') and args.port is not None and args.port > 0:
            selected_port = args.port
            print(f"Using user-specified port {selected_port} for distributed training")
        else:
            # NOTE(review): every process runs this search independently, and the
            # shuffle in find_available_port runs before set_seed below. With
            # world_size > 1, ranks may therefore choose different ports; pass an
            # explicit --port for multi-process runs.
            selected_port = find_available_port(available_ports)
            if selected_port is None:
                print("Warning: No available ports found. Using a random port which may cause issues.")
                selected_port = random.choice(available_ports)
                print(f"Selected random port {selected_port} - this may cause issues if already in use")
            else:
                print(f"Found available port {selected_port} for distributed training")

        args.port = selected_port

        setup_ddp(args.rank, args.world_size, port=selected_port)
        distributed_initialized = True
        print("no_use_half", args.no_use_half)
        # ISO-C mode needs full precision for its SVD operations.
        if hasattr(args, 'isoc') and args.isoc and not args.no_use_half:
            if is_main_process():
                print("\n")
                print("*" * 80)
                print("WARNING: ISO-C requires full precision for SVD operations.")
                print("Automatically switching to full precision mode (--no-use-half).")
                print("*" * 80)
                print("\n")
            args.no_use_half = True

        if args.seed is not None:
            set_seed(args.seed)

        args.target_dataset = args.target_dataset_name + "Val"
        if args.target_dataset_name in args.datasets:
            args.datasets.remove(args.target_dataset_name)
        original_datasets = args.datasets.copy()

        for idx, _ in enumerate(original_datasets):
            if idx < args.resume_from_idx:
                print(f"Skipping iteration {idx+1} because resume-from-idx is {args.resume_from_idx}")
                continue

            if idx >= args.end_index:
                print(f"Ending iteration {idx+1} because end-index is {args.end_index}")
                break

            try:
                # Use the source datasets 0..idx (inclusive) for this iteration.
                args.datasets = original_datasets[:idx+1]

                if is_main_process():
                    precision_mode = "half precision (float16)" if not args.no_use_half else "full precision (float32)"
                    print(f"Creating task vectors using {precision_mode}")
                    print(f"SVD learning mode: learning top {args.svd_threshold*100:.1f}% singular values for each layer")
                    if args.svd_threshold > 0 and args.no_use_half:
                        print(f"Using SVD thresholding (in merge_task_vectors) with threshold {args.svd_threshold}")
                        if args.keep_top_values:
                            print(f"Mode: KEEP top {args.svd_threshold*100}% singular values, ZERO OUT the rest")
                        else:
                            print(f"Mode: ZERO OUT top {args.svd_threshold*100}% singular values, KEEP the rest")
                    elif args.svd_threshold > 0 and not args.no_use_half:
                        print(f"WARNING: SVD thresholding requires full precision. Use --no-use-half flag for SVD operations.")

                task_vectors = {}
                for source_dataset in args.datasets:
                    if args.finetuning_mode == "linear":
                        pretrained_checkpoint = f"{args.save}/{source_dataset}Val/linear_zeroshot.pt"
                        finetuned_checkpoint = f"{args.save}/{source_dataset}Val/linear_finetuned.pt"
                        task_vectors[source_dataset] = LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint)
                    else:
                        pretrained_checkpoint = f"{args.save}/{source_dataset}Val/zeroshot.pt"
                        finetuned_checkpoint = f"{args.save}/{source_dataset}Val/finetuned.pt"
                        task_vectors[source_dataset] = NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint)

                print("=" * 100)
                print(f"Learning SVD-merged task vector singular values on {args.target_dataset} with {len(args.datasets)} datasets")
                print(f"Datasets being used: {args.datasets}")
                print("=" * 100)

                print("\nCommand line arguments:")
                print("-" * 50)
                for arg, value in vars(args).items():
                    print(f"{arg}: {value}")
                print("-" * 50)
                print()
                train(task_vectors, args)
                print(f"Successfully completed iteration {idx+1} with {len(args.datasets)} datasets")

            except Exception as e:
                # Best effort per iteration: log and move on, except for OOM.
                print(f"ERROR: Iteration {idx+1} with {len(args.datasets)} datasets failed with error: {e}")
                import traceback
                traceback.print_exc()
                if "CUDA out of memory" in str(e):
                    print("CUDA out of memory error detected - stopping the entire job!")
                    if distributed_initialized:
                        cleanup_ddp()
                        # Already torn down; keep `finally` from cleaning up twice.
                        distributed_initialized = False
                    sys.exit(1)

    finally:
        if distributed_initialized:
            try:
                cleanup_ddp()
            except Exception as e:
                print(f"Warning: Error during distributed cleanup: {e}")
        # Restore the full (target-free) dataset list on every exit path.
        if original_datasets is not None:
            args.datasets = original_datasets


def train(task_vectors, args):
    """Learn singular values on SVD-merged task vectors for ``args.target_dataset``.

    Steps performed:
      1. Drop the target task's vector from ``task_vectors``; when 3 to 8
         source vectors remain, replace the last one with a corrupted copy
         (stress test).
      2. Merge the source vectors with ``merge_task_vectors`` and wrap the
         base ``ImageEncoder`` in ``LearnableSingularValuesMergedEncoder``.
      3. Train the learnable parameters with DDP, mixed precision and
         gradient accumulation. After every epoch, training accuracy is
         aggregated across ranks; the main process logs to wandb and
         evaluates on the validation split.
      4. On the main process, evaluate the final state on the test split and
         write JSON result files under ``results/``.

    Args:
        task_vectors: dict of task vectors keyed by dataset name; the entry
            whose key equals ``args.target_dataset`` with "Val" removed is
            excluded from the merge.
        args: experiment namespace. Fields read here include
            ``target_dataset``, ``datasets``, ``rank``, ``world_size``,
            ``seed``, ``finetuning_mode``, ``svd_threshold``,
            ``keep_top_values``, ``lp_reg``, ``lr``, ``wd``, ``epochs``,
            ``batch_size``, ``num_grad_accumulation``, ``print_every``,
            ``subsample`` and ``zs_acc``.

    Returns:
        None. Returns early (finishing any active wandb run on the main
        process) if ``finetuning_mode == "linear"`` or if merging fails.
    """
    # Record the start time on every rank: the duration computed at the end
    # of this function runs on all ranks, not only on the main process.
    experiment_start_time = datetime.now()
    formatted_start_time = experiment_start_time.strftime("%Y-%m-%d %H:%M:%S")
    if is_main_process():
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} started at: {formatted_start_time}")

    # Format subsample value for filename/run name
    subsample_str = f"s{int(args.subsample * 100)}" if isinstance(args.subsample, float) else f"s{args.subsample}"

    # Results accumulated across runs; (re)loaded from disk on the main process below.
    comp_acc = {}

    if args.seed is not None:
        set_seed(args.seed + args.rank)  # Offset by rank so processes draw different random streams

    target_dataset = args.target_dataset
    target_dataset_test = target_dataset.replace("Val", "")

    # Source tasks: every task vector except the target task's own.
    task_vectors_list = [v for k, v in task_vectors.items() if target_dataset_test != k]
    num_source_tasks = len(task_vectors_list)

    # Stress test: corrupt the last source task vector when 3-8 sources remain.
    # "Last" follows the insertion order of the task_vectors dict.
    if 3 <= num_source_tasks <= 8:
        if is_main_process():
            print("\n" + "="*50)
            print(f"STRESS TEST: Corrupting the last of {num_source_tasks} source task vectors.")
            print(f"The vector to be corrupted corresponds to the last task in the source list.")
            print("="*50 + "\n")
        task_vectors_list[-1] = corrupt_task_vector(task_vectors_list[-1], SKIP_PARAMS, noise_level=0.5)

    # The base encoder is created before merging so its parameters can seed the merge.
    if args.finetuning_mode == "linear":
        # Linear mode is not implemented for singular-value learning.
        if is_main_process():
            print("Error: Linear mode with singular value learning not implemented.")
            if wandb.run is not None:
                wandb.finish(quiet=True)
        return
    base_image_encoder = ImageEncoder(args)

    # Clone the state dict so merge_task_vectors cannot modify the live model's tensors.
    base_model_params_dict = {
        name: tensor.clone()
        for name, tensor in base_image_encoder.model.state_dict().items()
    }

    if is_main_process():
        print("\n" + "="*50)
        print("ISO-C LEARN MODE: Merging task vectors with SVD-based component selection")
        print(f"SVD performed on (base_param + task_delta) for each task.")
        print(f"Learning top {args.svd_threshold*100:.1f}% singular values for each merged layer")
        print("="*50 + "\n")

    merged_vector_components = merge_task_vectors(base_model_params_dict, task_vectors_list, args)

    if merged_vector_components is None:
        if is_main_process():
            print("Error: Failed to merge task vectors with SVD.")
            if wandb.run is not None:
                wandb.finish(quiet=True)
        return

    # Wrap the base encoder with the merged SVD components and learnable singular values.
    image_encoder = LearnableSingularValuesMergedEncoder(
        base_image_encoder,
        merged_vector_components,
        args
    )

    # Initialize wandb in the main process (after encoder creation).
    if is_main_process():
        gpu_info = {}
        if torch.cuda.is_available():
            gpu_info = {
                "gpu_count": torch.cuda.device_count(),
                "gpu_model": torch.cuda.get_device_name(0),
                "gpu_memory_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3),
            }
        # Prefer the WANDB_API_KEY environment variable over the placeholder fallback.
        api_key = os.environ.get("WANDB_API_KEY", "your-wandb-key")
        wandb.login(key=api_key)

        threshold_pct = int(args.svd_threshold * 100)
        run_name = f"{args.model}_{target_dataset}-isoc-learn-svd-top{threshold_pct}pct-e{args.epochs}-nbdts{len(args.datasets)}"

        # Add merge_task_vectors thresholding info to the run name if enabled.
        if hasattr(args, 'svd_threshold') and args.svd_threshold > 0:
            merge_task_vectors_threshold_pct = int(args.svd_threshold * 100)
            merge_task_vectors_zeroing_mode = "keepTop" if args.keep_top_values else "zeroTop"
            run_name += f"-merge_task_vectors_{merge_task_vectors_zeroing_mode}{merge_task_vectors_threshold_pct}pct"

        run_name += f"-{subsample_str}"

        wandb.init(
            project="axis",
            entity="example-owner",
            name=run_name,
            config={
                # Basic experiment config
                "type": "learn_coef",
                "genre": "isoc_svd_components_learning",
                "kind": f"wocoe_topglobal_corrupted_{int(args.svd_threshold*100)}pct",
                "iter_fixed": 3,
                "model": args.model,
                "save": args.save,
                "target_dataset": target_dataset,
                "used_datasets": args.datasets,
                "partition": args.partition,
                "#nb_dts": len(args.datasets),
                "port": args.port,
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "finetuning_mode": args.finetuning_mode,
                "lp_reg": args.lp_reg,
                "seed": args.seed,
                "weight_decay": args.wd,
                "subsample": args.subsample,
                "merge_task_vectors_svd_threshold": args.svd_threshold,
                "merge_task_vectors_keep_top_values": args.keep_top_values,
                "learnable_sv_percentage": args.svd_threshold * 100,
                "total_learnable_singular_values": getattr(image_encoder, 'total_learnable_sv', 0),
                "learning_approach": "merge_task_vectors_selected_svd_components_with_learnable_S",

                # System information
                "system": {
                    "gpu": gpu_info,
                    "cpu_count": psutil.cpu_count(),
                    "cpu_physical_count": psutil.cpu_count(logical=False),
                    "memory_gb": psutil.virtual_memory().total / (1024**3),
                    "hostname": socket.gethostname(),
                    "platform": platform.platform(),
                },

                # Runtime info
                "runtime": {
                    "pytorch_version": torch.__version__,
                    "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
                    "python_version": sys.version.split()[0],
                },

                # Job details
                "job_id": os.environ.get('SLURM_JOB_ID', 'unknown'),
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),

                # Training details. "batch_size" appears once: an earlier duplicate
                # key (batch_size * num_grad_accumulation) was silently overwritten.
                "grad_clip_value": 1.0,
                "num_grad_accumulation": args.num_grad_accumulation,
                "batch_size": args.batch_size,
                "effective_batch_size": args.batch_size * args.num_grad_accumulation * args.world_size,
                "mixed_precision": True,  # Using torch.autocast

                # Model details
                "port_selection_method": "auto_find_available",
                "sorting_descending": args.sorting_descending,
            }
        )

        if wandb.run:
            print(f"WandB Run URL: {wandb.run.url}")
            print("-" * 50 + "\n")

        # Log count/mean/max and the first 5 initial singular values per layer.
        initial_selected_sv_log = {}
        for param_name in image_encoder.param_names:
            svd_key_safe_name = param_name.replace('.', '_')
            initial_s_buffer_name = f'initial_selected_S_{svd_key_safe_name}'
            if hasattr(image_encoder, initial_s_buffer_name):
                initial_s_values = getattr(image_encoder, initial_s_buffer_name)
                if initial_s_values.numel() > 0:
                    initial_selected_sv_log[f"initial_selected_sv_{svd_key_safe_name}_count"] = initial_s_values.numel()
                    initial_selected_sv_log[f"initial_selected_sv_{svd_key_safe_name}_mean"] = initial_s_values.mean().item()
                    initial_selected_sv_log[f"initial_selected_sv_{svd_key_safe_name}_max"] = initial_s_values.max().item()
                    for idx, val in enumerate(initial_s_values[:5].tolist()):
                        initial_selected_sv_log[f"initial_selected_sv_{svd_key_safe_name}_idx{idx}"] = val
        if initial_selected_sv_log:
            wandb.log(initial_selected_sv_log)

    ckpdir = os.path.join(args.save, target_dataset)
    os.makedirs(ckpdir, exist_ok=True)

    classification_head = get_classification_head(args, target_dataset)
    model = ImageClassifier(image_encoder, classification_head)

    model.freeze_head()
    model = model.cuda()

    # Training augmentation: random resized crop + horizontal flip, then the
    # last three transforms of the model's own train preprocessing.
    preprocess_fn = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(
            size=224, scale=(0.5, 1),
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC
        ), torchvision.transforms.RandomHorizontalFlip(p=0.5),
    ] + model.train_preprocess.transforms[-3:])
    dataset = get_dataset(
        target_dataset,
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size,
        num_workers=2,
    )
    data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(data_loader)

    # Print the loss between four and ten times per epoch.
    if args.print_every * 10 < num_batches:
        print_every = int(num_batches / 10)
    elif args.print_every * 4 > num_batches:
        print_every = int(num_batches / 4)
    else:
        print_every = args.print_every
    # Guard against a zero modulus in `step % print_every` below.
    print_every = max(print_every, 1)

    # Distribute the data and model across the GPUs.
    ddp_loader = distribute_loader(data_loader)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.rank],
        find_unused_parameters=True,
        output_device=args.rank,
    )

    loss_fn = torch.nn.CrossEntropyLoss()

    params = [p for p in ddp_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

    # Cosine schedule without warm-up.
    # NOTE(review): total steps use len(data_loader), while the loop iterates
    # ddp_loader; confirm both have the same length under distribute_loader.
    scheduler = cosine_lr(
        optimizer, args.lr, 0,
        args.epochs * num_batches // args.num_grad_accumulation,
    )

    # SLURM job ID, or a timestamp when not running under SLURM.
    slurm_job_id = os.environ.get('SLURM_JOB_ID', 'unknown')
    if "unknown" in slurm_job_id:
        slurm_job_id = time.strftime("%Y%m%d-%H%M%S")

    zeroing_mode = "keepTop" if args.keep_top_values else "zeroTop"
    svd_suffix = f"-{zeroing_mode}{int(args.svd_threshold * 100)}pct"

    learn_sv_suffix = f"-learnTop{int(args.svd_threshold * 100)}pct"
    log_path = os.path.join('results', f"slurm-{slurm_job_id}-{target_dataset}_isoc_learned_svd{learn_sv_suffix}{svd_suffix}_{subsample_str}.json")

    scaler = GradScaler()
    if is_main_process():
        print("log_path", log_path)
        print(f"=> Zero-shot accuracy on {target_dataset}:\t{100*args.zs_acc[target_dataset]:.2f}%.")
        wandb.log({
            "target_dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
            "zero_shot_accuracy": 100 * args.zs_acc[target_dataset]
        })

        if os.path.exists(log_path):
            with open(log_path, 'r') as f:
                comp_acc = json.load(f)
        else:
            comp_acc = {}

    # Only the learnable singular values are trained.
    num_trainable_sv = getattr(image_encoder, 'total_learnable_sv', 0)
    num_trainable_params = num_trainable_sv

    # Total params in the DDP-wrapped model (includes everything)
    total_params_in_ddp_model = sum(p.numel() for p in ddp_model.parameters())

    if is_main_process():
        print(f"\nModel Parameter Stats:")
        print(f"Total parameters in DDP model (includes frozen base + learnable): {total_params_in_ddp_model:,}")
        print(f"Trainable singular values: {num_trainable_sv:,}")
        print(f"Total trainable parameters: {num_trainable_params:,}")

        trainable_pct = (100 * num_trainable_params / total_params_in_ddp_model) if total_params_in_ddp_model > 0 else 0
        if total_params_in_ddp_model > 0:
            print(f"Percentage trainable of DDP model: {trainable_pct:.4f}%\n")

        wandb.config.update({
            "total_parameters_ddp_model": total_params_in_ddp_model,
            "trainable_singular_values": num_trainable_sv,
            "trainable_parameters": num_trainable_params,
            "trainable_parameters_percentage_of_ddp": trainable_pct
        })

        # Also log as metrics for tracking over time
        wandb.log({
            "total_parameters_ddp_model": total_params_in_ddp_model,
            "trainable_singular_values": num_trainable_sv,
            "trainable_parameters_total": num_trainable_params,
            "trainable_parameters_percentage_of_ddp": trainable_pct
        })

    for epoch in range(args.epochs):
        # Loss tracking (main process only) and per-process accuracy counters.
        epoch_loss = 0.0
        epoch_steps = 0
        epoch_total_correct = 0
        epoch_total_samples = 0

        ddp_loader.sampler.set_epoch(epoch)
        for i, batch in enumerate(ddp_loader):
            start_time = time.time()

            step = (
                i // args.num_grad_accumulation
                + epoch * num_batches // args.num_grad_accumulation
            )

            batch = maybe_dictionarize(batch)
            inputs = batch["images"].cuda()
            data_time = time.time() - start_time

            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = ddp_model(inputs)
                labels = batch["labels"].cuda()
                loss = loss_fn(logits, labels)

                # Training accuracy for the current batch
                preds = logits.argmax(dim=-1)
                correct_in_batch = (preds == labels).sum().item()
                total_in_batch = labels.size(0)
                batch_training_accuracy = correct_in_batch / total_in_batch

                # L2-norm regularisation on the learnable singular values.
                sv_reg = 0.0
                for s_param in ddp_model.module.image_encoder.learnable_s_values.values():
                    sv_reg += torch.norm(s_param, p=2)
                sv_reg = args.lp_reg * sv_reg if args.lp_reg is not None else 0

                # Scale the whole objective once for gradient accumulation
                # (the regulariser was previously divided twice).
                loss = (loss + sv_reg) / args.num_grad_accumulation

                if is_main_process():
                    epoch_loss += loss.item() * args.num_grad_accumulation
                    epoch_steps += 1

            scaler.scale(loss).backward()

            # Memory cleanup
            torch.cuda.empty_cache()

            if (i + 1) % args.num_grad_accumulation == 0:
                scheduler(step)

                # Unscale before clipping so the 1.0 threshold applies to the
                # true gradients rather than the loss-scaled ones.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            batch_time = time.time() - start_time

            if (
                step % print_every == 0
                and ((i + 1) % args.num_grad_accumulation == 0)
                and is_main_process()
            ):
                percent_complete = 100 * (i + 1) / len(ddp_loader)
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i + 1}/{num_batches}]\t"           # noqa: E501
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}",   # noqa: E501
                    flush=True,
                )
                print(f"Batch Training Accuracy: {batch_training_accuracy*100:.2f}%")

                wandb.log({
                    "dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
                    "epoch": epoch,
                    "batch": i,
                    "batch_loss": loss.item() * args.num_grad_accumulation,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                    "batch_time": batch_time,
                    "data_time": data_time,
                    "global_step": step,
                    "singular_value_reg": sv_reg.item() if hasattr(sv_reg, 'item') else sv_reg,
                    "batch_training_accuracy": batch_training_accuracy * 100,
                })

            epoch_total_correct += correct_in_batch
            epoch_total_samples += total_in_batch

        # all_reduce is a collective: every rank must enter it. Calling it only
        # on the main process would hang or mis-pair with DDP's gradient reduce.
        epoch_counts = torch.tensor(
            [epoch_total_correct, epoch_total_samples], dtype=torch.float64, device=args.rank
        )
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.all_reduce(epoch_counts, op=torch.distributed.ReduceOp.SUM)
        global_epoch_correct, global_epoch_samples = epoch_counts.tolist()

        if is_main_process():
            epoch_training_accuracy = 0
            if global_epoch_samples > 0:
                epoch_training_accuracy = (global_epoch_correct / global_epoch_samples) * 100
            else:
                print("Warning: global_epoch_samples is 0. Cannot compute epoch training accuracy.")

            avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0
            wandb.log({
                "target_dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
                "epoch": epoch,
                "epoch_loss": avg_epoch_loss,
                "epoch_training_accuracy": epoch_training_accuracy,
            })

            print(f"Epoch {epoch}: Average Loss: {avg_epoch_loss:.4f}, Epoch Training Accuracy: {epoch_training_accuracy:.2f}%")

            # Evaluate on validation set
            image_encoder = ddp_model.module.image_encoder
            val_metrics = eval_single_dataset(image_encoder, target_dataset, args)
            val_acc = val_metrics["top1"]

            epoch_log = {
                "dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
                "epoch": epoch,
                "val_accuracy": 100 * val_acc,
            }

            # Log the first 5 learnable singular values plus mean/max per layer;
            # keys are param names with '.' replaced by '_'.
            for key, param in image_encoder.learnable_s_values.items():
                values = param.detach().cpu().tolist()
                for idx, val in enumerate(values[:5]):
                    epoch_log[f"sv_{key}_idx{idx}_value"] = val
                epoch_log[f"sv_{key}_mean"] = param.detach().mean().item()
                epoch_log[f"sv_{key}_max"] = param.detach().max().item()

            wandb.log(epoch_log)

            print(f"Epoch {epoch}: Validation accuracy: {100 * val_acc:.2f}%")

    if is_main_process():
        # Evaluate the model at its final training state (last epoch).
        print("\nEvaluating model at the end of training (last epoch state)...")
        image_encoder_last_state = ddp_model.module.image_encoder

        # Debug: sum of learnable singular values in the last state
        last_s_values_sum = 0.0
        if hasattr(image_encoder_last_state, 'learnable_s_values'):
            for param in image_encoder_last_state.learnable_s_values.values():
                last_s_values_sum += param.detach().sum().item()

        print(f"[DEBUG LAST STATE] Sum of learnable_s_values: {last_s_values_sum}")
        wandb.log({
            "debug_last_state_s_values_sum": last_s_values_sum
        })

        test_metrics_last_model = eval_single_dataset(image_encoder_last_state, target_dataset_test, args)
        final_test_acc_last_model = test_metrics_last_model["top1"]

        print(f"Final test accuracy (last model state) on {target_dataset.replace('Val', '')}: {100 * final_test_acc_last_model:.2f}%")
        wandb.log({
            "dataset": str(target_dataset_test),
            "final_test_accuracy_last": 100 * final_test_acc_last_model,
        })

        comp_acc[target_dataset_test] = final_test_acc_last_model

        wandb.log({
            "dataset": str(target_dataset_test),
            "final_test_accuracy": 100 * final_test_acc_last_model,
        })

        # The results directory may not exist yet on a fresh checkout.
        os.makedirs('results', exist_ok=True)
        with open(log_path, 'w') as f:
            json.dump(comp_acc, f, indent=4)

        # Per-run accuracy file, named with the same suffixes as log_path.
        test_acc_path = os.path.join('results',
                                     f"slurm-{slurm_job_id}-{target_dataset_test}_isoc_learned_svd{learn_sv_suffix}{svd_suffix}_{subsample_str}_accuracy.json")

        test_acc_data = {
            "target_dataset": str(target_dataset_test),  # Convert to string to ensure JSON serializable
            "number_of_used_datasets": len(args.datasets),
            "datasets_names": args.datasets,
            "test_accuracy_last_model": float(final_test_acc_last_model),
            "test_metrics_last_model": test_metrics_last_model,
            "merge_task_vectors_svd_threshold": args.svd_threshold if hasattr(args, 'svd_threshold') else 0.0,
            "merge_task_vectors_keep_top_values": args.keep_top_values if hasattr(args, 'keep_top_values') else False,
            "merge_task_vectors_zeroing_mode": "keep_top" if (hasattr(args, 'keep_top_values') and args.keep_top_values) else "zero_top",
            "merge_task_vectors_zeroing_percentage": int(args.svd_threshold * 100) if hasattr(args, 'svd_threshold') else 0,
            # % selected by merge_task_vectors, all of which are learned
            "learnable_sv_percentage": args.svd_threshold * 100,
            "initial_selected_singular_values_from_isoc": {}
        }

        # Learned singular values, keyed by index, for the JSON file; and a single
        # batched wandb.log (instead of one call per value).
        final_sv_values = {}
        final_sv_log = {}
        for key, param in image_encoder.learnable_s_values.items():
            values = param.detach().cpu().tolist()
            final_sv_values[key] = {f"idx{idx}": val for idx, val in enumerate(values)}
            for idx, val in enumerate(values):
                final_sv_log[f"final_sv_{key}_idx{idx}_value"] = val
        if final_sv_log:
            wandb.log(final_sv_log)

        test_acc_data["learned_singular_values"] = final_sv_values

        # Initial selected singular values, keyed by index.
        initial_selected_sv_data_for_json = {}
        for param_name in image_encoder.param_names:
            svd_key_safe_name = param_name.replace('.', '_')
            initial_s_buffer_name = f'initial_selected_S_{svd_key_safe_name}'
            if hasattr(image_encoder, initial_s_buffer_name):
                initial_s_values = getattr(image_encoder, initial_s_buffer_name)
                if initial_s_values.numel() > 0:
                    initial_selected_sv_data_for_json[f"initial_selected_sv_{svd_key_safe_name}"] = {
                        f"idx{idx}": val for idx, val in enumerate(initial_s_values.tolist())
                    }
        test_acc_data["initial_selected_singular_values_from_isoc"] = initial_selected_sv_data_for_json

        # Append to an existing accuracy file (normalised to a list) or create a new one.
        if os.path.exists(test_acc_path):
            with open(test_acc_path, 'r') as f:
                existing_data = json.load(f)
            if not isinstance(existing_data, list):
                existing_data = [existing_data]
            existing_data.append(test_acc_data)
            with open(test_acc_path, 'w') as f:
                json.dump(existing_data, f, indent=4)
        else:
            with open(test_acc_path, 'w') as f:
                json.dump([test_acc_data], f, indent=4)
        print(f"Test accuracy saved to {test_acc_path}")

    # Experiment duration (start time is recorded on every rank).
    experiment_end_time = datetime.now()
    experiment_duration = experiment_end_time - experiment_start_time
    formatted_end_time = experiment_end_time.strftime("%Y-%m-%d %H:%M:%S")

    hours, remainder = divmod(experiment_duration.total_seconds(), 3600)
    minutes, seconds = divmod(remainder, 60)
    formatted_duration = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"

    if is_main_process():
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} completed at: {formatted_end_time}")
        print(f"[TIMING] Total duration: {formatted_duration} (H:M:S)")

        wandb.log({
            "experiment_start_time": formatted_start_time,
            "experiment_end_time": formatted_end_time,
            "experiment_duration_seconds": experiment_duration.total_seconds(),
            "experiment_duration_formatted": formatted_duration,
        })

        # Finish the wandb run before distributed cleanup happens in the caller.
        wandb.finish(quiet=True)


if __name__ == "__main__":
    # Every target dataset this launcher iterates over.
    target_datasets = [
        "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45",
        "SUN397", "SVHN", "CIFAR10", "CIFAR100", "ImageNet", "STL10",
        "Food101", "Caltech101", "Caltech256", "FGVCAircraft", "Flowers102",
        "OxfordIIITPet", "CUB200", "PascalVOC", "Country211", "UCF101",
    ]

    args = parse_arguments()

    banner = "=" * 80
    print(banner)
    print("AUTOMATIC EXPERIMENT TRACKING")
    print("-" * 50)
    print("to ensure that experiment code states are tracked.")
    print(banner)

    # Training defaults. ViT-L-14 uses a smaller per-step batch plus gradient
    # accumulation so the larger model fits in memory.
    is_large_model = args.model == "ViT-L-14"
    args.datasets = target_datasets
    args.lr = 1e-1
    args.epochs = 10
    args.batch_size = 64 if is_large_model else 128
    args.num_grad_accumulation = 2 if is_large_model else 1
    args.print_every = 10
    args.seed = 0  # fixed seed for deterministic runs
    args.save += f"{args.model}"

    zeroshot_path = os.path.join(args.save, "zeroshot_accuracies.json")
    with open(zeroshot_path, 'r') as fh:
        args.zs_acc = json.load(fh)

    # The SVD operations need full precision.
    if not args.no_use_half:
        print(f"SVD operations require full precision, automatically enabling --no-use-half=True")
        args.no_use_half = True

    # SVD flags. args.svd_threshold is not set here; it must already be present
    # on args (from parse_arguments()) and drives both the merge threshold and
    # the fraction of singular values that are learned.
    args.isoc = True
    args.use_svd = True
    args.keep_top_values = True  # True keeps the top values; False zeroes them out
    args.sorting_descending = True

    threshold_pct = args.svd_threshold * 100
    mode_description = (
        'KEEP top values, ZERO OUT the rest'
        if args.keep_top_values
        else 'ZERO OUT top values, KEEP the rest'
    )
    print("\n" + banner)
    print("SVD LEARNING & THRESHOLDING CONFIGURATION:")
    print(f"- Learning top {threshold_pct:.1f}% singular values for each layer")
    print(f"- merge_task_vectors Threshold: {threshold_pct:.1f}% of singular values")
    print(f"- merge_task_vectors Mode: {mode_description}")
    print(banner + "\n")

    # One process per rank; each runs main(rank, args).
    torch.multiprocessing.spawn(main, args=(args,), nprocs=args.world_size)
