import os
import time
import json
import torch
import torchvision
import sys

os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_CONSOLE"] = "off"
os.environ["WANDB_DISABLE_SERVICE"] = "true"

# Redirect Python-level stderr to the null device while wandb is imported.
# NOTE(review): this only swaps the sys.stderr object; output written directly
# to file descriptor 2 by native code is not silenced.
# The try/finally restores stderr even if the import fails, so the resulting
# traceback is still visible instead of being written to a closed file.
old_stderr = sys.stderr
with open(os.devnull, 'w') as devnull:
    sys.stderr = devnull
    try:
        import wandb
    finally:
        sys.stderr = old_stderr

import random
import numpy as np
import socket
import subprocess
import platform
import psutil
from datetime import datetime, timedelta

from torch.cuda.amp import GradScaler
from src.linearize import LinearizedImageEncoder
from src.modeling import ImageEncoder, ImageClassifier
from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
from src.composition import WeightedImageEncoder, WeightedLinearizedModel

from src.utils import cosine_lr
from src.args import parse_arguments
from src.eval import eval_single_dataset
from src.datasets.registry import get_dataset
from src.heads import get_classification_head
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp

@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Return the per-row entropy of softmax(x) taken over dimension 1.

    Computes ``-sum_j p_j * log p_j`` with ``p = softmax(x, dim=1)`` and reduces
    dimension 1, so a (N, C) input yields an (N,) result.
    """
    probs = torch.softmax(x, 1)
    log_probs = torch.log_softmax(x, 1)
    weighted = probs * log_probs
    return -weighted.sum(1)

def lp_reg(x, p=None, gamma=0.5) -> "torch.Tensor | int":
    """Return a scaled L_p penalty on ``x``, or ``0`` when the penalty is disabled.

    Args:
        x: Tensor to regularise. The p-norm is taken along dim 0 and the
            resulting norms are averaged.
        p: Norm order forwarded to ``torch.norm``; ``None`` disables the penalty.
        gamma: Multiplier applied to the averaged norm.

    Returns:
        The int ``0`` when ``p`` is ``None`` (so ``loss + lp_reg(...)`` is a
        no-op); otherwise the scalar tensor ``gamma * norm(x, p, dim=0).mean()``.
    """
    if p is None:
        return 0
    return gamma * torch.norm(x, p=p, dim=0).mean()

def set_seed(seed: int) -> None:
    """
    Seed Python, NumPy and PyTorch (CPU and every CUDA device) and force
    deterministic cuDNN behaviour.

    Args:
        seed: Value passed to every random number generator.

    Note:
        Disabling cuDNN benchmarking and requiring determinism may slow
        training down. ``PYTHONHASHSEED`` is exported to the environment, which
        only affects interpreters started afterwards (e.g. worker processes);
        the running interpreter's hash seed is fixed at its startup.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not is_main_process():
        return
    print(f"Random seed set to {seed} for deterministic results")


def is_port_available(port):
    """
    Report whether a TCP socket can currently bind ``port`` on all IPv4 interfaces.

    A throw-away socket (without SO_REUSEADDR) attempts the bind and is closed
    straight away, so the answer can be stale by the time the caller uses it.

    Args:
        port: The port number to probe

    Returns:
        bool: True if the bind succeeded, False if it raised OSError
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('', port))
    except OSError:
        return False
    finally:
        probe.close()
    return True

def find_available_port(port_list):
    """
    Find an available TCP port from a list of candidate ports.

    Candidates are tried in random order (using the module-level ``random``
    state). A port is accepted when it can be bound (with SO_REUSEADDR), and
    can be bound again 0.1 s later by a fresh socket after the first probe
    has been closed. The second bind therefore re-checks that nothing else
    claimed the port in the meantime.

    Args:
        port_list: Iterable of port numbers to check; it is copied, never modified.

    Returns:
        int: An available port from the list, or None if none is available.

    Note:
        The probe sockets are closed before returning, so another process can
        still grab the port before the caller binds it.
    """
    candidates = list(port_list)
    random.shuffle(candidates)

    for port in candidates:
        # First probe: released as soon as the bind succeeds.
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_socket:
                tcp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                tcp_socket.bind(('', port))
        except OSError as e:
            print(f"Port {port} is not available: {e}")
            continue

        # Second probe after a short pause. The first socket is already closed,
        # so this bind actually tests whether another process took the port.
        time.sleep(0.1)
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as second_check:
                second_check.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                second_check.bind(('', port))
        except OSError:
            print(f"Port {port} failed second availability check")
            continue

        print(f"Port {port} is confirmed available")
        return port

    print("No available ports found in the provided list")
    return None

def _select_ddp_port(args):
    """Return the rendezvous port: ``args.port`` if positive, else one probed from 29520-29589."""
    available_ports = list(range(29520, 29590))
    requested_port = getattr(args, 'port', None)
    if requested_port is not None and requested_port > 0:
        print(f"Using user-specified port {requested_port} for distributed training")
        return requested_port

    if getattr(args, 'world_size', 1) > 1:
        # Each rank runs this probe on its own, so ranks can end up on
        # different ports and the process-group rendezvous may fail. Prefer
        # choosing the port once before spawning workers and passing args.port.
        print("Warning: no port specified with world_size > 1; ranks may select different ports.")

    selected_port = find_available_port(available_ports)
    if selected_port is None:
        print("Warning: No available ports found. Using a random port which may cause issues.")
        selected_port = random.choice(available_ports)
        print(f"Selected random port {selected_port} - this may cause issues if already in use")
    else:
        print(f"Found available port {selected_port} for distributed training")
    return selected_port


def _load_task_vectors(args):
    """Build ``{dataset_name: task_vector}`` for every name in ``args.datasets`` from checkpoints under ``args.save``."""
    task_vectors = {}
    for ds_name in args.datasets:
        if args.finetuning_mode == "linear":
            pretrained_checkpoint = f"{args.save}/{ds_name}Val/linear_zeroshot.pt"
            finetuned_checkpoint = f"{args.save}/{ds_name}Val/linear_finetuned.pt"
            task_vectors[ds_name] = LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint)
        else:
            pretrained_checkpoint = f"{args.save}/{ds_name}Val/zeroshot.pt"
            finetuned_checkpoint = f"{args.save}/{ds_name}Val/finetuned.pt"
            task_vectors[ds_name] = NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint)
    return task_vectors


def main(rank, args):
    """Run one distributed worker over growing prefixes of ``args.datasets``.

    After setting up DDP, the target dataset is removed from ``args.datasets``.
    Then, for each index ``idx``, task vectors for the first ``idx + 1``
    remaining datasets are loaded and ``train`` is called. Iterations with
    ``idx < args.resume_from_idx`` or ``idx > args.end_index`` are skipped. An
    iteration that raises is reported and the loop continues with the next one.
    The process group is torn down on exit whenever it was set up.

    Args:
        rank: Rank of this process; stored as ``args.rank``.
        args: Parsed arguments. ``rank``, ``port``, ``target_dataset`` and
            ``datasets`` are written; once the dataset list has been computed,
            ``datasets`` is reset on exit to the list without the target dataset.
    """
    args.rank = rank

    distributed_initialized = False
    original_datasets = None

    try:
        selected_port = _select_ddp_port(args)
        args.port = selected_port

        setup_ddp(args.rank, args.world_size, port=selected_port)
        distributed_initialized = True

        if args.seed is not None:
            set_seed(args.seed)

        args.target_dataset = args.target_dataset_name + "Val"
        # The target dataset must not contribute a task vector to its own composition.
        if args.target_dataset_name in args.datasets:
            args.datasets.remove(args.target_dataset_name)
        original_datasets = args.datasets.copy()

        for idx, _newest_dataset in enumerate(original_datasets):
            if idx < args.resume_from_idx:
                print(f"Skipping iteration {idx+1} because resume-from-idx is {args.resume_from_idx}")
                continue

            if idx > args.end_index:
                print(f"Skipping iteration {idx+1} because end-idx is {args.end_index}")
                continue

            try:
                # Use the first idx + 1 datasets for this iteration.
                args.datasets = original_datasets[:idx + 1]
                task_vectors = _load_task_vectors(args)

                print("=" * 100)
                print(f"Learning task vector coefficients on {args.target_dataset} with {args.model} using {len(args.datasets)} datasets")
                print(f"Datasets being used: {args.datasets}")
                print("=" * 100)

                print("\nCommand line arguments:")
                print("-" * 50)
                for arg, value in vars(args).items():
                    print(f"{arg}: {value}")
                print("-" * 50)
                print()
                train(task_vectors, args)
                print(f"Successfully completed iteration {idx+1} with {len(args.datasets)} datasets")

            except Exception as e:
                # Best-effort sweep: report the failure and move on to the next prefix.
                print(f"ERROR: Iteration {idx+1} with {len(args.datasets)} datasets failed with error: {e}")
                import traceback
                traceback.print_exc()
                print("Continuing with next iteration...")
                continue

    finally:
        if distributed_initialized:
            try:
                cleanup_ddp()
            except Exception as e:
                print(f"Warning: Error during distributed cleanup: {e}")
        if original_datasets is not None:
            args.datasets = original_datasets

from torch import nn
from functorch import make_functional_with_buffers

class CustomWeightedImageEncoder(nn.Module):
    """Image encoder whose weights are the base weights plus a learned mix of task vectors.

    The number of learnable mixing coefficients is fixed by
    ``total_trainable_params``. Coefficients live in ``coef``, shaped
    ``(num_task_vectors, num_blocks, max_partition)``. Only entries where the
    ``trainable_coef_mask`` buffer is 1 affect the output. Every element of
    parameter tensor ``i`` is randomly assigned to one partition through the
    one-hot buffer ``mask_mat_block_{i}``, shaped ``(max_partition, *param.shape)``.
    """

    def __init__(self, model, task_vectors, total_trainable_params, num_task_vectors, num_blocks) -> None:
        """A custom wrapper class for task vector composition based on total trainable parameters.

        Args:
            model: Base image encoder model. It must expose ``train_preprocess``,
                ``val_preprocess`` and ``cache_dir``.
            task_vectors: List of task vectors. Each must have a ``.vector``
                dict; its values are assumed to be in the same order as the
                model's parameters (TODO confirm against the TaskVector classes).
            total_trainable_params: The exact number of coefficients to learn.
            num_task_vectors: Number of task vectors being used.
            num_blocks: Number of parameter blocks in the base model.

        Raises:
            ValueError: If ``num_task_vectors`` or ``num_blocks`` is not
                positive, or ``total_trainable_params`` is negative.
        """
        super().__init__()

        if num_task_vectors <= 0 or num_blocks <= 0:
            raise ValueError(f"num_task_vectors ({num_task_vectors}) and num_blocks ({num_blocks}) must be positive.")
        if total_trainable_params < 0:
            raise ValueError(f"total_trainable_params ({total_trainable_params}) must be non-negative.")

        func, params, self.buffer = make_functional_with_buffers(model)
        # The functional module sits behind a lambda. nn.Module only registers
        # Module-typed attributes as children, so device/dtype moves never
        # traverse it. self.buffer is a plain tuple and is moved in _apply.
        self.func = lambda p, b, x: func(p, b, x)
        self.params = torch.nn.ParameterList(params)
        for p in self.params:
            p.requires_grad = False  # base weights stay frozen; only coef is learned

        self.train_preprocess = model.train_preprocess
        self.val_preprocess = model.val_preprocess
        self.cache_dir = model.cache_dir

        # Per-task lists of deltas, indexed like self.params. These are plain
        # lists, so _apply relocates them by hand.
        self.dparams = [[tv.vector[k] for k in tv.vector] for tv in task_vectors]

        self.num_task_vectors = num_task_vectors
        self.num_blocks = num_blocks
        self.total_trainable_params = total_trainable_params

        # Spread the budget evenly: every (task, block) pair gets base_partition
        # slots, and the leftover coefficients go into one extra slot.
        params_per_full_partition = self.num_task_vectors * self.num_blocks  # > 0 after validation
        self.base_partition, self.remainder_params = divmod(self.total_trainable_params, params_per_full_partition)
        # Keep at least one slot so coef never has a zero-sized dimension.
        self.max_partition = max(self.base_partition + (1 if self.remainder_params > 0 else 0), 1)

        print(f"[CustomEncoder] Total Params: {self.total_trainable_params}, Tasks: {self.num_task_vectors}, Blocks: {self.num_blocks}")
        print(f"[CustomEncoder] Base Partition: {self.base_partition}, Remainder Params: {self.remainder_params}, Max Partition Dim: {self.max_partition}")

        self.coef = torch.nn.Parameter(torch.zeros(self.num_task_vectors, self.num_blocks, self.max_partition))

        # This mask determines WHICH coef entries are actually used and trained.
        trainable_coef_mask = torch.zeros(self.coef.shape, dtype=torch.bool)
        if self.base_partition > 0:
            trainable_coef_mask[:, :, :self.base_partition] = True

        if self.remainder_params > 0:
            print(f"[CustomEncoder] Distributing {self.remainder_params} remainder coefficients...")
            # Enable the extra slot for the first remainder_params (task, block)
            # pairs in block-major order: (t0, b0), (t1, b0), ..., (t0, b1), ...
            # remainder_params < num_task_vectors * num_blocks, so all of them fit.
            flat = torch.arange(self.remainder_params)
            task_idx = flat % self.num_task_vectors
            block_idx = flat // self.num_task_vectors
            trainable_coef_mask[task_idx, block_idx, self.base_partition] = True
            print(f"[CustomEncoder] Assigned {self.remainder_params} remainder coefficients.")

        self.register_buffer('trainable_coef_mask', trainable_coef_mask.float())

        # One-hot partition assignment per parameter element, stored as buffers
        # so they follow the module across devices and into state_dict.
        print(f"[CustomEncoder] Creating mask_mats buffers for {self.num_blocks} blocks...")
        for i, p in enumerate(self.params):
            p_shape = tuple(p.shape)
            if p.ndim == 0:
                print(f"    - Skipping scalar parameter at index {i}")
                continue
            print(f"    - Processing shape {p_shape} (Block {i})")
            assignments = torch.randint(0, self.max_partition, p_shape, device='cpu')
            # one_hot appends the class dim last; move it to the front -> (P, *p_shape).
            mask_mat = torch.nn.functional.one_hot(assignments, num_classes=self.max_partition).permute(-1, *range(p.ndim)).float()

            buffer_name = f"mask_mat_block_{i}"
            self.register_buffer(buffer_name, mask_mat)
            print(f"      - Created buffer '{buffer_name}' with shape: {mask_mat.shape}")
        print(f"[CustomEncoder] Finished creating mask_mats buffers for {self.num_blocks} unique blocks.")

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` to registered tensors, then to the untracked ones.

        ``self.buffer`` (functional-model buffers) and ``self.dparams``
        (task-vector deltas) are plain Python containers that ``nn.Module`` does
        not track. Without this override, ``.cuda()``, ``.to()`` or ``.half()``
        would leave them behind. Extra arguments (e.g. ``recurse`` in newer
        PyTorch versions) are forwarded unchanged.
        """
        new_self = super()._apply(fn, *args, **kwargs)
        new_self.buffer = tuple(fn(b) if isinstance(b, torch.Tensor) else b for b in new_self.buffer)
        new_self.dparams = [[fn(d) if isinstance(d, torch.Tensor) else d for d in tv] for tv in new_self.dparams]
        return new_self

    def train(self, mode=True):
        """Set training mode and return ``self``, matching ``nn.Module.train``."""
        return super().train(mode)

    def forward(self, x) -> torch.Tensor:
        """Run the base encoder with weights ``p + sum_t (coef-weighted partition mask) * delta_t``."""
        effective_coefs = self.coef * self.trainable_coef_mask  # (T, B, P); masked slots contribute 0

        final_dparams = []
        for block_idx, p_orig in enumerate(self.params):
            if p_orig.ndim == 0:
                # Scalar parameters have no partition mask and are left unchanged.
                final_dparams.append(torch.tensor(0.0, device=p_orig.device, dtype=p_orig.dtype))
                continue

            buffer_name = f"mask_mat_block_{block_idx}"
            block_mask_mat = getattr(self, buffer_name, None)  # (P, *p_shape)
            if block_mask_mat is None:
                # Should not happen: __init__ registers one buffer per non-scalar block.
                raise RuntimeError(f"Mask matrix buffer '{buffer_name}' not found for block index {block_idx}")

            final_delta_b = torch.zeros_like(p_orig)
            block_effective_coefs = effective_coefs[:, block_idx, :]  # (T, P)

            for task_idx in range(self.num_task_vectors):
                task_delta = self.dparams[task_idx][block_idx]
                if not isinstance(task_delta, torch.Tensor):
                    continue

                # Broadcast the P coefficients over the parameter shape and
                # collapse the one-hot partition dim into a per-element weight.
                coef_view = block_effective_coefs[task_idx, :].view(-1, *([1] * p_orig.ndim))
                weighted_mask_for_task = (coef_view * block_mask_mat).sum(dim=0)
                final_delta_b += weighted_mask_for_task * task_delta

            final_dparams.append(final_delta_b)

        new_params = [dp + p for dp, p in zip(final_dparams, self.params)]
        return self.func(new_params, self.buffer, x)

def train(task_vectors, args):
    if is_main_process():
        experiment_start_time = datetime.now()
        formatted_start_time = experiment_start_time.strftime("%Y-%m-%d %H:%M:%S")
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} started at: {formatted_start_time}")
    
    subsample_str = f"s{int(args.subsample * 100)}" if isinstance(args.subsample, float) else f"s{args.subsample}"

    if args.seed is not None:
        set_seed(args.seed + args.rank)  # Add rank to avoid identical random values across processes

    target_dataset = args.target_dataset
    
    if is_main_process():
        if torch.cuda.is_available():
            gpu_info = {
                "gpu_count": torch.cuda.device_count(),
                "gpu_model": torch.cuda.get_device_name(0),
                "gpu_memory_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3),
            }
        wandb.login(key="your-wandb-key")
        wandb.init(
            project="axis",
            entity="example-owner",
            name=f"{args.model}_{target_dataset}_{subsample_str}-blkcoef-{args.blockwise_coef}-e{args.epochs}-nbdts{len(args.datasets)}-part{args.partition}",
            config={
                # Basic experiment config
                "type": "learn_coef",
                "kind": "baseline-20pct",
                "iter_fixed": 3,
                "model": args.model,
                "save": args.save,
                "target_dataset": target_dataset,
                "used_datasets": args.datasets,
                "partition": args.partition,
                "#nb_dts": len(args.datasets),
                "port": args.port,
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "batch_size": args.batch_size * args.num_grad_accumulation,
                "finetuning_mode": args.finetuning_mode,
                "lp_reg": args.lp_reg,
                "seed": args.seed,
                "weight_decay": args.wd,
                "blockwise_coef": True,
                "subsample": args.subsample,
                
                # System information
                "system": {
                    "gpu": gpu_info,
                    "cpu_count": psutil.cpu_count(),
                    "cpu_physical_count": psutil.cpu_count(logical=False),
                    "memory_gb": psutil.virtual_memory().total / (1024**3),
                    "hostname": socket.gethostname(),
                    "platform": platform.platform(),
                },
                
                # Runtime info
                "runtime": {
                    "pytorch_version": torch.__version__,
                    "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
                    "python_version": sys.version.split()[0],
                },
                
                # Job details
                "job_id": os.environ.get('SLURM_JOB_ID', 'unknown'),
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                
                # Training details
                "grad_clip_value": 1.0,
                "num_grad_accumulation": args.num_grad_accumulation,
                "batch_size": args.batch_size,
                "effective_batch_size": args.batch_size * args.num_grad_accumulation * args.world_size,
                "mixed_precision": True,
                
                # Model details
                "port_selection_method": "auto_find_available",
            }
        )

        if wandb.run:
            print(f"WandB Run URL: {wandb.run.url}")
        print("-" * 50 + "\n")

    
    ckpdir = os.path.join(args.save, target_dataset)
    os.makedirs(ckpdir, exist_ok=True)

    assert args.finetuning_mode in [
        "linear",
        "standard",
    ], "Only linear and standard fine-tuning are supported."

    linearized_finetuning = args.finetuning_mode == "linear"
    if linearized_finetuning:
        print("Using linearized fine-tuning.")

    orig_dataset = target_dataset.replace("Val", "")
    pool = [
        "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SUN397", "SVHN",
        "CIFAR10", "CIFAR100", "ImageNet", "STL10", "Food101", "Caltech101", "Caltech256",
        "FGVCAircraft", "Flowers102", "OxfordIIITPet", "CUB200", "PascalVOC", "Country211", "UCF101",
    ]
    dataset_names = [ds for ds in pool if ds != orig_dataset]
    
    task_vectors = [v for k, v in task_vectors.items() if orig_dataset != k]
    
    # Store the mapping of index to dataset name for tracking in wandb
    coef_dataset_mapping = {i: dataset_names[i] for i in range(len(dataset_names))}
    if is_main_process():
        print("Coefficient to dataset mapping:")
        for idx, ds_name in coef_dataset_mapping.items():
            print(f"Coefficient {idx}: {ds_name}")
    
    if args.finetuning_mode == "linear":
        print("WARNING: --total-trainable-params currently only supported for standard finetuning mode. Using original WeightedLinearizedModel.")
        image_encoder = LinearizedImageEncoder(args, keep_lang=False)
        image_encoder.model = WeightedLinearizedModel(
            image_encoder.model, task_vectors, blockwise=args.blockwise_coef, partition=args.partition
        )
    elif args.total_trainable_params is not None:
        print(f"Using CustomWeightedImageEncoder with total_trainable_params = {args.total_trainable_params}")
        temp_base_encoder = ImageEncoder(args)
        _, temp_params, _ = make_functional_with_buffers(temp_base_encoder)
        num_blocks = len(temp_params)
        num_task_vectors = len(task_vectors)
        del temp_base_encoder, temp_params
        
        image_encoder = CustomWeightedImageEncoder(
            model=ImageEncoder(args), # Pass a new instance
            task_vectors=task_vectors,
            total_trainable_params=args.total_trainable_params,
            num_task_vectors=num_task_vectors,
            num_blocks=num_blocks
        )
        print(f"Created CustomWeightedImageEncoder with {num_task_vectors} tasks, {num_blocks} blocks.")

        if is_main_process():
            print("\nParameters of the CustomWeightedImageEncoder (Trainable Coefs):")
            param_counter = 1
            total_params_in_coef = 0
            for name, param in image_encoder.named_parameters():
                if param.requires_grad:
                    print(f"{param_counter}. {name} (Shape: {list(param.shape)}, Type: {param.dtype}) - Trainable")
                    total_params_in_coef += param.numel()
                    param_counter += 1
            print(f"Total elements in trainable 'coef' parameter: {total_params_in_coef}")
            if hasattr(image_encoder, 'trainable_coef_mask'):
                num_masked_trainable = image_encoder.trainable_coef_mask.sum().item()
                print(f"Number of trainable coefficients according to mask: {int(num_masked_trainable)}")
                if int(num_masked_trainable) != args.total_trainable_params:
                     print(f"WARNING: Masked trainable count ({int(num_masked_trainable)}) does not match requested total_trainable_params ({args.total_trainable_params})!")
            print("-" * 30)


    classification_head = get_classification_head(args, target_dataset)
    model = ImageClassifier(image_encoder, classification_head)

    model.freeze_head()
    model = model.cuda()

    preprocess_fn = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(
            size=224, scale=(0.5, 1),
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC
        ), torchvision.transforms.RandomHorizontalFlip(p=0.5),
    ] + model.train_preprocess.transforms[-3:])


    dataset = get_dataset(
        target_dataset,
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size,
        num_workers=4,
    )
    data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(data_loader)

    # Printing loss between four and ten times an epoch
    if args.print_every * 10 < num_batches:
        print_every = int(num_batches / 10)
    elif args.print_every * 4 > num_batches:
        print_every = max(int(num_batches / 4), 1)
    else:
        print_every = args.print_every

    ddp_loader = distribute_loader(data_loader)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.rank],
        find_unused_parameters=False,
        output_device=args.rank,
    )

    loss_fn = torch.nn.CrossEntropyLoss()

    params = [p for p in ddp_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

    # Do not use warm up
    scheduler = cosine_lr(
        optimizer, args.lr, 0,
        args.epochs * num_batches // args.num_grad_accumulation,
    )

    slurm_job_id = os.environ.get('SLURM_JOB_ID', 'unknown')
    if "unknown" in slurm_job_id:
        slurm_job_id = time.strftime("%Y%m%d-%H%M%S")
    
    if args.total_trainable_params is not None:
        head_path = os.path.join(ckpdir, f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_custom_composition_total{args.total_trainable_params}.pt")
        log_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_custom_composition_total{args.total_trainable_params}.json")
        coef_param = ddp_model.module.image_encoder.coef
        coef_mask = ddp_model.module.image_encoder.trainable_coef_mask
        print("Using CustomWeightedImageEncoder - paths set.")
        print(f"Coef parameter shape: {coef_param.shape}, Mask shape: {coef_mask.shape}")
    elif linearized_finetuning:
        head_path = os.path.join(ckpdir, f"learned_linear_composition_{subsample_str}.pt")
        log_path = os.path.join(args.save, f"learned_linear_composition_{subsample_str}.json")
        coef_param = ddp_model.module.image_encoder.model.coef # Access coef in WeightedLinearizedModel
        coef_mask = None # No mask for linearized or original weighted encoder
        print("Using Linearized fine-tuning - paths set.")
    else:
        head_path = os.path.join(ckpdir, f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_composition.pt")
        log_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_composition.json")
        coef_param = ddp_model.module.image_encoder.coef # Access coef in WeightedImageEncoder
        coef_mask = None # No mask for linearized or original weighted encoder
        print("Using standard WeightedImageEncoder - paths set.")

    print("head_path:", head_path)
    print("log_path:", log_path)
    print("Initial coef param shape:", coef_param.shape)

    scaler = GradScaler()
    if is_main_process():
        print(f"=> Zero-shot accuracy on {target_dataset}:\t{100*args.zs_acc[target_dataset]:.2f}%.")
        wandb.log({
            "target_dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
            "zero_shot_accuracy": 100 * args.zs_acc[target_dataset]
        })
        
        if os.path.exists(log_path):
            with open(log_path, 'r') as f:
                comp_acc = json.load(f)
        else:
            comp_acc = {}

    best_coef = coef_param.data.clone()
    best_acc = args.zs_acc[target_dataset]


    total_params = sum(p.numel() for p in ddp_model.parameters())
    
    # Adjust calculation of trainable params if using CustomWeightedImageEncoder
    is_custom_encoder = args.total_trainable_params is not None and hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
    if is_custom_encoder:
        # Sum params excluding the coef tensor, then add the masked count
        params_excluding_coef = [p for name, p in ddp_model.named_parameters() if p.requires_grad and 'image_encoder.coef' not in name]
        num_trainable_params_other = sum(p.numel() for p in params_excluding_coef)
        num_trainable_params_coef = int(ddp_model.module.image_encoder.trainable_coef_mask.sum().item())
        num_trainable_params = num_trainable_params_other + num_trainable_params_coef
        print(f"[Param Stats - Custom] Other trainable: {num_trainable_params_other}, Masked coef: {num_trainable_params_coef}, Total effective trainable: {num_trainable_params}")
    else:
        num_trainable_params = sum(p.numel() for p in ddp_model.parameters() if p.requires_grad)
        print(f"[Param Stats - Standard] Total trainable: {num_trainable_params}")

    if is_main_process():
        print(f"\nModel Parameter Stats:")
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {num_trainable_params:,}")
        print(f"Percentage trainable: {100 * num_trainable_params / total_params:.2f}%\n")
        wandb.config.update({
            "total_parameters": total_params,
            "trainable_parameters": num_trainable_params,
            "trainable_parameters_pct": 100 * num_trainable_params / total_params
        }, allow_val_change=True)
        
        wandb.log({
            "total_parameters": total_params,
            "trainable_parameters": num_trainable_params,
            "trainable_parameters_pct": 100 * num_trainable_params / total_params
        })

    for epoch in range(args.epochs):
        epoch_loss = 0.0
        epoch_steps = 0
        
        ddp_loader.sampler.set_epoch(epoch)
        for i, batch in enumerate(ddp_loader):
            start_time = time.time()

            step = (
                i // args.num_grad_accumulation
                + epoch * num_batches // args.num_grad_accumulation
            )

            batch = maybe_dictionarize(batch)
            inputs = batch["images"].cuda()
            data_time = time.time() - start_time

            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = ddp_model(inputs)
                labels = batch["labels"].cuda()
                loss = loss_fn(logits, labels)

                reg = lp_reg(coef_param, args.lp_reg)
                loss = loss + reg
                loss = loss / args.num_grad_accumulation
                
                if is_main_process():
                    epoch_loss += loss.item() * args.num_grad_accumulation
                    epoch_steps += 1

            scaler.scale(loss).backward()

            if (i + 1) % args.num_grad_accumulation == 0:
                scheduler(step)

                torch.nn.utils.clip_grad_norm_(params, 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            batch_time = time.time() - start_time

            if (
                step % print_every == 0
                and ((i + 1) % args.num_grad_accumulation == 0)
                and is_main_process()
            ):
                percent_complete = 100 * (i + 1) / len(ddp_loader)
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i + 1}/{num_batches}]\t"           # noqa: E501
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}",   # noqa: E501
                    flush=True,
                )
                
                batch_log = {
                    "dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
                    "epoch": epoch,
                    "batch": i,
                    "batch_loss": loss.item() * args.num_grad_accumulation,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                    "batch_time": batch_time,
                    "data_time": data_time,
                    "global_step": step,
                    "regularization": reg.item() if reg != 0 else 0,
                }
                
                # Use .data to get tensor without grad_fn
                current_coef_tensor = ddp_model.module.image_encoder.coef.data
                
                # Check if we are using the custom encoder with a mask
                is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                if is_custom_encoder:
                    current_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool()
                    # Iterate through the coef tensor and log only where mask is True
                    for task_idx in range(current_coef_tensor.size(0)):
                        task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                        for block_idx in range(current_coef_tensor.size(1)):
                            for part_idx in range(current_coef_tensor.size(2)):
                                if current_mask[task_idx, block_idx, part_idx]:
                                    coef_value = current_coef_tensor[task_idx, block_idx, part_idx].item()
                                    batch_log[f"coef_{task_name}_block{block_idx}_part{part_idx}"] = coef_value
                elif args.blockwise_coef:
                    # Original blockwise logging (assuming 2D coef)
                    current_coef_values = current_coef_tensor.cpu()
                    for dataset_idx in range(current_coef_values.size(0)):
                        layer_coefs = current_coef_values[dataset_idx]
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for layer_idx, coef_value in enumerate(layer_coefs):
                             batch_log[f"coef_layer{layer_idx}_{dataset_name}"] = coef_value.item()
                elif args.partition: # Original partition logging (assuming 3D coef but different structure)
                     current_coef_values = current_coef_tensor.cpu()
                     # Assuming original partition structure: (tasks, blocks, partitions)
                     for dataset_idx in range(current_coef_values.size(0)):
                         dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                         for block_idx in range(current_coef_values.size(1)):
                             for part_idx in range(current_coef_values.size(2)):
                                 coef_value = current_coef_values[dataset_idx, block_idx, part_idx].item()
                                 batch_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}"] = coef_value
                else:
                    # Original non-blockwise logging (assuming 1D coef)
                    current_coef_values = current_coef_tensor.cpu()
                    for dataset_idx, coef_value in enumerate(current_coef_values.flatten()):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        batch_log[f"coef_{dataset_name}"] = coef_value.item()
                
                wandb.log(batch_log)
        
        if is_main_process():
            avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0
            wandb.log({
                "target_dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
                "epoch": epoch,
                "epoch_loss": avg_epoch_loss,
            })
            
            image_encoder = ddp_model.module.image_encoder
            val_metrics = eval_single_dataset(image_encoder, target_dataset, args)
            val_acc = val_metrics["top1"]
            
            epoch_log = {
                "dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
                "epoch": epoch,
                "val_accuracy": 100 * val_acc,
            }
            
            # Use .data to get tensor without grad_fn
            current_coef_tensor = ddp_model.module.image_encoder.coef.data
            
            # Check if we are using the custom encoder with a mask
            is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
            if is_custom_encoder:
                current_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool()
                # Iterate and log only masked values
                for task_idx in range(current_coef_tensor.size(0)):
                    task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                    for block_idx in range(current_coef_tensor.size(1)):
                        for part_idx in range(current_coef_tensor.size(2)):
                            if current_mask[task_idx, block_idx, part_idx]:
                                coef_value = current_coef_tensor[task_idx, block_idx, part_idx].item()
                                epoch_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_epoch"] = coef_value
            elif args.blockwise_coef:
                 # Original blockwise logging (assuming 2D coef)
                 current_coef_values = current_coef_tensor.cpu()
                 for dataset_idx in range(current_coef_values.size(0)):
                     layer_coefs = current_coef_values[dataset_idx]
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for layer_idx, coef_value in enumerate(layer_coefs):
                         epoch_log[f"coef_layer{layer_idx}_{dataset_name}_epoch"] = coef_value.item()
            elif args.partition: # Original partition logging
                 current_coef_values = current_coef_tensor.cpu()
                 for dataset_idx in range(current_coef_values.size(0)):
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for block_idx in range(current_coef_values.size(1)):
                         for part_idx in range(current_coef_values.size(2)):
                            coef_value = current_coef_values[dataset_idx, block_idx, part_idx].item()
                            epoch_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_epoch"] = coef_value
            else:
                # Original non-blockwise logging (assuming 1D coef)
                current_coef_values = current_coef_tensor.cpu()
                for dataset_idx, coef_value in enumerate(current_coef_values.flatten()):
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    epoch_log[f"coef_{dataset_name}_epoch"] = coef_value.item()
                
            wandb.log(epoch_log)
            
            print(f"Epoch {epoch}: Validation accuracy: {100 * val_acc:.2f}%")
            
            if val_acc > best_acc:
                best_acc = val_acc
                best_coef = coef_param.data.clone()
                torch.save(best_coef, head_path)
                
                best_log = {
                    "target_dataset": target_dataset,
                    "best_val_accuracy": 100 * best_acc,
                    "best_epoch": epoch,
                }
                
                best_coef_tensor = best_coef.cpu() # best_coef is already cloned .data

                # Check if we are using the custom encoder with a mask
                is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                if is_custom_encoder:
                    # We need the mask corresponding to the state when best_coef was saved
                    # Assuming the mask doesn't change during training (it's a buffer)
                    best_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool().cpu()
                    for task_idx in range(best_coef_tensor.size(0)):
                        task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                        for block_idx in range(best_coef_tensor.size(1)):
                            for part_idx in range(best_coef_tensor.size(2)):
                                if best_mask[task_idx, block_idx, part_idx]:
                                    coef_value = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                    best_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_best"] = coef_value
                elif args.blockwise_coef:
                    # Original blockwise
                    best_coef_values = best_coef_tensor
                    for dataset_idx in range(best_coef_values.size(0)):
                        layer_coefs = best_coef_values[dataset_idx]
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for layer_idx, coef_value in enumerate(layer_coefs):
                            best_log[f"coef_layer{layer_idx}_{dataset_name}_best"] = coef_value.item()
                elif args.partition:
                    # Original partition
                    best_coef_values = best_coef_tensor
                    for dataset_idx in range(best_coef_values.size(0)):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for block_idx in range(best_coef_values.size(1)):
                            for part_idx in range(best_coef_values.size(2)):
                                coef_value = best_coef_values[dataset_idx, block_idx, part_idx].item()
                                best_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_best"] = coef_value
                else:
                    # Original non-blockwise
                    best_coef_values = best_coef_tensor
                    for dataset_idx, coef_value in enumerate(best_coef_values.flatten()):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        best_log[f"coef_{dataset_name}_best"] = coef_value.item()
                
                wandb.log(best_log)

    if is_main_process():
        comp_acc[target_dataset] = best_acc
        target_dataset_test = target_dataset.replace("Val", "")
        image_encoder = ddp_model.module.image_encoder
        
        is_custom_encoder = hasattr(image_encoder, 'trainable_coef_mask')
        if is_custom_encoder:
            # For custom encoder, best_coef holds the entire 3D tensor
            # Parameter assignment is handled internally via Parameter(best_coef)
            if best_coef is not None:
                image_encoder.coef = torch.nn.Parameter(best_coef)
                print("Loaded best_coef (3D tensor) into CustomWeightedImageEncoder")
            else:
                print("Warning: best_coef is None, cannot load into CustomWeightedImageEncoder")
        elif linearized_finetuning:
            # Original linearized logic
            if best_coef is not None:
                 image_encoder.model.coef = torch.nn.Parameter(best_coef)
            else:
                print("Warning: best_coef is None, cannot load into WeightedLinearizedModel")
        else:
            # Original standard logic
            if best_coef is not None:
                image_encoder.coef = torch.nn.Parameter(best_coef)
            else:
                 print("Warning: best_coef is None, cannot load into WeightedImageEncoder")

        print("best_coef shape after loading:", image_encoder.coef.shape if best_coef is not None else "N/A")
        
        final_coef_log = {}
        if best_coef is not None:
            best_coef_tensor = best_coef.cpu()
            if is_custom_encoder:
                best_mask = image_encoder.trainable_coef_mask.data.bool().cpu()
                for task_idx in range(best_coef_tensor.size(0)):
                    task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                    for block_idx in range(best_coef_tensor.size(1)):
                        for part_idx in range(best_coef_tensor.size(2)):
                            if best_mask[task_idx, block_idx, part_idx]:
                                val = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                final_coef_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_final"] = val
            elif args.blockwise_coef:
                best_coef_values = best_coef_tensor
                for dataset_idx in range(best_coef_values.size(0)):
                    layer_coefs = best_coef_values[dataset_idx]
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    for layer_idx, val in enumerate(layer_coefs):
                        final_coef_log[f"coef_layer{layer_idx}_{dataset_name}_final"] = val.item()
            elif args.partition:
                 best_coef_values = best_coef_tensor
                 for dataset_idx in range(best_coef_values.size(0)):
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for block_idx in range(best_coef_values.size(1)):
                         for part_idx in range(best_coef_values.size(2)):
                             val = best_coef_values[dataset_idx, block_idx, part_idx].item()
                             final_coef_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_final"] = val
            else:
                best_coef_values = best_coef_tensor
                for dataset_idx, val in enumerate(best_coef_values.flatten()):
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    final_coef_log[f"coef_{dataset_name}_final"] = val.item()
        else:
            print("Warning: best_coef is None, skipping final coefficient logging.")

        wandb.log(final_coef_log)

        image_encoder = ddp_model.module.image_encoder
            
        test_metrics = eval_single_dataset(image_encoder, target_dataset_test, args)
        final_test_acc = test_metrics["top1"]
        comp_acc[target_dataset_test] = final_test_acc
        
        wandb.log({
            "dataset": str(target_dataset),  # Convert to string to ensure JSON serializable
            "final_val_accuracy": 100 * best_acc,
            "final_test_accuracy": 100 * final_test_acc,
        })
        
        print(f"Final test accuracy on {target_dataset_test}: {100 * final_test_acc:.2f}%")
        
        with open(log_path, 'w') as f:
            json.dump(comp_acc, f, indent=4)

        test_acc_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset_test}_{subsample_str}_test_accuracy.json")
        
        test_acc_data = {
            "target_dataset": str(target_dataset_test),  # Convert to string to ensure JSON serializable
            "number_of_used_datasets": len(args.datasets),  # Number of datasets used in this run
            "datasets_names": args.datasets,  # List of dataset names used
            "test_accuracy": float(final_test_acc),  # Final test accuracy value
            "test_metrics": test_metrics,
        }
        
        if os.path.exists(test_acc_path):
            with open(test_acc_path, 'r') as f:
                existing_data = json.load(f)
                
            if not isinstance(existing_data, list):
                existing_data = [existing_data]
                
            existing_data.append(test_acc_data)
            
            with open(test_acc_path, 'w') as f:
                json.dump(existing_data, f, indent=4)
        else:
            with open(test_acc_path, 'w') as f:
                json.dump([test_acc_data], f, indent=4)
        print(f"Test accuracy saved to {test_acc_path}")
            
        if best_coef is not None:
            try:
                coef_data = []
                if best_coef is not None:
                    best_coef_tensor = best_coef.cpu()
                    
                    # Check if we are using the custom encoder with a mask
                    is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                    if is_custom_encoder:
                        best_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool().cpu()
                        # Log table with task, block, partition index, and value for active coefs
                        for task_idx in range(best_coef_tensor.size(0)):
                            task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                            for block_idx in range(best_coef_tensor.size(1)):
                                for part_idx in range(best_coef_tensor.size(2)):
                                    if best_mask[task_idx, block_idx, part_idx]:
                                        val = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                        coef_data.append([task_idx, task_name, block_idx, part_idx, val])
                        
                        coef_table = wandb.Table(
                            columns=["task_idx", "dataset", "block_idx", "part_idx", "coefficient_value"],
                            data=coef_data
                        )
                        wandb.log({"final_coefficient_values_table": coef_table})
                        
                        active_coef_values = best_coef_tensor[best_mask]
                        if active_coef_values.numel() > 0:
                           wandb.log({"final_active_coefficient_histogram": wandb.Histogram(active_coef_values.numpy())})
                        else:
                            print("No active coefficients found for histogram logging.")

                    elif args.blockwise_coef:
                        best_coef_values = best_coef_tensor
                        for dataset_idx in range(best_coef_values.size(0)):
                            layer_coefs = best_coef_values[dataset_idx]
                            dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                            for layer_idx, val in enumerate(layer_coefs):
                                coef_data.append([layer_idx, dataset_idx, dataset_name, val.item()])
                        
                        coef_table = wandb.Table(
                            columns=["layer", "index", "dataset", "coefficient_value"], 
                            data=coef_data
                        )
                        wandb.log({"coefficient_values_table": coef_table})
                        
                        for layer_idx in range(best_coef_values.size(0)):
                             wandb.log({
                                 f"coefficient_histogram_layer_{layer_idx}": 
                                 wandb.Histogram(best_coef_values[layer_idx].numpy())
                             })
                    elif args.partition:
                        best_coef_values = best_coef_tensor
                        print("WARNING: wandb.Table logging for original --partition mode not explicitly adapted. Logging histogram only.")
                        wandb.log({"coefficient_histogram": wandb.Histogram(best_coef_values.numpy())})
                    else:
                         # Original non-blockwise table and histogram
                         best_coef_values = best_coef_tensor
                         for dataset_idx, val in enumerate(best_coef_values.flatten()):
                             dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                             coef_data.append([dataset_idx, dataset_name, val.item()])
                        
                         coef_table = wandb.Table(
                             columns=["index", "dataset", "coefficient_value"], 
                             data=coef_data
                         )
                         wandb.log({"coefficient_values_table": coef_table})
                         wandb.log({"coefficient_histogram": wandb.Histogram(best_coef_values.numpy())})
                else:
                     print("Warning: best_coef is None, cannot log final coefficient table/histogram.")
            except Exception as e:
                print(f"Error logging coefficients to wandb: {e}")

    
    experiment_end_time = datetime.now()
    experiment_duration = experiment_end_time - experiment_start_time
    formatted_end_time = experiment_end_time.strftime("%Y-%m-%d %H:%M:%S")
    
    hours, remainder = divmod(experiment_duration.total_seconds(), 3600)
    minutes, seconds = divmod(remainder, 60)
    formatted_duration = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"
    
    if is_main_process():
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} completed at: {formatted_end_time}")
        print(f"[TIMING] Total duration: {formatted_duration} (H:M:S)")
        
        wandb.log({
            "experiment_start_time": formatted_start_time,
            "experiment_end_time": formatted_end_time,
            "experiment_duration_seconds": experiment_duration.total_seconds(),
            "experiment_duration_formatted": formatted_duration,
        })
    
    # Finish the wandb run before cleaning up distributed process
    if is_main_process():
        wandb.finish(quiet=True)


if __name__ == "__main__":
    # Pool of datasets whose task vectors are composed. It is assigned to
    # args.datasets below, so it replaces any dataset list from the command line.
    target_datasets = [
        "Cars",
        "DTD",
        "EuroSAT",
        "GTSRB",
        "MNIST",
        "RESISC45",
        "SUN397",
        "SVHN",
        "CIFAR10",
        "CIFAR100",
        "ImageNet",
        "STL10",
        "Food101",
        "Caltech101",
        "Caltech256",
        "FGVCAircraft",
        "Flowers102",
        "OxfordIIITPet",
        "CUB200",
        "PascalVOC",
        "Country211",
        "UCF101",
    ]

    args = parse_arguments()
    print("=" * 80)
    print("AUTOMATIC EXPERIMENT TRACKING")
    print("-" * 50)
    print("to ensure that experiment code states are tracked.")
    print("=" * 80)

    args.datasets = target_datasets
    # HACK: Some command line arguments are overwritten by defaults here.
    args.lr = 1e-1
    args.epochs = 10
    # We use gradient accumulation to simulate larger batch sizes if the model does not fit in memory.
    args.batch_size = 64 if args.model == "ViT-L-14" else 128
    args.num_grad_accumulation = 2 if args.model == "ViT-L-14" else 1

    args.print_every = 10
    args.seed = 0

    # Plain concatenation, not os.path.join: args.save is treated as a prefix.
    # NOTE(review): presumably args.save already ends with a path separator
    # (or is meant as a name prefix). Confirm against the argument defaults.
    args.save = args.save + f"{args.model}"

    print("=" * 80)
    print("NOTE: When running multiple instances, use different ports to avoid conflicts")
    print("Example: python src/learn_coef.py --port=29501")
    print("Available port range: 29500-29509")
    print("=" * 80)

    # Zero-shot accuracies must already exist under args.save. A missing file
    # raises FileNotFoundError here, before any worker process is spawned.
    # JSON is UTF-8 by specification, so decode it explicitly rather than
    # relying on the platform's locale encoding.
    zs_acc_path = os.path.join(args.save, "zeroshot_accuracies.json")
    with open(zs_acc_path, "r", encoding="utf-8") as f:
        args.zs_acc = json.load(f)

    # spawn() runs main(process_index, args) in args.world_size processes.
    torch.multiprocessing.spawn(main, args=(args,), nprocs=args.world_size)
