import os
import time
import json
import torch
import torchvision
import sys

# Set environment variables to suppress all wandb output
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_CONSOLE"] = "off"
os.environ["WANDB_DISABLE_SERVICE"] = "true"

# Discard anything wandb writes to stderr while it is being imported.
# The try/finally guarantees sys.stderr is restored even if the import fails;
# otherwise sys.stderr would be left pointing at the (now closed) null device
# and the ImportError traceback itself would be lost.
old_stderr = sys.stderr
with open(os.devnull, 'w') as devnull:
    sys.stderr = devnull
    try:
        import wandb
    finally:
        sys.stderr = old_stderr

import random
import numpy as np
import socket
import subprocess
import platform
import psutil
from datetime import datetime, timedelta

from torch.cuda.amp import GradScaler
from src.linearize import LinearizedImageEncoder
from src.modeling import ImageEncoder, ImageClassifier
from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
from src.composition import WeightedImageEncoder, WeightedLinearizedModel

from src.utils import cosine_lr
from src.args import parse_arguments
from src.eval import eval_single_dataset
from src.datasets.registry import get_dataset
from src.heads import get_classification_head
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp

@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Return the entropy of softmax(x) along dim 1, computed from raw logits."""
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    entropy = (probs * log_probs).sum(1)
    return -entropy

def lp_reg(x, p=None, gamma=0.5) -> torch.Tensor:
    """Return ``gamma`` times the mean column-wise p-norm of ``x``.

    The norm is taken over dim 0. When ``p`` is None the regulariser is
    disabled and the plain integer ``0`` is returned.
    """
    if p is None:
        return 0
    column_norms = torch.norm(x, p=p, dim=0)
    return gamma * column_norms.mean()

def log_gpu_memory(log_point_name="", num_datasets=None):
    """Print current CUDA memory statistics and mirror them to wandb.

    Does nothing unless CUDA is available and this is the main process.
    The values are logged to wandb only when a wandb run is active.

    Args:
        log_point_name: Label used in the printed tag and as the suffix of the
            wandb metric keys.
        num_datasets: If given, also logged as "num_source_tasks_for_mem_log".
    """
    if not torch.cuda.is_available() or not is_main_process():
        return

    mib = 1024 ** 2
    allocated_mb = torch.cuda.memory_allocated() / mib
    reserved_mb = torch.cuda.memory_reserved() / mib
    peak_mb = torch.cuda.max_memory_allocated() / mib

    print(
        f"[MEM_LOG | {log_point_name}] "
        f"Allocated: {allocated_mb:.2f} MB | "
        f"Reserved: {reserved_mb:.2f} MB | "
        f"Peak Allocated: {peak_mb:.2f} MB"
    )

    if not wandb.run:
        return
    payload = {
        f"mem_allocated_mb_{log_point_name}": allocated_mb,
        f"mem_reserved_mb_{log_point_name}": reserved_mb,
        f"mem_peak_mb_{log_point_name}": peak_mb,
    }
    if num_datasets is not None:
        payload["num_source_tasks_for_mem_log"] = num_datasets
    wandb.log(payload)

def set_seed(seed: int) -> None:
    """
    Seed every RNG this script relies on so runs are reproducible.

    Seeds Python's ``random``, NumPy, the PyTorch CPU generator and all CUDA
    generators. It also switches cuDNN to deterministic, non-benchmarking
    kernels and exports PYTHONHASHSEED. The main process prints a confirmation.

    Args:
        seed: The random seed to apply everywhere.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Deterministic cuDNN kernels for reproducibility; this can slow training down.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Only affects processes started after this point; the running
    # interpreter's string-hash seed was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not is_main_process():
        return
    print(f"Random seed set to {seed} for deterministic results")


def is_port_available(port):
    """
    Report whether a TCP socket can currently bind ``port`` on all interfaces.

    The probe socket is closed before returning, whatever the outcome.

    Args:
        port: The port number to check

    Returns:
        bool: True if the bind succeeded, False if it raised OSError
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('', port))
    except OSError:
        return False
    else:
        return True
    finally:
        probe.close()

def _bind_error(port):
    """Try to bind a fresh SO_REUSEADDR TCP socket to ``port``; return the OSError, or None on success."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            probe.bind(('', port))
        except OSError as exc:
            return exc
    return None


def find_available_port(port_list, shuffle=True, recheck_delay=0.1):
    """
    Find the first available port from a list of ports.

    Each candidate is bound with a fresh socket, released, and then bound
    again after ``recheck_delay`` seconds, to catch a port another process
    grabbed in between. The previous version ran the second bind while the
    first socket was still open, so the re-check did not test this.

    Args:
        port_list: Iterable of port numbers to check (not modified)
        shuffle: If True (default), try the ports in an order shuffled with the
            module-level ``random`` RNG, so repeated calls spread out
        recheck_delay: Seconds to wait between the two probes of a port

    Returns:
        int: First port that passed both probes, or None if none did
    """
    candidates = list(port_list)  # copy so the caller's sequence is untouched
    if shuffle:
        random.shuffle(candidates)

    for port in candidates:
        err = _bind_error(port)
        if err is not None:
            print(f"Port {port} is not available: {err}")
            continue
        # The first probe socket is closed now; wait and probe again.
        time.sleep(recheck_delay)
        if _bind_error(port) is not None:
            print(f"Port {port} failed second availability check")
            continue
        print(f"Port {port} is confirmed available")
        return port

    print("No available ports found in the provided list")
    return None

def _select_ddp_port(args, candidate_ports):
    """Return the TCP port for DDP rendezvous: an explicit positive ``args.port``, else a probed one."""
    if getattr(args, 'port', None) is not None and args.port > 0:
        print(f"Using user-specified port {args.port} for distributed training")
        return args.port

    # main() takes a rank and is presumably launched once per process, so each
    # rank probes independently. With an unseeded shuffle (set_seed only runs
    # after setup_ddp) every rank would try ports in its own random order and
    # could end up on a different port, so rendezvous would hang. Seed the
    # module RNG, which both find_available_port's shuffle and the fallback
    # below use, with a key shared by the local ranks of this launch. Then
    # restore the caller's RNG state.
    # NOTE(review): a race remains if rank 0 is already listening before a
    # slower rank probes, and this does not help across nodes. Pass an
    # explicit port for reliable multi-GPU runs.
    shared_key = f"ddp-port-{os.environ.get('SLURM_JOB_ID', '')}-{os.getppid()}"
    saved_state = random.getstate()
    random.seed(shared_key)
    try:
        selected_port = find_available_port(candidate_ports)
        if selected_port is None:
            print("Warning: No available ports found. Using a random port which may cause issues.")
            selected_port = random.choice(candidate_ports)
            print(f"Selected random port {selected_port} - this may cause issues if already in use")
        else:
            print(f"Found available port {selected_port} for distributed training")
    finally:
        random.setstate(saved_state)
    return selected_port


def _load_task_vectors(dataset_names, args):
    """Load one task vector per dataset from ``{args.save}/{name}Val`` zero-shot/fine-tuned checkpoints."""
    if args.finetuning_mode == "linear":
        vector_cls, prefix = LinearizedTaskVector, "linear_"
    else:
        vector_cls, prefix = NonLinearTaskVector, ""
    task_vectors = {}
    for name in dataset_names:
        ckpt_dir = f"{args.save}/{name}Val"
        task_vectors[name] = vector_cls(f"{ckpt_dir}/{prefix}zeroshot.pt", f"{ckpt_dir}/{prefix}finetuned.pt")
    return task_vectors


def main(rank, args):
    """Per-rank entry point: set up DDP, then learn coefficients for growing source-task sets.

    ``args.target_dataset_name`` is removed from ``args.datasets``. For each
    ``idx`` in ``[args.resume_from_idx, args.end_index)``, the first ``idx + 1``
    remaining datasets are used: their task vectors are loaded and ``train`` is
    called. A failing iteration is logged and skipped. On exit, DDP is cleaned
    up and ``args.datasets`` is reset to the full source list.

    Args:
        rank: Process rank; stored in ``args.rank``.
        args: Parsed arguments; mutated in place (``rank``, ``port``,
            ``target_dataset``, ``datasets``).
    """
    args.rank = rank

    distributed_initialized = False
    original_datasets = None

    try:
        selected_port = _select_ddp_port(args, list(range(29520, 29590)))
        args.port = selected_port

        setup_ddp(args.rank, args.world_size, port=selected_port)
        distributed_initialized = True

        # Seed only after DDP is up.
        if args.seed is not None:
            set_seed(args.seed)

        args.target_dataset = args.target_dataset_name + "Val"
        if args.target_dataset_name in args.datasets:
            args.datasets.remove(args.target_dataset_name)
        original_datasets = args.datasets.copy()

        for idx, _ in enumerate(original_datasets):
            if idx < args.resume_from_idx:
                print(f"Skipping iteration {idx+1} because resume-from-idx is {args.resume_from_idx}")
                continue
            if idx >= args.end_index:
                print(f"Skipping iteration {idx+1} because end-idx is {args.end_index}")
                continue

            # Iteration idx uses source datasets 0..idx (inclusive).
            args.datasets = original_datasets[:idx + 1]
            try:
                task_vectors = _load_task_vectors(args.datasets, args)

                print("=" * 100)
                print(f"Learning task vector coefficients on {args.target_dataset} with {args.model} using {len(args.datasets)} datasets")
                print(f"Datasets being used: {args.datasets}")
                print("=" * 100)

                print("\nCommand line arguments:")
                print("-" * 50)
                for arg, value in vars(args).items():
                    print(f"{arg}: {value}")
                print("-" * 50)
                print()
                train(task_vectors, args)
                print(f"Successfully completed iteration {idx+1} with {len(args.datasets)} datasets")

            except Exception as e:
                # NOTE(review): if only some ranks fail here, the others may block in a
                # collective inside train(); consider re-raising for multi-GPU runs.
                print(f"ERROR: Iteration {idx+1} with {len(args.datasets)} datasets failed with error: {e}")
                import traceback
                traceback.print_exc()
                print("Continuing with next iteration...")
                continue

    finally:
        if distributed_initialized:
            try:
                cleanup_ddp()
            except Exception as e:
                print(f"Warning: Error during distributed cleanup: {e}")
        # Restore the full source list on every exit path once it exists.
        if original_datasets is not None:
            args.datasets = original_datasets

# === Custom Weighted Image Encoder Class ===
from torch import nn
from functorch import make_functional_with_buffers

class CustomWeightedImageEncoder(nn.Module):
    """Image encoder that adds a learned, partition-wise mix of task vectors to frozen base weights.

    For every parameter block ``b`` of the functionalized base model, the forward
    pass adds

        sum_t sum_p (coef[t, b, p] * trainable_coef_mask[t, b, p]) * mask_mat[p] * dparams[t][b]

    to the frozen weight. ``mask_mat[p]`` is a one-hot indicator of the elements
    randomly assigned to partition ``p``; there is one buffer per distinct
    parameter shape. Exactly ``total_trainable_params`` entries of
    ``trainable_coef_mask`` are set, so only that many coefficients affect the
    output.
    """

    def __init__(self, model, task_vectors, total_trainable_params, num_task_vectors, num_blocks) -> None:
        """A custom wrapper class for task vector composition based on total trainable parameters.

        Args:
            model: Base image encoder model; must expose ``train_preprocess``,
                ``val_preprocess`` and ``cache_dir``.
            task_vectors: Sequence of objects whose ``.vector`` dict holds per-parameter deltas.
            total_trainable_params: The exact number of coefficients to learn.
            num_task_vectors: Number of task vectors being used.
            num_blocks: Number of parameter blocks (tensors) in the base model.

        Raises:
            ValueError: If ``num_task_vectors`` or ``num_blocks`` is not positive,
                or ``total_trainable_params`` is negative.
        """
        super().__init__()

        if num_task_vectors <= 0 or num_blocks <= 0:
            raise ValueError(f"num_task_vectors ({num_task_vectors}) and num_blocks ({num_blocks}) must be positive.")
        if total_trainable_params < 0:
            raise ValueError(f"total_trainable_params ({total_trainable_params}) must be non-negative.")

        # Functionalize the base model. The frozen weights live in self.params;
        # self.func evaluates the architecture with any parameter list.
        func, params, self.buffer = make_functional_with_buffers(model)
        self.func = lambda p, b, x: func(p, b, x)
        self.params = torch.nn.ParameterList(params)
        for p in self.params:
            p.requires_grad = False

        # Store model attributes
        self.train_preprocess = model.train_preprocess
        self.val_preprocess = model.val_preprocess
        self.cache_dir = model.cache_dir

        # Per-task lists of deltas, indexed by block like self.params.
        # NOTE(review): assumes each tv.vector iterates its keys in the same order
        # as the base model's parameters, with one entry per parameter; confirm.
        self.dparams = [[tv.vector[k] for k in tv.vector] for tv in task_vectors]

        self.num_task_vectors = num_task_vectors
        self.num_blocks = num_blocks
        self.total_trainable_params = total_trainable_params

        # Every (task, block) pair gets base_partition coefficients. The first
        # remainder_params pairs get one extra. The divisor is > 0 because of the
        # validation above.
        params_per_full_partition = self.num_task_vectors * self.num_blocks
        self.base_partition, self.remainder_params = divmod(self.total_trainable_params, params_per_full_partition)
        self.max_partition = self.base_partition + (1 if self.remainder_params > 0 else 0)
        if self.max_partition == 0:  # total_trainable_params == 0: avoid a zero-sized dimension
            self.max_partition = 1

        print(f"[CustomEncoder] Total Params: {self.total_trainable_params}, Tasks: {self.num_task_vectors}, Blocks: {self.num_blocks}")
        print(f"[CustomEncoder] Base Partition: {self.base_partition}, Remainder Params: {self.remainder_params}, Max Partition Dim: {self.max_partition}")

        self.coef = torch.nn.Parameter(torch.zeros(self.num_task_vectors, self.num_blocks, self.max_partition))

        # The mask selects which coef entries are used (and therefore trained).
        trainable_coef_mask = torch.zeros_like(self.coef, dtype=torch.bool)
        if self.base_partition > 0:
            trainable_coef_mask[:, :, :self.base_partition] = True
        if self.remainder_params > 0:
            print(f"[CustomEncoder] Distributing {self.remainder_params} remainder coefficients...")
            # Fill the extra slot block by block, cycling through tasks within a block.
            # remainder_params < num_task_vectors * num_blocks, so block_idx stays in range.
            for k in range(self.remainder_params):
                block_idx, task_idx = divmod(k, self.num_task_vectors)
                trainable_coef_mask[task_idx, block_idx, self.base_partition] = True
            print(f"[CustomEncoder] Assigned {self.remainder_params} remainder coefficients.")

        # Stored as float so it can multiply coef directly.
        self.register_buffer('trainable_coef_mask', trainable_coef_mask.float())

        # One random partition assignment per distinct parameter shape, stored as
        # a one-hot (P, *shape) buffer. Blocks of identical shape share it.
        print(f"[CustomEncoder] Creating mask_mats buffers for {self.max_partition} partitions...")
        self._mask_mat_buffer_shapes = {}  # parameter shape -> buffer name
        for i, p in enumerate(self.params):
            p_shape = tuple(p.shape)
            if p_shape in self._mask_mat_buffer_shapes:
                continue
            if p.ndim == 0:
                print(f"    - Skipping scalar parameter at index {i}")
                continue
            print(f"    - Processing shape {p_shape} (Block {i})")
            # Generated on CPU; nn.Module._apply moves the registered buffer later.
            assignments = torch.randint(0, self.max_partition, p_shape, device='cpu')
            mask_mat = torch.nn.functional.one_hot(assignments, num_classes=self.max_partition).permute(-1, *range(p.ndim)).float()
            buffer_name = f"mask_mat_{'_'.join(map(str, p_shape))}"
            self.register_buffer(buffer_name, mask_mat)
            self._mask_mat_buffer_shapes[p_shape] = buffer_name
            print(f"      - Created buffer '{buffer_name}' with shape: {mask_mat.shape}")
        print(f"[CustomEncoder] Finished creating mask_mats buffers for {len(self._mask_mat_buffer_shapes)} unique shapes.")

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` to registered tensors and to the unregistered tensor containers.

        ``nn.Module._apply`` already converts registered parameters and buffers
        (``params``, ``coef``, ``trainable_coef_mask``, ``mask_mat_*``). The
        task-vector deltas (``dparams``) and the functional buffers returned by
        ``make_functional_with_buffers`` (``buffer``) are plain Python
        containers, so they are converted here. Extra arguments, such as
        ``recurse`` in newer PyTorch, are forwarded.
        """
        new_self = super()._apply(fn, *args, **kwargs)
        new_self.dparams = [[fn(x) if isinstance(x, torch.Tensor) else x for x in tv] for tv in new_self.dparams]
        new_self.buffer = tuple(fn(b) if isinstance(b, torch.Tensor) else b for b in new_self.buffer)
        return new_self

    def train(self, mode=True):
        """Set training mode as ``nn.Module.train`` does and return ``self`` so calls can be chained."""
        return super().train(mode)

    def forward(self, x) -> torch.Tensor:
        """Run the base architecture with weights ``params + combined task-vector delta``."""
        # Masked-out entries contribute nothing and receive zero gradient.
        effective_coefs = self.coef * self.trainable_coef_mask  # (T, B, P)

        final_dparams = []
        for block_idx, p_orig in enumerate(self.params):
            p_shape = tuple(p_orig.shape)

            if p_shape in self._mask_mat_buffer_shapes:
                block_mask_mat = getattr(self, self._mask_mat_buffer_shapes[p_shape])  # (P, *p_shape)
            elif p_orig.ndim == 0:
                # Scalar parameters get no mask in __init__; leave them unchanged.
                final_dparams.append(torch.tensor(0.0, device=p_orig.device, dtype=p_orig.dtype))
                continue
            else:
                raise RuntimeError(f"Mask matrix buffer not found for parameter shape {p_shape} at block index {block_idx}")

            valid_task_idx = [
                t for t in range(self.num_task_vectors)
                if isinstance(self.dparams[t][block_idx], torch.Tensor)
            ]
            if not valid_task_idx:
                final_dparams.append(torch.zeros_like(p_orig))
                continue
            stacked_deltas = torch.stack([self.dparams[t][block_idx] for t in valid_task_idx], dim=0)  # (T', *p_shape)

            # Use the coefficient rows of exactly the stacked tasks so the broadcast
            # lines up. A partial set previously made .view() fail.
            if len(valid_task_idx) == self.num_task_vectors:
                block_effective_coefs = effective_coefs[:, block_idx, :]
            else:
                block_effective_coefs = effective_coefs[valid_task_idx, block_idx, :]  # (T', P)

            T = stacked_deltas.shape[0]
            P = block_mask_mat.shape[0]
            shape_dims = tuple(stacked_deltas.shape[1:])

            # term[t, p, ...] = coef[t, p] * mask_mat[p, ...] * delta[t, ...]
            # NOTE(review): this materializes (T, P, *p_shape) intermediates per
            # block, so peak memory grows with T * P * numel.
            term = (
                block_effective_coefs.view((T, P) + (1,) * len(shape_dims))
                * block_mask_mat.view((1, P) + shape_dims)
                * stacked_deltas.view((T, 1) + shape_dims)
            )
            # Sum over tasks and partitions -> (*p_shape)
            final_dparams.append(term.sum(dim=(0, 1)))

        new_params = [dp + p for dp, p in zip(final_dparams, self.params)]
        return self.func(new_params, self.buffer, x)

# === End Custom Weighted Image Encoder Class ===

def train(task_vectors, args):
    # Track the experiment start time
    if is_main_process():
        experiment_start_time = datetime.now()
        formatted_start_time = experiment_start_time.strftime("%Y-%m-%d %H:%M:%S")
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} started at: {formatted_start_time}")
    
    # Format subsample value for filename/run name
    subsample_str = f"s{int(args.subsample * 100)}" if isinstance(args.subsample, float) else f"s{args.subsample}"

    if args.seed is not None:
        set_seed(args.seed + args.rank)  # Add rank to avoid identical random values across processes

    target_dataset = args.target_dataset
    
    # Initialize wandb in the main process (after distributed setup)
    if is_main_process():
        # Collect system information
        gpu_info = {}
        if torch.cuda.is_available():
            gpu_info = {
                "gpu_count": torch.cuda.device_count(),
                "gpu_model": torch.cuda.get_device_name(0),
                "gpu_memory_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3),
            }
        wandb.login(key="your-wandb-key")
        wandb.init(
            project="axis",
            entity="example-owner",
            name=f"{args.model}_{target_dataset}_{subsample_str}-blkcoef-{args.blockwise_coef}-e{args.epochs}-nbdts{len(args.datasets)}-part{args.partition}",
            config={
                "type": "learn_coef",
                "kind": "baseline-20pct", # update per experiment
                "iter_fixed": 3,
                "model": args.model,
                "save": args.save,
                "target_dataset": target_dataset,
                "used_datasets": args.datasets,
                "partition": args.partition,
                "#nb_dts": len(args.datasets),
                "port": args.port,
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "batch_size": args.batch_size * args.num_grad_accumulation,
                "finetuning_mode": args.finetuning_mode,
                "lp_reg": args.lp_reg,
                "seed": args.seed,
                "weight_decay": args.wd,
                "blockwise_coef": True,
                "subsample": args.subsample,
                
                # System information
                "system": {
                    "gpu": gpu_info,
                    "cpu_count": psutil.cpu_count(),
                    "cpu_physical_count": psutil.cpu_count(logical=False),
                    "memory_gb": psutil.virtual_memory().total / (1024**3),
                    "hostname": socket.gethostname(),
                    "platform": platform.platform(),
                },
                
                # Runtime info
                "runtime": {
                    "pytorch_version": torch.__version__,
                    "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
                    "python_version": sys.version.split()[0],
                },
                
                # Job details
                "job_id": os.environ.get('SLURM_JOB_ID', 'unknown'),
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                
                # Training details
                "grad_clip_value": 1.0,
                "num_grad_accumulation": args.num_grad_accumulation,
                "batch_size": args.batch_size,
                "effective_batch_size": args.batch_size * args.num_grad_accumulation * args.world_size,
                "mixed_precision": True,  # Using torch.autocast
                
                # Model details
                "port_selection_method": "auto_find_available",
            }
        )

        if wandb.run:
            print(f"WandB Run URL: {wandb.run.url}")
        print("-" * 50 + "\n")

    
    ckpdir = os.path.join(args.save, target_dataset)
    os.makedirs(ckpdir, exist_ok=True)

    assert args.finetuning_mode in [
        "linear",
        "standard",
    ], "Only linear and standard fine-tuning are supported."

    linearized_finetuning = args.finetuning_mode == "linear"
    if linearized_finetuning:
        print("Using linearized fine-tuning.")

    orig_dataset = target_dataset.replace("Val", "")
    # Get a list of datasets (without the target dataset)
    pool = [
        "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SUN397", "SVHN",
        "CIFAR10", "CIFAR100", "ImageNet", "STL10", "Food101", "Caltech101", "Caltech256",
        "FGVCAircraft", "Flowers102", "OxfordIIITPet", "CUB200", "PascalVOC", "Country211", "UCF101",
    ]
    dataset_names = [ds for ds in pool if ds != orig_dataset]
    
    # Remove the task vector for the target task
    task_vectors = [v for k, v in task_vectors.items() if orig_dataset != k]
    
    # Store the mapping of index to dataset name for tracking in wandb
    coef_dataset_mapping = {i: dataset_names[i] for i in range(len(dataset_names))}
    if is_main_process():
        print("Coefficient to dataset mapping:")
        for idx, ds_name in coef_dataset_mapping.items():
            print(f"Coefficient {idx}: {ds_name}")
    
    if args.finetuning_mode == "linear":
        # TODO: Implement custom logic for linearized fine-tuning if needed.
        # For now, assuming total_trainable_params only applies to standard mode.
        print("WARNING: --total-trainable-params currently only supported for standard finetuning mode. Using original WeightedLinearizedModel.")
        image_encoder = LinearizedImageEncoder(args, keep_lang=False)
        image_encoder.model = WeightedLinearizedModel(
            image_encoder.model, task_vectors, blockwise=args.blockwise_coef, partition=args.partition
        )
    # Standard fine-tuning mode
    elif args.total_trainable_params is not None:
        print(f"Using CustomWeightedImageEncoder with total_trainable_params = {args.total_trainable_params}")
        # Create a base encoder instance to get parameters info
        temp_base_encoder = ImageEncoder(args)
        # Functionalize the temporary encoder to get the parameter list
        _, temp_params, _ = make_functional_with_buffers(temp_base_encoder)
        num_blocks = len(temp_params)
        num_task_vectors = len(task_vectors)
        del temp_base_encoder, temp_params # Free memory
        
        # Create the custom encoder
        image_encoder = CustomWeightedImageEncoder(
            model=ImageEncoder(args), # Pass a new instance
            task_vectors=task_vectors,
            total_trainable_params=args.total_trainable_params,
            num_task_vectors=num_task_vectors,
            num_blocks=num_blocks
        )
        print(f"Created CustomWeightedImageEncoder with {num_task_vectors} tasks, {num_blocks} blocks.")
        if is_main_process():
            log_gpu_memory("after_encoder_init", num_datasets=len(task_vectors))


        # Print parameter details for the custom encoder if needed
        if is_main_process():
            print("\nParameters of the CustomWeightedImageEncoder (Trainable Coefs):")
            param_counter = 1
            total_params_in_coef = 0
            for name, param in image_encoder.named_parameters():
                if param.requires_grad:
                    print(f"{param_counter}. {name} (Shape: {list(param.shape)}, Type: {param.dtype}) - Trainable")
                    total_params_in_coef += param.numel()
                    param_counter += 1
            print(f"Total elements in trainable 'coef' parameter: {total_params_in_coef}")
            # Verify against the mask
            if hasattr(image_encoder, 'trainable_coef_mask'):
                num_masked_trainable = image_encoder.trainable_coef_mask.sum().item()
                print(f"Number of trainable coefficients according to mask: {int(num_masked_trainable)}")
                if int(num_masked_trainable) != args.total_trainable_params:
                     print(f"WARNING: Masked trainable count ({int(num_masked_trainable)}) does not match requested total_trainable_params ({args.total_trainable_params})!")
            print("-" * 30)


    classification_head = get_classification_head(args, target_dataset)
    model = ImageClassifier(image_encoder, classification_head)

    model.freeze_head()
    model = model.cuda()
    if is_main_process():
        log_gpu_memory("model_on_cuda", num_datasets=len(task_vectors))

    preprocess_fn = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(
            size=224, scale=(0.5, 1),
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC
        ), torchvision.transforms.RandomHorizontalFlip(p=0.5),
    ] + model.train_preprocess.transforms[-3:])

    dataset = get_dataset(
        target_dataset,
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size,
        num_workers=4,
    )
    data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(data_loader)

    # Printing loss between four and ten times an epoch
    if args.print_every * 10 < num_batches:
        print_every = int(num_batches / 10)
    elif args.print_every * 4 > num_batches:
        print_every = max(int(num_batches / 4), 1)
    else:
        print_every = args.print_every

    # Distribute the data and model across the GPUs.
    ddp_loader = distribute_loader(data_loader)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.rank],
        find_unused_parameters=False,
        output_device=args.rank,
    )
    if is_main_process():
        torch.cuda.reset_peak_memory_stats()
        log_gpu_memory("after_ddp_init", num_datasets=len(task_vectors))

    loss_fn = torch.nn.CrossEntropyLoss()

    params = [p for p in ddp_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

    # Do not use warm up
    scheduler = cosine_lr(
        optimizer, args.lr, 0,
        args.epochs * num_batches // args.num_grad_accumulation,
    )

    # Get SLURM job ID from environment variables (use 'unknown' if not available)
    slurm_job_id = os.environ.get('SLURM_JOB_ID', 'unknown')
    if "unknown" in slurm_job_id:
        slurm_job_id = time.strftime("%Y%m%d-%H%M%S")
    

    # --- Path and Coefficient Handling ---
    if args.total_trainable_params is not None:
        # Custom encoder path and coef handling
        head_path = os.path.join(ckpdir, f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_custom_composition_total{args.total_trainable_params}.pt")
        log_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_custom_composition_total{args.total_trainable_params}.json")
        # Get the coef parameter (which is 3D) and the mask
        coef_param = ddp_model.module.image_encoder.coef
        coef_mask = ddp_model.module.image_encoder.trainable_coef_mask
        print("Using CustomWeightedImageEncoder - paths set.")
        print(f"Coef parameter shape: {coef_param.shape}, Mask shape: {coef_mask.shape}")
    elif linearized_finetuning:
        # Original linearized paths
        head_path = os.path.join(ckpdir, f"learned_linear_composition_{subsample_str}.pt")
        log_path = os.path.join(args.save, f"learned_linear_composition_{subsample_str}.json")
        coef_param = ddp_model.module.image_encoder.model.coef # Access coef in WeightedLinearizedModel
        coef_mask = None # No mask for linearized or original weighted encoder
        print("Using Linearized fine-tuning - paths set.")
    else:
        # Original standard paths
        head_path = os.path.join(ckpdir, f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_composition.pt")
        log_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_composition.json")
        coef_param = ddp_model.module.image_encoder.coef # Access coef in WeightedImageEncoder
        coef_mask = None # No mask for linearized or original weighted encoder
        print("Using standard WeightedImageEncoder - paths set.")

    print("head_path:", head_path)
    print("log_path:", log_path)
    # Print initial coef value/shape (use .data to avoid grad issues)
    print("Initial coef param shape:", coef_param.shape)
    # --- End Path and Coefficient Handling ---

    scaler = GradScaler()
    if is_main_process():
        print(f"=> Zero-shot accuracy on {target_dataset}:\t{100*args.zs_acc[target_dataset]:.2f}%.")
        # Log zero-shot accuracy to wandb
        wandb.log({
            "target_dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
            "zero_shot_accuracy": 100 * args.zs_acc[target_dataset]
        })
        
        if os.path.exists(log_path):
            with open(log_path, 'r') as f:
                comp_acc = json.load(f)
        else:
            comp_acc = {}

    # Initialize best_coef with the initial coefficients
    best_coef = coef_param.data.clone()
    best_acc = args.zs_acc[target_dataset]


    # Calculate and log number of trainable parameters
    total_params = sum(p.numel() for p in ddp_model.parameters())
    
    # Adjust calculation of trainable params if using CustomWeightedImageEncoder
    is_custom_encoder = args.total_trainable_params is not None and hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
    if is_custom_encoder:
        # Sum params excluding the coef tensor, then add the masked count
        params_excluding_coef = [p for name, p in ddp_model.named_parameters() if p.requires_grad and 'image_encoder.coef' not in name]
        num_trainable_params_other = sum(p.numel() for p in params_excluding_coef)
        num_trainable_params_coef = int(ddp_model.module.image_encoder.trainable_coef_mask.sum().item())
        num_trainable_params = num_trainable_params_other + num_trainable_params_coef
        print(f"[Param Stats - Custom] Other trainable: {num_trainable_params_other}, Masked coef: {num_trainable_params_coef}, Total effective trainable: {num_trainable_params}")
    else:
        # Original calculation
        num_trainable_params = sum(p.numel() for p in ddp_model.parameters() if p.requires_grad)
        print(f"[Param Stats - Standard] Total trainable: {num_trainable_params}")

    if is_main_process():
        print(f"\nModel Parameter Stats:")
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {num_trainable_params:,}")
        print(f"Percentage trainable: {100 * num_trainable_params / total_params:.2f}%\n")
        # Log parameter counts to wandb config and metrics
        wandb.config.update({
            "total_parameters": total_params,
            "trainable_parameters": num_trainable_params,
            "trainable_parameters_pct": 100 * num_trainable_params / total_params
        }, allow_val_change=True)
        
        # Also log as metrics for tracking over time
        wandb.log({
            "total_parameters": total_params,
            "trainable_parameters": num_trainable_params,
            "trainable_parameters_pct": 100 * num_trainable_params / total_params
        })

    for epoch in range(args.epochs):
        # Track epoch metrics
        epoch_loss = 0.0
        epoch_steps = 0
        
        ddp_loader.sampler.set_epoch(epoch)
        for i, batch in enumerate(ddp_loader):
            start_time = time.time()

            step = (
                i // args.num_grad_accumulation
                + epoch * num_batches // args.num_grad_accumulation
            )

            batch = maybe_dictionarize(batch)
            inputs = batch["images"].cuda()
            data_time = time.time() - start_time

            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = ddp_model(inputs)
                labels = batch["labels"].cuda()
                loss = loss_fn(logits, labels)
                # Apply regularisation if needed.

                reg = lp_reg(coef_param, args.lp_reg)
                loss = loss + reg
                loss = loss / args.num_grad_accumulation
                
                # Track loss for epoch average
                if is_main_process():
                    epoch_loss += loss.item() * args.num_grad_accumulation
                    epoch_steps += 1

            scaler.scale(loss).backward()

            if (i + 1) >= 2:
                if is_main_process():
                    print("Completed 2 batches. Logging memory and exiting training.")
                    log_gpu_memory("after_2_batches", num_datasets=len(task_vectors))
                
                # Finish wandb run cleanly before exiting
                if is_main_process() and wandb.run:
                    wandb.finish(quiet=True)
                return

            if (i + 1) % args.num_grad_accumulation == 0:
                scheduler(step)

                torch.nn.utils.clip_grad_norm_(params, 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            batch_time = time.time() - start_time

            if (
                step % print_every == 0
                and ((i + 1) % args.num_grad_accumulation == 0)
                and is_main_process()
            ):
                percent_complete = 100 * (i + 1) / len(ddp_loader)
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i + 1}/{num_batches}]\t"           # noqa: E501
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}",   # noqa: E501
                    flush=True,
                )
                
                # Log batch metrics to wandb
                batch_log = {
                    "dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
                    "epoch": epoch,
                    "batch": i,
                    "batch_loss": loss.item() * args.num_grad_accumulation,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                    "batch_time": batch_time,
                    "data_time": data_time,
                    "global_step": step,
                    "regularization": reg.item() if reg != 0 else 0,
                }
                
                # Add individual coefficient values to logging
                # Use .data to get tensor without grad_fn
                current_coef_tensor = ddp_model.module.image_encoder.coef.data
                
                # Check if we are using the custom encoder with a mask
                is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                if is_custom_encoder:
                    current_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool()
                    # Iterate through the coef tensor and log only where mask is True
                    for task_idx in range(current_coef_tensor.size(0)):
                        task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                        for block_idx in range(current_coef_tensor.size(1)):
                            for part_idx in range(current_coef_tensor.size(2)):
                                if current_mask[task_idx, block_idx, part_idx]:
                                    coef_value = current_coef_tensor[task_idx, block_idx, part_idx].item()
                                    batch_log[f"coef_{task_name}_block{block_idx}_part{part_idx}"] = coef_value
                elif args.blockwise_coef:
                    # Original blockwise logging (assuming 2D coef)
                    current_coef_values = current_coef_tensor.cpu()
                    for dataset_idx in range(current_coef_values.size(0)):
                        layer_coefs = current_coef_values[dataset_idx]
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for layer_idx, coef_value in enumerate(layer_coefs):
                             batch_log[f"coef_layer{layer_idx}_{dataset_name}"] = coef_value.item()
                elif args.partition: # Original partition logging (assuming 3D coef but different structure)
                     current_coef_values = current_coef_tensor.cpu()
                     # Assuming original partition structure: (tasks, blocks, partitions)
                     for dataset_idx in range(current_coef_values.size(0)):
                         dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                         for block_idx in range(current_coef_values.size(1)):
                             for part_idx in range(current_coef_values.size(2)):
                                 coef_value = current_coef_values[dataset_idx, block_idx, part_idx].item()
                                 batch_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}"] = coef_value
                else:
                    # Original non-blockwise logging (assuming 1D coef)
                    current_coef_values = current_coef_tensor.cpu()
                    for dataset_idx, coef_value in enumerate(current_coef_values.flatten()):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        batch_log[f"coef_{dataset_name}"] = coef_value.item()
                
                wandb.log(batch_log)
        
        # Evaluate after each epoch
        if is_main_process():
            # Log average epoch loss
            avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0
            wandb.log({
                "target_dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
                "epoch": epoch,
                "epoch_loss": avg_epoch_loss,
            })
            
            # Evaluate on validation set
            image_encoder = ddp_model.module.image_encoder
            val_metrics = eval_single_dataset(image_encoder, target_dataset, args)
            val_acc = val_metrics["top1"]
            
            # Log validation metrics
            epoch_log = {
                "dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
                "epoch": epoch,
                "val_accuracy": 100 * val_acc,
            }
            
            # Add individual coefficient values to epoch logging
            # Use .data to get tensor without grad_fn
            current_coef_tensor = ddp_model.module.image_encoder.coef.data
            
            # Check if we are using the custom encoder with a mask
            is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
            if is_custom_encoder:
                current_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool()
                # Iterate and log only masked values
                for task_idx in range(current_coef_tensor.size(0)):
                    task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                    for block_idx in range(current_coef_tensor.size(1)):
                        for part_idx in range(current_coef_tensor.size(2)):
                            if current_mask[task_idx, block_idx, part_idx]:
                                coef_value = current_coef_tensor[task_idx, block_idx, part_idx].item()
                                epoch_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_epoch"] = coef_value
            elif args.blockwise_coef:
                 # Original blockwise logging (assuming 2D coef)
                 current_coef_values = current_coef_tensor.cpu()
                 for dataset_idx in range(current_coef_values.size(0)):
                     layer_coefs = current_coef_values[dataset_idx]
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for layer_idx, coef_value in enumerate(layer_coefs):
                         epoch_log[f"coef_layer{layer_idx}_{dataset_name}_epoch"] = coef_value.item()
            elif args.partition: # Original partition logging
                 current_coef_values = current_coef_tensor.cpu()
                 for dataset_idx in range(current_coef_values.size(0)):
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for block_idx in range(current_coef_values.size(1)):
                         for part_idx in range(current_coef_values.size(2)):
                            coef_value = current_coef_values[dataset_idx, block_idx, part_idx].item()
                            epoch_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_epoch"] = coef_value
            else:
                # Original non-blockwise logging (assuming 1D coef)
                current_coef_values = current_coef_tensor.cpu()
                for dataset_idx, coef_value in enumerate(current_coef_values.flatten()):
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    epoch_log[f"coef_{dataset_name}_epoch"] = coef_value.item()
                
            wandb.log(epoch_log)
            
            print(f"Epoch {epoch}: Validation accuracy: {100 * val_acc:.2f}%")
            
            if val_acc > best_acc:
                best_acc = val_acc
                best_coef = coef_param.data.clone()
                torch.save(best_coef, head_path)
                
                # Log best coefficients
                best_log = {
                    "target_dataset": target_dataset,
                    "best_val_accuracy": 100 * best_acc,
                    "best_epoch": epoch,
                }
                
                # Add individual best coefficient values
                best_coef_tensor = best_coef.cpu() # best_coef is already cloned .data

                # Check if we are using the custom encoder with a mask
                is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                if is_custom_encoder:
                    # We need the mask corresponding to the state when best_coef was saved
                    # Assuming the mask doesn't change during training (it's a buffer)
                    best_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool().cpu()
                    for task_idx in range(best_coef_tensor.size(0)):
                        task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                        for block_idx in range(best_coef_tensor.size(1)):
                            for part_idx in range(best_coef_tensor.size(2)):
                                if best_mask[task_idx, block_idx, part_idx]:
                                    coef_value = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                    best_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_best"] = coef_value
                elif args.blockwise_coef:
                    # Original blockwise
                    best_coef_values = best_coef_tensor
                    for dataset_idx in range(best_coef_values.size(0)):
                        layer_coefs = best_coef_values[dataset_idx]
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for layer_idx, coef_value in enumerate(layer_coefs):
                            best_log[f"coef_layer{layer_idx}_{dataset_name}_best"] = coef_value.item()
                elif args.partition:
                    # Original partition
                    best_coef_values = best_coef_tensor
                    for dataset_idx in range(best_coef_values.size(0)):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for block_idx in range(best_coef_values.size(1)):
                            for part_idx in range(best_coef_values.size(2)):
                                coef_value = best_coef_values[dataset_idx, block_idx, part_idx].item()
                                best_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_best"] = coef_value
                else:
                    # Original non-blockwise
                    best_coef_values = best_coef_tensor
                    for dataset_idx, coef_value in enumerate(best_coef_values.flatten()):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        best_log[f"coef_{dataset_name}_best"] = coef_value.item()
                
                wandb.log(best_log)

    if is_main_process():
        comp_acc[target_dataset] = best_acc
        target_dataset_test = target_dataset.replace("Val", "")
        image_encoder = ddp_model.module.image_encoder
        
        # Set the best coefficients back into the encoder
        is_custom_encoder = hasattr(image_encoder, 'trainable_coef_mask')
        if is_custom_encoder:
            # For custom encoder, best_coef holds the entire 3D tensor
            # Parameter assignment is handled internally via Parameter(best_coef)
            if best_coef is not None:
                image_encoder.coef = torch.nn.Parameter(best_coef)
                print("Loaded best_coef (3D tensor) into CustomWeightedImageEncoder")
            else:
                print("Warning: best_coef is None, cannot load into CustomWeightedImageEncoder")
        elif linearized_finetuning:
            # Original linearized logic
            if best_coef is not None:
                 image_encoder.model.coef = torch.nn.Parameter(best_coef)
            else:
                print("Warning: best_coef is None, cannot load into WeightedLinearizedModel")
        else:
            # Original standard logic
            if best_coef is not None:
                image_encoder.coef = torch.nn.Parameter(best_coef)
            else:
                 print("Warning: best_coef is None, cannot load into WeightedImageEncoder")

        # Print the shape of the coef parameter after loading best_coef
        print("best_coef shape after loading:", image_encoder.coef.shape if best_coef is not None else "N/A")
        
        # Log the final best coefficients with dataset names
        final_coef_log = {}
        if best_coef is not None:
            best_coef_tensor = best_coef.cpu()
            if is_custom_encoder:
                best_mask = image_encoder.trainable_coef_mask.data.bool().cpu()
                for task_idx in range(best_coef_tensor.size(0)):
                    task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                    for block_idx in range(best_coef_tensor.size(1)):
                        for part_idx in range(best_coef_tensor.size(2)):
                            if best_mask[task_idx, block_idx, part_idx]:
                                val = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                final_coef_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_final"] = val
            elif args.blockwise_coef:
                best_coef_values = best_coef_tensor
                for dataset_idx in range(best_coef_values.size(0)):
                    layer_coefs = best_coef_values[dataset_idx]
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    for layer_idx, val in enumerate(layer_coefs):
                        final_coef_log[f"coef_layer{layer_idx}_{dataset_name}_final"] = val.item()
            elif args.partition:
                 best_coef_values = best_coef_tensor
                 for dataset_idx in range(best_coef_values.size(0)):
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for block_idx in range(best_coef_values.size(1)):
                         for part_idx in range(best_coef_values.size(2)):
                             val = best_coef_values[dataset_idx, block_idx, part_idx].item()
                             final_coef_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_final"] = val
            else:
                best_coef_values = best_coef_tensor
                for dataset_idx, val in enumerate(best_coef_values.flatten()):
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    final_coef_log[f"coef_{dataset_name}_final"] = val.item()
        else:
            print("Warning: best_coef is None, skipping final coefficient logging.")

        wandb.log(final_coef_log)
        
        # Get the final test accuracy with the best model parameters
        # Re-fetch encoder in case the DDP module reference changed (though unlikely)
        image_encoder = ddp_model.module.image_encoder
        # No need to set coef again here, it was set above
            
        test_metrics = eval_single_dataset(image_encoder, target_dataset_test, args)
        final_test_acc = test_metrics["top1"]
        comp_acc[target_dataset_test] = final_test_acc
        
        # Log final results
        wandb.log({
            "dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
            "final_val_accuracy": 100 * best_acc,
            "final_test_accuracy": 100 * final_test_acc,
        })
        
        print(f"Final test accuracy on {target_dataset_test}: {100 * final_test_acc:.2f}%")
        
        with open(log_path, 'w') as f:
            json.dump(comp_acc, f, indent=4)

        # Create a new path for the test accuracy file with SLURM job ID
        test_acc_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset_test}_{subsample_str}_test_accuracy.json")
        
        # Prepare test accuracy data with additional information
        test_acc_data = {
            "target_dataset": str(target_dataset_test),  # Convert to string to ensure JSON serializable
            "number_of_used_datasets": len(args.datasets),  # Number of datasets used in this run
            "datasets_names": args.datasets,  # List of dataset names used
            "test_accuracy": float(final_test_acc),  # Final test accuracy value
            "test_metrics": test_metrics,
        }
        
        # Save or append the test accuracy data to the file
        if os.path.exists(test_acc_path):
            # Load existing data
            with open(test_acc_path, 'r') as f:
                existing_data = json.load(f)
                
            # Convert to list if single dict
            if not isinstance(existing_data, list):
                existing_data = [existing_data]
                
            # Append new data
            existing_data.append(test_acc_data)
            
            # Save updated data
            with open(test_acc_path, 'w') as f:
                json.dump(existing_data, f, indent=4)
        else:
            # Create new file with initial data
            with open(test_acc_path, 'w') as f:
                json.dump([test_acc_data], f, indent=4)
        print(f"Test accuracy saved to {test_acc_path}")
            
        # Log coefficient values
        if best_coef is not None:
            try:
                # Create a table for the coefficients with dataset names
                coef_data = []
                if best_coef is not None:
                    best_coef_tensor = best_coef.cpu()
                    
                    # Check if we are using the custom encoder with a mask
                    is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                    if is_custom_encoder:
                        best_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool().cpu()
                        # Log table with task, block, partition index, and value for active coefs
                        for task_idx in range(best_coef_tensor.size(0)):
                            task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                            for block_idx in range(best_coef_tensor.size(1)):
                                for part_idx in range(best_coef_tensor.size(2)):
                                    if best_mask[task_idx, block_idx, part_idx]:
                                        val = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                        coef_data.append([task_idx, task_name, block_idx, part_idx, val])
                        
                        coef_table = wandb.Table(
                            columns=["task_idx", "dataset", "block_idx", "part_idx", "coefficient_value"],
                            data=coef_data
                        )
                        wandb.log({"final_coefficient_values_table": coef_table})
                        
                        # Log histogram of only the active coefficients
                        active_coef_values = best_coef_tensor[best_mask]
                        if active_coef_values.numel() > 0:
                           wandb.log({"final_active_coefficient_histogram": wandb.Histogram(active_coef_values.numpy())})
                        else:
                            print("No active coefficients found for histogram logging.")

                    elif args.blockwise_coef:
                        # Original blockwise table and histogram
                        best_coef_values = best_coef_tensor
                        for dataset_idx in range(best_coef_values.size(0)):
                            layer_coefs = best_coef_values[dataset_idx]
                            dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                            for layer_idx, val in enumerate(layer_coefs):
                                coef_data.append([layer_idx, dataset_idx, dataset_name, val.item()])
                        
                        coef_table = wandb.Table(
                            columns=["layer", "index", "dataset", "coefficient_value"], 
                            data=coef_data
                        )
                        wandb.log({"coefficient_values_table": coef_table})
                        
                        for layer_idx in range(best_coef_values.size(0)):
                             wandb.log({
                                 f"coefficient_histogram_layer_{layer_idx}": 
                                 wandb.Histogram(best_coef_values[layer_idx].numpy())
                             })
                    elif args.partition:
                        # Original partition table and histogram (adapt as needed)
                        best_coef_values = best_coef_tensor
                        # ... (Add table creation logic similar to blockwise/custom if needed)
                        print("WARNING: wandb.Table logging for original --partition mode not explicitly adapted. Logging histogram only.")
                        wandb.log({"coefficient_histogram": wandb.Histogram(best_coef_values.numpy())})
                    else:
                         # Original non-blockwise table and histogram
                         best_coef_values = best_coef_tensor
                         for dataset_idx, val in enumerate(best_coef_values.flatten()):
                             dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                             coef_data.append([dataset_idx, dataset_name, val.item()])
                        
                         coef_table = wandb.Table(
                             columns=["index", "dataset", "coefficient_value"], 
                             data=coef_data
                         )
                         wandb.log({"coefficient_values_table": coef_table})
                         wandb.log({"coefficient_histogram": wandb.Histogram(best_coef_values.numpy())})
                else:
                     print("Warning: best_coef is None, cannot log final coefficient table/histogram.")
            except Exception as e:
                print(f"Error logging coefficients to wandb: {e}")

     
    # Calculate experiment duration
    experiment_end_time = datetime.now()
    experiment_duration = experiment_end_time - experiment_start_time
    formatted_end_time = experiment_end_time.strftime("%Y-%m-%d %H:%M:%S")
    
    # Format duration as hours:minutes:seconds
    hours, remainder = divmod(experiment_duration.total_seconds(), 3600)
    minutes, seconds = divmod(remainder, 60)
    formatted_duration = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"
    
    if is_main_process():
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} completed at: {formatted_end_time}")
        print(f"[TIMING] Total duration: {formatted_duration} (H:M:S)")
        
        # Log timing information to wandb
        wandb.log({
            "experiment_start_time": formatted_start_time,
            "experiment_end_time": formatted_end_time,
            "experiment_duration_seconds": experiment_duration.total_seconds(),
            "experiment_duration_formatted": formatted_duration,
        })
    
    # Finish the wandb run before cleaning up distributed process
    if is_main_process():
        wandb.finish(quiet=True)


# Datasets whose task vectors are composed in this experiment. The whole list is
# assigned to ``args.datasets`` before ``main`` is spawned.
TARGET_DATASETS = [
    "Cars",
    "DTD",
    "EuroSAT",
    "GTSRB",
    "MNIST",
    "RESISC45",
    "SUN397",
    "SVHN",
    "CIFAR10",
    "CIFAR100",
    "ImageNet",
    "STL10",
    "Food101",
    "Caltech101",
    "Caltech256",
    "FGVCAircraft",
    "Flowers102",
    "OxfordIIITPet",
    "CUB200",
    "PascalVOC",
    "Country211",
    "UCF101",
]


def _apply_default_overrides(args, datasets):
    """Overwrite selected command-line arguments with this script's fixed values.

    Args:
        args: Namespace returned by ``parse_arguments()``. It must already carry
            ``model`` and ``save``. It is mutated in place.
        datasets: Dataset names to store on ``args.datasets``. A shallow copy is
            stored so later mutation of ``args.datasets`` cannot alter the caller's
            list.

    Returns:
        The same ``args`` object, for convenience.
    """
    args.datasets = list(datasets)

    # HACK: Some command line arguments are overwritten by defaults here.
    args.lr = 1e-1
    args.epochs = 1

    # We use gradient accumulation to simulate larger batch sizes if the model
    # does not fit in memory.
    is_large_model = args.model == "ViT-L-14"
    args.batch_size = 64 if is_large_model else 128
    args.num_grad_accumulation = 2 if is_large_model else 1

    args.print_every = 10

    # Set the seed to 0 for deterministic runs.
    args.seed = 0

    # Plain string concatenation: the model name is appended directly to
    # ``--save``. For the result to be a subdirectory (e.g. "checkpoints/ViT-B-32"),
    # ``--save`` must already end with a path separator.
    args.save = args.save + f"{args.model}"
    return args


def _load_zeroshot_accuracies(save_dir):
    """Load ``zeroshot_accuracies.json`` from ``save_dir`` and return the parsed JSON.

    The file is decoded as UTF-8 explicitly so the result does not depend on the
    host's locale encoding.
    """
    path = os.path.join(save_dir, "zeroshot_accuracies.json")
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


if __name__ == "__main__":
    args = parse_arguments()

    print("=" * 80)
    print("AUTOMATIC EXPERIMENT TRACKING")
    print("-" * 50)
    print("to ensure that experiment code states are tracked.")
    print("=" * 80)

    _apply_default_overrides(args, TARGET_DATASETS)

    # Print port information for reference
    print("=" * 80)
    print("NOTE: When running multiple instances, use different ports to avoid conflicts")
    print("Example: python src/learn_coef.py --port=29501")
    print("Available port range: 29500-29509")
    print("=" * 80)

    args.zs_acc = _load_zeroshot_accuracies(args.save)

    # spawn passes the process index as the first positional argument to ``main``.
    torch.multiprocessing.spawn(main, args=(args,), nprocs=args.world_size)
