import os
import time
import json
import torch
import torchvision
import sys

os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_CONSOLE"] = "off"
os.environ["WANDB_DISABLE_SERVICE"] = "true"

# Import wandb with stderr redirected to the null device (presumably to hide
# import-time noise). The try/finally restores stderr even if the import fails,
# and os.devnull keeps this working on non-POSIX platforms.
old_stderr = sys.stderr
with open(os.devnull, 'w') as devnull:
    sys.stderr = devnull
    try:
        import wandb
    finally:
        sys.stderr = old_stderr

import random
import numpy as np
import socket
import subprocess
import platform
import psutil
from datetime import datetime, timedelta

from torch.cuda.amp import GradScaler
from src.linearize import LinearizedImageEncoder
from src.modeling import ImageEncoder, ImageClassifier
from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
from src.composition import WeightedImageEncoder, WeightedLinearizedModel

from src.utils import cosine_lr
from src.args import parse_arguments
from src.eval import eval_single_dataset
from src.datasets.registry import get_dataset
from src.heads import get_classification_head
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp

# Parameter names passed to prunned_task_vector as ``skip_params``: entries
# with these keys are never pruned. Each name is listed both without and with
# the ``model.`` prefix; every name appears exactly once.
SKIP_PARAMS = [
    "positional_embedding", "visual.positional_embedding",
    "visual.proj", "token_embedding", "text_projection", "token_embedding.weight",
    "model.positional_embedding", "model.visual.positional_embedding", "model.token_embedding.weight",
    "model.visual.proj", "model.token_embedding", "model.text_projection",
]

def prunned_task_vector(task_vector, skip_params, prunned_level=0.98):
    """
    Prune a task vector in place using per-layer magnitude pruning.

    For every 2-D floating-point tensor in ``task_vector.vector`` whose key is
    not in ``skip_params``, the ``int(prunned_level * numel)`` smallest-magnitude
    entries are set to zero. Entries whose magnitude equals the cut-off value
    are zeroed too, so ties can make slightly more than that count zero.
    Tensors that are not 2-D or not floating point are left untouched.

    Args:
        task_vector: Object exposing a ``vector`` dict of name -> tensor. The
            tensors are modified in place.
        skip_params (list): Parameter names to leave unpruned.
        prunned_level (float): Fraction of weights to prune (e.g. 0.98 for 98%).
            If it rounds to zero elements, nothing is pruned. If it covers every
            element, the whole tensor is zeroed.

    Returns:
        The same ``task_vector`` object, pruned in place.
    """
    if is_main_process():
        print(f"Pruning a task vector with pruning level: {prunned_level}")

    with torch.no_grad():
        for key, delta in task_vector.vector.items():
            if key in skip_params or delta.dim() != 2 or not delta.is_floating_point():
                continue

            num_elements = delta.numel()
            num_to_prune = int(prunned_level * num_elements)

            if num_to_prune <= 0:
                continue
            if num_to_prune >= num_elements:
                delta.zero_()
                continue

            magnitudes = delta.abs()
            kth_input = magnitudes.flatten()
            # kthvalue may not be implemented for reduced-precision dtypes on
            # some backends/versions. Up-casting to float32 is exact for those
            # dtypes, so the threshold value is unchanged.
            if kth_input.dtype in (torch.float16, torch.bfloat16):
                kth_input = kth_input.float()
            threshold = torch.kthvalue(kth_input, num_to_prune).values

            delta.masked_fill_(magnitudes <= threshold, 0.0)

    if is_main_process():
        print("Task vector pruning complete.")

    return task_vector

@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of softmax(x) along dimension 1, computed from raw logits."""
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    weighted = probs * log_probs
    return -weighted.sum(1)

def lp_reg(x, p=None, gamma=0.5) -> torch.Tensor:
    """
    Optional L_p penalty on ``x``.

    Returns the integer ``0`` when ``p`` is None (regularisation disabled).
    Otherwise returns ``gamma`` times the mean of the p-norms taken along dim 0.
    """
    if p is None:
        return 0
    norms_along_dim0 = torch.norm(x, p=p, dim=0)
    return gamma * norms_along_dim0.mean()

def set_seed(seed: int) -> None:
    """
    Seed every RNG this script relies on and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy, and torch (CPU and all CUDA devices).
    Disables cuDNN autotuning and exports ``PYTHONHASHSEED``.

    Args:
        seed: The seed value applied to every generator.

    Note:
        Setting ``PYTHONHASHSEED`` at runtime does not change hash
        randomisation of the already-running interpreter. It only affects
        child processes started afterwards.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

    if is_main_process():
        print(f"Random seed set to {seed} for deterministic results")


def is_port_available(port):
    """
    Report whether ``port`` can currently be bound on all IPv4 interfaces.

    A plain TCP socket (without ``SO_REUSEADDR``) is bound to ``('', port)``
    and then closed again.

    Args:
        port: TCP port number to probe.

    Returns:
        bool: True when the bind succeeds, False when it raises ``OSError``.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('', port))
    except OSError:
        return False
    else:
        return True
    finally:
        probe.close()

def find_available_port(port_list):
    """
    Find an available port from ``port_list``, trying candidates in random order.

    Each candidate is bound (with ``SO_REUSEADDR``) and released, then bound and
    released a second time after a short pause to confirm it is still free. The
    two checks run one after the other, never while the first socket is still
    bound. Holding the port during the second bind would always fail on
    platforms (e.g. BSD/macOS) where ``SO_REUSEADDR`` forbids two binds to the
    same address and port.

    The result is advisory: another process can still take the port between
    this check and its actual use.

    Args:
        port_list: Iterable of port numbers to check.

    Returns:
        int: The first candidate that passed both checks, or None if none did.
    """
    candidates = list(port_list)
    random.shuffle(candidates)

    def _bind_and_release(port):
        """Bind ``port`` on all interfaces and close immediately; raises OSError if unavailable."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind(('', port))

    for port in candidates:
        try:
            _bind_and_release(port)
        except OSError as e:
            print(f"Port {port} is not available: {e}")
            continue

        time.sleep(0.1)

        try:
            _bind_and_release(port)
        except OSError:
            print(f"Port {port} failed second availability check")
            continue

        print(f"Port {port} is confirmed available")
        return port

    print("No available ports found in the provided list")
    return None

def _select_port(args, candidate_ports):
    """
    Choose the TCP port used for the DDP rendezvous.

    Uses ``args.port`` when it is set to a positive value. Otherwise asks
    ``find_available_port`` for a free port from ``candidate_ports``. If none
    is found, falls back to a random candidate, which may already be in use.
    """
    requested_port = getattr(args, 'port', None)
    if requested_port is not None and requested_port > 0:
        print(f"Using user-specified port {requested_port} for distributed training")
        return requested_port

    selected_port = find_available_port(candidate_ports)
    if selected_port is None:
        print("Warning: No available ports found. Using a random port which may cause issues.")
        selected_port = random.choice(candidate_ports)
        print(f"Selected random port {selected_port} - this may cause issues if already in use")
    else:
        print(f"Found available port {selected_port} for distributed training")
    return selected_port


def _load_task_vectors(args, dataset_names):
    """
    Build ``{dataset_name: task_vector}`` from the checkpoints under ``args.save``.

    In linear fine-tuning mode, reads ``{save}/{name}Val/linear_zeroshot.pt`` and
    ``linear_finetuned.pt`` into ``LinearizedTaskVector``. Any other mode reads
    ``zeroshot.pt`` and ``finetuned.pt`` into ``NonLinearTaskVector``.
    """
    if args.finetuning_mode == "linear":
        prefix, vector_cls = "linear_", LinearizedTaskVector
    else:
        prefix, vector_cls = "", NonLinearTaskVector

    task_vectors = {}
    for name in dataset_names:
        pretrained_checkpoint = f"{args.save}/{name}Val/{prefix}zeroshot.pt"
        finetuned_checkpoint = f"{args.save}/{name}Val/{prefix}finetuned.pt"
        task_vectors[name] = vector_cls(pretrained_checkpoint, finetuned_checkpoint)
    return task_vectors


def main(rank, args):
    """
    Per-process entry point: initialise DDP, then learn coefficients for each
    growing prefix of the source datasets.

    The target dataset is removed from ``args.datasets`` first. Iteration ``idx``
    (0-based) then trains on ``datasets[:idx + 1]``. It is skipped when
    ``idx < args.resume_from_idx`` or ``idx >= args.end_index``. A failing
    iteration is logged with its traceback and the loop moves on.

    Args:
        rank: Index of this process. Stored as ``args.rank`` and passed to ``setup_ddp``.
        args: Parsed arguments. ``rank``, ``port`` and ``target_dataset`` are
            written here. ``args.datasets`` is replaced by each prefix while
            training and restored to the full list (minus the target) on exit,
            including when an exception propagates.
    """
    args.rank = rank
    distributed_initialized = False
    original_datasets = None

    try:
        # NOTE(review): when args.port is not given, each rank runs this search
        # independently, and find_available_port shuffles candidates at random.
        # Ranks can therefore settle on different ports and fail to rendezvous.
        # Choosing the port once before ranks are spawned would avoid this;
        # confirm how the launcher calls main().
        args.port = _select_port(args, list(range(29520, 29590)))

        setup_ddp(args.rank, args.world_size, port=args.port)
        distributed_initialized = True

        if args.seed is not None:
            set_seed(args.seed)

        args.target_dataset = args.target_dataset_name + "Val"
        if args.target_dataset_name in args.datasets:
            args.datasets.remove(args.target_dataset_name)
        original_datasets = args.datasets.copy()

        for idx in range(len(original_datasets)):
            if idx < args.resume_from_idx:
                print(f"Skipping iteration {idx+1} because resume-from-idx is {args.resume_from_idx}")
                continue
            if idx >= args.end_index:
                print(f"Skipping iteration {idx+1} because end-idx is {args.end_index}")
                continue

            args.datasets = original_datasets[:idx + 1]
            try:
                task_vectors = _load_task_vectors(args, args.datasets)

                print("=" * 100)
                print(f"Learning task vector coefficients on {args.target_dataset} with {args.model} using {len(args.datasets)} datasets")
                print(f"Datasets being used: {args.datasets}")
                print("=" * 100)

                print("\nCommand line arguments:")
                print("-" * 50)
                for arg, value in vars(args).items():
                    print(f"{arg}: {value}")
                print("-" * 50)
                print()
                train(task_vectors, args)
                print(f"Successfully completed iteration {idx+1} with {len(args.datasets)} datasets")

            except Exception as e:
                # Deliberate best-effort: log the failure and continue with the next prefix.
                print(f"ERROR: Iteration {idx+1} with {len(args.datasets)} datasets failed with error: {e}")
                import traceback
                traceback.print_exc()
                print(f"Continuing with next iteration...")

    finally:
        if distributed_initialized:
            try:
                cleanup_ddp()
            except Exception as e:
                print(f"Warning: Error during distributed cleanup: {e}")
        # Restore the full dataset list even when an exception escapes.
        if original_datasets is not None:
            args.datasets = original_datasets

from torch import nn
from functorch import make_functional_with_buffers

class CustomWeightedImageEncoder(nn.Module):
    """Image encoder whose weights are the base weights plus a learned mix of task-vector deltas.

    For parameter block ``b`` the added delta is
    ``sum_t sum_p c[t, b, p] * mask_p * delta_t[b]``, where
    ``c = coef * trainable_coef_mask`` and ``mask_p`` is a 0/1 mask assigning
    each element of the block to one of ``max_partition`` partitions. Blocks of
    the same shape share one mask. Exactly ``total_trainable_params`` entries
    of ``trainable_coef_mask`` are 1; all other coefficients are multiplied by 0.
    """

    def __init__(self, model, task_vectors, total_trainable_params, num_task_vectors, num_blocks) -> None:
        """A custom wrapper class for task vector composition based on total trainable parameters.

        Args:
            model: Base image encoder model. Must expose ``train_preprocess``,
                ``val_preprocess`` and ``cache_dir``.
            task_vectors: List of task vectors (objects with a ``vector`` dict).
            total_trainable_params: The exact number of coefficients to learn.
            num_task_vectors: Number of task vectors being used.
            num_blocks: Number of parameter blocks in the base model.

        Raises:
            ValueError: If ``num_task_vectors`` or ``num_blocks`` is not positive,
                or ``total_trainable_params`` is negative.
        """
        super().__init__()

        if num_task_vectors <= 0 or num_blocks <= 0:
            raise ValueError(f"num_task_vectors ({num_task_vectors}) and num_blocks ({num_blocks}) must be positive.")
        if total_trainable_params < 0:
            raise ValueError(f"total_trainable_params ({total_trainable_params}) must be non-negative.")

        func, params, self.buffer = make_functional_with_buffers(model)
        self.func = lambda p, b, x: func(p, b, x)
        self.params = torch.nn.ParameterList(params)
        for p in self.params:
            p.requires_grad = False

        self.train_preprocess = model.train_preprocess
        self.val_preprocess = model.val_preprocess
        self.cache_dir = model.cache_dir

        # One list of deltas per task vector, in ``tv.vector`` key order. forward()
        # indexes these by block position, so this assumes that order matches the
        # order of ``self.params``. TODO: confirm against the TaskVector classes.
        self.dparams = [[tv.vector[k] for k in tv.vector] for tv in task_vectors]

        self.num_task_vectors = num_task_vectors
        self.num_blocks = num_blocks
        self.total_trainable_params = total_trainable_params

        trainable_coef_mask, self.base_partition, self.remainder_params, self.max_partition = (
            self._build_trainable_coef_mask(num_task_vectors, num_blocks, total_trainable_params)
        )

        print(f"[CustomEncoder] Total Params: {self.total_trainable_params}, Tasks: {self.num_task_vectors}, Blocks: {self.num_blocks}")
        print(f"[CustomEncoder] Base Partition: {self.base_partition}, Remainder Params: {self.remainder_params}, Max Partition Dim: {self.max_partition}")

        self.coef = torch.nn.Parameter(torch.zeros(self.num_task_vectors, self.num_blocks, self.max_partition))

        if self.remainder_params > 0:
            # The remainder is always smaller than num_task_vectors * num_blocks,
            # so every leftover coefficient gets assigned.
            print(f"[CustomEncoder] Distributing {self.remainder_params} remainder coefficients...")
            print(f"[CustomEncoder] Assigned {self.remainder_params} remainder coefficients.")

        self.register_buffer('trainable_coef_mask', trainable_coef_mask.float())

        print(f"[CustomEncoder] Creating mask_mats buffers for {self.max_partition} partitions...")
        self._mask_mat_buffer_shapes = {}
        for i, p in enumerate(self.params):
            p_shape = tuple(p.shape)
            if p_shape in self._mask_mat_buffer_shapes:
                continue  # parameters of the same shape share one partition mask
            if p.ndim == 0:
                print(f"    - Skipping scalar parameter at index {i}")
                continue
            print(f"    - Processing shape {p_shape} (Block {i})")
            # Assign every element to a random partition, then expand to a one-hot
            # mask of shape (max_partition, *p_shape).
            assignments = torch.randint(0, self.max_partition, p_shape, device='cpu')
            mask_mat = torch.nn.functional.one_hot(assignments, num_classes=self.max_partition).permute(-1, *range(p.ndim)).float()

            buffer_name = f"mask_mat_{'_'.join(map(str, p_shape))}"
            self.register_buffer(buffer_name, mask_mat)
            self._mask_mat_buffer_shapes[p_shape] = buffer_name
            print(f"      - Created buffer '{buffer_name}' with shape: {mask_mat.shape}")
        print(f"[CustomEncoder] Finished creating mask_mats buffers for {len(self._mask_mat_buffer_shapes)} unique shapes.")

    @staticmethod
    def _build_trainable_coef_mask(num_task_vectors, num_blocks, total_trainable_params):
        """Split a coefficient budget over (task, block, partition) slots.

        Every (task, block) pair receives ``base_partition`` trainable partitions.
        The ``remainder_params`` leftover coefficients go to partition index
        ``base_partition`` in block-major order: block 0 for tasks 0..T-1, then
        block 1, and so on.

        Returns:
            tuple: ``(mask, base_partition, remainder_params, max_partition)``.
            ``mask`` is a bool tensor of shape
            ``(num_task_vectors, num_blocks, max_partition)`` and ``max_partition >= 1``.
        """
        per_full_partition = num_task_vectors * num_blocks
        base_partition, remainder_params = divmod(total_trainable_params, per_full_partition)
        max_partition = max(base_partition + (1 if remainder_params > 0 else 0), 1)

        mask = torch.zeros(num_task_vectors, num_blocks, max_partition, dtype=torch.bool)
        mask[:, :, :base_partition] = True
        if remainder_params > 0:
            flat = torch.arange(remainder_params)
            mask[flat % num_task_vectors, flat // num_task_vectors, base_partition] = True
        return mask, base_partition, remainder_params, max_partition

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` to registered state and to the unregistered tensors held here.

        ``nn.Module._apply`` already converts parameters and registered buffers
        (``coef``, ``trainable_coef_mask`` and the ``mask_mat_*`` buffers). The
        task-vector deltas in ``self.dparams`` and the functional model buffers
        in ``self.buffer`` are plain attributes, so they are converted here.
        Extra arguments (e.g. ``recurse`` in newer PyTorch) are forwarded unchanged.
        """
        new_self = super()._apply(fn, *args, **kwargs)
        new_self.dparams = [[fn(x) if isinstance(x, torch.Tensor) else x for x in tv] for tv in new_self.dparams]
        new_self.buffer = tuple(fn(b) if isinstance(b, torch.Tensor) else b for b in new_self.buffer)
        return new_self

    def train(self, mode=True):
        """Set training mode and return ``self``, as ``nn.Module.train`` does."""
        return super().train(mode)

    def forward(self, x) -> torch.Tensor:
        """Run the base encoder with each parameter shifted by its coefficient-weighted task deltas.

        Scalar (0-d) parameters receive a zero delta. A block with no tensor
        deltas receives a zero delta of the parameter's shape.
        """
        effective_coefs = self.coef * self.trainable_coef_mask

        final_dparams = []
        for block_idx, p_orig in enumerate(self.params):
            p_shape = tuple(p_orig.shape)

            if p_shape in self._mask_mat_buffer_shapes:
                block_mask_mat = getattr(self, self._mask_mat_buffer_shapes[p_shape])
            elif p_orig.ndim == 0:
                final_dparams.append(torch.tensor(0.0, device=p_orig.device, dtype=p_orig.dtype))
                continue
            else:
                raise RuntimeError(f"Mask matrix buffer not found for parameter shape {p_shape} at block index {block_idx}")

            block_task_deltas = [self.dparams[task_idx][block_idx] for task_idx in range(self.num_task_vectors)]
            valid_deltas = [delta for delta in block_task_deltas if isinstance(delta, torch.Tensor)]
            if not valid_deltas:
                final_dparams.append(torch.zeros_like(p_orig))
                continue
            stacked_deltas = torch.stack(valid_deltas, dim=0)

            block_effective_coefs = effective_coefs[:, block_idx, :]

            T = stacked_deltas.shape[0]
            P = block_mask_mat.shape[0]
            shape_dims = stacked_deltas.shape[1:]
            # Broadcast coefs (T, P), partition masks (P, *shape) and deltas (T, *shape)
            # to (T, P, *shape), then reduce over tasks and partitions. This
            # materialises a T x P x numel intermediate per block.
            term = (
                block_effective_coefs.view((T, P) + (1,) * len(shape_dims))
                * block_mask_mat.view((1, P) + shape_dims)
                * stacked_deltas.view((T, 1) + shape_dims)
            )
            final_dparams.append(term.sum(dim=(0, 1)))

        new_params = [(dp + p) if isinstance(dp, torch.Tensor) else p for dp, p in zip(final_dparams, self.params)]
        return self.func(new_params, self.buffer, x)

def train(task_vectors, args):

    if is_main_process():
        experiment_start_time = datetime.now()
        formatted_start_time = experiment_start_time.strftime("%Y-%m-%d %H:%M:%S")
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} started at: {formatted_start_time}")
    
    subsample_str = f"s{int(args.subsample * 100)}" if isinstance(args.subsample, float) else f"s{args.subsample}"

    if args.seed is not None:
        set_seed(args.seed + args.rank)

    target_dataset = args.target_dataset
    
    
    if is_main_process():
        gpu_info = {}
        if torch.cuda.is_available():
            gpu_info = {
                "gpu_count": torch.cuda.device_count(),
                "gpu_model": torch.cuda.get_device_name(0),
                "gpu_memory_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3),
            }
        wandb.login(key="your-wandb-key")
        wandb.init(
            project="atlas2",
            entity="example-owner",
            name=f"{args.model}_{target_dataset}_{subsample_str}-blkcoef-{args.blockwise_coef}-e{args.epochs}-nbdts{len(args.datasets)}-part{args.partition}",
            config={
                "type": "learn_coef",
                "kind": "baseline-10pct-prunned3",
                "iter_fixed": 3,
                "model": args.model,
                "save": args.save,
                "target_dataset": target_dataset,
                "used_datasets": args.datasets,
                "partition": args.partition,
                "#nb_dts": len(args.datasets),
                "port": args.port,
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "batch_size": args.batch_size * args.num_grad_accumulation,
                "finetuning_mode": args.finetuning_mode,
                "lp_reg": args.lp_reg,
                "seed": args.seed,
                "weight_decay": args.wd,
                "blockwise_coef": True,
                "subsample": args.subsample,
                
                "system": {
                    "gpu": gpu_info,
                    "cpu_count": psutil.cpu_count(),
                    "cpu_physical_count": psutil.cpu_count(logical=False),
                    "memory_gb": psutil.virtual_memory().total / (1024**3),
                    "hostname": socket.gethostname(),
                    "platform": platform.platform(),
                },
                
                "runtime": {
                    "pytorch_version": torch.__version__,
                    "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
                    "python_version": sys.version.split()[0],
                },
                
                "job_id": os.environ.get('SLURM_JOB_ID', 'unknown'),
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                
                "grad_clip_value": 1.0,
                "num_grad_accumulation": args.num_grad_accumulation,
                "batch_size": args.batch_size,
                "effective_batch_size": args.batch_size * args.num_grad_accumulation * args.world_size,
                "mixed_precision": True,
                
                "port_selection_method": "auto_find_available",
            }
        )

        if wandb.run:
            print(f"WandB Run URL: {wandb.run.url}")
        print("-" * 50 + "\n")

    
    ckpdir = os.path.join(args.save, target_dataset)
    os.makedirs(ckpdir, exist_ok=True)

    assert args.finetuning_mode in [
        "linear",
        "standard",
    ], "Only linear and standard fine-tuning are supported."

    linearized_finetuning = args.finetuning_mode == "linear"
    if linearized_finetuning:
        print("Using linearized fine-tuning.")

    orig_dataset = target_dataset.replace("Val", "")
    pool = [
        "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SUN397", "SVHN",
        "CIFAR10", "CIFAR100", "ImageNet", "STL10", "Food101", "Caltech101", "Caltech256",
        "FGVCAircraft", "Flowers102", "OxfordIIITPet", "CUB200", "PascalVOC", "Country211", "UCF101",
    ]
    dataset_names = [ds for ds in pool if ds != orig_dataset]
    
    source_task_vectors_with_names = [(k, v) for k, v in task_vectors.items() if orig_dataset != k]
    num_source_tasks = len(source_task_vectors_with_names)

    if is_main_process():
        print("\n" + "="*50)
        print(f"Pruning all {num_source_tasks} source task vectors.")
        print("="*50 + "\n")

    task_vectors = []
    for name, task_vector in source_task_vectors_with_names:
        if is_main_process():
            print(f"Pruning task vector for dataset: {name}")
        pruned_vector = prunned_task_vector(task_vector, SKIP_PARAMS, prunned_level=0.98)
        task_vectors.append(pruned_vector)

    coef_dataset_mapping = {i: dataset_names[i] for i in range(len(dataset_names))}
    if is_main_process():
        print("Coefficient to dataset mapping:")
        for idx, ds_name in coef_dataset_mapping.items():
            print(f"Coefficient {idx}: {ds_name}")
    
    if args.finetuning_mode == "linear":
        print("WARNING: --total-trainable-params currently only supported for standard finetuning mode. Using original WeightedLinearizedModel.")
        image_encoder = LinearizedImageEncoder(args, keep_lang=False)
        image_encoder.model = WeightedLinearizedModel(
            image_encoder.model, task_vectors, blockwise=args.blockwise_coef, partition=args.partition
        )
    elif args.total_trainable_params is not None:
        print(f"Using CustomWeightedImageEncoder with total_trainable_params = {args.total_trainable_params}")
        temp_base_encoder = ImageEncoder(args)
        _, temp_params, _ = make_functional_with_buffers(temp_base_encoder)
        num_blocks = len(temp_params)
        num_task_vectors = len(task_vectors)
        del temp_base_encoder, temp_params
        
        image_encoder = CustomWeightedImageEncoder(
            model=ImageEncoder(args),
            task_vectors=task_vectors,
            total_trainable_params=args.total_trainable_params,
            num_task_vectors=num_task_vectors,
            num_blocks=num_blocks
        )
        print(f"Created CustomWeightedImageEncoder with {num_task_vectors} tasks, {num_blocks} blocks.")

        if is_main_process():
            print("\nParameters of the CustomWeightedImageEncoder (Trainable Coefs):")
            param_counter = 1
            total_params_in_coef = 0
            for name, param in image_encoder.named_parameters():
                if param.requires_grad:
                    print(f"{param_counter}. {name} (Shape: {list(param.shape)}, Type: {param.dtype}) - Trainable")
                    total_params_in_coef += param.numel()
                    param_counter += 1
            print(f"Total elements in trainable 'coef' parameter: {total_params_in_coef}")
            if hasattr(image_encoder, 'trainable_coef_mask'):
                num_masked_trainable = image_encoder.trainable_coef_mask.sum().item()
                print(f"Number of trainable coefficients according to mask: {int(num_masked_trainable)}")
                if int(num_masked_trainable) != args.total_trainable_params:
                     print(f"WARNING: Masked trainable count ({int(num_masked_trainable)}) does not match requested total_trainable_params ({args.total_trainable_params})!")
            print("-" * 30)


    classification_head = get_classification_head(args, target_dataset)
    model = ImageClassifier(image_encoder, classification_head)

    model.freeze_head()
    model = model.cuda()

    preprocess_fn = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(
            size=224, scale=(0.5, 1),
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC
        ), torchvision.transforms.RandomHorizontalFlip(p=0.5),
    ] + model.train_preprocess.transforms[-3:])

    dataset = get_dataset(
        target_dataset,
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size,
        num_workers=4,
    )
    data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(data_loader)

    if args.print_every * 10 < num_batches:
        print_every = int(num_batches / 10)
    elif args.print_every * 4 > num_batches:
        print_every = max(int(num_batches / 4), 1)
    else:
        print_every = args.print_every

    ddp_loader = distribute_loader(data_loader)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.rank],
        find_unused_parameters=False,
        output_device=args.rank,
    )

    loss_fn = torch.nn.CrossEntropyLoss()

    params = [p for p in ddp_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

    scheduler = cosine_lr(
        optimizer, args.lr, 0,
        args.epochs * num_batches // args.num_grad_accumulation,
    )

    slurm_job_id = os.environ.get('SLURM_JOB_ID', 'unknown')
    if "unknown" in slurm_job_id:
        slurm_job_id = time.strftime("%Y%m%d-%H%M%S")
    
    if args.total_trainable_params is not None:
        head_path = os.path.join(ckpdir, f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_custom_composition_total{args.total_trainable_params}.pt")
        log_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_custom_composition_total{args.total_trainable_params}.json")
        coef_param = ddp_model.module.image_encoder.coef
        coef_mask = ddp_model.module.image_encoder.trainable_coef_mask
        print("Using CustomWeightedImageEncoder - paths set.")
        print(f"Coef parameter shape: {coef_param.shape}, Mask shape: {coef_mask.shape}")
    elif linearized_finetuning:
        head_path = os.path.join(ckpdir, f"learned_linear_composition_{subsample_str}.pt")
        log_path = os.path.join(args.save, f"learned_linear_composition_{subsample_str}.json")
        coef_param = ddp_model.module.image_encoder.model.coef
        coef_mask = None
        print("Using Linearized fine-tuning - paths set.")
    else:
        head_path = os.path.join(ckpdir, f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_composition.pt")
        log_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_composition.json")
        coef_param = ddp_model.module.image_encoder.coef
        coef_mask = None
        print("Using standard WeightedImageEncoder - paths set.")

    print("head_path:", head_path)
    print("log_path:", log_path)
    print("Initial coef param shape:", coef_param.shape)

    scaler = GradScaler()
    if is_main_process():
        print(f"=> Zero-shot accuracy on {target_dataset}:\t{100*args.zs_acc[target_dataset]:.2f}%.")
        wandb.log({
            "target_dataset": str(target_dataset),
            "zero_shot_accuracy": 100 * args.zs_acc[target_dataset]
        })
        
        if os.path.exists(log_path):
            with open(log_path, 'r') as f:
                comp_acc = json.load(f)
        else:
            comp_acc = {}

    best_coef = coef_param.data.clone()
    best_acc = args.zs_acc[target_dataset]


    total_params = sum(p.numel() for p in ddp_model.parameters())
    
    is_custom_encoder = args.total_trainable_params is not None and hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
    if is_custom_encoder:
        params_excluding_coef = [p for name, p in ddp_model.named_parameters() if p.requires_grad and 'image_encoder.coef' not in name]
        num_trainable_params_other = sum(p.numel() for p in params_excluding_coef)
        num_trainable_params_coef = int(ddp_model.module.image_encoder.trainable_coef_mask.sum().item())
        num_trainable_params = num_trainable_params_other + num_trainable_params_coef
        print(f"[Param Stats - Custom] Other trainable: {num_trainable_params_other}, Masked coef: {num_trainable_params_coef}, Total effective trainable: {num_trainable_params}")
    else:
        num_trainable_params = sum(p.numel() for p in ddp_model.parameters() if p.requires_grad)
        print(f"[Param Stats - Standard] Total trainable: {num_trainable_params}")

    if is_main_process():
        print(f"\nModel Parameter Stats:")
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {num_trainable_params:,}")
        print(f"Percentage trainable: {100 * num_trainable_params / total_params:.2f}%\n")
        wandb.config.update({
            "total_parameters": total_params,
            "trainable_parameters": num_trainable_params,
            "trainable_parameters_pct": 100 * num_trainable_params / total_params
        }, allow_val_change=True)
        
        wandb.log({
            "total_parameters": total_params,
            "trainable_parameters": num_trainable_params,
            "trainable_parameters_pct": 100 * num_trainable_params / total_params
        })

    for epoch in range(args.epochs):
        epoch_loss = 0.0
        epoch_steps = 0
        
        ddp_loader.sampler.set_epoch(epoch)
        for i, batch in enumerate(ddp_loader):
            start_time = time.time()

            step = (
                i // args.num_grad_accumulation
                + epoch * num_batches // args.num_grad_accumulation
            )

            batch = maybe_dictionarize(batch)
            inputs = batch["images"].cuda()
            data_time = time.time() - start_time

            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = ddp_model(inputs)
                labels = batch["labels"].cuda()
                loss = loss_fn(logits, labels)

                reg = lp_reg(coef_param, args.lp_reg)
                loss = loss + reg
                loss = loss / args.num_grad_accumulation
                
                if is_main_process():
                    epoch_loss += loss.item() * args.num_grad_accumulation
                    epoch_steps += 1

            scaler.scale(loss).backward()

            if (i + 1) % args.num_grad_accumulation == 0:
                scheduler(step)

                torch.nn.utils.clip_grad_norm_(params, 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            batch_time = time.time() - start_time

            if (
                step % print_every == 0
                and ((i + 1) % args.num_grad_accumulation == 0)
                and is_main_process()
            ):
                percent_complete = 100 * (i + 1) / len(ddp_loader)
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i + 1}/{num_batches}]\t"
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}",
                    flush=True,
                )
                
                batch_log = {
                    "dataset": str(target_dataset),
                    "epoch": epoch,
                    "batch": i,
                    "batch_loss": loss.item() * args.num_grad_accumulation,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                    "batch_time": batch_time,
                    "data_time": data_time,
                    "global_step": step,
                    "regularization": reg.item() if reg != 0 else 0,
                }
                
                current_coef_tensor = ddp_model.module.image_encoder.coef.data
                
                is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                if is_custom_encoder:
                    current_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool()
                    for task_idx in range(current_coef_tensor.size(0)):
                        task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                        for block_idx in range(current_coef_tensor.size(1)):
                            for part_idx in range(current_coef_tensor.size(2)):
                                if current_mask[task_idx, block_idx, part_idx]:
                                    coef_value = current_coef_tensor[task_idx, block_idx, part_idx].item()
                                    batch_log[f"coef_{task_name}_block{block_idx}_part{part_idx}"] = coef_value
                elif args.blockwise_coef:
                    current_coef_values = current_coef_tensor.cpu()
                    for dataset_idx in range(current_coef_values.size(0)):
                        layer_coefs = current_coef_values[dataset_idx]
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for layer_idx, coef_value in enumerate(layer_coefs):
                             batch_log[f"coef_layer{layer_idx}_{dataset_name}"] = coef_value.item()
                elif args.partition:
                     current_coef_values = current_coef_tensor.cpu()
                     for dataset_idx in range(current_coef_values.size(0)):
                         dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                         for block_idx in range(current_coef_values.size(1)):
                             for part_idx in range(current_coef_values.size(2)):
                                 coef_value = current_coef_values[dataset_idx, block_idx, part_idx].item()
                                 batch_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}"] = coef_value
                else:
                    current_coef_values = current_coef_tensor.cpu()
                    for dataset_idx, coef_value in enumerate(current_coef_values.flatten()):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        batch_log[f"coef_{dataset_name}"] = coef_value.item()
                
                wandb.log(batch_log)
        
        if is_main_process():
            avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0
            wandb.log({
                "target_dataset": str(target_dataset),
                "epoch": epoch,
                "epoch_loss": avg_epoch_loss,
            })
            
            image_encoder = ddp_model.module.image_encoder
            val_metrics = eval_single_dataset(image_encoder, target_dataset, args)
            val_acc = val_metrics["top1"]
            
            epoch_log = {
                "dataset": str(target_dataset),
                "epoch": epoch,
                "val_accuracy": 100 * val_acc,
            }
            
            current_coef_tensor = ddp_model.module.image_encoder.coef.data
            
            is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
            if is_custom_encoder:
                current_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool()
                for task_idx in range(current_coef_tensor.size(0)):
                    task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                    for block_idx in range(current_coef_tensor.size(1)):
                        for part_idx in range(current_coef_tensor.size(2)):
                            if current_mask[task_idx, block_idx, part_idx]:
                                coef_value = current_coef_tensor[task_idx, block_idx, part_idx].item()
                                epoch_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_epoch"] = coef_value
            elif args.blockwise_coef:
                 current_coef_values = current_coef_tensor.cpu()
                 for dataset_idx in range(current_coef_values.size(0)):
                     layer_coefs = current_coef_values[dataset_idx]
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for layer_idx, coef_value in enumerate(layer_coefs):
                         epoch_log[f"coef_layer{layer_idx}_{dataset_name}_epoch"] = coef_value.item()
            elif args.partition:
                 current_coef_values = current_coef_tensor.cpu()
                 for dataset_idx in range(current_coef_values.size(0)):
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for block_idx in range(current_coef_values.size(1)):
                         for part_idx in range(current_coef_values.size(2)):
                            coef_value = current_coef_values[dataset_idx, block_idx, part_idx].item()
                            epoch_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_epoch"] = coef_value
            else:
                current_coef_values = current_coef_tensor.cpu()
                for dataset_idx, coef_value in enumerate(current_coef_values.flatten()):
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    epoch_log[f"coef_{dataset_name}_epoch"] = coef_value.item()
                
            wandb.log(epoch_log)
            
            print(f"Epoch {epoch}: Validation accuracy: {100 * val_acc:.2f}%")
            
            if val_acc > best_acc:
                best_acc = val_acc
                best_coef = coef_param.data.clone()
                torch.save(best_coef, head_path)
                
                best_log = {
                    "target_dataset": target_dataset,
                    "best_val_accuracy": 100 * best_acc,
                    "best_epoch": epoch,
                }
                
                best_coef_tensor = best_coef.cpu()

                is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                if is_custom_encoder:
                    best_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool().cpu()
                    for task_idx in range(best_coef_tensor.size(0)):
                        task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                        for block_idx in range(best_coef_tensor.size(1)):
                            for part_idx in range(best_coef_tensor.size(2)):
                                if best_mask[task_idx, block_idx, part_idx]:
                                    coef_value = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                    best_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_best"] = coef_value
                elif args.blockwise_coef:
                    best_coef_values = best_coef_tensor
                    for dataset_idx in range(best_coef_values.size(0)):
                        layer_coefs = best_coef_values[dataset_idx]
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for layer_idx, coef_value in enumerate(layer_coefs):
                            best_log[f"coef_layer{layer_idx}_{dataset_name}_best"] = coef_value.item()
                elif args.partition:
                    best_coef_values = best_coef_tensor
                    for dataset_idx in range(best_coef_values.size(0)):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for block_idx in range(best_coef_values.size(1)):
                            for part_idx in range(best_coef_values.size(2)):
                                coef_value = best_coef_values[dataset_idx, block_idx, part_idx].item()
                                best_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_best"] = coef_value
                else:
                    best_coef_values = best_coef_tensor
                    for dataset_idx, coef_value in enumerate(best_coef_values.flatten()):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        best_log[f"coef_{dataset_name}_best"] = coef_value.item()
                
                wandb.log(best_log)

    if is_main_process():
        comp_acc[target_dataset] = best_acc
        target_dataset_test = target_dataset.replace("Val", "")
        image_encoder = ddp_model.module.image_encoder
        
        is_custom_encoder = hasattr(image_encoder, 'trainable_coef_mask')
        if is_custom_encoder:
            if best_coef is not None:
                image_encoder.coef = torch.nn.Parameter(best_coef)
                print("Loaded best_coef (3D tensor) into CustomWeightedImageEncoder")
            else:
                print("Warning: best_coef is None, cannot load into CustomWeightedImageEncoder")
        elif linearized_finetuning:
            if best_coef is not None:
                 image_encoder.model.coef = torch.nn.Parameter(best_coef)
            else:
                print("Warning: best_coef is None, cannot load into WeightedLinearizedModel")
        else:
            if best_coef is not None:
                image_encoder.coef = torch.nn.Parameter(best_coef)
            else:
                 print("Warning: best_coef is None, cannot load into WeightedImageEncoder")

        print("best_coef shape after loading:", image_encoder.coef.shape if best_coef is not None else "N/A")
        
        final_coef_log = {}
        if best_coef is not None:
            best_coef_tensor = best_coef.cpu()
            if is_custom_encoder:
                best_mask = image_encoder.trainable_coef_mask.data.bool().cpu()
                for task_idx in range(best_coef_tensor.size(0)):
                    task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                    for block_idx in range(best_coef_tensor.size(1)):
                        for part_idx in range(best_coef_tensor.size(2)):
                            if best_mask[task_idx, block_idx, part_idx]:
                                val = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                final_coef_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_final"] = val
            elif args.blockwise_coef:
                best_coef_values = best_coef_tensor
                for dataset_idx in range(best_coef_values.size(0)):
                    layer_coefs = best_coef_values[dataset_idx]
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    for layer_idx, val in enumerate(layer_coefs):
                        final_coef_log[f"coef_layer{layer_idx}_{dataset_name}_final"] = val.item()
            elif args.partition:
                 best_coef_values = best_coef_tensor
                 for dataset_idx in range(best_coef_values.size(0)):
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for block_idx in range(best_coef_values.size(1)):
                         for part_idx in range(best_coef_values.size(2)):
                             val = best_coef_values[dataset_idx, block_idx, part_idx].item()
                             final_coef_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_final"] = val
            else:
                best_coef_values = best_coef_tensor
                for dataset_idx, val in enumerate(best_coef_values.flatten()):
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    final_coef_log[f"coef_{dataset_name}_final"] = val.item()
        else:
            print("Warning: best_coef is None, skipping final coefficient logging.")

        wandb.log(final_coef_log)
        
        image_encoder = ddp_model.module.image_encoder
            
        test_metrics = eval_single_dataset(image_encoder, target_dataset_test, args)
        final_test_acc = test_metrics["top1"]
        comp_acc[target_dataset_test] = final_test_acc
        
        wandb.log({
            "dataset": str(target_dataset),
            "final_val_accuracy": 100 * best_acc,
            "final_test_accuracy": 100 * final_test_acc,
        })
        
        print(f"Final test accuracy on {target_dataset_test}: {100 * final_test_acc:.2f}%")
        
        with open(log_path, 'w') as f:
            json.dump(comp_acc, f, indent=4)

        test_acc_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset_test}_{subsample_str}_test_accuracy.json")
        
        test_acc_data = {
            "target_dataset": str(target_dataset_test),
            "number_of_used_datasets": len(args.datasets),
            "datasets_names": args.datasets,
            "test_accuracy": float(final_test_acc),
            "test_metrics": test_metrics,
        }
        
        if os.path.exists(test_acc_path):
            with open(test_acc_path, 'r') as f:
                existing_data = json.load(f)
                
            if not isinstance(existing_data, list):
                existing_data = [existing_data]
                
            existing_data.append(test_acc_data)
            
            with open(test_acc_path, 'w') as f:
                json.dump(existing_data, f, indent=4)
        else:
            with open(test_acc_path, 'w') as f:
                json.dump([test_acc_data], f, indent=4)
        print(f"Test accuracy saved to {test_acc_path}")
            
        if best_coef is not None:
            try:
                coef_data = []
                if best_coef is not None:
                    best_coef_tensor = best_coef.cpu()
                    
                    is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                    if is_custom_encoder:
                        best_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool().cpu()
                        for task_idx in range(best_coef_tensor.size(0)):
                            task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                            for block_idx in range(best_coef_tensor.size(1)):
                                for part_idx in range(best_coef_tensor.size(2)):
                                    if best_mask[task_idx, block_idx, part_idx]:
                                        val = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                        coef_data.append([task_idx, task_name, block_idx, part_idx, val])
                        
                        coef_table = wandb.Table(
                            columns=["task_idx", "dataset", "block_idx", "part_idx", "coefficient_value"],
                            data=coef_data
                        )
                        wandb.log({"final_coefficient_values_table": coef_table})
                        
                        active_coef_values = best_coef_tensor[best_mask]
                        if active_coef_values.numel() > 0:
                           wandb.log({"final_active_coefficient_histogram": wandb.Histogram(active_coef_values.numpy())})
                        else:
                            print("No active coefficients found for histogram logging.")

                    elif args.blockwise_coef:
                        best_coef_values = best_coef_tensor
                        for dataset_idx in range(best_coef_values.size(0)):
                            layer_coefs = best_coef_values[dataset_idx]
                            dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                            for layer_idx, val in enumerate(layer_coefs):
                                coef_data.append([layer_idx, dataset_idx, dataset_name, val.item()])
                        
                        coef_table = wandb.Table(
                            columns=["layer", "index", "dataset", "coefficient_value"], 
                            data=coef_data
                        )
                        wandb.log({"coefficient_values_table": coef_table})
                        
                        for layer_idx in range(best_coef_values.size(0)):
                             wandb.log({
                                 f"coefficient_histogram_layer_{layer_idx}": 
                                 wandb.Histogram(best_coef_values[layer_idx].numpy())
                             })
                    elif args.partition:
                        best_coef_values = best_coef_tensor
                        print("WARNING: wandb.Table logging for original --partition mode not explicitly adapted. Logging histogram only.")
                        wandb.log({"coefficient_histogram": wandb.Histogram(best_coef_values.numpy())})
                    else:
                         best_coef_values = best_coef_tensor
                         for dataset_idx, val in enumerate(best_coef_values.flatten()):
                             dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                             coef_data.append([dataset_idx, dataset_name, val.item()])
                        
                         coef_table = wandb.Table(
                             columns=["index", "dataset", "coefficient_value"], 
                             data=coef_data
                         )
                         wandb.log({"coefficient_values_table": coef_table})
                         wandb.log({"coefficient_histogram": wandb.Histogram(best_coef_values.numpy())})
                else:
                     print("Warning: best_coef is None, cannot log final coefficient table/histogram.")
            except Exception as e:
                print(f"Error logging coefficients to wandb: {e}")

    
    experiment_end_time = datetime.now()
    experiment_duration = experiment_end_time - experiment_start_time
    formatted_end_time = experiment_end_time.strftime("%Y-%m-%d %H:%M:%S")
    
    hours, remainder = divmod(experiment_duration.total_seconds(), 3600)
    minutes, seconds = divmod(remainder, 60)
    formatted_duration = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"
    
    if is_main_process():
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} completed at: {formatted_end_time}")
        print(f"[TIMING] Total duration: {formatted_duration} (H:M:S)")
        
        wandb.log({
            "experiment_start_time": formatted_start_time,
            "experiment_end_time": formatted_end_time,
            "experiment_duration_seconds": experiment_duration.total_seconds(),
            "experiment_duration_formatted": formatted_duration,
        })
    
    if is_main_process():
        wandb.finish(quiet=True)


def apply_experiment_defaults(args, datasets):
    """Overwrite parsed CLI arguments with this experiment's fixed settings.

    ``args`` is mutated in place and also returned for convenience.

    Args:
        args: Namespace returned by ``parse_arguments()``. It must already
            carry ``model`` (str) and ``save`` (str) attributes.
        datasets: List of dataset names stored on ``args.datasets``.

    Returns:
        The same ``args`` object.
    """
    args.datasets = datasets
    args.lr = 1e-1
    args.epochs = 10

    # ViT-L-14 halves the per-step batch and accumulates gradients over two
    # steps, so each optimizer step still sees 128 samples per process.
    large_model = args.model == "ViT-L-14"
    args.batch_size = 64 if large_model else 128
    args.num_grad_accumulation = 2 if large_model else 1

    args.print_every = 10
    args.seed = 0

    # ``args.save`` is used as a prefix: the model name is appended verbatim
    # and no path separator is inserted.
    args.save = args.save + f"{args.model}"
    return args


if __name__ == "__main__":
    # Dataset names assigned to ``args.datasets`` for this run.
    target_datasets = [
        "Cars",
        "DTD",
        "EuroSAT",
        "GTSRB",
        "MNIST",
        "RESISC45",
        "SUN397",
        "SVHN",
        "CIFAR10",
        "CIFAR100",
        "ImageNet",
        "STL10",
        "Food101",
        "Caltech101",
        "Caltech256",
        "FGVCAircraft",
        "Flowers102",
        "OxfordIIITPet",
        "CUB200",
        "PascalVOC",
        "Country211",
        "UCF101",
    ]

    args = parse_arguments()

    print("=" * 80)
    print("AUTOMATIC EXPERIMENT TRACKING")
    print("-" * 50)
    print("to ensure that experiment code states are tracked.")
    print("=" * 80)

    apply_experiment_defaults(args, target_datasets)

    print("=" * 80)
    print("NOTE: When running multiple instances, use different ports to avoid conflicts")
    print("Example: python src/learn_coef.py --port=29501")
    print("Available port range: 29500-29509")
    print("=" * 80)

    # Zero-shot accuracies are read from the model-specific save location
    # built by apply_experiment_defaults and stored on args for main().
    zs_acc_path = os.path.join(args.save, "zeroshot_accuracies.json")
    with open(zs_acc_path, "r", encoding="utf-8") as f:
        args.zs_acc = json.load(f)

    # Launch one process per rank; spawn passes the process index as the
    # first positional argument to main, followed by ``args``.
    torch.multiprocessing.spawn(main, args=(args,), nprocs=args.world_size)
