"""Given a training objective, learn coefficients and singular values for the
SVD components of merged task vectors on a target dataset.

Features:
- Uses merge_task_vectors to precompute and select SVD components (U, S, Vh)
- Learns selected singular values and layer-wise coefficients
- Maintains original parameter dtype and shape information
"""

import os
import time
import json
import torch
import torchvision
import sys
import argparse

os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_CONSOLE"] = "off"
os.environ["WANDB_DISABLE_SERVICE"] = "true"

# Import wandb with stderr temporarily redirected to the null device so its
# import-time chatter is hidden. The try/finally restores stderr even if the
# import fails; otherwise sys.stderr would be left pointing at a closed file
# and the ImportError traceback would be lost.
old_stderr = sys.stderr
with open(os.devnull, 'w') as devnull:
    sys.stderr = devnull
    try:
        import wandb
    finally:
        sys.stderr = old_stderr

import random
import numpy as np
import socket
import subprocess
import platform
import psutil
from datetime import datetime, timedelta

from torch.cuda.amp import GradScaler
from src.linearize import LinearizedImageEncoder
from src.modeling import ImageEncoder, ImageClassifier
from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
from src.composition import WeightedImageEncoder, WeightedLinearizedModel
from src.composition import TopValuesTaskBasedSVDWeightedImageEncoder
from src.composition import MergedTaskVectorImageEncoder
from torch import nn

from src.utils import cosine_lr
from src.args import parse_arguments
from src.eval import eval_single_dataset
from src.datasets.registry import get_dataset
from src.heads import get_classification_head
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp

@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Return the entropy of softmax(x) computed along dim 1.

    Evaluates -sum_j p_j * log(p_j) with p = softmax(x, dim=1). The log term
    uses log_softmax rather than log(softmax) for numerical stability.
    """
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    weighted = probs * log_probs
    return -weighted.sum(1)

def lp_reg(x, p=None, gamma=0.5) -> torch.Tensor:
    """Lp-norm regulariser: ``gamma * mean(norm(x, p, dim=0))``.

    Returns the plain integer 0 (not a tensor) when there is nothing to
    regularise: ``x`` is None, ``p`` is None, or ``x`` does not require grad.

    Args:
        x: tensor to regularise; the norm is taken along dim 0 and then
            averaged over the remaining entries.
        p: norm order passed to ``torch.norm``.
        gamma: scale applied to the averaged norm.
    """
    inactive = x is None or p is None or not x.requires_grad
    if inactive:
        return 0
    norms_along_dim0 = torch.norm(x, p=p, dim=0)
    return gamma * norms_along_dim0.mean()

def log_gpu_memory(log_point_name="", num_datasets=None):
    """Print, and log to wandb when a run is active, current CUDA memory in MB.

    Does nothing unless CUDA is available and this is the main process.

    Args:
        log_point_name: label used in the printed line and as the suffix of
            the wandb metric names (e.g. ``mem_allocated_mb_<name>``).
        num_datasets: if not None, also logged as
            ``num_source_tasks_for_mem_log``.
    """
    if not (torch.cuda.is_available() and is_main_process()):
        return

    bytes_per_mb = 1024 ** 2
    allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
    reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
    peak_mb = torch.cuda.max_memory_allocated() / bytes_per_mb

    print(
        f"[MEM_LOG | {log_point_name}] "
        f"Allocated: {allocated_mb:.2f} MB | "
        f"Reserved: {reserved_mb:.2f} MB | "
        f"Peak Allocated: {peak_mb:.2f} MB"
    )

    if not wandb.run:
        return

    metrics = {
        f"mem_allocated_mb_{log_point_name}": allocated_mb,
        f"mem_reserved_mb_{log_point_name}": reserved_mb,
        f"mem_peak_mb_{log_point_name}": peak_mb,
    }
    if num_datasets is not None:
        metrics["num_source_tasks_for_mem_log"] = num_datasets
    wandb.log(metrics)

def set_seed(seed: int) -> None:
    """
    Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, the PyTorch CPU generator and all CUDA
    devices. Also forces deterministic cuDNN kernels with autotuning disabled,
    and exports ``PYTHONHASHSEED``. The environment variable only influences
    interpreters started afterwards, not the current process's str hashing.

    Args:
        seed: The random seed to set
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

    if is_main_process():
        print(f"Random seed set to {seed} for deterministic results")

def is_port_available(port):
    """
    Report whether a TCP socket can currently bind to ``port`` on all interfaces.

    The probe socket is always closed before returning, so the port is not
    held afterwards.

    Args:
        port: The port number to check

    Returns:
        bool: True if the bind succeeded, False if it raised ``OSError``
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('', port))
    except OSError:
        return False
    else:
        return True
    finally:
        probe.close()

def _try_bind(port):
    """Bind a throwaway TCP socket (SO_REUSEADDR) to ``port`` and release it.

    Returns None if the bind succeeded, otherwise the ``OSError`` it raised.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            probe.bind(('', port))
        except OSError as err:
            return err
    return None


def find_available_port(port_list):
    """
    Find an available port from a list of ports, probing them in random order.

    Each candidate must pass two independent bind checks separated by 0.1 s.
    The first probe socket is closed before the second check. Previously the
    second bind ran while the first socket still held the port, so it did
    not re-verify anything: it always passed on Linux (both sockets use
    SO_REUSEADDR) and always failed on BSD/macOS.

    Args:
        port_list: Iterable of port numbers to check (not modified)

    Returns:
        int: First port that passed both checks, or None if none did
    """
    candidates = list(port_list)
    random.shuffle(candidates)

    for port in candidates:
        first_error = _try_bind(port)
        if first_error is not None:
            print(f"Port {port} is not available: {first_error}")
            continue

        # Give a concurrently starting process a moment to grab the port,
        # then verify it is still free.
        time.sleep(0.1)
        if _try_bind(port) is not None:
            print(f"Port {port} failed second availability check")
            continue

        print(f"Port {port} is confirmed available")
        return port

    print("No available ports found in the provided list")
    return None

# Parameter names, with and without the "model." prefix, whose deltas are
# always averaged across tasks instead of being SVD-merged.
_MERGE_SKIP_PARAMS = frozenset({
    "positional_embedding", "visual.positional_embedding",
    "visual.proj", "token_embedding", "text_projection", "token_embedding.weight",
    "model.positional_embedding", "model.visual.positional_embedding",
    "model.token_embedding.weight", "model.visual.proj", "model.token_embedding",
    "model.text_projection",
})

# Number of pooled SVD components kept per 2D layer when the config does not
# set ``merge_num_components`` (the value previously hard-coded here).
_DEFAULT_MERGE_NUM_COMPONENTS = 76


def _average_task_deltas(task_vectors, task_key, device):
    """Return the mean of ``task_key``'s delta over all task vectors, moved to ``device``."""
    if len(task_vectors) == 1:
        return task_vectors[0].vector[task_key].to(device)
    deltas = [tv.vector[task_key].to(device) for tv in task_vectors]
    return sum(deltas) / len(deltas)


def _pooled_svd_components(task_vectors, task_key, device, num_to_keep, descending):
    """SVD each task's delta for ``task_key``, pool every component and keep ``num_to_keep``.

    Returns ``(U, S, Vh, total)`` in float32, where U is (rows, k), S is (k,),
    Vh is (k, cols), k = min(num_to_keep, total), and ``total`` is the number
    of pooled components before selection.
    """
    U_parts, S_parts, Vh_parts = [], [], []
    for task_vector in task_vectors:
        delta = task_vector.vector[task_key].to(device=device, dtype=torch.float32)
        U_task, S_task, Vh_task = torch.linalg.svd(delta, full_matrices=False)
        U_parts.append(U_task)
        S_parts.append(S_task)
        Vh_parts.append(Vh_task)

    S_all = torch.cat(S_parts)
    # stable=True keeps components with equal singular values in task order,
    # matching the stable Python list.sort used previously.
    order = torch.sort(S_all, descending=descending, stable=True).indices[:num_to_keep]
    U_selected = torch.cat(U_parts, dim=1)[:, order]
    Vh_selected = torch.cat(Vh_parts, dim=0)[order, :]
    return U_selected, S_all[order], Vh_selected, S_all.shape[0]


def merge_task_vectors(base_model_params_dict, task_vectors, config):
    """
    Merge task vectors by pooling every task's SVD components per 2D layer and
    keeping a fixed number of them, ranked globally by singular value.

    For each 2D parameter whose key is not in ``_MERGE_SKIP_PARAMS`` and does
    not contain "text_projection":
    1. Run a thin SVD (float32) on each task's delta matrix.
    2. Pool all (u, s, vh) triplets from all tasks.
    3. Sort the pool by singular value: descending if
       ``config.sorting_descending`` is true, else ascending. Ties keep task order.
    4. Keep the first K triplets. K is ``config.merge_num_components`` when
       that attribute exists and is not None, otherwise 76. It is capped at the
       pool size. ``config.svd_threshold`` does not affect K; it only appears
       in a warning.
    All other parameters' deltas are averaged across tasks.

    Keys whose name (after stripping a leading "model.") is missing from
    ``base_model_params_dict`` are skipped. That dict is otherwise only used to
    pick the device. If SVD processing of a key raises, the error is printed
    and the key is left out of the result.

    Args:
        base_model_params_dict: base-model parameter name -> tensor.
        task_vectors: sequence of objects exposing a ``.vector`` mapping of deltas.
        config: namespace with ``sorting_descending`` and ``svd_threshold``,
            optionally ``merge_num_components``.

    Returns:
        None if ``task_vectors`` is empty. Otherwise a dict keyed by task-vector key:
            - SVD layers: 'U' (rows, K), 'S' (K,), 'Vh' (K, cols), cast to the
              delta's original dtype, 'is_svd' True, 'num_selected_components' K
            - other layers: 'tensor' (averaged delta), 'is_svd' False
            - both: 'original_dtype', 'original_shape'
    """
    if not task_vectors:
        print("Warning: No task vectors provided to merge_task_vectors.")
        return None

    print(f"Computing global SVD component selection for {len(task_vectors)} task vectors...")

    requested = getattr(config, "merge_num_components", None)
    if requested is None:
        requested = _DEFAULT_MERGE_NUM_COMPONENTS
    num_to_keep = max(0, int(requested))

    new_vector = {}
    with torch.no_grad():
        for task_key in task_vectors[0].vector:
            base_key = task_key[len("model."):] if task_key.startswith("model.") else task_key
            if base_key not in base_model_params_dict:
                print(f"Warning: Base key '{base_key}' (from task key '{task_key}') not in base_model_params_dict. Skipping.")
                continue

            current_device = base_model_params_dict[base_key].device
            reference = task_vectors[0].vector[task_key]
            original_dtype = reference.dtype
            original_shape = reference.shape

            if len(original_shape) != 2 or "text_projection" in task_key or task_key in _MERGE_SKIP_PARAMS:
                new_vector[task_key] = {
                    'tensor': _average_task_deltas(task_vectors, task_key, current_device),
                    'is_svd': False,
                    'original_dtype': original_dtype,
                    'original_shape': original_shape,
                }
                continue

            try:
                U_final, S_final, Vh_final, total_components = _pooled_svd_components(
                    task_vectors, task_key, current_device, num_to_keep, config.sorting_descending
                )
                if total_components == 0:
                    print(f"Warning: No SVD components generated for key {task_key}. Skipping.")
                    continue

                kept_component_count = S_final.shape[0]
                if kept_component_count == 0:
                    if config.svd_threshold > 0:
                        print(f"Warning for {task_key}: num_to_keep is 0 with svd_threshold {config.svd_threshold}. Model will handle 0 components.")
                    # Indexing with an empty selection already yields
                    # (rows, 0), (0,) and (0, cols) tensors.
                    print(f"Warning: No singular value components selected for key {task_key}. Creating zero SVD components.")

                new_vector[task_key] = {
                    'U': U_final.to(original_dtype),
                    'S': S_final.to(original_dtype),
                    'Vh': Vh_final.to(original_dtype),
                    'is_svd': True,
                    'original_dtype': original_dtype,
                    'original_shape': original_shape,
                    'num_selected_components': kept_component_count,
                }

                print(f"Global SVD components extracted for key '{task_key}', original shape: {original_shape}")
                print(f"  Total components from all tasks: {total_components}")
                print(f"  Selected {kept_component_count} components (fixed selection).")
                print(f"  Final U shape: {U_final.shape}, Final S shape: {S_final.shape}, Final Vh shape: {Vh_final.shape}")

            except Exception as e:
                # Best effort: a failing layer is reported and left out, so it
                # receives no merged delta.
                print(f"Error: Global SVD component selection failed for key '{task_key}' with error: {str(e)}")
                import traceback
                traceback.print_exc()

    return new_vector

def create_task_vector_from_merged(merged_vector, task_vectors, args):
    """Wrap an already-materialised merged vector in the matching TaskVector class.

    Kept mainly for compatibility. It expects plain delta tensors, so the
    component dicts produced by ``merge_task_vectors`` must be reconstructed
    first. ``task_vectors`` is not used.

    Returns a ``LinearizedTaskVector`` when ``args.finetuning_mode == "linear"``,
    otherwise a ``NonLinearTaskVector``. In both cases ``use_half`` is
    ``not args.no_use_half``.
    """
    is_linear = args.finetuning_mode == "linear"
    vector_cls = LinearizedTaskVector if is_linear else NonLinearTaskVector
    return vector_cls(vector=merged_vector, use_half=not args.no_use_half)

class LearnableSingularValuesMergedEncoder(nn.Module):
    """Wrap ``model`` so each weight is ``base + delta`` with learnable singular values.

    For every parameter of ``model`` that appears in ``merged_vector_components``
    (the output of ``merge_task_vectors``):

    - SVD layers (``is_svd`` True, at least one component): the delta
      ``U @ diag(S) @ Vh`` is rebuilt from the merged components and
      re-decomposed with a second SVD. The top ``int(args.svd_threshold * k)``
      singular values become parameters in ``learnable_s_values``, and the
      rest are kept in a ``frozen_S_*`` buffer. NOTE: the learnable values are
      initialised to 0.0, not to their SVD values. The SVD values are stored in
      ``initial_selected_S_*`` but ``forward`` never reads that buffer, so at
      initialisation the top components contribute nothing to the delta.
    - SVD layers with zero components: the delta is zero.
    - Non-SVD layers: the given delta tensor is stored as a buffer and added
      unchanged.
    - Parameters missing from ``merged_vector_components`` get a zero delta.

    Base weights come from ``functorch.make_functional_with_buffers(model)``.
    Those functional copies (``self.params``) have ``requires_grad=False``.

    Args:
        model: encoder exposing ``train_preprocess``, ``val_preprocess`` and
            ``cache_dir``.
        merged_vector_components: dict returned by ``merge_task_vectors``.
        args: namespace providing ``svd_threshold``.
    """

    def __init__(self, model, merged_vector_components, args):
        super().__init__()

        self.model = model
        self.args = args

        self.train_preprocess = model.train_preprocess
        self.val_preprocess = model.val_preprocess
        self.cache_dir = model.cache_dir

        # Functional form of the model: forward() passes in base + delta
        # tensors instead of using the module's own parameters.
        from functorch import make_functional_with_buffers
        func, params_from_functional, self.buffer = make_functional_with_buffers(model)
        self.func = lambda p, b, x: func(p, b, x)
        self.params = nn.ParameterList(params_from_functional)
        for p in self.params:
            p.requires_grad = False

        # Assumes functorch returns params in named_parameters() order -- TODO confirm.
        self.param_names = [name for name, _ in model.named_parameters()]

        self.svd_components_info = {}
        self.learnable_s_values = nn.ParameterDict()
        self.direct_deltas = {}
        self.total_learnable_sv = 0

        print("\n--- Initializing LearnableSingularValuesMergedEncoder ---")
        print("Processing parameters based on pre-computed SVD components or direct deltas:")

        for param_idx, param_name in enumerate(self.param_names):
            if param_name not in merged_vector_components:
                print(f"  No SVD components or direct delta in merged_vector_components for {param_name}. Update for this layer will be scaled by multiplier but start at zero.")
                continue

            layer_data = merged_vector_components[param_name]
            # ParameterDict keys and buffer names may not contain '.'.
            safe_name = param_name.replace('.', '_')

            if not layer_data['is_svd']:
                self._register_direct_delta(param_name, safe_name, layer_data)
            elif layer_data['num_selected_components'] > 0:
                self._register_svd_layer(param_name, safe_name, layer_data)
            else:
                self._register_empty_svd_layer(
                    param_name, safe_name, layer_data, self.params[param_idx].device
                )

        print(f"Total learnable singular values across all layers: {self.total_learnable_sv}")
        print("--- Initialization Complete --- \n")

    def _register_svd_layer(self, param_name, safe_name, layer_data):
        """Re-SVD a merged layer and register its U/Vh and learnable/frozen singular values."""
        original_dtype = layer_data['original_dtype']
        original_shape = layer_data['original_shape']
        num_components = layer_data['num_selected_components']

        print(f"  SVD layer {param_name}: Reconstructing delta from {num_components} components...")
        reconstructed_delta = (
            layer_data['U'].to(torch.float32)
            @ torch.diag_embed(layer_data['S'].to(torch.float32))
            @ layer_data['Vh'].to(torch.float32)
        )

        print(f"  SVD layer {param_name}: Performing second SVD on reconstructed delta...")
        U_new, S_new, Vh_new = torch.linalg.svd(reconstructed_delta, full_matrices=False)

        num_total_new_components = S_new.shape[0]
        num_to_learn = int(self.args.svd_threshold * num_total_new_components)
        S_initial_learnable = S_new[:num_to_learn]
        S_initial_frozen = S_new[num_to_learn:]

        self.register_buffer(f'U_{safe_name}', U_new)
        self.register_buffer(f'Vh_{safe_name}', Vh_new)
        self.register_buffer(f'initial_selected_S_{safe_name}', S_initial_learnable)
        self.register_buffer(f'frozen_S_{safe_name}', S_initial_frozen)

        # Learnable singular values deliberately start at zero (see class docstring).
        self.learnable_s_values[safe_name] = nn.Parameter(
            torch.full_like(S_initial_learnable, 0.0, dtype=torch.float32)
        )
        self.total_learnable_sv += num_to_learn

        self.svd_components_info[param_name] = {
            'original_dtype': original_dtype,
            'original_shape': original_shape,
            'num_learnable': num_to_learn,
            'num_frozen': len(S_initial_frozen),
            'num_total_components': num_total_new_components,
        }
        print(f"  SVD layer {param_name} (Shape {original_shape}): "
              f"Re-SVD created {num_total_new_components} new components. "
              f"Learning top {num_to_learn} ({self.args.svd_threshold*100:.1f}%) singular values. "
              f"Freezing remaining {len(S_initial_frozen)}.")

    def _register_empty_svd_layer(self, param_name, safe_name, layer_data, device):
        """Register empty U/Vh/frozen-S buffers for an SVD layer with no selected components."""
        original_dtype = layer_data['original_dtype']
        original_shape = layer_data['original_shape']
        print(f"  SVD layer {param_name} (Shape {original_shape}, Dtype {original_dtype}): "
              f"No SVD components selected by merge_task_vectors. Delta will be zero.")
        self.svd_components_info[param_name] = {
            'original_dtype': original_dtype,
            'original_shape': original_shape,
            'num_learnable': 0,
            'num_frozen': 0,
            'num_total_components': 0,
        }
        self.register_buffer(f'U_{safe_name}', torch.empty((original_shape[0], 0), dtype=torch.float32, device=device))
        self.register_buffer(f'Vh_{safe_name}', torch.empty((0, original_shape[1]), dtype=torch.float32, device=device))
        self.register_buffer(f'frozen_S_{safe_name}', torch.empty((0,), dtype=torch.float32, device=device))

    def _register_direct_delta(self, param_name, safe_name, layer_data):
        """Store a non-SVD layer's delta tensor as a buffer, added as-is in forward()."""
        direct_delta_tensor = layer_data['tensor']
        buffer_name = f'direct_delta_{safe_name}'
        self.register_buffer(buffer_name, direct_delta_tensor)
        self.direct_deltas[param_name] = {
            'original_dtype': layer_data['original_dtype'],
            'key_for_buffer': buffer_name,
        }
        is_tensor = isinstance(direct_delta_tensor, torch.Tensor)
        shape_str = direct_delta_tensor.shape if is_tensor else 'N/A'
        dtype_str = direct_delta_tensor.dtype if is_tensor else 'N/A'
        print(f"  Non-SVD layer {param_name} (Shape {shape_str}, Dtype {dtype_str}): Using direct delta.")

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` to registered state and to the unregistered functorch buffers.

        ``self.buffer`` is a plain tuple that ``nn.Module._apply`` does not
        see, so it is converted here. ``self.params`` is a registered
        ``nn.ParameterList`` and is already converted by ``super()._apply``.
        The previous second pass over it applied ``fn`` twice. Extra arguments
        (e.g. ``recurse`` in newer PyTorch) are forwarded.
        """
        new_self = super()._apply(fn, *args, **kwargs)
        if getattr(new_self, 'buffer', None) is not None:
            new_self.buffer = tuple(fn(b) for b in new_self.buffer)
        return new_self

    def train(self, mode=True):
        """Set training mode and return ``self``, as ``nn.Module.train`` does."""
        return super().train(mode)

    def _layer_delta(self, param_name, base_param):
        """Return the delta added to ``base_param`` for ``param_name``."""
        safe_name = param_name.replace('.', '_')
        svd_info = self.svd_components_info.get(param_name)

        if svd_info is not None:
            if svd_info.get('num_total_components', 0) <= 0:
                return torch.zeros_like(base_param)

            U = getattr(self, f'U_{safe_name}')
            Vh = getattr(self, f'Vh_{safe_name}')
            learnable_S = self.learnable_s_values.get(safe_name)
            frozen_S = getattr(self, f'frozen_S_{safe_name}', None)

            # Learnable values cover the leading components, frozen values the rest.
            s_parts = [s for s in (learnable_S, frozen_S) if s is not None and s.numel() > 0]
            if not s_parts:
                reconstructed = torch.zeros(svd_info['original_shape'], device=base_param.device)
            else:
                full_S = torch.cat(s_parts, dim=0).to(torch.float32)
                reconstructed = U.to(torch.float32) @ torch.diag_embed(full_S) @ Vh.to(torch.float32)
            return reconstructed.to(svd_info['original_dtype'])

        direct_key = f'direct_delta_{safe_name}'
        if direct_key in self._buffers:
            return getattr(self, direct_key)
        return torch.zeros_like(base_param)

    def forward(self, x):
        """Run the functional model on ``x`` with every parameter set to base + delta."""
        final_model_params = [
            base_param + self._layer_delta(param_name, base_param)
            for base_param, param_name in zip(self.params, self.param_names)
        ]
        return self.func(final_model_params, self.buffer, x)

def main(rank, args):
    """Run one worker: set up DDP, then call ``train`` on growing prefixes of the source datasets.

    The target dataset is removed from ``args.datasets``. Iteration ``idx`` then
    trains with ``original_datasets[:idx + 1]``. Iterations with
    ``idx < args.resume_from_idx`` are skipped, and the loop stops at
    ``args.end_index``. A failing iteration is logged and the loop continues.
    A "CUDA out of memory" failure instead exits the process with status 1.
    DDP is torn down exactly once, in the ``finally`` clause. The OOM path used
    to call ``cleanup_ddp()`` itself as well, so teardown ran twice.

    Args:
        rank: process rank, stored into ``args.rank``.
        args: parsed command-line namespace. Mutated in place (``rank``,
            ``port``, ``no_use_half``, ``target_dataset``, ``datasets``).
    """
    args.rank = rank

    distributed_initialized = False

    try:
        available_ports = list(range(29520, 29590))
        if hasattr(args, 'port') and args.port is not None and args.port > 0:
            selected_port = args.port
            print(f"Using user-specified port {selected_port} for distributed training")
        else:
            # NOTE(review): if main() runs once per rank, each rank probes (and
            # shuffles) ports independently and may pick a different port.
            # Consider passing --port or choosing the port before spawning.
            selected_port = find_available_port(available_ports)
            if selected_port is None:
                print("Warning: No available ports found. Using a random port which may cause issues.")
                selected_port = random.choice(available_ports)
                print(f"Selected random port {selected_port} - this may cause issues if already in use")
            else:
                print(f"Found available port {selected_port} for distributed training")

        args.port = selected_port

        setup_ddp(args.rank, args.world_size, port=selected_port)
        distributed_initialized = True
        print("no_use_half", args.no_use_half)
        if hasattr(args, 'isoc') and args.isoc and not args.no_use_half:
            if is_main_process():
                print("\n")
                print("*" * 80)
                print("WARNING: requires full precision for SVD operations.")
                print("Automatically switching to full precision mode (--no-use-half).")
                print("*" * 80)
                print("\n")
            args.no_use_half = True

        if args.seed is not None:
            set_seed(args.seed)

        args.target_dataset = args.target_dataset_name + "Val"
        # The target dataset is never used as a source task.
        if args.target_dataset_name in args.datasets:
            args.datasets.remove(args.target_dataset_name)
        original_datasets = args.datasets.copy()

        for idx, _ in enumerate(original_datasets):

            if idx < args.resume_from_idx:
                print(f"Skipping iteration {idx+1} because resume-from-idx is {args.resume_from_idx}")
                continue

            if idx >= args.end_index:
                print(f"Ending iteration {idx+1} because end-index is {args.end_index}")
                break

            try:
                # Iteration idx merges the first idx + 1 source datasets.
                args.datasets = original_datasets[:idx+1]

                if is_main_process():
                    precision_mode = "half precision (float16)" if not args.no_use_half else "full precision (float32)"
                    print(f"Creating task vectors using {precision_mode}")
                    print(f"SVD learning mode: learning top {args.svd_threshold*100:.1f}% singular values for each layer")
                    if args.svd_threshold > 0 and args.no_use_half:
                        print(f"Using SVD thresholding (in merge_task_vectors) with threshold {args.svd_threshold}")
                        if args.keep_top_values:
                            print(f"Mode: KEEP top {args.svd_threshold*100}% singular values, ZERO OUT the rest")
                        else:
                            print(f"Mode: ZERO OUT top {args.svd_threshold*100}% singular values, KEEP the rest")
                    elif args.svd_threshold > 0 and not args.no_use_half:
                        print(f"WARNING: SVD thresholding requires full precision. Use --no-use-half flag for SVD operations.")

                task_vectors = {}
                for dataset in args.datasets:
                    if args.finetuning_mode == "linear":
                        pretrained_checkpoint = f"{args.save}/{dataset}Val/linear_zeroshot.pt"
                        finetuned_checkpoint = f"{args.save}/{dataset}Val/linear_finetuned.pt"
                        task_vectors[dataset] = LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint)
                    else:
                        pretrained_checkpoint = f"{args.save}/{dataset}Val/zeroshot.pt"
                        finetuned_checkpoint = f"{args.save}/{dataset}Val/finetuned.pt"
                        task_vectors[dataset] = NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint)

                print("=" * 100)
                print(f"Learning SVD-merged task vector singular values on {args.target_dataset} with {len(args.datasets)} datasets")
                print(f"Datasets being used: {args.datasets}")
                print("=" * 100)

                print("\nCommand line arguments:")
                print("-" * 50)
                for arg, value in vars(args).items():
                    print(f"{arg}: {value}")
                print("-" * 50)
                print()
                train(task_vectors, args)
                print(f"Successfully completed iteration {idx+1} with {len(args.datasets)} datasets")

            except Exception as e:
                # Best effort: report the failed iteration and move on, unless
                # the GPU ran out of memory.
                print(f"ERROR: Iteration {idx+1} with {len(args.datasets)} datasets failed with error: {e}")
                import traceback
                traceback.print_exc()
                if "CUDA out of memory" in str(e):
                    print("CUDA out of memory error detected - stopping the entire job!")
                    # The outer finally performs cleanup_ddp() on the way out.
                    sys.exit(1)

    finally:
        if distributed_initialized:
            try:
                cleanup_ddp()
            except Exception as e:
                print(f"Warning: Error during distributed cleanup: {e}")

    args.datasets = original_datasets


def _iter_initial_selected_sv(image_encoder):
    """Yield ``(safe_name, values)`` for every non-empty ``initial_selected_S_*`` attribute.

    ``safe_name`` is each entry of ``image_encoder.param_names`` with ``.`` replaced
    by ``_``; the attribute looked up is ``initial_selected_S_<safe_name>``.
    Parameters without that attribute, or whose tensor is empty, are skipped.
    """
    for param_name in image_encoder.param_names:
        safe_name = param_name.replace('.', '_')
        values = getattr(image_encoder, f'initial_selected_S_{safe_name}', None)
        if values is not None and values.numel() > 0:
            yield safe_name, values


def _append_json_record(path, record):
    """Append ``record`` to the JSON list stored at ``path``, creating the file if absent.

    If the existing file holds a single JSON value rather than a list, it is wrapped
    into a list before ``record`` is appended.
    """
    if os.path.exists(path):
        with open(path, 'r') as f:
            records = json.load(f)
        if not isinstance(records, list):
            records = [records]
        records.append(record)
    else:
        records = [record]
    with open(path, 'w') as f:
        json.dump(records, f, indent=4)


def train(task_vectors, args):
    """Learn the selected singular values of SVD-merged task vectors on ``args.target_dataset``.

    The task vectors of every dataset except the target are merged with
    ``merge_task_vectors`` and wrapped in ``LearnableSingularValuesMergedEncoder``.
    That encoder is combined with the target's classification head (frozen), wrapped
    in DDP and trained with AdamW, a cosine LR schedule, fp16 autocast and gradient
    accumulation. On the main process, progress is logged to Weights & Biases, and the
    final test accuracy plus learned singular values are written to JSON files under
    ``results/``.

    NOTE: by default training stops after 2 batches (a GPU-memory profiling exit,
    see ``exit_after_batches`` below). Set ``args.exit_after_batches`` to ``None``
    or ``0`` to run the full schedule, per-epoch validation and final test evaluation.

    Args:
        task_vectors: mapping from dataset name to task vector. The entry whose key
            equals ``args.target_dataset`` with ``"Val"`` removed is excluded.
        args: experiment arguments (``rank``, ``world_size``, ``svd_threshold``,
            ``keep_top_values``, ``num_grad_accumulation``, ``lp_reg``, ...).

    Returns:
        None. Returns early when ``args.finetuning_mode == "linear"`` (unsupported),
        when task-vector merging returns ``None``, or at the profiling exit.
    """
    # The start time is needed on every rank: the duration is computed
    # unconditionally at the end of training.
    experiment_start_time = datetime.now()
    formatted_start_time = experiment_start_time.strftime("%Y-%m-%d %H:%M:%S")
    if is_main_process():
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} started at: {formatted_start_time}")

    subsample_str = f"s{int(args.subsample * 100)}" if isinstance(args.subsample, float) else f"s{args.subsample}"

    comp_acc = {}

    if args.seed is not None:
        set_seed(args.seed + args.rank)

    target_dataset = args.target_dataset
    target_dataset_test = target_dataset.replace("Val", "")

    # Source tasks: every task vector except the target dataset's own.
    task_vectors_list = [v for k, v in task_vectors.items() if k != target_dataset_test]
    num_source_tasks = len(task_vectors_list)

    if args.finetuning_mode == "linear":
        if is_main_process():
            print("Error: Linear mode with singular value learning not implemented.")
            if wandb.run is not None:
                wandb.finish(quiet=True)
        return
    base_image_encoder = ImageEncoder(args)

    base_model_params_dict = {
        name: tensor.clone()
        for name, tensor in base_image_encoder.model.state_dict().items()
    }

    if is_main_process():
        print("\n" + "="*50)
        print("LEARN MODE: Merging task vectors with SVD-based component selection")
        print("SVD performed on (base_param + task_delta) for each task.")
        print(f"Learning top {args.svd_threshold*100:.1f}% singular values for each merged layer")
        print("="*50 + "\n")

    merged_vector_components = merge_task_vectors(base_model_params_dict, task_vectors_list, args)
    if is_main_process():
        log_gpu_memory("after_isoc", num_datasets=num_source_tasks)

    if merged_vector_components is None:
        if is_main_process():
            print("Error: Failed to merge task vectors with SVD.")
            if wandb.run is not None:
                wandb.finish(quiet=True)
        return

    image_encoder = LearnableSingularValuesMergedEncoder(
        base_image_encoder,
        merged_vector_components,
        args
    )
    if is_main_process():
        log_gpu_memory("after_encoder_init", num_datasets=num_source_tasks)

    threshold_pct = int(args.svd_threshold * 100)
    zeroing_mode = "keepTop" if args.keep_top_values else "zeroTop"
    num_trainable_sv = getattr(image_encoder, 'total_learnable_sv', 0)

    if is_main_process():
        gpu_info = {}
        if torch.cuda.is_available():
            gpu_info = {
                "gpu_count": torch.cuda.device_count(),
                "gpu_model": torch.cuda.get_device_name(0),
                "gpu_memory_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3),
            }
        api_key = os.environ.get("WANDB_API_KEY", "your-wandb-key")
        wandb.login(key=api_key)

        run_name = f"{args.model}_{target_dataset}-isoc-learn-svd-top{threshold_pct}pct-e{args.epochs}-nbdts{len(args.datasets)}"
        if args.svd_threshold > 0:
            run_name += f"-merge_task_vectors_{zeroing_mode}{threshold_pct}pct"
        run_name += f"-{subsample_str}"

        wandb.init(
            project="atlas2",
            entity="example-owner",
            name=run_name,
            config={
                "type": "learn_coef",
                "genre": "isoc_svd_components_learning",
                "kind": f"axis_top_global_main_{threshold_pct}pct-sortM{args.sorting_descending}",
                "iter_fixed": 3,
                "model": args.model,
                "save": args.save,
                "target_dataset": target_dataset,
                "used_datasets": args.datasets,
                "partition": args.partition,
                "#nb_dts": len(args.datasets),
                "port": args.port,
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "finetuning_mode": args.finetuning_mode,
                "lp_reg": args.lp_reg,
                "seed": args.seed,
                "weight_decay": args.wd,
                "subsample": args.subsample,
                "merge_task_vectors_svd_threshold": args.svd_threshold,
                "merge_task_vectors_keep_top_values": args.keep_top_values,
                "learnable_sv_percentage": args.svd_threshold * 100,
                "total_learnable_singular_values": num_trainable_sv,
                "learning_approach": "merge_task_vectors_selected_svd_components_with_learnable_S",

                "system": {
                    "gpu": gpu_info,
                    "cpu_count": psutil.cpu_count(),
                    "cpu_physical_count": psutil.cpu_count(logical=False),
                    "memory_gb": psutil.virtual_memory().total / (1024**3),
                    "hostname": socket.gethostname(),
                    "platform": platform.platform(),
                },

                "runtime": {
                    "pytorch_version": torch.__version__,
                    "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
                    "python_version": sys.version.split()[0],
                },

                "job_id": os.environ.get('SLURM_JOB_ID', 'unknown'),
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),

                "grad_clip_value": 1.0,
                "num_grad_accumulation": args.num_grad_accumulation,
                # Per-step batch size. A duplicate "batch_size" key holding
                # batch_size * num_grad_accumulation used to be silently overwritten here.
                "batch_size": args.batch_size,
                "effective_batch_size": args.batch_size * args.num_grad_accumulation * args.world_size,
                "mixed_precision": True,

                "port_selection_method": "auto_find_available",
                "sorting_descending": args.sorting_descending,
            }
        )

        if wandb.run:
            print(f"WandB Run URL: {wandb.run.url}")
            print("-" * 50 + "\n")

        # Summary stats (and the first 5 values) of each layer's initially selected S.
        initial_selected_sv_log = {}
        for safe_name, initial_s_values in _iter_initial_selected_sv(image_encoder):
            prefix = f"initial_selected_sv_{safe_name}"
            initial_selected_sv_log[f"{prefix}_count"] = initial_s_values.numel()
            initial_selected_sv_log[f"{prefix}_mean"] = initial_s_values.mean().item()
            initial_selected_sv_log[f"{prefix}_max"] = initial_s_values.max().item()
            for idx, val in enumerate(initial_s_values[:5].tolist()):
                initial_selected_sv_log[f"{prefix}_idx{idx}"] = val
        if initial_selected_sv_log:
            wandb.log(initial_selected_sv_log)

    ckpdir = os.path.join(args.save, target_dataset)
    os.makedirs(ckpdir, exist_ok=True)

    classification_head = get_classification_head(args, target_dataset)
    model = ImageClassifier(image_encoder, classification_head)

    model.freeze_head()
    model = model.cuda()
    if is_main_process():
        log_gpu_memory("model_on_cuda", num_datasets=num_source_tasks)

    preprocess_fn = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(
            size=224, scale=(0.5, 1),
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC
        ), torchvision.transforms.RandomHorizontalFlip(p=0.5),
    ] + model.train_preprocess.transforms[-3:])
    dataset = get_dataset(
        target_dataset,
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size,
        num_workers=2,
    )
    data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(data_loader)

    if args.print_every * 10 < num_batches:
        print_every = int(num_batches / 10)
    elif args.print_every * 4 > num_batches:
        print_every = max(int(num_batches / 4), 1)
    else:
        print_every = args.print_every

    ddp_loader = distribute_loader(data_loader)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.rank],
        find_unused_parameters=True,
        output_device=args.rank,
    )
    if is_main_process():
        torch.cuda.reset_peak_memory_stats()
        log_gpu_memory("after_ddp_init", num_datasets=num_source_tasks)

    loss_fn = torch.nn.CrossEntropyLoss()

    params = [p for p in ddp_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

    accum = args.num_grad_accumulation
    scheduler = cosine_lr(
        optimizer, args.lr, 0,
        args.epochs * num_batches // accum,
    )

    slurm_job_id = os.environ.get('SLURM_JOB_ID', 'unknown')
    if "unknown" in slurm_job_id:
        slurm_job_id = time.strftime("%Y%m%d-%H%M%S")

    svd_suffix = f"-{zeroing_mode}{threshold_pct}pct"
    learn_sv_suffix = f"-learnTop{threshold_pct}pct"
    log_path = os.path.join('results', f"slurm-{slurm_job_id}-{target_dataset}_isoc_learned_svd{learn_sv_suffix}{svd_suffix}_{subsample_str}.json")
    print("log_path", log_path)

    scaler = GradScaler()
    if is_main_process():
        wandb.log({
            "target_dataset": str(target_dataset),
        })

        if os.path.exists(log_path):
            with open(log_path, 'r') as f:
                comp_acc = json.load(f)

    num_trainable_params = num_trainable_sv
    total_params_in_ddp_model = sum(p.numel() for p in ddp_model.parameters())

    if is_main_process():
        pct_trainable = (100 * num_trainable_params / total_params_in_ddp_model) if total_params_in_ddp_model > 0 else 0
        print("\nModel Parameter Stats:")
        print(f"Total parameters in DDP model (includes frozen base + learnable): {total_params_in_ddp_model:,}")
        print(f"Trainable singular values: {num_trainable_sv:,}")
        print(f"Total trainable parameters: {num_trainable_params:,}")
        if total_params_in_ddp_model > 0:
            print(f"Percentage trainable of DDP model: {pct_trainable:.4f}%\n")

        wandb.config.update({
            "total_parameters_ddp_model": total_params_in_ddp_model,
            "trainable_singular_values": num_trainable_sv,
            "trainable_parameters": num_trainable_params,
            "trainable_parameters_percentage_of_ddp": pct_trainable
        })

        wandb.log({
            "total_parameters_ddp_model": total_params_in_ddp_model,
            "trainable_singular_values": num_trainable_sv,
            "trainable_parameters_total": num_trainable_params,
            "trainable_parameters_percentage_of_ddp": pct_trainable
        })

    # GPU-memory profiling exit: stop after this many batches (original behaviour: 2).
    # A falsy value (None / 0) disables the exit and runs the full schedule.
    exit_after_batches = getattr(args, "exit_after_batches", 2)

    for epoch in range(args.epochs):
        epoch_loss = 0.0
        epoch_steps = 0
        # Per-rank accuracy counters, summed over all ranks at the end of the epoch.
        process_epoch_correct = 0
        process_epoch_samples = 0

        ddp_loader.sampler.set_epoch(epoch)
        for i, batch in enumerate(ddp_loader):
            start_time = time.time()

            step = (
                i // accum
                + epoch * num_batches // accum
            )

            batch = maybe_dictionarize(batch)
            inputs = batch["images"].cuda()
            data_time = time.time() - start_time

            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = ddp_model(inputs)
                labels = batch["labels"].cuda()
                loss = loss_fn(logits, labels)

                preds = logits.argmax(dim=-1)
                correct_in_batch = (preds == labels).sum().item()
                total_in_batch = labels.size(0)
                batch_training_accuracy = correct_in_batch / total_in_batch

                # L2 penalty on the learnable singular values, scaled by lp_reg.
                sv_reg = 0.0
                if args.lp_reg is not None:
                    for param in ddp_model.module.image_encoder.learnable_s_values.values():
                        sv_reg += torch.norm(param, p=2)
                    sv_reg = args.lp_reg * sv_reg
                # Scale (loss + regulariser) once by the accumulation factor, so the
                # gradient summed over one optimizer step contains the regulariser
                # exactly once. It was previously divided by `accum` twice.
                loss = (loss + sv_reg) / accum

                if is_main_process():
                    epoch_loss += loss.item() * accum
                    epoch_steps += 1

            scaler.scale(loss).backward()

            torch.cuda.empty_cache()

            if (i + 1) % accum == 0:
                scheduler(step)

                # Gradients are still multiplied by the AMP loss scale; unscale them
                # first so the clipping threshold applies to the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            batch_time = time.time() - start_time

            if (
                step % print_every == 0
                and ((i + 1) % accum == 0)
                and is_main_process()
            ):
                percent_complete = 100 * (i + 1) / len(ddp_loader)
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i + 1}/{num_batches}]\t"
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}",
                    flush=True,
                )
                print(f"Batch Training Accuracy: {batch_training_accuracy*100:.2f}%")

                batch_log = {
                    "dataset": str(target_dataset),
                    "epoch": epoch,
                    "batch": i,
                    "batch_loss": loss.item() * accum,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                    "batch_time": batch_time,
                    "data_time": data_time,
                    "global_step": step,
                    "singular_value_reg": sv_reg.item() if hasattr(sv_reg, 'item') else sv_reg,
                    "batch_training_accuracy": batch_training_accuracy * 100,
                }

                wandb.log(batch_log)

            process_epoch_correct += correct_in_batch
            process_epoch_samples += total_in_batch

            if exit_after_batches and (i + 1) >= exit_after_batches:
                if is_main_process():
                    print("Completed 2 batches. Logging memory and exiting training.")
                    log_gpu_memory("after_2_batches", num_datasets=num_source_tasks)

                if is_main_process() and wandb.run:
                    wandb.finish(quiet=True)
                return

        # all_reduce is a collective: every rank must call it, otherwise the main
        # process blocks forever waiting for the other ranks.
        epoch_counts = torch.tensor(
            [process_epoch_correct, process_epoch_samples],
            dtype=torch.float64,
            device=args.rank,
        )
        torch.distributed.all_reduce(epoch_counts, op=torch.distributed.ReduceOp.SUM)
        global_epoch_correct, global_epoch_samples = epoch_counts.tolist()

        if is_main_process():
            epoch_training_accuracy = 0
            if global_epoch_samples > 0:
                epoch_training_accuracy = (global_epoch_correct / global_epoch_samples) * 100
            else:
                print("Warning: global_epoch_samples is 0. Cannot compute epoch training accuracy.")

            avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0
            wandb.log({
                "target_dataset": str(target_dataset),
                "epoch": epoch,
                "epoch_loss": avg_epoch_loss,
                "epoch_training_accuracy": epoch_training_accuracy,
            })

            print(f"Epoch {epoch}: Average Loss: {avg_epoch_loss:.4f}, Epoch Training Accuracy: {epoch_training_accuracy:.2f}%")

            epoch_encoder = ddp_model.module.image_encoder
            val_metrics = eval_single_dataset(epoch_encoder, target_dataset, args)
            val_acc = val_metrics["top1"]

            epoch_log = {
                "dataset": str(target_dataset),
                "epoch": epoch,
                "val_accuracy": 100 * val_acc,
            }

            for key, param in epoch_encoder.learnable_s_values.items():
                values = param.detach().cpu().tolist()
                for idx, val in enumerate(values[:5]):
                    epoch_log[f"sv_{key}_idx{idx}_value"] = val

                epoch_log[f"sv_{key}_mean"] = param.detach().mean().item()
                epoch_log[f"sv_{key}_max"] = param.detach().max().item()

            wandb.log(epoch_log)

            print(f"Epoch {epoch}: Validation accuracy: {100 * val_acc:.2f}%")

    if is_main_process():
        print("\nEvaluating model at the end of training (last epoch state)...")
        final_encoder = ddp_model.module.image_encoder

        last_s_values_sum = 0.0
        if hasattr(final_encoder, 'learnable_s_values'):
            for param in final_encoder.learnable_s_values.values():
                last_s_values_sum += param.detach().sum().item()

        print(f"[DEBUG LAST STATE] Sum of learnable_s_values: {last_s_values_sum}")
        wandb.log({
            "debug_last_state_s_values_sum": last_s_values_sum
        })

        test_metrics_last_model = eval_single_dataset(final_encoder, target_dataset_test, args)
        final_test_acc_last_model = test_metrics_last_model["top1"]

        print(f"Final test accuracy (last model state) on {target_dataset_test}: {100 * final_test_acc_last_model:.2f}%")
        wandb.log({
            "dataset": str(target_dataset_test),
            "final_test_accuracy_last": 100 * final_test_acc_last_model,
        })

        comp_acc[target_dataset_test] = final_test_acc_last_model

        wandb.log({
            "dataset": str(target_dataset_test),
            "final_test_accuracy": 100 * final_test_acc_last_model,
        })

        # Both output files live under results/; make sure it exists.
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        with open(log_path, 'w') as f:
            json.dump(comp_acc, f, indent=4)

        test_acc_path = os.path.join('results',
                                     f"slurm-{slurm_job_id}-{target_dataset_test}_isoc_learned_svd{learn_sv_suffix}{svd_suffix}_{subsample_str}_accuracy.json")

        final_sv_values = {}
        final_sv_log = {}
        for key, param in final_encoder.learnable_s_values.items():
            values = param.detach().cpu().tolist()
            final_sv_values[key] = {f"idx{idx}": val for idx, val in enumerate(values)}
            for idx, val in enumerate(values):
                final_sv_log[f"final_sv_{key}_idx{idx}_value"] = val
        if final_sv_log:
            # A single wandb.log call rather than one call (and one step) per value.
            wandb.log(final_sv_log)

        initial_selected_sv_data_for_json = {
            f"initial_selected_sv_{safe_name}": {
                f"idx{idx}": val for idx, val in enumerate(initial_s_values.tolist())
            }
            for safe_name, initial_s_values in _iter_initial_selected_sv(final_encoder)
        }

        test_acc_data = {
            "target_dataset": str(target_dataset_test),
            "number_of_used_datasets": len(args.datasets),
            "datasets_names": args.datasets,
            "test_accuracy_last_model": float(final_test_acc_last_model),
            "test_metrics_last_model": test_metrics_last_model,
            "merge_task_vectors_svd_threshold": args.svd_threshold,
            "merge_task_vectors_keep_top_values": args.keep_top_values,
            "merge_task_vectors_zeroing_mode": "keep_top" if args.keep_top_values else "zero_top",
            "merge_task_vectors_zeroing_percentage": threshold_pct,
            "learnable_sv_percentage": args.svd_threshold * 100,
            "initial_selected_singular_values_from_isoc": initial_selected_sv_data_for_json,
            "learned_singular_values": final_sv_values,
        }

        _append_json_record(test_acc_path, test_acc_data)
        print(f"Test accuracy saved to {test_acc_path}")

    experiment_end_time = datetime.now()
    experiment_duration = experiment_end_time - experiment_start_time
    formatted_end_time = experiment_end_time.strftime("%Y-%m-%d %H:%M:%S")

    hours, remainder = divmod(experiment_duration.total_seconds(), 3600)
    minutes, seconds = divmod(remainder, 60)
    formatted_duration = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"

    if is_main_process():
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} completed at: {formatted_end_time}")
        print(f"[TIMING] Total duration: {formatted_duration} (H:M:S)")

        wandb.log({
            "experiment_start_time": formatted_start_time,
            "experiment_end_time": formatted_end_time,
            "experiment_duration_seconds": experiment_duration.total_seconds(),
            "experiment_duration_formatted": formatted_duration,
        })

        wandb.finish(quiet=True)


if __name__ == "__main__":
    # Pool of datasets assigned to args.datasets for this experiment.
    target_datasets = [
        "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45",
        "SUN397", "SVHN", "CIFAR10", "CIFAR100", "ImageNet", "STL10",
        "Food101", "Caltech101", "Caltech256", "FGVCAircraft", "Flowers102",
        "OxfordIIITPet", "CUB200", "PascalVOC", "Country211", "UCF101",
    ]

    args = parse_arguments()

    wide_rule = "=" * 80
    print(wide_rule)
    print("AUTOMATIC EXPERIMENT TRACKING")
    print("-" * 50)
    print("to ensure that experiment code states are tracked.")
    print(wide_rule)

    # Fixed optimisation hyper-parameters. ViT-L-14 uses half the per-step batch
    # with 2 accumulation steps, so the accumulated batch stays at 128.
    is_vit_large = args.model == "ViT-L-14"
    args.datasets = target_datasets
    args.lr = 1e-1
    args.epochs = 10
    args.batch_size, args.num_grad_accumulation = (64, 2) if is_vit_large else (128, 1)
    args.print_every = 10
    args.seed = 0
    args.save += f"{args.model}"

    # SVD needs full precision, so half precision is force-disabled.
    if not args.no_use_half:
        print("SVD operations require full precision, automatically enabling --no-use-half=True")
        args.no_use_half = True

    # Fixed SVD-merging switches for this experiment.
    for flag_name in ("isoc", "use_svd", "keep_top_values", "sorting_descending"):
        setattr(args, flag_name, True)

    threshold_percent = args.svd_threshold * 100
    mode_text = (
        'KEEP top values, ZERO OUT the rest'
        if args.keep_top_values
        else 'ZERO OUT top values, KEEP the rest'
    )
    print("\n" + wide_rule)
    print("SVD LEARNING & THRESHOLDING CONFIGURATION:")
    print(f"- Learning top {threshold_percent:.1f}% singular values for each layer")
    print(f"- merge_task_vectors Threshold: {threshold_percent:.1f}% of singular values")
    print(f"- merge_task_vectors Mode: {mode_text}")
    print(wide_rule + "\n")

    torch.multiprocessing.spawn(main, args=(args,), nprocs=args.world_size)
