"""Given an objective, learn the coefficients and singular values 
on SVD components of merged task vectors for a dataset.

Features:
- Uses merge_task_vectors to precompute and select SVD components (U, S, Vh)
- Learns selected singular values and layer-wise coefficients
- Maintains original parameter dtype and shape information
"""

import os
import time
import json
import torch
import torchvision
import sys
import argparse

# Keep wandb quiet: no console echo and no background service process.
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_CONSOLE"] = "off"
os.environ["WANDB_DISABLE_SERVICE"] = "true"

# Swallow anything wandb writes to ``sys.stderr`` while it is imported. The
# original stream is restored in ``finally`` so that a failing import (e.g.
# wandb missing) does not leave the process writing to a closed null device.
old_stderr = sys.stderr
with open(os.devnull, 'w') as devnull:
    sys.stderr = devnull
    try:
        import wandb
    finally:
        sys.stderr = old_stderr

import random
import numpy as np
import socket
import subprocess
import platform
import psutil
from datetime import datetime, timedelta

from PIL import Image as PILImage

from src.corruptions import (
    gaussian_noise, shot_noise, impulse_noise, speckle_noise,
    gaussian_blur, defocus_blur, zoom_blur,
    contrast, brightness, saturate, jpeg_compression, pixelate
)

from tqdm import tqdm

from torch.cuda.amp import GradScaler
from src.linearize import LinearizedImageEncoder
from src.modeling import ImageEncoder, ImageClassifier
from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
from src.composition import WeightedImageEncoder, WeightedLinearizedModel
from src.composition import TopValuesTaskBasedSVDWeightedImageEncoder
from src.composition import MergedTaskVectorImageEncoder
from torch import nn

from src.utils import cosine_lr
from src.args import parse_arguments
from src.eval import eval_single_dataset
from src.datasets.registry import get_dataset
from src.heads import get_classification_head
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp

@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Return the entropy of ``softmax(x)`` taken along dimension 1.

    ``x`` holds unnormalised logits; the result has dimension 1 reduced away.
    """
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    return -(probs * log_probs).sum(1)

def as_pil_image(result):
    """Return ``result`` as a PIL image when it is a NumPy array, else unchanged.

    Arrays are clipped to [0, 255] and cast to ``uint8`` before conversion, so a
    float array in the 0-255 range becomes a valid 8-bit image. Any other value
    (typically a PIL image already) is passed through untouched. This lets every
    ``*_pil`` wrapper below hand a PIL image to the next transform (e.g.
    ``ToTensor``, which only rescales PIL / uint8 inputs), whichever type the
    underlying corruption happens to return.
    """
    if isinstance(result, np.ndarray):
        return PILImage.fromarray(np.clip(result, 0, 255).astype(np.uint8))
    return result

def gaussian_noise_pil(img, severity=1):
    """Apply ``gaussian_noise`` and return the result as a PIL image."""
    return as_pil_image(gaussian_noise(img, severity))

def shot_noise_pil(img, severity=1):
    """Apply ``shot_noise`` and return the result as a PIL image."""
    return as_pil_image(shot_noise(img, severity))

def impulse_noise_pil(img, severity=1):
    """Apply ``impulse_noise`` and return the result as a PIL image."""
    return as_pil_image(impulse_noise(img, severity))

def speckle_noise_pil(img, severity=1):
    """Apply ``speckle_noise`` and return the result as a PIL image."""
    return as_pil_image(speckle_noise(img, severity))

def gaussian_blur_pil(img, severity=1):
    """Apply ``gaussian_blur`` and return the result as a PIL image."""
    return as_pil_image(gaussian_blur(img, severity))

def defocus_blur_pil(img, severity=1):
    """Apply ``defocus_blur`` and return the result as a PIL image."""
    return as_pil_image(defocus_blur(img, severity))

def zoom_blur_pil(img, severity=1):
    """Apply ``zoom_blur`` and return the result as a PIL image."""
    return as_pil_image(zoom_blur(img, severity))

def contrast_pil(img, severity=1):
    """Apply ``contrast`` and return the result as a PIL image."""
    return as_pil_image(contrast(img, severity))

def brightness_pil(img, severity=1):
    """Apply ``brightness`` and return the result as a PIL image."""
    return as_pil_image(brightness(img, severity))

def saturate_pil(img, severity=1):
    """Apply ``saturate`` and return the result as a PIL image."""
    return as_pil_image(saturate(img, severity))

def jpeg_compression_pil(img, severity=1):
    """Apply ``jpeg_compression`` and return the result as a PIL image."""
    return as_pil_image(jpeg_compression(img, severity))

def pixelate_pil(img, severity=1):
    """Apply ``pixelate`` and return the result as a PIL image."""
    return as_pil_image(pixelate(img, severity))

def lp_reg(x, p=None, gamma=0.5) -> torch.Tensor:
    """Return ``gamma`` times the mean of the ``p``-norms of ``x`` taken along dim 0.

    No penalty (the plain integer ``0``) is returned when ``x`` or ``p`` is None,
    or when ``x`` does not require gradients.
    """
    skip_penalty = x is None or p is None or not x.requires_grad
    if skip_penalty:
        return 0
    norms_along_dim0 = torch.norm(x, p=p, dim=0)
    return gamma * norms_along_dim0.mean()

def set_seed(seed: int) -> None:
    """
    Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and
    switch cuDNN to deterministic, non-autotuned kernels.

    Args:
        seed: Value used for every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # PYTHONHASHSEED is read at interpreter start-up, so setting it here only
    # influences processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not is_main_process():
        return
    print(f"Random seed set to {seed} for deterministic results")

def is_port_available(port):
    """
    Report whether a TCP socket can currently be bound to ``port`` on all interfaces.

    Args:
        port: Port number to probe.

    Returns:
        bool: True if the bind succeeded, False if it raised ``OSError``.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('', port))
    except OSError:
        return False
    else:
        return True
    finally:
        probe.close()

def find_available_port(port_list):
    """
    Pick a bindable TCP port from ``port_list``, trying candidates in random order.

    Each candidate is probed twice. The first probe socket is closed, then after a
    short pause a fresh socket probes again, to catch a port that another
    process grabbed in between. Both probes are closed before returning, so the
    port is free (but not reserved) for the caller.

    Args:
        port_list: Iterable of port numbers to consider. The shuffle uses the
            module-level ``random`` state.

    Returns:
        int | None: A port that passed both probes, or None if none did.
    """
    candidates = list(port_list)
    random.shuffle(candidates)

    def _probe(port):
        """Bind and immediately release ``port``; return the OSError on failure, else None."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe_socket:
            probe_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                probe_socket.bind(('', port))
            except OSError as err:
                return err
        return None

    for port in candidates:
        first_error = _probe(port)
        if first_error is not None:
            print(f"Port {port} is not available: {first_error}")
            continue
        # The first probe socket is closed at this point. Re-probing while it was
        # still bound would test against our own socket, not other processes.
        time.sleep(0.1)
        if _probe(port) is not None:
            print(f"Port {port} failed second availability check")
            continue
        print(f"Port {port} is confirmed available")
        return port

    print("No available ports found in the provided list")
    return None

# Parameter names whose deltas are never SVD-decomposed, even when 2-D; they are
# averaged across tasks instead. Both bare and "model."-prefixed spellings are
# listed because task-vector keys may use either form.
_MERGE_SKIP_PARAMS = frozenset({
    "positional_embedding", "visual.positional_embedding",
    "visual.proj", "token_embedding", "text_projection", "token_embedding.weight",
    "model.positional_embedding", "model.visual.positional_embedding", "model.token_embedding.weight",
    "model.visual.proj", "model.token_embedding", "model.text_projection",
})


def _average_task_deltas(task_vectors, task_key, device):
    """Return the element-wise mean of ``task_key``'s delta over all task vectors, on ``device``."""
    if len(task_vectors) == 1:
        return task_vectors[0].vector[task_key].to(device)
    deltas = [tv.vector[task_key].to(device) for tv in task_vectors]
    return sum(deltas) / len(deltas)


def _pool_svd_components(task_vectors, task_key, device):
    """Thin-SVD every task's delta for ``task_key`` (in float32) and concatenate the results.

    Returns ``(U_all, S_all, Vh_all)``. Component ``j`` of the pool is
    ``U_all[:, j]``, ``S_all[j]``, ``Vh_all[j, :]``, ordered by task and then by
    each task's own SVD order.
    """
    U_parts, S_parts, Vh_parts = [], [], []
    for task_vector in task_vectors:
        delta_matrix = task_vector.vector[task_key].to(device=device, dtype=torch.float32)
        U_task, S_task, Vh_task = torch.linalg.svd(delta_matrix, full_matrices=False)
        U_parts.append(U_task)
        S_parts.append(S_task)
        Vh_parts.append(Vh_task)
    return torch.cat(U_parts, dim=1), torch.cat(S_parts), torch.cat(Vh_parts, dim=0)


def merge_task_vectors(base_model_params_dict, task_vectors, config, num_components_to_keep=76):
    """
    Merge task vectors by pooling every task's SVD components per layer and
    keeping a fixed number of them, ranked by singular value.

    For each 2-D parameter (other than those in ``_MERGE_SKIP_PARAMS`` or whose key
    contains ``"text_projection"``):

    1. Run a thin SVD on each task's delta matrix (in float32).
    2. Pool all (u, s, vh) components from all tasks.
    3. Sort the pool by singular value: descending if ``config.sorting_descending``
       is truthy, ascending otherwise. Equal values keep pool (task) order.
    4. Keep the first ``min(num_components_to_keep, pool size)`` components. This
       count is fixed; ``config.svd_threshold`` does not change it and is only
       consulted for a warning when nothing is kept.
    5. Return the selected U, S and Vh directly, without a second SVD.

    All other parameters are averaged across tasks.

    Args:
        base_model_params_dict: Maps parameter names (without the ``"model."``
            prefix) to base tensors. Used to filter keys and pick the device.
        task_vectors: Sequence of objects exposing a ``.vector`` dict of deltas.
        config: Object with ``sorting_descending`` and ``svd_threshold``.
        num_components_to_keep: Pooled components to retain per 2-D layer
            (default 76). Negative values are treated as 0.

    Returns:
        dict | None: None if ``task_vectors`` is empty. Otherwise a mapping from
        task-vector key to one of:
            - ``{'U', 'S', 'Vh', 'is_svd': True, 'num_selected_components'}`` for SVD layers
            - ``{'tensor', 'is_svd': False}`` for averaged layers
        Each entry also carries ``original_dtype`` and ``original_shape``. Keys
        missing from ``base_model_params_dict``, or whose SVD step raises, are omitted.
    """
    if not task_vectors:
        print("Warning: No task vectors provided to merge_task_vectors.")
        return None

    print(f"Computing global SVD component selection for {len(task_vectors)} task vectors...")

    new_vector = {}
    with torch.no_grad():
        for task_key in task_vectors[0].vector:
            base_key = task_key[len("model."):] if task_key.startswith("model.") else task_key
            if base_key not in base_model_params_dict:
                print(f"Warning: Base key '{base_key}' (from task key '{task_key}') not in base_model_params_dict. Skipping.")
                continue

            current_device = base_model_params_dict[base_key].device
            reference_delta = task_vectors[0].vector[task_key]
            original_dtype = reference_delta.dtype
            original_shape = reference_delta.shape

            if len(original_shape) != 2 or "text_projection" in task_key or task_key in _MERGE_SKIP_PARAMS:
                new_vector[task_key] = {
                    'tensor': _average_task_deltas(task_vectors, task_key, current_device),
                    'is_svd': False,
                    'original_dtype': original_dtype,
                    'original_shape': original_shape,
                }
                continue

            try:
                U_all, S_all, Vh_all = _pool_svd_components(task_vectors, task_key, current_device)
                total_components = S_all.numel()
                if total_components == 0:
                    print(f"Warning: No SVD components generated for key {task_key}. Skipping.")
                    continue

                # A stable sort matches list.sort(reverse=...) semantics: components with
                # equal singular values keep their pool order.
                order = torch.sort(S_all, descending=bool(config.sorting_descending), stable=True).indices

                num_to_keep = min(max(0, num_components_to_keep), total_components)
                if num_to_keep == 0 and config.svd_threshold > 0:
                    print(f"Warning for {task_key}: num_to_keep is 0 with svd_threshold {config.svd_threshold}. Model will handle 0 components.")

                selected = order[:num_to_keep]
                kept_component_count = selected.numel()
                if kept_component_count == 0:
                    # Indexing with an empty selection yields (rows, 0), (0,) and (0, cols) tensors.
                    print(f"Warning: No singular value components selected for key {task_key}. Creating zero SVD components.")

                S_final = S_all[selected]
                U_final = U_all[:, selected]
                Vh_final = Vh_all[selected, :]

                new_vector[task_key] = {
                    'U': U_final.to(original_dtype),
                    'S': S_final.to(original_dtype),
                    'Vh': Vh_final.to(original_dtype),
                    'is_svd': True,
                    'original_dtype': original_dtype,
                    'original_shape': original_shape,
                    'num_selected_components': kept_component_count,
                }

                print(f"Global SVD components extracted for key '{task_key}', original shape: {original_shape}")
                print(f"  Total components from all tasks: {total_components}")
                print(f"  Selected {kept_component_count} components (fixed selection).")
                print(f"  Final U shape: {U_final.shape}, Final S shape: {S_final.shape}, Final Vh shape: {Vh_final.shape}")

            except Exception as e:
                # Best-effort per layer: report and skip so one bad key does not abort the merge.
                print(f"Error: Global SVD component selection failed for key '{task_key}' with error: {str(e)}")
                import traceback
                traceback.print_exc()
                print(f"Skipping key {task_key} due to SVD error.")
                continue

    return new_vector

def create_task_vector_from_merged(merged_vector, task_vectors, args):
    """Wrap an already-reconstructed ``merged_vector`` dict in a task-vector object.

    The class is ``LinearizedTaskVector`` when ``args.finetuning_mode == "linear"``
    and ``NonLinearTaskVector`` otherwise. Half precision is used unless
    ``args.no_use_half`` is set. ``task_vectors`` is accepted for interface
    compatibility and is not used. Outputs of ``merge_task_vectors`` that hold SVD
    components (rather than tensors) must be reconstructed before calling this.
    """
    vector_cls = LinearizedTaskVector if args.finetuning_mode == "linear" else NonLinearTaskVector
    return vector_cls(vector=merged_vector, use_half=not args.no_use_half)

class LearnableSingularValuesMergedEncoder(nn.Module):
    """Encoder whose merged task-vector delta has learnable singular values.

    For every parameter of ``model`` found in ``merged_vector_components``:

    - SVD layers: the selected components are multiplied back into a dense delta
      (``U @ diag(S) @ Vh``) and decomposed again with a second SVD. The first
      ``int(args.svd_threshold * k)`` new singular values become a learnable
      parameter initialised to 0.0, so they replace (not offset) the top values.
      The remaining values are frozen buffers. ``U``/``Vh`` of the second SVD
      are buffers. The original top values are kept in the
      ``initial_selected_S_*`` buffer, which ``forward`` does not read.
    - Non-SVD layers: the delta tensor is stored as a buffer and added as-is.
    - Parameters not in ``merged_vector_components`` receive no delta.

    The functional copies of the base weights (``self.params``) are frozen.

    NOTE(review): ``self.model`` is registered as a child module, so its own
    parameters also appear in ``self.parameters()`` with whatever
    ``requires_grad`` they had. ``forward`` does not use them; confirm the
    optimizer only receives ``learnable_s_values``.

    Args:
        model: Encoder to wrap. Must expose ``train_preprocess``,
            ``val_preprocess`` and ``cache_dir``, which are copied onto this module.
        merged_vector_components: Output of ``merge_task_vectors``. Per parameter
            name it holds either SVD components (``'U'``, ``'S'``, ``'Vh'``,
            ``'num_selected_components'``, ``is_svd=True``) or a direct ``'tensor'``
            (``is_svd=False``), plus ``original_dtype`` and ``original_shape``.
        args: Namespace providing ``svd_threshold``, the fraction of re-SVD
            components whose singular values are learned.
    """
    def __init__(self, model, merged_vector_components, args):
        super().__init__()

        self.model = model
        self.args = args

        self.train_preprocess = model.train_preprocess
        self.val_preprocess = model.val_preprocess
        self.cache_dir = model.cache_dir

        from functorch import make_functional_with_buffers
        # Split the model into a stateless function plus (params, buffers) so that
        # forward() can substitute base + delta for every parameter.
        func, params_from_functional, self.buffer = make_functional_with_buffers(model)
        # Store a plain lambda rather than ``func`` itself; presumably so the
        # functional module is not registered as an nn.Module child.
        self.func = lambda p, b, x: func(p, b, x)
        self.params = nn.ParameterList(params_from_functional)
        for p in self.params:
            p.requires_grad = False

        # Assumed to be index-aligned with ``self.params``. TODO confirm for models with tied weights.
        self.param_names = [name for name, _ in model.named_parameters()]

        self.svd_components_info = {}
        self.learnable_s_values = nn.ParameterDict()
        self.direct_deltas = {}
        self.total_learnable_sv = 0

        print("\n--- Initializing LearnableSingularValuesMergedEncoder ---")
        print("Processing parameters based on pre-computed SVD components or direct deltas:")

        for param_idx, param_name in enumerate(self.param_names):
            if param_name in merged_vector_components:
                layer_data = merged_vector_components[param_name]
                # Buffer / ParameterDict keys may not contain '.'.
                svd_key_safe_name = param_name.replace('.', '_')

                original_dtype = layer_data['original_dtype']
                original_shape = layer_data['original_shape']

                if layer_data['is_svd']:
                    U_from_isoc = layer_data['U']
                    S_from_isoc = layer_data['S']
                    Vh_from_isoc = layer_data['Vh']
                    num_components_from_isoc = layer_data['num_selected_components']

                    if num_components_from_isoc > 0:
                        print(f"  SVD layer {param_name}: Reconstructing delta from {num_components_from_isoc} components...")
                        reconstructed_delta = U_from_isoc.to(torch.float32) @ torch.diag_embed(S_from_isoc.to(torch.float32)) @ Vh_from_isoc.to(torch.float32)

                        # Pooled components from different tasks are not orthogonal; a
                        # second SVD gives an orthogonal basis whose singular values
                        # are sorted, so "top num_to_learn" is well defined.
                        print(f"  SVD layer {param_name}: Performing second SVD on reconstructed delta...")
                        U_new, S_new, Vh_new = torch.linalg.svd(reconstructed_delta, full_matrices=False)

                        num_total_new_components = S_new.shape[0]
                        num_to_learn = int(self.args.svd_threshold * num_total_new_components)

                        S_initial_learnable = S_new[:num_to_learn]
                        S_initial_frozen = S_new[num_to_learn:]

                        self.register_buffer(f'U_{svd_key_safe_name}', U_new)
                        self.register_buffer(f'Vh_{svd_key_safe_name}', Vh_new)

                        self.register_buffer(f'initial_selected_S_{svd_key_safe_name}', S_initial_learnable)
                        self.register_buffer(f'frozen_S_{svd_key_safe_name}', S_initial_frozen)

                        # Learnable top singular values start at 0, not at S_initial_learnable.
                        learnable_S_for_layer = torch.full_like(S_initial_learnable, 0.0, dtype=torch.float32)
                        self.learnable_s_values[svd_key_safe_name] = nn.Parameter(learnable_S_for_layer)

                        self.total_learnable_sv += num_to_learn

                        self.svd_components_info[param_name] = {
                            'original_dtype': original_dtype,
                            'original_shape': original_shape,
                            'num_learnable': num_to_learn,
                            'num_frozen': len(S_initial_frozen),
                            'num_total_components': num_total_new_components,
                        }
                        print(f"  SVD layer {param_name} (Shape {original_shape}): "
                              f"Re-SVD created {num_total_new_components} new components. "
                              f"Learning top {num_to_learn} ({self.args.svd_threshold*100:.1f}%) singular values. "
                              f"Freezing remaining {len(S_initial_frozen)}.")
                    else:
                        print(f"  SVD layer {param_name} (Shape {original_shape}, Dtype {original_dtype}): "
                              f"No SVD components selected by merge_task_vectors. Delta will be zero.")
                        self.svd_components_info[param_name] = {
                            'original_dtype': original_dtype,
                            'original_shape': original_shape,
                            'num_learnable': 0,
                            'num_frozen': 0,
                            'num_total_components': 0,
                        }
                        device = self.params[param_idx].device
                        self.register_buffer(f'U_{svd_key_safe_name}', torch.empty((original_shape[0], 0), dtype=torch.float32, device=device))
                        self.register_buffer(f'Vh_{svd_key_safe_name}', torch.empty((0, original_shape[1]), dtype=torch.float32, device=device))
                        self.register_buffer(f'frozen_S_{svd_key_safe_name}', torch.empty((0,), dtype=torch.float32, device=device))
                else:
                    direct_delta_tensor = layer_data['tensor']
                    self.register_buffer(f'direct_delta_{svd_key_safe_name}', direct_delta_tensor)
                    self.direct_deltas[param_name] = {
                         'original_dtype': original_dtype,
                         'key_for_buffer': f'direct_delta_{svd_key_safe_name}'
                    }
                    print(f"  Non-SVD layer {param_name} (Shape {direct_delta_tensor.shape if isinstance(direct_delta_tensor, torch.Tensor) else 'N/A'}, Dtype {direct_delta_tensor.dtype if isinstance(direct_delta_tensor, torch.Tensor) else 'N/A'}): Using direct delta.")
            else:
                print(f"  No SVD components or direct delta in merged_vector_components for {param_name}. Update for this layer will be scaled by multiplier but start at zero.")

        print(f"Total learnable singular values across all layers: {self.total_learnable_sv}")
        print("--- Initialization Complete --- \n")

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` to every tensor, including the untracked functorch buffer tuple.

        ``self.params`` is a registered ``nn.ParameterList`` child, so
        ``super()._apply`` already converts it. ``self.buffer`` is a plain tuple
        that ``nn.Module`` does not track, so it is converted by hand. Extra
        arguments (e.g. ``recurse``, which some PyTorch versions pass) are forwarded.
        """
        new_self = super()._apply(fn, *args, **kwargs)
        if getattr(new_self, 'buffer', None) is not None:
            new_self.buffer = tuple(fn(b) for b in new_self.buffer)
        return new_self

    def train(self, mode=True):
        """Set training mode and return ``self``, matching ``nn.Module.train`` (allows ``m = m.train()``)."""
        return super().train(mode)

    def forward(self, x):
        """Run the wrapped model with ``base + delta`` substituted for every parameter.

        For SVD layers the delta is ``U @ diag(cat(learnable_S, frozen_S)) @ Vh``,
        computed in float32 and cast to the layer's original dtype. Non-SVD layers
        add their stored delta; all other parameters add zero.
        """
        final_model_params = []
        for param_idx, param_name in enumerate(self.param_names):
            base_param = self.params[param_idx]

            actual_delta_c: torch.Tensor
            svd_key_safe_name = param_name.replace('.', '_')

            if param_name in self.svd_components_info:
                svd_info = self.svd_components_info[param_name]
                num_total_components = svd_info.get('num_total_components', 0)
                original_delta_dtype = svd_info['original_dtype']

                if num_total_components > 0:
                    U = getattr(self, f'U_{svd_key_safe_name}')
                    Vh = getattr(self, f'Vh_{svd_key_safe_name}')

                    learnable_S_values = self.learnable_s_values.get(svd_key_safe_name)
                    frozen_S_values = getattr(self, f'frozen_S_{svd_key_safe_name}', None)

                    # Learnable values occupy the leading (largest) slots, matching
                    # how they were split off S_new in __init__.
                    s_parts = []
                    if learnable_S_values is not None and learnable_S_values.numel() > 0:
                        s_parts.append(learnable_S_values)
                    if frozen_S_values is not None and frozen_S_values.numel() > 0:
                        s_parts.append(frozen_S_values)

                    if not s_parts:
                        reconstructed_delta = torch.zeros(svd_info['original_shape'], device=base_param.device)
                    else:
                        full_S_vector = torch.cat(s_parts, dim=0)

                        U_f32 = U.to(torch.float32)
                        Vh_f32 = Vh.to(torch.float32)
                        full_S_f32 = full_S_vector.to(torch.float32)

                        reconstructed_delta = U_f32 @ torch.diag_embed(full_S_f32) @ Vh_f32

                    actual_delta_c = reconstructed_delta.to(original_delta_dtype)
                else:
                    actual_delta_c = torch.zeros_like(base_param)

            elif f'direct_delta_{svd_key_safe_name}' in self._buffers:
                direct_delta_tensor = getattr(self, f'direct_delta_{svd_key_safe_name}')
                actual_delta_c = direct_delta_tensor
            else:
                actual_delta_c = torch.zeros_like(base_param)

            final_model_params.append(base_param + actual_delta_c)

        return self.func(final_model_params, self.buffer, x)


class CorruptionTransform:
    """Transform that applies ``corruption_func(img, severity=...)`` to an image.

    Implemented as a module-level class (rather than a closure) so it can be
    pickled for multiprocessing. If the corruption raises, the error is printed
    and the input image is returned unchanged.
    """

    def __init__(self, corruption_func, severity):
        self.corruption_func = corruption_func
        self.severity = severity
        # Callables without __name__ (e.g. partials) fall back to their str().
        self.corruption_name = getattr(corruption_func, '__name__', str(corruption_func))

    def __call__(self, img):
        try:
            corrupted = self.corruption_func(img, severity=self.severity)
        except Exception as e:
            print(f"Error in corruption {self.corruption_name} with severity {self.severity}: {e}", flush=True)
            corrupted = img
        return corrupted

    def __repr__(self):
        return f"CorruptionTransform({self.corruption_name}, severity={self.severity})"

def save_first_corrupted_image(image_tensor, dataset_name, corruption_func, severity,
                               output_dir="/check/corruptions",
                               mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Save a (normalised) image tensor to disk as a PNG for visual inspection.

    The tensor is un-normalised with ``x * std + mean`` per channel, clamped to
    [0, 1] and written as ``{dataset_name}_{corruption}_severity{severity}_example.png``.
    A trailing ``_pil`` is stripped from the corruption name. This is a
    best-effort debug helper: any failure is printed as a warning, not raised.

    Args:
        image_tensor: Tensor of shape (C, H, W) holding the corrupted image.
        dataset_name: Name of the dataset, used in the file name.
        corruption_func: The corruption function used; its ``__name__`` (or str())
            goes into the file name.
        severity: Severity level of the corruption.
        output_dir: Directory to write into (created if missing).
        mean: Per-channel normalisation mean to undo (defaults to the
            ImageNet values previously hard-coded). NOTE(review): confirm these
            match the model's preprocessing.
        std: Per-channel normalisation std to undo.
    """
    try:
        os.makedirs(output_dir, exist_ok=True)

        corruption_name = getattr(corruption_func, '__name__', str(corruption_func))
        if corruption_name.endswith('_pil'):
            corruption_name = corruption_name[:-len('_pil')]

        # detach() so .numpy() also works on tensors that track gradients.
        image = image_tensor.detach().cpu()
        mean_t = torch.as_tensor(mean, dtype=torch.float32).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=torch.float32).view(-1, 1, 1)

        image = torch.clamp(image * std_t + mean_t, 0, 1)

        # (C, H, W) in [0, 1] -> (H, W, C) uint8
        image_np = (image.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        pil_image = PILImage.fromarray(image_np)

        filename = f"{dataset_name}_{corruption_name}_severity{severity}_example.png"
        filepath = os.path.join(output_dir, filename)
        pil_image.save(filepath)

        print(f"[DEBUG] Saved corrupted image example: {filepath}", flush=True)

    except Exception as e:
        print(f"[WARNING] Failed to save corrupted image: {e}", flush=True)


def evaluate_with_corruption(image_encoder, dataset_name, args, corruption_func, severity, num_workers=4):
    """
    Evaluate ``image_encoder`` on ``dataset_name`` with a corruption applied to every input.

    The corruption is wrapped in a ``CorruptionTransform`` and inserted into the
    encoder's ``val_preprocess`` pipeline immediately before ``ToTensor``, or at
    the front if there is no ``ToTensor``. Only the main process evaluates.

    Args:
        image_encoder: Encoder to evaluate. It is put in eval mode; the wrapping
            ``ImageClassifier`` is moved to CUDA.
        dataset_name: Registry name passed to ``get_dataset`` and
            ``get_classification_head``.
        args: Run arguments. ``data_location`` and ``batch_size`` are read; a copy
            with ``num_workers`` overridden is passed to ``get_dataloader``.
        corruption_func: Callable ``f(img, severity=...)``.
        severity: Severity forwarded to ``corruption_func``.
        num_workers: DataLoader worker count (default 4, the previous fixed value).

    Returns:
        None on non-main processes. ``{"top1": 0.0}`` if ``val_preprocess`` has no
        ``transforms`` attribute. Otherwise ``{"top1": accuracy}`` with accuracy in
        [0, 1] (0.0 if no samples were seen).
    """
    if not is_main_process():
        return None

    print(f"[DEBUG] Starting evaluation with {getattr(corruption_func, '__name__', str(corruption_func))} severity {severity}", flush=True)

    image_encoder.eval()

    classification_head = get_classification_head(args, dataset_name)
    model = ImageClassifier(image_encoder, classification_head)
    model = model.cuda()
    model.eval()

    corruption_transform = CorruptionTransform(corruption_func, severity)
    print(f"[DEBUG] Created corruption transform: {corruption_transform}", flush=True)

    original_preprocess = model.val_preprocess
    if not hasattr(original_preprocess, 'transforms'):
        print(f"Error: val_preprocess for {dataset_name} is not a Compose object.", flush=True)
        return {"top1": 0.0}

    # Place the corruption right before ToTensor so it runs on the output of the
    # earlier (resize/crop) transforms rather than on a tensor.
    new_transforms_list = list(original_preprocess.transforms)
    to_tensor_idx = next(
        (i for i, t in enumerate(new_transforms_list) if isinstance(t, torchvision.transforms.ToTensor)),
        None,
    )
    if to_tensor_idx is None:
        print(f"Warning: ToTensor not found. Adding corruption to the beginning of transforms for {dataset_name}.", flush=True)
        to_tensor_idx = 0
    new_transforms_list.insert(to_tensor_idx, corruption_transform)

    corrupted_preprocess = torchvision.transforms.Compose(new_transforms_list)
    print(f"[DEBUG] Created corrupted preprocessing pipeline with {len(new_transforms_list)} transforms", flush=True)

    print(f"[DEBUG] Getting dataset {dataset_name} with corrupted preprocessing...", flush=True)
    dataset = get_dataset(
        dataset_name,
        corrupted_preprocess,
        location=args.data_location,
        batch_size=args.batch_size
    )

    print(f"[DEBUG] Creating dataloader with num_workers={num_workers}...", flush=True)

    # Copy args so the num_workers override does not leak into the caller's namespace.
    loader_args = argparse.Namespace(**vars(args))
    loader_args.num_workers = num_workers

    dataloader = get_dataloader(
        dataset, is_train=False, args=loader_args, image_encoder=None
    )

    # NOTE(review): the returned loader may have been built by the dataset with its
    # own worker count, so the override is applied to the loader as well.
    if hasattr(dataloader, 'num_workers'):
        dataloader.num_workers = num_workers

    print(f"[DEBUG] Starting evaluation loop with {len(dataloader)} batches...", flush=True)

    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        pbar = tqdm(dataloader, desc=f"Evaluating {getattr(corruption_func, '__name__', 'corruption')} sev{severity}",
                    leave=False, ncols=100)

        for batch in pbar:
            batch = maybe_dictionarize(batch)
            images, labels = batch["images"].cuda(), batch["labels"].cuda()

            preds = model(images).argmax(dim=-1)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

            current_acc = total_correct / total_samples if total_samples > 0 else 0.0
            pbar.set_postfix({'acc': f'{current_acc:.3f}', 'samples': total_samples})

    accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    print(f"[DEBUG] Completed evaluation: {total_correct}/{total_samples} = {accuracy:.4f}", flush=True)

    return {"top1": accuracy}

def main(rank, args):
    """Per-process entry point launched by ``torch.multiprocessing.spawn``.

    Initialises DDP on a port (user-specified or the first free one found),
    forces full precision when ISO-C is enabled, then calls ``train`` once per
    prefix of the source-dataset list: iteration ``idx`` trains with the first
    ``idx + 1`` datasets (the target dataset is removed from the list first).
    Iterations before ``args.resume_from_idx`` are skipped and the loop stops
    at ``args.end_index``. A failing iteration is reported and skipped, except
    for a CUDA out-of-memory error, which terminates the process.

    Args:
        rank: process index supplied by ``spawn``; stored as ``args.rank``.
        args: experiment namespace; mutated in place (``rank``, ``port``,
            ``target_dataset``, ``datasets``, possibly ``no_use_half``).
    """
    args.rank = rank

    distributed_initialized = False

    try:
        available_ports = list(range(29520, 29590))
        if hasattr(args, 'port') and args.port is not None and args.port > 0:
            selected_port = args.port
            print(f"Using user-specified port {selected_port} for distributed training")
        else:
            selected_port = find_available_port(available_ports)
            if selected_port is None:
                print("Warning: No available ports found. Using a random port which may cause issues.")
                selected_port = random.choice(available_ports)
                print(f"Selected random port {selected_port} - this may cause issues if already in use")
            else:
                print(f"Found available port {selected_port} for distributed training")

        args.port = selected_port

        setup_ddp(args.rank, args.world_size, port=selected_port)
        distributed_initialized = True
        print("no_use_half", args.no_use_half)
        if hasattr(args, 'isoc') and args.isoc and not args.no_use_half:
            if is_main_process():
                print("\n")
                print("*" * 80)
                print("WARNING: ISO-C requires full precision for SVD operations.")
                print("Automatically switching to full precision mode (--no-use-half).")
                print("*" * 80)
                print("\n")
            args.no_use_half = True

        if args.seed is not None:
            set_seed(args.seed)

        args.target_dataset = args.target_dataset_name + "Val"
        # The target must never be used as one of its own source task vectors.
        if args.target_dataset_name in args.datasets:
            args.datasets.remove(args.target_dataset_name)
        original_datasets = args.datasets.copy()

        for idx, _ in enumerate(original_datasets):

            if idx < args.resume_from_idx:
                print(f"Skipping iteration {idx+1} because resume-from-idx is {args.resume_from_idx}")
                continue

            if idx >= args.end_index:
                print(f"Ending iteration {idx+1} because end-index is {args.end_index}")
                break

            try:
                # Iteration idx uses a growing prefix of the source datasets.
                args.datasets = original_datasets[:idx+1]

                if is_main_process():
                    precision_mode = "half precision (float16)" if not args.no_use_half else "full precision (float32)"
                    print(f"Creating task vectors using {precision_mode}")
                    print(f"SVD learning mode: learning top {args.svd_threshold*100:.1f}% singular values for each layer")
                    if args.svd_threshold > 0 and args.no_use_half:
                        print(f"Using SVD thresholding (in merge_task_vectors) with threshold {args.svd_threshold}")
                        if args.keep_top_values:
                            print(f"Mode: KEEP top {args.svd_threshold*100}% singular values, ZERO OUT the rest")
                        else:
                            print(f"Mode: ZERO OUT top {args.svd_threshold*100}% singular values, KEEP the rest")
                    elif args.svd_threshold > 0 and not args.no_use_half:
                        print("WARNING: SVD thresholding requires full precision. Use --no-use-half flag for SVD operations.")

                task_vectors = {}
                for source_dataset in args.datasets:
                    if args.finetuning_mode == "linear":
                        pretrained_checkpoint = f"{args.save}/{source_dataset}Val/linear_zeroshot.pt"
                        finetuned_checkpoint = f"{args.save}/{source_dataset}Val/linear_finetuned.pt"
                        task_vectors[source_dataset] = LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint)
                    else:
                        pretrained_checkpoint = f"{args.save}/{source_dataset}Val/zeroshot.pt"
                        finetuned_checkpoint = f"{args.save}/{source_dataset}Val/finetuned.pt"
                        task_vectors[source_dataset] = NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint)

                print("=" * 100)
                print(f"Learning SVD-merged task vector singular values on {args.target_dataset} with {len(args.datasets)} datasets")
                print(f"Datasets being used: {args.datasets}")
                print("=" * 100)

                print("\nCommand line arguments:")
                print("-" * 50)
                for arg, value in vars(args).items():
                    print(f"{arg}: {value}")
                print("-" * 50)
                print()
                train(task_vectors, args)
                print(f"Successfully completed iteration {idx+1} with {len(args.datasets)} datasets")

            except Exception as e:
                # Best-effort: report and move on to the next iteration, except
                # for OOM, which would keep failing for larger dataset prefixes.
                print(f"ERROR: Iteration {idx+1} with {len(args.datasets)} datasets failed with error: {e}")
                import traceback
                traceback.print_exc()
                if "CUDA out of memory" in str(e):
                    print("CUDA out of memory error detected - stopping the entire job!")
                    # SystemExit propagates through the ``finally`` below, which
                    # tears down the process group exactly once.
                    sys.exit(1)

    finally:
        if distributed_initialized:
            try:
                cleanup_ddp()
            except Exception as e:
                print(f"Warning: Error during distributed cleanup: {e}")

    # Only reached on normal completion; restore the full source-dataset list.
    args.datasets = original_datasets


def _iter_initial_selected_svs(image_encoder):
    """Yield ``(safe_name, values)`` for each non-empty initial-selected-S buffer.

    For every entry of ``image_encoder.param_names`` the encoder may expose an
    attribute named ``initial_selected_S_<param name with '.' -> '_'>``.
    Missing attributes and zero-element tensors are skipped.
    """
    for param_name in image_encoder.param_names:
        safe_name = param_name.replace('.', '_')
        buffer_name = f'initial_selected_S_{safe_name}'
        if not hasattr(image_encoder, buffer_name):
            continue
        values = getattr(image_encoder, buffer_name)
        if values.numel() > 0:
            yield safe_name, values


def _append_json_record(path, record):
    """Append ``record`` to the JSON list stored at ``path``, creating the file if needed.

    A pre-existing file holding a single (non-list) JSON value is promoted to a
    one-element list before appending.
    """
    records = []
    if os.path.exists(path):
        with open(path, 'r') as f:
            existing_data = json.load(f)
        records = existing_data if isinstance(existing_data, list) else [existing_data]
    records.append(record)
    with open(path, 'w') as f:
        json.dump(records, f, indent=4)


def _init_wandb_run(args, image_encoder, target_dataset, subsample_str):
    """Log in to W&B and start the run for this experiment (call on the main process only)."""
    gpu_info = {}
    if torch.cuda.is_available():
        gpu_info = {
            "gpu_count": torch.cuda.device_count(),
            "gpu_model": torch.cuda.get_device_name(0),
            "gpu_memory_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3),
        }

    api_key = os.environ.get("WANDB_API_KEY")
    if api_key:
        wandb.login(key=api_key)
    else:
        # No key in the environment: let wandb use its stored credentials
        # instead of attempting a login with a placeholder key.
        wandb.login()

    threshold_pct = int(args.svd_threshold * 100)
    run_name = f"{args.model}_{target_dataset}-isoc-learn-svd-top{threshold_pct}pct-e{args.epochs}-nbdts{len(args.datasets)}"
    if args.svd_threshold > 0:
        merge_task_vectors_zeroing_mode = "keepTop" if args.keep_top_values else "zeroTop"
        run_name += f"-merge_task_vectors_{merge_task_vectors_zeroing_mode}{threshold_pct}pct"
    run_name += f"-{subsample_str}"

    wandb.init(
        project="atlas2",
        entity="example-owner",
        name=run_name,
        config={
            "type": "learn_coef",
            "genre": "isoc_svd_components_learning",
            "kind": "corruption_robustness_evaluation_ready",
            "iter_fixed": 3,
            "model": args.model,
            "save": args.save,
            "target_dataset": target_dataset,
            "used_datasets": args.datasets,
            "partition": args.partition,
            "#nb_dts": len(args.datasets),
            "port": args.port,
            "learning_rate": args.lr,
            "epochs": args.epochs,
            # Samples per optimizer step on one process; "batch_size" below is
            # the per-forward-pass size, so the two need distinct keys.
            "accumulated_batch_size": args.batch_size * args.num_grad_accumulation,
            "finetuning_mode": args.finetuning_mode,
            "lp_reg": args.lp_reg,
            "seed": args.seed,
            "weight_decay": args.wd,
            "subsample": args.subsample,
            "merge_task_vectors_svd_threshold": args.svd_threshold,
            "merge_task_vectors_keep_top_values": args.keep_top_values,
            "learnable_sv_percentage": args.svd_threshold * 100,
            "total_learnable_singular_values": getattr(image_encoder, 'total_learnable_sv', 0),
            "learning_approach": "merge_task_vectors_selected_svd_components_with_learnable_S",

            "system": {
                "gpu": gpu_info,
                "cpu_count": psutil.cpu_count(),
                "cpu_physical_count": psutil.cpu_count(logical=False),
                "memory_gb": psutil.virtual_memory().total / (1024**3),
                "hostname": socket.gethostname(),
                "platform": platform.platform(),
            },

            "runtime": {
                "pytorch_version": torch.__version__,
                "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
                "python_version": sys.version.split()[0],
            },

            "job_id": os.environ.get('SLURM_JOB_ID', 'unknown'),
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),

            "grad_clip_value": 1.0,
            "num_grad_accumulation": args.num_grad_accumulation,
            "batch_size": args.batch_size,
            "effective_batch_size": args.batch_size * args.num_grad_accumulation * args.world_size,
            "mixed_precision": True,

            "port_selection_method": "auto_find_available",
            "sorting_descending": args.sorting_descending,
        }
    )

    if wandb.run:
        print(f"WandB Run URL: {wandb.run.url}")
        print("-" * 50 + "\n")


def _evaluate_corruption_robustness(image_encoder, dataset_name, args):
    """Evaluate ``image_encoder`` on ``dataset_name`` for each corruption at severities 1-5, logging to W&B."""
    print("\n--- Evaluating robustness to image corruptions ---", flush=True)

    corruption_functions = [
        ("gaussian_noise", gaussian_noise_pil),
        ("shot_noise", shot_noise_pil),
        ("impulse_noise", impulse_noise_pil),
        ("speckle_noise", speckle_noise_pil),
        ("gaussian_blur", gaussian_blur_pil),
        ("defocus_blur", defocus_blur_pil),
        ("zoom_blur", zoom_blur_pil),
        ("contrast", contrast_pil),
        ("brightness", brightness_pil),
        ("saturate", saturate_pil),
        ("jpeg_compression", jpeg_compression_pil),
        ("pixelate", pixelate_pil),
    ]

    corruption_pbar = tqdm(corruption_functions, desc="Corruption Types", ncols=120, position=0)

    for corruption_name, corruption_func in corruption_pbar:
        corruption_pbar.set_description(f"Corruption: {corruption_name}")
        print(f"\n-- Evaluating with {corruption_name} corruption --", flush=True)

        severity_pbar = tqdm(range(1, 6), desc=f"{corruption_name} severities",
                             leave=False, ncols=100, position=1)

        for severity in severity_pbar:
            severity_pbar.set_description(f"{corruption_name} sev{severity}")
            print(f"--- Evaluating {corruption_name} severity {severity} ---", flush=True)

            start_time = time.time()

            metrics = evaluate_with_corruption(
                image_encoder=image_encoder,
                dataset_name=dataset_name,
                args=args,
                corruption_func=corruption_func,
                severity=severity
            )

            acc = metrics['top1']
            eval_time = time.time() - start_time

            print(f"Accuracy with {corruption_name} severity {severity}: {100 * acc:.2f}% (took {eval_time:.1f}s)", flush=True)

            wandb.log({
                f"test_accuracy_{corruption_name}_severity_{severity}": 100 * acc,
                "dataset": dataset_name,
                "corruption_type": corruption_name,
                "severity": severity,
                "evaluation_time_seconds": eval_time
            })

            corruption_pbar.set_postfix({'current_acc': f'{100*acc:.1f}%', 'time': f'{eval_time:.1f}s'})

    print("\n✅ Completed all corruption evaluations!", flush=True)


def train(task_vectors, args):
    """Learn the selected singular values of SVD-merged task vectors on the target dataset.

    Merges the source task vectors with ``merge_task_vectors``, wraps them in a
    ``LearnableSingularValuesMergedEncoder`` and trains only the learnable
    parameters under DDP with AdamW, a cosine schedule, fp16 autocast and an
    optional L2 penalty (``args.lp_reg``) on the learnable singular values.
    After training, the main process evaluates on the clean test split and on
    12 corruption types x 5 severities, writing results to W&B and to JSON
    files under ``results/``.

    Args:
        task_vectors: mapping from source dataset name to task vector; an entry
            keyed by the target's test name (``args.target_dataset`` without
            "Val") is excluded.
        args: experiment namespace (``rank``/``world_size``/``port`` are set by
            ``main``).

    Returns:
        None. Returns early (after closing any W&B run) in linear fine-tuning
        mode or when merging fails.
    """
    # Bound on every rank: the end-of-run timing below needs it.
    experiment_start_time = datetime.now()
    formatted_start_time = experiment_start_time.strftime("%Y-%m-%d %H:%M:%S")
    if is_main_process():
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} started at: {formatted_start_time}")

    subsample_str = f"s{int(args.subsample * 100)}" if isinstance(args.subsample, float) else f"s{args.subsample}"

    comp_acc = {}

    if args.seed is not None:
        set_seed(args.seed + args.rank)

    target_dataset = args.target_dataset
    target_dataset_test = target_dataset.replace("Val", "")

    task_vectors_list = [v for k, v in task_vectors.items() if k != target_dataset_test]

    if args.finetuning_mode == "linear":
        if is_main_process():
            print("Error: Linear mode with singular value learning not implemented.")
            if wandb.run is not None:
                wandb.finish(quiet=True)
        return

    base_image_encoder = ImageEncoder(args)

    base_model_params_dict = {
        name: tensor.clone()
        for name, tensor in base_image_encoder.model.state_dict().items()
    }

    if is_main_process():
        print("\n" + "="*50)
        print("ISO-C LEARN MODE: Merging task vectors with SVD-based component selection")
        print("SVD performed on (base_param + task_delta) for each task.")
        print(f"Learning top {args.svd_threshold*100:.1f}% singular values for each merged layer")
        print("="*50 + "\n")

    merged_vector_components = merge_task_vectors(base_model_params_dict, task_vectors_list, args)

    if merged_vector_components is None:
        if is_main_process():
            print("Error: Failed to merge task vectors with SVD.")
            if wandb.run is not None:
                wandb.finish(quiet=True)
        return

    image_encoder = LearnableSingularValuesMergedEncoder(
        base_image_encoder,
        merged_vector_components,
        args
    )

    if is_main_process():
        _init_wandb_run(args, image_encoder, target_dataset, subsample_str)

        initial_selected_sv_log = {}
        for safe_name, initial_s_values in _iter_initial_selected_svs(image_encoder):
            prefix = f"initial_selected_sv_{safe_name}"
            initial_selected_sv_log[f"{prefix}_count"] = initial_s_values.numel()
            initial_selected_sv_log[f"{prefix}_mean"] = initial_s_values.mean().item()
            initial_selected_sv_log[f"{prefix}_max"] = initial_s_values.max().item()
            for idx, val in enumerate(initial_s_values[:5].tolist()):
                initial_selected_sv_log[f"{prefix}_idx{idx}"] = val
        if initial_selected_sv_log:
            wandb.log(initial_selected_sv_log)

    ckpdir = os.path.join(args.save, target_dataset)
    os.makedirs(ckpdir, exist_ok=True)

    classification_head = get_classification_head(args, target_dataset)
    model = ImageClassifier(image_encoder, classification_head)

    model.freeze_head()
    model = model.cuda()

    # Random-resized-crop + flip augmentation, followed by the last three
    # transforms of the model's own training preprocessing.
    preprocess_fn = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(
            size=224, scale=(0.5, 1),
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC
        ), torchvision.transforms.RandomHorizontalFlip(p=0.5),
    ] + model.train_preprocess.transforms[-3:])
    dataset = get_dataset(
        target_dataset,
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size,
        num_workers=2,
    )
    data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(data_loader)

    if args.print_every * 10 < num_batches:
        print_every = int(num_batches / 10)
    elif args.print_every * 4 > num_batches:
        print_every = max(int(num_batches / 4), 1)
    else:
        print_every = args.print_every
    # A zero interval would make ``step % print_every`` raise ZeroDivisionError.
    print_every = max(print_every, 1)

    ddp_loader = distribute_loader(data_loader)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.rank],
        find_unused_parameters=True,
        output_device=args.rank,
    )

    loss_fn = torch.nn.CrossEntropyLoss()

    params = [p for p in ddp_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

    scheduler = cosine_lr(
        optimizer, args.lr, 0,
        args.epochs * num_batches // args.num_grad_accumulation,
    )

    slurm_job_id = os.environ.get('SLURM_JOB_ID', 'unknown')
    if "unknown" in slurm_job_id:
        slurm_job_id = time.strftime("%Y%m%d-%H%M%S")

    zeroing_mode = "keepTop" if args.keep_top_values else "zeroTop"
    svd_suffix = f"-{zeroing_mode}{int(args.svd_threshold * 100)}pct"

    learn_sv_suffix = f"-learnTop{int(args.svd_threshold * 100)}pct"
    log_path = os.path.join('results', f"slurm-{slurm_job_id}-{target_dataset}_isoc_learned_svd{learn_sv_suffix}{svd_suffix}_{subsample_str}.json")
    print("log_path", log_path)

    scaler = GradScaler()
    if is_main_process():
        wandb.log({
            "target_dataset": str(target_dataset),
        })

        if os.path.exists(log_path):
            with open(log_path, 'r') as f:
                comp_acc = json.load(f)
        else:
            comp_acc = {}

    # Only the learnable singular values are trained.
    num_trainable_sv = getattr(image_encoder, 'total_learnable_sv', 0)
    num_trainable_params = num_trainable_sv

    total_params_in_ddp_model = sum(p.numel() for p in ddp_model.parameters())

    if is_main_process():
        trainable_pct = (100 * num_trainable_params / total_params_in_ddp_model) if total_params_in_ddp_model > 0 else 0

        print("\nModel Parameter Stats:")
        print(f"Total parameters in DDP model (includes frozen base + learnable): {total_params_in_ddp_model:,}")
        print(f"Trainable singular values: {num_trainable_sv:,}")
        print(f"Total trainable parameters: {num_trainable_params:,}")

        if total_params_in_ddp_model > 0:
            print(f"Percentage trainable of DDP model: {trainable_pct:.4f}%\n")

        wandb.config.update({
            "total_parameters_ddp_model": total_params_in_ddp_model,
            "trainable_singular_values": num_trainable_sv,
            "trainable_parameters": num_trainable_params,
            "trainable_parameters_percentage_of_ddp": trainable_pct
        })

        wandb.log({
            "total_parameters_ddp_model": total_params_in_ddp_model,
            "trainable_singular_values": num_trainable_sv,
            "trainable_parameters_total": num_trainable_params,
            "trainable_parameters_percentage_of_ddp": trainable_pct
        })

    for epoch in range(args.epochs):
        epoch_loss = 0.0
        epoch_steps = 0
        # Per-process counters, summed across all ranks at the end of the epoch.
        process_epoch_correct = 0
        process_epoch_samples = 0

        ddp_loader.sampler.set_epoch(epoch)
        for i, batch in enumerate(ddp_loader):
            start_time = time.time()

            step = (
                i // args.num_grad_accumulation
                + epoch * num_batches // args.num_grad_accumulation
            )

            batch = maybe_dictionarize(batch)
            inputs = batch["images"].cuda()
            data_time = time.time() - start_time

            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = ddp_model(inputs)
                labels = batch["labels"].cuda()
                loss = loss_fn(logits, labels)

                preds = logits.argmax(dim=-1)
                correct_in_batch = (preds == labels).sum().item()
                total_in_batch = labels.size(0)
                batch_training_accuracy = correct_in_batch / total_in_batch

                sv_reg = 0.0
                for param in ddp_model.module.image_encoder.learnable_s_values.values():
                    sv_reg += torch.norm(param, p=2)
                sv_reg = args.lp_reg * sv_reg if args.lp_reg is not None else 0

                # Scale the full regularized objective once for gradient
                # accumulation, so the penalty is weighted like the CE term.
                loss = (loss + sv_reg) / args.num_grad_accumulation

                if is_main_process():
                    epoch_loss += loss.item() * args.num_grad_accumulation
                    epoch_steps += 1

            scaler.scale(loss).backward()

            torch.cuda.empty_cache()

            if (i + 1) % args.num_grad_accumulation == 0:
                scheduler(step)

                # Gradients still carry the loss scale; unscale them so the
                # clipping threshold applies to the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            batch_time = time.time() - start_time

            if (
                step % print_every == 0
                and ((i + 1) % args.num_grad_accumulation == 0)
                and is_main_process()
            ):
                percent_complete = 100 * (i + 1) / len(ddp_loader)
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i + 1}/{num_batches}]\t"
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}",
                    flush=True,
                )
                print(f"Batch Training Accuracy: {batch_training_accuracy*100:.2f}%")

                batch_log = {
                    "dataset": str(target_dataset),
                    "epoch": epoch,
                    "batch": i,
                    "batch_loss": loss.item() * args.num_grad_accumulation,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                    "batch_time": batch_time,
                    "data_time": data_time,
                    "global_step": step,
                    "singular_value_reg": sv_reg.item() if hasattr(sv_reg, 'item') else sv_reg,
                    "batch_training_accuracy": batch_training_accuracy * 100,
                }

                wandb.log(batch_log)

            process_epoch_correct += correct_in_batch
            process_epoch_samples += total_in_batch

        # all_reduce is a collective: every rank must call it. Calling it only
        # on the main process would block rank 0 forever when world_size > 1.
        epoch_counts = torch.tensor(
            [process_epoch_correct, process_epoch_samples],
            dtype=torch.float32, device=args.rank,
        )
        torch.distributed.all_reduce(epoch_counts, op=torch.distributed.ReduceOp.SUM)
        global_epoch_correct, global_epoch_samples = epoch_counts.tolist()

        if is_main_process():
            epoch_training_accuracy = 0
            if global_epoch_samples > 0:
                epoch_training_accuracy = (global_epoch_correct / global_epoch_samples) * 100
            else:
                print("Warning: global_epoch_samples is 0. Cannot compute epoch training accuracy.")

            avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0
            wandb.log({
                "target_dataset": str(target_dataset),
                "epoch": epoch,
                "epoch_loss": avg_epoch_loss,
                "epoch_training_accuracy": epoch_training_accuracy,
            })

            print(f"Epoch {epoch}: Average Loss: {avg_epoch_loss:.4f}, Epoch Training Accuracy: {epoch_training_accuracy:.2f}%", flush=True)

            image_encoder = ddp_model.module.image_encoder
            val_metrics = eval_single_dataset(image_encoder, target_dataset, args)
            val_acc = val_metrics["top1"]

            epoch_log = {
                "dataset": str(target_dataset),
                "epoch": epoch,
                "val_accuracy": 100 * val_acc,
            }

            for key, param in image_encoder.learnable_s_values.items():
                values = param.detach().cpu().tolist()
                for idx, val in enumerate(values[:5]):
                    epoch_log[f"sv_{key}_idx{idx}_value"] = val

                epoch_log[f"sv_{key}_mean"] = param.detach().mean().item()
                epoch_log[f"sv_{key}_max"] = param.detach().max().item()

            wandb.log(epoch_log)

            print(f"Epoch {epoch}: Validation accuracy: {100 * val_acc:.2f}%", flush=True)

    if not is_main_process():
        # NOTE(review): non-main ranks return here while rank 0 runs the long
        # final/corruption evaluation; the next collective in ``main`` may hit
        # the process-group timeout. Confirm the timeout is large enough.
        return

    print("\nEvaluating model at the end of training (last epoch state)...", flush=True)
    image_encoder_last_state = ddp_model.module.image_encoder

    last_s_values_sum = 0.0
    if hasattr(image_encoder_last_state, 'learnable_s_values'):
        for param in image_encoder_last_state.learnable_s_values.values():
            last_s_values_sum += param.detach().sum().item()

    print(f"[DEBUG LAST STATE] Sum of learnable_s_values: {last_s_values_sum}", flush=True)
    wandb.log({
        "debug_last_state_s_values_sum": last_s_values_sum
    })

    test_metrics_last_model = eval_single_dataset(image_encoder_last_state, target_dataset_test, args)
    final_test_acc_last_model = test_metrics_last_model["top1"]

    print(f"Final test accuracy (last model state) on {target_dataset_test}: {100 * final_test_acc_last_model:.2f}%", flush=True)
    wandb.log({
        "dataset": str(target_dataset_test),
        "final_test_accuracy_last": 100 * final_test_acc_last_model,
    })

    comp_acc[target_dataset_test] = final_test_acc_last_model

    wandb.log({
        "dataset": str(target_dataset_test),
        "final_test_accuracy": 100 * final_test_acc_last_model,
    })

    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    with open(log_path, 'w') as f:
        json.dump(comp_acc, f, indent=4)

    test_acc_path = os.path.join('results',
                                 f"slurm-{slurm_job_id}-{target_dataset_test}_isoc_learned_svd{learn_sv_suffix}{svd_suffix}_{subsample_str}_accuracy.json")

    test_acc_data = {
        "target_dataset": str(target_dataset_test),
        "number_of_used_datasets": len(args.datasets),
        "datasets_names": args.datasets,
        "test_accuracy_last_model": float(final_test_acc_last_model),
        "test_metrics_last_model": test_metrics_last_model,
        "merge_task_vectors_svd_threshold": args.svd_threshold if hasattr(args, 'svd_threshold') else 0.0,
        "merge_task_vectors_keep_top_values": args.keep_top_values if hasattr(args, 'keep_top_values') else False,
        "merge_task_vectors_zeroing_mode": "keep_top" if (hasattr(args, 'keep_top_values') and args.keep_top_values) else "zero_top",
        "merge_task_vectors_zeroing_percentage": int(args.svd_threshold * 100) if hasattr(args, 'svd_threshold') else 0,
        "learnable_sv_percentage": args.svd_threshold * 100,
        "initial_selected_singular_values_from_isoc": {}
    }

    final_sv_values = {}
    final_sv_log = {}
    for key, param in image_encoder_last_state.learnable_s_values.items():
        values = param.detach().cpu().tolist()
        final_sv_values[key] = {f"idx{idx}": val for idx, val in enumerate(values)}
        for idx, val in enumerate(values):
            final_sv_log[f"final_sv_{key}_idx{idx}_value"] = val
    # One log call instead of one per singular value (each call advances the W&B step).
    if final_sv_log:
        wandb.log(final_sv_log)

    test_acc_data["learned_singular_values"] = final_sv_values

    test_acc_data["initial_selected_singular_values_from_isoc"] = {
        f"initial_selected_sv_{safe_name}": {f"idx{idx}": val for idx, val in enumerate(initial_s_values.tolist())}
        for safe_name, initial_s_values in _iter_initial_selected_svs(image_encoder_last_state)
    }

    _append_json_record(test_acc_path, test_acc_data)
    print(f"Test accuracy saved to {test_acc_path}", flush=True)

    _evaluate_corruption_robustness(image_encoder_last_state, target_dataset_test, args)

    experiment_end_time = datetime.now()
    experiment_duration = experiment_end_time - experiment_start_time
    formatted_end_time = experiment_end_time.strftime("%Y-%m-%d %H:%M:%S")

    hours, remainder = divmod(experiment_duration.total_seconds(), 3600)
    minutes, seconds = divmod(remainder, 60)
    formatted_duration = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"

    print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} completed at: {formatted_end_time}", flush=True)
    print(f"[TIMING] Total duration: {formatted_duration} (H:M:S)", flush=True)

    wandb.log({
        "experiment_start_time": formatted_start_time,
        "experiment_end_time": formatted_end_time,
        "experiment_duration_seconds": experiment_duration.total_seconds(),
        "experiment_duration_formatted": formatted_duration,
    })

    wandb.finish(quiet=True)


if __name__ == "__main__":
    # Candidate source datasets; ``main`` drops the target from this list.
    target_datasets = [
        "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SUN397", "SVHN",
        "CIFAR10", "CIFAR100", "ImageNet", "STL10", "Food101", "Caltech101",
        "Caltech256", "FGVCAircraft", "Flowers102", "OxfordIIITPet", "CUB200",
        "PascalVOC", "Country211", "UCF101",
    ]

    args = parse_arguments()

    heavy_rule = "=" * 80
    print("\n".join((
        heavy_rule,
        "AUTOMATIC EXPERIMENT TRACKING",
        "-" * 50,
        "to ensure that experiment code states are tracked.",
        heavy_rule,
    )))

    # Fixed experiment hyper-parameters (override whatever was parsed).
    is_vit_large = args.model == "ViT-L-14"
    args.datasets = target_datasets
    args.lr = 1e-1
    args.epochs = 10
    # ViT-L-14: half-size batches with 2-step accumulation.
    args.batch_size = 64 if is_vit_large else 128
    args.num_grad_accumulation = 2 if is_vit_large else 1
    args.print_every = 10

    args.seed = 0
    args.save += f"{args.model}"

    # SVD needs full precision.
    if not args.no_use_half:
        print("SVD operations require full precision, automatically enabling --no-use-half=True")
        args.no_use_half = True

    for enabled_flag in ("isoc", "use_svd", "keep_top_values", "sorting_descending"):
        setattr(args, enabled_flag, True)

    threshold_pct = args.svd_threshold * 100
    mode_description = (
        'KEEP top values, ZERO OUT the rest'
        if args.keep_top_values
        else 'ZERO OUT top values, KEEP the rest'
    )
    print("\n" + heavy_rule)
    print("SVD LEARNING & THRESHOLDING CONFIGURATION:")
    print(f"- Learning top {threshold_pct:.1f}% singular values for each layer")
    print(f"- merge_task_vectors Threshold: {threshold_pct:.1f}% of singular values")
    print(f"- merge_task_vectors Mode: {mode_description}")
    print(heavy_rule + "\n")

    torch.multiprocessing.spawn(main, args=(args,), nprocs=args.world_size)
