import os
import time
import json
import torch
import torchvision
import sys
import argparse

# Configure wandb through environment variables before it is imported below.
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_CONSOLE"] = "off"
os.environ["WANDB_DISABLE_SERVICE"] = "true"

# Point stderr at /dev/null for the duration of the wandb import, then put the
# original stream back so later errors are still visible.
old_stderr = sys.stderr
with open('/dev/null', 'w') as devnull:
    sys.stderr = devnull
    import wandb
sys.stderr = old_stderr

import random
import numpy as np
import socket
import subprocess
import platform
import psutil
from datetime import datetime, timedelta

from PIL import Image as PILImage

from src.corruptions import (
    gaussian_noise, shot_noise, impulse_noise, speckle_noise,
    gaussian_blur, defocus_blur, zoom_blur,
    contrast, brightness, saturate, jpeg_compression, pixelate
)

from tqdm import tqdm

from torch.cuda.amp import GradScaler
from src.linearize import LinearizedImageEncoder
from src.modeling import ImageEncoder, ImageClassifier
from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
from src.composition import WeightedImageEncoder, WeightedLinearizedModel

from src.utils import cosine_lr
from src.args import parse_arguments
from src.eval import eval_single_dataset
from src.datasets.registry import get_dataset
from src.heads import get_classification_head
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp

@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Return the entropy of softmax(x), computed along dimension 1.

    Args:
        x: Logits; the softmax is taken over dimension 1.

    Returns:
        Tensor with dimension 1 reduced away, holding -sum(p * log p).
    """
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    return -torch.sum(probs * log_probs, dim=1)

def gaussian_noise_pil(img, severity=1):
    """Apply ``gaussian_noise`` and return the uint8-cast result as a PIL image."""
    return PILImage.fromarray(gaussian_noise(img, severity).astype(np.uint8))

def shot_noise_pil(img, severity=1):
    """Apply ``shot_noise`` and return the uint8-cast result as a PIL image."""
    return PILImage.fromarray(shot_noise(img, severity).astype(np.uint8))

def impulse_noise_pil(img, severity=1):
    """Apply ``impulse_noise`` and hand back its result unchanged."""
    # No conversion here: presumably impulse_noise already returns a PIL
    # image -- verify against src.corruptions.
    corrupted = impulse_noise(img, severity)
    return corrupted

def speckle_noise_pil(img, severity=1):
    """Apply ``speckle_noise`` and return the uint8-cast result as a PIL image."""
    return PILImage.fromarray(speckle_noise(img, severity).astype(np.uint8))

def gaussian_blur_pil(img, severity=1):
    """Apply ``gaussian_blur`` and return the uint8-cast result as a PIL image."""
    return PILImage.fromarray(gaussian_blur(img, severity).astype(np.uint8))

def defocus_blur_pil(img, severity=1):
    """Apply ``defocus_blur`` and return the uint8-cast result as a PIL image."""
    return PILImage.fromarray(defocus_blur(img, severity).astype(np.uint8))

def zoom_blur_pil(img, severity=1):
    """Apply ``zoom_blur`` and return the uint8-cast result as a PIL image."""
    return PILImage.fromarray(zoom_blur(img, severity).astype(np.uint8))

def contrast_pil(img, severity=1):
    """Apply ``contrast`` and hand back its result unchanged."""
    # No conversion here: presumably contrast already returns a PIL image
    # -- verify against src.corruptions.
    return contrast(img, severity)

def brightness_pil(img, severity=1):
    """Apply ``brightness`` and hand back its result unchanged."""
    # No conversion here: presumably brightness already returns a PIL image
    # -- verify against src.corruptions.
    return brightness(img, severity)

def saturate_pil(img, severity=1):
    """Apply ``saturate`` and hand back its result unchanged."""
    # No conversion here: presumably saturate already returns a PIL image
    # -- verify against src.corruptions.
    return saturate(img, severity)

def jpeg_compression_pil(img, severity=1):
    """Apply ``jpeg_compression`` and hand back its result unchanged."""
    # No conversion here: presumably jpeg_compression already returns a PIL
    # image -- verify against src.corruptions.
    compressed = jpeg_compression(img, severity)
    return compressed

def pixelate_pil(img, severity=1):
    """Apply ``pixelate`` and hand back its result unchanged."""
    # No conversion here: presumably pixelate already returns a PIL image
    # -- verify against src.corruptions.
    pixelated = pixelate(img, severity)
    return pixelated

def lp_reg(x, p=None, gamma=0.5) -> torch.Tensor:
    """Return an Lp-norm regularisation term for ``x``.

    Args:
        x: Tensor whose p-norm is taken along dimension 0.
        p: Norm order passed to ``torch.norm``; ``None`` disables the term.
        gamma: Multiplier applied to the mean of the norms.

    Returns:
        The plain integer ``0`` when ``p`` is None, otherwise
        ``gamma * mean(norm_p(x, dim=0))`` as a tensor.
    """
    if p is None:
        return 0
    norms_along_dim0 = torch.norm(x, p=p, dim=0)
    return gamma * norms_along_dim0.mean()

def set_seed(seed: int) -> None:
    """
    Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Also exports ``PYTHONHASHSEED`` into ``os.environ``. The confirmation
    message is printed only when ``is_main_process()`` is true.

    Args:
        seed: The random seed to set
    """
    # Same seed for every generator, applied in a fixed order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not is_main_process():
        return
    print(f"Random seed set to {seed} for deterministic results")


def is_port_available(port):
    """
    Report whether a TCP port can currently be bound on all interfaces.

    Args:
        port: The port number to check

    Returns:
        bool: True if binding succeeded, False if it raised OSError
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('', port))
    except OSError:
        return False
    else:
        return True
    finally:
        # Always release the probe socket, whatever the outcome.
        probe.close()

def _probe_port_bind(port):
    """Bind a throw-away TCP socket to ``port`` and release it again.

    Raises:
        OSError: If the port cannot be bound.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        probe.bind(('', port))


def find_available_port(port_list):
    """
    Find an available port from a list of ports.

    Candidates are tried in random order (the caller's list is copied, not
    mutated). Each candidate must pass two bind probes 0.1 s apart.

    The first probe socket is closed before the second one is opened.
    Binding the same port twice while the first socket is still open only
    succeeds on platforms where SO_REUSEADDR permits duplicate binds; on
    others the re-check would always fail and no port would be returned.

    Args:
        port_list: Iterable of port numbers to check

    Returns:
        int: An available port from the list or None if none available
    """
    candidates = list(port_list)
    random.shuffle(candidates)

    for port in candidates:
        try:
            _probe_port_bind(port)
        except OSError as e:
            print(f"Port {port} is not available: {e}")
            continue

        # Short pause so a port grabbed by someone else right after the first
        # probe is caught by the second one.
        time.sleep(0.1)

        try:
            _probe_port_bind(port)
        except OSError:
            print(f"Port {port} failed second availability check")
            continue

        print(f"Port {port} is confirmed available")
        return port

    print("No available ports found in the provided list")
    return None

class CorruptionTransform:
    """Callable that applies one corruption function at a fixed severity.

    Written as a plain class holding the function and severity as attributes
    so that instances stay picklable for multiprocessing.
    """

    def __init__(self, corruption_func, severity):
        """Store the corruption callable, its severity and a display name."""
        self.corruption_func = corruption_func
        self.severity = severity
        # Display name: the callable's __name__, or its str() form if it has none.
        fallback_name = str(corruption_func)
        self.corruption_name = getattr(corruption_func, '__name__', fallback_name)

    def __call__(self, img):
        """Return the corrupted image, or ``img`` itself if the corruption raises."""
        try:
            corrupted = self.corruption_func(img, severity=self.severity)
        except Exception as err:
            # Best effort: report the failure and fall back to the clean input.
            print(f"Error in corruption {self.corruption_name} with severity {self.severity}: {err}", flush=True)
            return img
        return corrupted

    def __repr__(self):
        name, level = self.corruption_name, self.severity
        return f"CorruptionTransform({name}, severity={level})"

def save_first_corrupted_image(image_tensor, dataset_name, corruption_func, severity,
                               output_dir="/check/corruptions"):
    """
    Save a normalised image tensor to disk as a PNG for visual inspection.

    The tensor is un-normalised with fixed per-channel mean/std constants,
    clamped to [0, 1], scaled to 8-bit and written to
    ``<output_dir>/<dataset>_<corruption>_severity<severity>_example.png``.
    Any failure is reported as a warning and swallowed; nothing is raised.

    Args:
        image_tensor: Tensor of shape (C, H, W) representing the corrupted image
        dataset_name: Name of the dataset
        corruption_func: The corruption function used
        severity: Severity level of the corruption
        output_dir: Directory the PNG is written to (created if missing).
            Defaults to the previously hard-coded location.
    """
    try:
        os.makedirs(output_dir, exist_ok=True)

        corruption_name = getattr(corruption_func, '__name__', str(corruption_func))
        # Drop the wrapper suffix so the file is named after the corruption.
        if corruption_name.endswith('_pil'):
            corruption_name = corruption_name[:-4]

        # detach() so tensors that still require grad can be converted to NumPy.
        image = image_tensor.detach().cpu().clone()

        # NOTE(review): these are the usual ImageNet statistics; confirm they
        # match the Normalize step of the preprocessing actually in use.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        # Undo normalisation, then clip to the displayable range.
        image = torch.clamp(image * std + mean, 0, 1)

        # (C, H, W) float in [0, 1] -> (H, W, C) uint8 in [0, 255].
        pixels = (image.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        filename = f"{dataset_name}_{corruption_name}_severity{severity}_example.png"
        filepath = os.path.join(output_dir, filename)

        PILImage.fromarray(pixels).save(filepath)

        print(f"[DEBUG] Saved corrupted image example: {filepath}", flush=True)

    except Exception as e:
        # Debug aid only: never let a failed save interrupt the caller.
        print(f"[WARNING] Failed to save corrupted image: {e}", flush=True)

def evaluate_with_corruption(image_encoder, dataset_name, args, corruption_func, severity):
    """
    Measure top-1 accuracy on ``dataset_name`` with a corruption applied to the inputs.

    The corruption is wrapped in a ``CorruptionTransform`` and inserted into the
    model's validation preprocessing directly before ``ToTensor`` (or at the
    front when no ``ToTensor`` step exists).

    Returns:
        ``None`` on non-main processes; otherwise ``{"top1": accuracy}``.
        ``{"top1": 0.0}`` is returned when the validation preprocessing has no
        ``transforms`` attribute.
    """
    # Only the main process evaluates.
    if not is_main_process():
        return None

    func_label = getattr(corruption_func, '__name__', str(corruption_func))
    print(f"[DEBUG] Starting evaluation with {func_label} severity {severity}", flush=True)

    image_encoder.eval()

    head = get_classification_head(args, dataset_name)
    model = ImageClassifier(image_encoder, head).cuda()
    model.eval()

    corruption_transform = CorruptionTransform(corruption_func, severity)
    print(f"[DEBUG] Created corruption transform: {corruption_transform}", flush=True)

    base_preprocess = model.val_preprocess
    if not hasattr(base_preprocess, 'transforms'):
        print(f"Error: val_preprocess for {dataset_name} is not a Compose object.", flush=True)
        return {"top1": 0.0}

    # Work on a copy of the transform list and locate the first ToTensor step.
    pipeline = list(base_preprocess.transforms)
    insert_at = next(
        (pos for pos, step in enumerate(pipeline)
         if isinstance(step, torchvision.transforms.ToTensor)),
        -1,
    )
    if insert_at == -1:
        print(f"Warning: ToTensor not found. Adding corruption to the beginning of transforms for {dataset_name}.", flush=True)
        insert_at = 0
    pipeline.insert(insert_at, corruption_transform)

    corrupted_preprocess = torchvision.transforms.Compose(pipeline)
    print(f"[DEBUG] Created corrupted preprocessing pipeline with {len(pipeline)} transforms", flush=True)

    print(f"[DEBUG] Getting dataset {dataset_name} with corrupted preprocessing...", flush=True)
    dataset = get_dataset(
        dataset_name,
        corrupted_preprocess,
        location=args.data_location,
        batch_size=args.batch_size
    )

    print("[DEBUG] Creating dataloader with num_workers=4...", flush=True)

    # Shallow copy of args so the caller's namespace keeps its own num_workers.
    eval_args = argparse.Namespace(**vars(args))
    eval_args.num_workers = 4

    dataloader = get_dataloader(
        dataset, is_train=False, args=eval_args, image_encoder=None
    )

    if hasattr(dataloader, 'num_workers'):
        dataloader.num_workers = 4

    print(f"[DEBUG] Starting evaluation loop with {len(dataloader)} batches...", flush=True)

    correct_so_far = 0
    seen_so_far = 0
    progress_label = f"Evaluating {getattr(corruption_func, '__name__', 'corruption')} sev{severity}"

    with torch.no_grad():
        pbar = tqdm(dataloader, desc=progress_label, leave=False, ncols=100)

        for batch in pbar:
            batch = maybe_dictionarize(batch)
            images = batch["images"].cuda()
            labels = batch["labels"].cuda()

            preds = model(images).argmax(dim=-1)
            correct_so_far += (preds == labels).sum().item()
            seen_so_far += labels.size(0)

            # Running accuracy shown on the progress bar.
            running_acc = correct_so_far / seen_so_far if seen_so_far > 0 else 0.0
            pbar.set_postfix({'acc': f'{running_acc:.3f}', 'samples': seen_so_far})

    accuracy = correct_so_far / seen_so_far if seen_so_far > 0 else 0.0
    print(f"[DEBUG] Completed evaluation: {correct_so_far}/{seen_so_far} = {accuracy:.4f}", flush=True)

    return {"top1": accuracy}

def main(rank, args):
    """
    Set up distributed execution, then learn coefficients on growing dataset prefixes.

    For every index ``idx`` within ``[args.resume_from_idx, args.end_index)``
    the first ``idx + 1`` datasets (target dataset excluded) are loaded as
    task vectors and handed to ``train``. A failing iteration is logged with
    its traceback and the loop moves on. Distributed state is cleaned up in
    all cases once it has been initialised.

    Args:
        rank: Rank of this process; stored on ``args.rank``.
        args: Parsed argument namespace. ``port``, ``target_dataset`` and
            ``datasets`` are modified in place.
    """
    args.rank = rank

    ddp_ready = False

    try:
        candidate_ports = list(range(29520, 29590))
        requested_port = getattr(args, 'port', None)
        # NOTE(review): automatic selection probes ports in random order on
        # every call; if each rank runs this independently they may pick
        # different ports -- confirm the launcher fixes args.port beforehand.
        if requested_port is not None and requested_port > 0:
            selected_port = requested_port
            print(f"Using user-specified port {selected_port} for distributed training")
        else:
            selected_port = find_available_port(candidate_ports)
            if selected_port is not None:
                print(f"Found available port {selected_port} for distributed training")
            else:
                print("Warning: No available ports found. Using a random port which may cause issues.")
                selected_port = random.choice(candidate_ports)
                print(f"Selected random port {selected_port} - this may cause issues if already in use")

        args.port = selected_port

        setup_ddp(args.rank, args.world_size, port=selected_port)
        ddp_ready = True

        if args.seed is not None:
            set_seed(args.seed)

        # The target dataset is evaluated on, so it is removed from the sources.
        args.target_dataset = args.target_dataset_name + "Val"
        if args.target_dataset_name in args.datasets:
            args.datasets.remove(args.target_dataset_name)
        original_datasets = args.datasets.copy()

        for idx, _ in enumerate(original_datasets):

            if idx < args.resume_from_idx:
                print(f"Skipping iteration {idx+1} because resume-from-idx is {args.resume_from_idx}")
                continue

            if idx >= args.end_index:
                print(f"Skipping iteration {idx+1} because end-idx is {args.end_index}")
                continue

            try:
                # Iteration idx uses the first idx + 1 source datasets.
                args.datasets = original_datasets[:idx+1]

                task_vectors = {}
                for source_name in args.datasets:
                    if args.finetuning_mode == "linear":
                        zeroshot_ckpt = f"{args.save}/{source_name}Val/linear_zeroshot.pt"
                        finetuned_ckpt = f"{args.save}/{source_name}Val/linear_finetuned.pt"
                        task_vectors[source_name] = LinearizedTaskVector(zeroshot_ckpt, finetuned_ckpt)
                    else:
                        zeroshot_ckpt = f"{args.save}/{source_name}Val/zeroshot.pt"
                        finetuned_ckpt = f"{args.save}/{source_name}Val/finetuned.pt"
                        task_vectors[source_name] = NonLinearTaskVector(zeroshot_ckpt, finetuned_ckpt)

                banner = "=" * 100
                print(banner)
                print(f"Learning task vector coefficients on {args.target_dataset} with {args.model} using {len(args.datasets)} datasets")
                print(f"Datasets being used: {args.datasets}")
                print(banner)

                rule = "-" * 50
                print("\nCommand line arguments:")
                print(rule)
                for arg_name, arg_value in vars(args).items():
                    print(f"{arg_name}: {arg_value}")
                print(rule)
                print()
                train(task_vectors, args)
                print(f"Successfully completed iteration {idx+1} with {len(args.datasets)} datasets")

            except Exception as e:
                # One failed prefix must not abort the remaining iterations.
                print(f"ERROR: Iteration {idx+1} with {len(args.datasets)} datasets failed with error: {e}")
                import traceback
                traceback.print_exc()
                print("Continuing with next iteration...")
                continue

    finally:
        if ddp_ready:
            try:
                cleanup_ddp()
            except Exception as e:
                print(f"Warning: Error during distributed cleanup: {e}")

    # Restore the full source list (target dataset already removed).
    args.datasets = original_datasets

from torch import nn
from functorch import make_functional_with_buffers

class CustomWeightedImageEncoder(nn.Module):
    """Image encoder whose weights are the frozen base weights plus a learned blend of task vectors.

    Every parameter block of the base model gets, per task vector, up to
    ``max_partition`` coefficients. Each element of a block is randomly
    assigned to one partition (one-hot ``mask_mat_*`` buffers, shared between
    blocks of identical shape), and ``trainable_coef_mask`` selects which
    coefficients are active so that exactly ``total_trainable_params`` of
    them contribute to the forward pass.
    """

    def __init__(self, model, task_vectors, total_trainable_params, num_task_vectors, num_blocks) -> None:
        """A custom wrapper class for task vector composition based on total trainable parameters.

        Args:
            model: Base image encoder model.
            task_vectors: List of task vectors.
            total_trainable_params: The exact number of coefficients to learn.
            num_task_vectors: Number of task vectors being used.
            num_blocks: Number of parameter blocks in the base model.

        Raises:
            ValueError: If ``num_task_vectors`` or ``num_blocks`` is not positive.
        """
        super().__init__()

        if num_task_vectors <= 0 or num_blocks <= 0:
            raise ValueError(f"num_task_vectors ({num_task_vectors}) and num_blocks ({num_blocks}) must be positive.")

        func, params, self.buffer = make_functional_with_buffers(model)
        # Held behind a lambda rather than stored directly -- presumably so the
        # functional module is not registered as a submodule; TODO confirm.
        self.func = lambda p, b, x: func(p, b, x)
        self.params = torch.nn.ParameterList(params)
        # The base weights stay frozen; only ``coef`` is trained.
        for p in self.params:
            p.requires_grad = False

        self.train_preprocess = model.train_preprocess
        self.val_preprocess = model.val_preprocess
        self.cache_dir = model.cache_dir

        # dparams[task][block]: per-block entries of each task vector.
        self.dparams = [[tv.vector[k] for k in tv.vector] for tv in task_vectors]

        self.num_task_vectors = num_task_vectors
        self.num_blocks = num_blocks
        self.total_trainable_params = total_trainable_params

        # One "full partition" is one coefficient for every (task, block) pair.
        # Both factors were validated as positive above, so this is never zero.
        coefs_per_partition = self.num_task_vectors * self.num_blocks
        self.base_partition, self.remainder_params = divmod(
            self.total_trainable_params, coefs_per_partition
        )
        self.max_partition = self.base_partition + (1 if self.remainder_params > 0 else 0)
        if self.max_partition == 0:
            # Keep at least one partition so the buffers below are non-empty.
            self.max_partition = 1

        print(f"[CustomEncoder] Total Params: {self.total_trainable_params}, Tasks: {self.num_task_vectors}, Blocks: {self.num_blocks}")
        print(f"[CustomEncoder] Base Partition: {self.base_partition}, Remainder Params: {self.remainder_params}, Max Partition Dim: {self.max_partition}")

        self.coef = torch.nn.Parameter(torch.zeros(self.num_task_vectors, self.num_blocks, self.max_partition))

        trainable_coef_mask = torch.zeros(
            self.num_task_vectors, self.num_blocks, self.max_partition, dtype=torch.bool
        )

        # Every (task, block) pair gets the first ``base_partition`` slots.
        if self.base_partition > 0:
            trainable_coef_mask[:, :, :self.base_partition] = True

        # The remainder occupies the next slot, handed out block by block and,
        # within a block, task by task. The remainder is always smaller than
        # tasks * blocks, so every leftover coefficient finds a place.
        if self.remainder_params > 0:
            print(f"[CustomEncoder] Distributing {self.remainder_params} remainder coefficients...")
            extra_slot = self.base_partition
            assigned_count = 0
            for flat_idx in range(self.remainder_params):
                block_idx, task_idx = divmod(flat_idx, self.num_task_vectors)
                trainable_coef_mask[task_idx, block_idx, extra_slot] = True
                assigned_count += 1
            print(f"[CustomEncoder] Assigned {assigned_count} remainder coefficients.")

        self.register_buffer('trainable_coef_mask', trainable_coef_mask.float())

        print(f"[CustomEncoder] Creating mask_mats buffers for {self.max_partition} partitions...")
        # Maps a parameter shape to the name of its registered mask buffer.
        self._mask_mat_buffer_shapes = {}
        for i, p in enumerate(self.params):
            p_shape = tuple(p.shape)
            if p_shape in self._mask_mat_buffer_shapes:
                continue
            if p.ndim == 0:
                print(f"    - Skipping scalar parameter at index {i}")
                continue
            print(f"    - Processing shape {p_shape} (Block {i})")
            # Random partition index per element, expanded to a one-hot mask
            # with the partition axis moved to the front: (P, *p_shape).
            assignments = torch.randint(0, self.max_partition, p_shape, device='cpu')
            mask_mat = torch.nn.functional.one_hot(assignments, num_classes=self.max_partition).permute(-1, *range(p.ndim)).float()

            buffer_name = f"mask_mat_{'_'.join(map(str, p_shape))}"
            self.register_buffer(buffer_name, mask_mat)
            self._mask_mat_buffer_shapes[p_shape] = buffer_name
            print(f"      - Created buffer '{buffer_name}' with shape: {mask_mat.shape}")
        print(f"[CustomEncoder] Finished creating mask_mats buffers for {len(self._mask_mat_buffer_shapes)} unique shapes.")

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` to module state, including tensors ``nn.Module`` does not track.

        Registered parameters and buffers (``coef``, ``trainable_coef_mask``,
        the ``mask_mat_*`` buffers) are handled by ``super()._apply``. The
        functional buffers in ``self.buffer`` and the task-vector deltas in
        ``self.dparams`` are plain attributes, so they are converted here.
        Extra arguments are forwarded unchanged to ``super()._apply``.
        """
        new_self = super()._apply(fn, *args, **kwargs)
        new_self.buffer = tuple(
            fn(b) if isinstance(b, torch.Tensor) else b for b in new_self.buffer
        )
        new_self.dparams = [
            [fn(x) if isinstance(x, torch.Tensor) else x for x in tv]
            for tv in new_self.dparams
        ]
        return new_self

    def train(self, mode=True):
        """Set training mode and return ``self``, as ``nn.Module.train`` does.

        Returning the module matters because ``nn.Module.eval`` returns the
        result of ``train(False)``.
        """
        return super().train(mode)

    def forward(self, x) -> torch.Tensor:
        """Run the base model on ``x`` with task-vector-adjusted weights.

        For every block the delta added to the base weight is the sum over
        tasks and partitions of ``coef * partition_mask * task_delta``.

        Raises:
            RuntimeError: If a non-scalar parameter has no mask buffer.
        """
        # Zero out the coefficients that are not part of the trainable budget.
        effective_coefs = self.coef * self.trainable_coef_mask

        final_dparams = []
        for block_idx, p_orig in enumerate(self.params):
            p_shape = tuple(p_orig.shape)

            buffer_name = self._mask_mat_buffer_shapes.get(p_shape)
            if buffer_name is None:
                if p_orig.ndim == 0:
                    # Scalars have no partition mask and receive no delta.
                    final_dparams.append(torch.tensor(0.0, device=p_orig.device, dtype=p_orig.dtype))
                    continue
                raise RuntimeError(f"Mask matrix buffer not found for parameter shape {p_shape} at block index {block_idx}")
            block_mask_mat = getattr(self, buffer_name)

            block_task_deltas = [self.dparams[task_idx][block_idx] for task_idx in range(self.num_task_vectors)]

            valid_deltas = [delta for delta in block_task_deltas if isinstance(delta, torch.Tensor)]
            if not valid_deltas:
                final_dparams.append(torch.zeros_like(p_orig))
                continue
            stacked_deltas = torch.stack(valid_deltas, dim=0)

            block_effective_coefs = effective_coefs[:, block_idx, :]

            # Broadcast (T, P, 1...) * (1, P, *shape) * (T, 1, *shape), then
            # reduce over the task and partition axes.
            T = stacked_deltas.shape[0]
            P = block_mask_mat.shape[0]
            shape_dims = stacked_deltas.shape[1:]
            view_T_P_1s = (T, P) + (1,) * len(shape_dims)
            view_1_P_shape = (1, P) + shape_dims
            view_T_1_shape = (T, 1) + shape_dims

            term = block_effective_coefs.view(view_T_P_1s) * \
                   block_mask_mat.view(view_1_P_shape) * \
                   stacked_deltas.view(view_T_1_shape)

            final_dparams.append(term.sum(dim=(0, 1)))

        new_params = [dp + p for dp, p in zip(final_dparams, self.params)]

        return self.func(new_params, self.buffer, x)

def train(task_vectors, args):

    if is_main_process():
        experiment_start_time = datetime.now()
        formatted_start_time = experiment_start_time.strftime("%Y-%m-%d %H:%M:%S")
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} started at: {formatted_start_time}")
    
    subsample_str = f"s{int(args.subsample * 100)}" if isinstance(args.subsample, float) else f"s{args.subsample}"

    if args.seed is not None:
        set_seed(args.seed + args.rank)

    target_dataset = args.target_dataset
    
    experiment_commit = os.environ.get('EXPERIMENT_COMMIT_HASH', 'unknown')
    
    if is_main_process():
        gpu_info = {}
        if torch.cuda.is_available():
            gpu_info = {
                "gpu_count": torch.cuda.device_count(),
                "gpu_model": torch.cuda.get_device_name(0),
                "gpu_memory_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3),
            }
        wandb.login(key="your-wandb-key")
        wandb.init(
            project="axis",
            entity="example-owner",
            name=f"{args.model}_{target_dataset}_{subsample_str}-blkcoef-{args.blockwise_coef}-e{args.epochs}-nbdts{len(args.datasets)}-part{args.partition}",
            config={
                "type": "learn_coef",
                "kind": "baseline-10pct-corruption",
                "iter_fixed": 3,
                "model": args.model,
                "save": args.save,
                "target_dataset": target_dataset,
                "used_datasets": args.datasets,
                "partition": args.partition,
                "#nb_dts": len(args.datasets),
                "port": args.port,
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "batch_size": args.batch_size * args.num_grad_accumulation,
                "finetuning_mode": args.finetuning_mode,
                "lp_reg": args.lp_reg,
                "seed": args.seed,
                "weight_decay": args.wd,
                "blockwise_coef": True,
                "git_commit": experiment_commit,
                "subsample": args.subsample,
                
                "system": {
                    "gpu": gpu_info,
                    "cpu_count": psutil.cpu_count(),
                    "cpu_physical_count": psutil.cpu_count(logical=False),
                    "memory_gb": psutil.virtual_memory().total / (1024**3),
                    "hostname": socket.gethostname(),
                    "platform": platform.platform(),
                },
                
                "runtime": {
                    "pytorch_version": torch.__version__,
                    "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
                    "python_version": sys.version.split()[0],
                },
                
                "job_id": os.environ.get('SLURM_JOB_ID', 'unknown'),
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                
                "grad_clip_value": 1.0,
                "num_grad_accumulation": args.num_grad_accumulation,
                "batch_size": args.batch_size,
                "effective_batch_size": args.batch_size * args.num_grad_accumulation * args.world_size,
                "mixed_precision": True,
                
                "port_selection_method": "auto_find_available",
            }
        )

        if wandb.run:
            print(f"WandB Run URL: {wandb.run.url}")
        print("-" * 50 + "\n")

        wandb.log({"git_commit": experiment_commit})
    
    ckpdir = os.path.join(args.save, target_dataset)
    os.makedirs(ckpdir, exist_ok=True)

    assert args.finetuning_mode in [
        "linear",
        "standard",
    ], "Only linear and standard fine-tuning are supported."

    linearized_finetuning = args.finetuning_mode == "linear"
    if linearized_finetuning:
        print("Using linearized fine-tuning.")

    orig_dataset = target_dataset.replace("Val", "")
    pool = [
        "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SUN397", "SVHN",
        "CIFAR10", "CIFAR100", "ImageNet", "STL10", "Food101", "Caltech101", "Caltech256",
        "FGVCAircraft", "Flowers102", "OxfordIIITPet", "CUB200", "PascalVOC", "Country211", "UCF101",
    ]
    dataset_names = [ds for ds in pool if ds != orig_dataset]
    
    task_vectors = [v for k, v in task_vectors.items() if orig_dataset != k]
    
    coef_dataset_mapping = {i: dataset_names[i] for i in range(len(dataset_names))}
    if is_main_process():
        print("Coefficient to dataset mapping:")
        for idx, ds_name in coef_dataset_mapping.items():
            print(f"Coefficient {idx}: {ds_name}")
    
    if args.finetuning_mode == "linear":
        print("WARNING: --total-trainable-params currently only supported for standard finetuning mode. Using original WeightedLinearizedModel.")
        image_encoder = LinearizedImageEncoder(args, keep_lang=False)
        image_encoder.model = WeightedLinearizedModel(
            image_encoder.model, task_vectors, blockwise=args.blockwise_coef, partition=args.partition
        )
    elif args.total_trainable_params is not None:
        print(f"Using CustomWeightedImageEncoder with total_trainable_params = {args.total_trainable_params}")
        temp_base_encoder = ImageEncoder(args)
        _, temp_params, _ = make_functional_with_buffers(temp_base_encoder)
        num_blocks = len(temp_params)
        num_task_vectors = len(task_vectors)
        del temp_base_encoder, temp_params
        
        image_encoder = CustomWeightedImageEncoder(
            model=ImageEncoder(args),
            task_vectors=task_vectors,
            total_trainable_params=args.total_trainable_params,
            num_task_vectors=num_task_vectors,
            num_blocks=num_blocks
        )
        print(f"Created CustomWeightedImageEncoder with {num_task_vectors} tasks, {num_blocks} blocks.")

        if is_main_process():
            print("\nParameters of the CustomWeightedImageEncoder (Trainable Coefs):")
            param_counter = 1
            total_params_in_coef = 0
            for name, param in image_encoder.named_parameters():
                if param.requires_grad:
                    print(f"{param_counter}. {name} (Shape: {list(param.shape)}, Type: {param.dtype}) - Trainable")
                    total_params_in_coef += param.numel()
                    param_counter += 1
            print(f"Total elements in trainable 'coef' parameter: {total_params_in_coef}")
            if hasattr(image_encoder, 'trainable_coef_mask'):
                num_masked_trainable = image_encoder.trainable_coef_mask.sum().item()
                print(f"Number of trainable coefficients according to mask: {int(num_masked_trainable)}")
                if int(num_masked_trainable) != args.total_trainable_params:
                     print(f"WARNING: Masked trainable count ({int(num_masked_trainable)}) does not match requested total_trainable_params ({args.total_trainable_params})!")
            print("-" * 30)


    classification_head = get_classification_head(args, target_dataset)
    model = ImageClassifier(image_encoder, classification_head)

    model.freeze_head()
    model = model.cuda()

    preprocess_fn = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(
            size=224, scale=(0.5, 1),
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC
        ), torchvision.transforms.RandomHorizontalFlip(p=0.5),
    ] + model.train_preprocess.transforms[-3:])


    dataset = get_dataset(
        target_dataset,
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size,
        num_workers=4,
    )
    data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(data_loader)

    if args.print_every * 10 < num_batches:
        print_every = int(num_batches / 10)
    elif args.print_every * 4 > num_batches:
        print_every = max(int(num_batches / 4), 1)
    else:
        print_every = args.print_every

    ddp_loader = distribute_loader(data_loader)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.rank],
        find_unused_parameters=False,
        output_device=args.rank,
    )

    loss_fn = torch.nn.CrossEntropyLoss()

    params = [p for p in ddp_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

    scheduler = cosine_lr(
        optimizer, args.lr, 0,
        args.epochs * num_batches // args.num_grad_accumulation,
    )

    slurm_job_id = os.environ.get('SLURM_JOB_ID', 'unknown')
    if "unknown" in slurm_job_id:
        slurm_job_id = time.strftime("%Y%m%d-%H%M%S")
    

    if args.total_trainable_params is not None:
        head_path = os.path.join(ckpdir, f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_custom_composition_total{args.total_trainable_params}.pt")
        log_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_custom_composition_total{args.total_trainable_params}.json")
        coef_param = ddp_model.module.image_encoder.coef
        coef_mask = ddp_model.module.image_encoder.trainable_coef_mask
        print("Using CustomWeightedImageEncoder - paths set.")
        print(f"Coef parameter shape: {coef_param.shape}, Mask shape: {coef_mask.shape}")
    elif linearized_finetuning:
        head_path = os.path.join(ckpdir, f"learned_linear_composition_{subsample_str}.pt")
        log_path = os.path.join(args.save, f"learned_linear_composition_{subsample_str}.json")
        coef_param = ddp_model.module.image_encoder.model.coef
        coef_mask = None
        print("Using Linearized fine-tuning - paths set.")
    else:
        head_path = os.path.join(ckpdir, f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_composition.pt")
        log_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset}_{subsample_str}_learned_composition.json")
        coef_param = ddp_model.module.image_encoder.coef
        coef_mask = None
        print("Using standard WeightedImageEncoder - paths set.")

    print("head_path:", head_path)
    print("log_path:", log_path)
    print("Initial coef param shape:", coef_param.shape)

    scaler = GradScaler()
    if is_main_process():
        print(f"=> Zero-shot accuracy on {target_dataset}:\t{100*args.zs_acc[target_dataset]:.2f}%.")
        wandb.log({
            "target_dataset": str(target_dataset),
            "zero_shot_accuracy": 100 * args.zs_acc[target_dataset]
        })
        
        if os.path.exists(log_path):
            with open(log_path, 'r') as f:
                comp_acc = json.load(f)
        else:
            comp_acc = {}

    best_coef = coef_param.data.clone()
    best_acc = args.zs_acc[target_dataset]


    total_params = sum(p.numel() for p in ddp_model.parameters())
    
    is_custom_encoder = args.total_trainable_params is not None and hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
    if is_custom_encoder:
        params_excluding_coef = [p for name, p in ddp_model.named_parameters() if p.requires_grad and 'image_encoder.coef' not in name]
        num_trainable_params_other = sum(p.numel() for p in params_excluding_coef)
        num_trainable_params_coef = int(ddp_model.module.image_encoder.trainable_coef_mask.sum().item())
        num_trainable_params = num_trainable_params_other + num_trainable_params_coef
        print(f"[Param Stats - Custom] Other trainable: {num_trainable_params_other}, Masked coef: {num_trainable_params_coef}, Total effective trainable: {num_trainable_params}")
    else:
        num_trainable_params = sum(p.numel() for p in ddp_model.parameters() if p.requires_grad)
        print(f"[Param Stats - Standard] Total trainable: {num_trainable_params}")

    if is_main_process():
        print(f"\nModel Parameter Stats:")
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {num_trainable_params:,}")
        print(f"Percentage trainable: {100 * num_trainable_params / total_params:.2f}%\n")
        wandb.config.update({
            "total_parameters": total_params,
            "trainable_parameters": num_trainable_params,
            "trainable_parameters_pct": 100 * num_trainable_params / total_params
        }, allow_val_change=True)
        
        wandb.log({
            "total_parameters": total_params,
            "trainable_parameters": num_trainable_params,
            "trainable_parameters_pct": 100 * num_trainable_params / total_params
        })

    for epoch in range(args.epochs):
        epoch_loss = 0.0
        epoch_steps = 0
        
        ddp_loader.sampler.set_epoch(epoch)
        for i, batch in enumerate(ddp_loader):
            start_time = time.time()

            step = (
                i // args.num_grad_accumulation
                + epoch * num_batches // args.num_grad_accumulation
            )

            batch = maybe_dictionarize(batch)
            inputs = batch["images"].cuda()
            data_time = time.time() - start_time

            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = ddp_model(inputs)
                labels = batch["labels"].cuda()
                loss = loss_fn(logits, labels)

                reg = lp_reg(coef_param, args.lp_reg)
                loss = loss + reg
                loss = loss / args.num_grad_accumulation
                
                if is_main_process():
                    epoch_loss += loss.item() * args.num_grad_accumulation
                    epoch_steps += 1

            scaler.scale(loss).backward()

            if (i + 1) % args.num_grad_accumulation == 0:
                scheduler(step)

                torch.nn.utils.clip_grad_norm_(params, 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            batch_time = time.time() - start_time

            if (
                step % print_every == 0
                and ((i + 1) % args.num_grad_accumulation == 0)
                and is_main_process()
            ):
                percent_complete = 100 * (i + 1) / len(ddp_loader)
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i + 1}/{num_batches}]\t"
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}",
                    flush=True,
                )
                
                batch_log = {
                    "dataset": str(target_dataset),
                    "epoch": epoch,
                    "batch": i,
                    "batch_loss": loss.item() * args.num_grad_accumulation,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                    "batch_time": batch_time,
                    "data_time": data_time,
                    "global_step": step,
                    "regularization": reg.item() if reg != 0 else 0,
                }
                
                current_coef_tensor = ddp_model.module.image_encoder.coef.data
                
                is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                if is_custom_encoder:
                    current_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool()
                    for task_idx in range(current_coef_tensor.size(0)):
                        task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                        for block_idx in range(current_coef_tensor.size(1)):
                            for part_idx in range(current_coef_tensor.size(2)):
                                if current_mask[task_idx, block_idx, part_idx]:
                                    coef_value = current_coef_tensor[task_idx, block_idx, part_idx].item()
                                    batch_log[f"coef_{task_name}_block{block_idx}_part{part_idx}"] = coef_value
                elif args.blockwise_coef:
                    current_coef_values = current_coef_tensor.cpu()
                    for dataset_idx in range(current_coef_values.size(0)):
                        layer_coefs = current_coef_values[dataset_idx]
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for layer_idx, coef_value in enumerate(layer_coefs):
                             batch_log[f"coef_layer{layer_idx}_{dataset_name}"] = coef_value.item()
                elif args.partition:
                     current_coef_values = current_coef_tensor.cpu()
                     for dataset_idx in range(current_coef_values.size(0)):
                         dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                         for block_idx in range(current_coef_values.size(1)):
                             for part_idx in range(current_coef_values.size(2)):
                                 coef_value = current_coef_values[dataset_idx, block_idx, part_idx].item()
                                 batch_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}"] = coef_value
                else:
                    current_coef_values = current_coef_tensor.cpu()
                    for dataset_idx, coef_value in enumerate(current_coef_values.flatten()):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        batch_log[f"coef_{dataset_name}"] = coef_value.item()
                
                wandb.log(batch_log)
        
        if is_main_process():
            avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0
            wandb.log({
                "target_dataset": str(target_dataset),
                "epoch": epoch,
                "epoch_loss": avg_epoch_loss,
            })
            
            image_encoder = ddp_model.module.image_encoder
            val_metrics = eval_single_dataset(image_encoder, target_dataset, args)
            val_acc = val_metrics["top1"]
            
            epoch_log = {
                "dataset": str(target_dataset),
                "epoch": epoch,
                "val_accuracy": 100 * val_acc,
            }
            
            current_coef_tensor = ddp_model.module.image_encoder.coef.data
            
            is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
            if is_custom_encoder:
                current_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool()
                for task_idx in range(current_coef_tensor.size(0)):
                    task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                    for block_idx in range(current_coef_tensor.size(1)):
                        for part_idx in range(current_coef_tensor.size(2)):
                            if current_mask[task_idx, block_idx, part_idx]:
                                coef_value = current_coef_tensor[task_idx, block_idx, part_idx].item()
                                epoch_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_epoch"] = coef_value
            elif args.blockwise_coef:
                 current_coef_values = current_coef_tensor.cpu()
                 for dataset_idx in range(current_coef_values.size(0)):
                     layer_coefs = current_coef_values[dataset_idx]
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for layer_idx, coef_value in enumerate(layer_coefs):
                         epoch_log[f"coef_layer{layer_idx}_{dataset_name}_epoch"] = coef_value.item()
            elif args.partition:
                 current_coef_values = current_coef_tensor.cpu()
                 for dataset_idx in range(current_coef_values.size(0)):
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for block_idx in range(current_coef_values.size(1)):
                         for part_idx in range(current_coef_values.size(2)):
                            coef_value = current_coef_values[dataset_idx, block_idx, part_idx].item()
                            epoch_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_epoch"] = coef_value
            else:
                current_coef_values = current_coef_tensor.cpu()
                for dataset_idx, coef_value in enumerate(current_coef_values.flatten()):
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    epoch_log[f"coef_{dataset_name}_epoch"] = coef_value.item()
                
            wandb.log(epoch_log)
            
            print(f"Epoch {epoch}: Validation accuracy: {100 * val_acc:.2f}%")
            
            if val_acc > best_acc:
                best_acc = val_acc
                best_coef = coef_param.data.clone()
                torch.save(best_coef, head_path)
                
                best_log = {
                    "target_dataset": target_dataset,
                    "best_val_accuracy": 100 * best_acc,
                    "best_epoch": epoch,
                }
                
                best_coef_tensor = best_coef.cpu()

                is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                if is_custom_encoder:
                    best_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool().cpu()
                    for task_idx in range(best_coef_tensor.size(0)):
                        task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                        for block_idx in range(best_coef_tensor.size(1)):
                            for part_idx in range(best_coef_tensor.size(2)):
                                if best_mask[task_idx, block_idx, part_idx]:
                                    coef_value = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                    best_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_best"] = coef_value
                elif args.blockwise_coef:
                    best_coef_values = best_coef_tensor
                    for dataset_idx in range(best_coef_values.size(0)):
                        layer_coefs = best_coef_values[dataset_idx]
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for layer_idx, coef_value in enumerate(layer_coefs):
                            best_log[f"coef_layer{layer_idx}_{dataset_name}_best"] = coef_value.item()
                elif args.partition:
                    best_coef_values = best_coef_tensor
                    for dataset_idx in range(best_coef_values.size(0)):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        for block_idx in range(best_coef_values.size(1)):
                            for part_idx in range(best_coef_values.size(2)):
                                coef_value = best_coef_values[dataset_idx, block_idx, part_idx].item()
                                best_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_best"] = coef_value
                else:
                    best_coef_values = best_coef_tensor
                    for dataset_idx, coef_value in enumerate(best_coef_values.flatten()):
                        dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                        best_log[f"coef_{dataset_name}_best"] = coef_value.item()
                
                wandb.log(best_log)

    if is_main_process():
        comp_acc[target_dataset] = best_acc
        target_dataset_test = target_dataset.replace("Val", "")
        image_encoder = ddp_model.module.image_encoder
        
        is_custom_encoder = hasattr(image_encoder, 'trainable_coef_mask')
        if is_custom_encoder:
            if best_coef is not None:
                image_encoder.coef = torch.nn.Parameter(best_coef)
                print("Loaded best_coef (3D tensor) into CustomWeightedImageEncoder")
            else:
                print("Warning: best_coef is None, cannot load into CustomWeightedImageEncoder")
        elif linearized_finetuning:
            if best_coef is not None:
                 image_encoder.model.coef = torch.nn.Parameter(best_coef)
            else:
                print("Warning: best_coef is None, cannot load into WeightedLinearizedModel")
        else:
            if best_coef is not None:
                image_encoder.coef = torch.nn.Parameter(best_coef)
            else:
                 print("Warning: best_coef is None, cannot load into WeightedImageEncoder")

        print("best_coef shape after loading:", image_encoder.coef.shape if best_coef is not None else "N/A")
        
        final_coef_log = {}
        if best_coef is not None:
            best_coef_tensor = best_coef.cpu()
            if is_custom_encoder:
                best_mask = image_encoder.trainable_coef_mask.data.bool().cpu()
                for task_idx in range(best_coef_tensor.size(0)):
                    task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                    for block_idx in range(best_coef_tensor.size(1)):
                        for part_idx in range(best_coef_tensor.size(2)):
                            if best_mask[task_idx, block_idx, part_idx]:
                                val = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                final_coef_log[f"coef_{task_name}_block{block_idx}_part{part_idx}_final"] = val
            elif args.blockwise_coef:
                best_coef_values = best_coef_tensor
                for dataset_idx in range(best_coef_values.size(0)):
                    layer_coefs = best_coef_values[dataset_idx]
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    for layer_idx, val in enumerate(layer_coefs):
                        final_coef_log[f"coef_layer{layer_idx}_{dataset_name}_final"] = val.item()
            elif args.partition:
                 best_coef_values = best_coef_tensor
                 for dataset_idx in range(best_coef_values.size(0)):
                     dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                     for block_idx in range(best_coef_values.size(1)):
                         for part_idx in range(best_coef_values.size(2)):
                             val = best_coef_values[dataset_idx, block_idx, part_idx].item()
                             final_coef_log[f"coef_{dataset_name}_block{block_idx}_part{part_idx}_final"] = val
            else:
                best_coef_values = best_coef_tensor
                for dataset_idx, val in enumerate(best_coef_values.flatten()):
                    dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                    final_coef_log[f"coef_{dataset_name}_final"] = val.item()
        else:
            print("Warning: best_coef is None, skipping final coefficient logging.")

        wandb.log(final_coef_log)
        
        image_encoder = ddp_model.module.image_encoder
            
        test_metrics = eval_single_dataset(image_encoder, target_dataset_test, args)
        final_test_acc = test_metrics["top1"]
        comp_acc[target_dataset_test] = final_test_acc
        
        wandb.log({
            "dataset": str(target_dataset),
            "final_val_accuracy": 100 * best_acc,
            "final_test_accuracy": 100 * final_test_acc,
            "git_commit": experiment_commit,
        })
        
        print(f"Final test accuracy on {target_dataset_test}: {100 * final_test_acc:.2f}%")
        
        with open(log_path, 'w') as f:
            json.dump(comp_acc, f, indent=4)

        test_acc_path = os.path.join('path/to/results', f"slurm-{slurm_job_id}-{target_dataset_test}_{subsample_str}_test_accuracy.json")
        
        test_acc_data = {
            "target_dataset": str(target_dataset_test),
            "number_of_used_datasets": len(args.datasets),
            "datasets_names": args.datasets,
            "test_accuracy": float(final_test_acc),
            "test_metrics": test_metrics,
            "git_commit": experiment_commit,
        }
        
        if os.path.exists(test_acc_path):
            with open(test_acc_path, 'r') as f:
                existing_data = json.load(f)
                
            if not isinstance(existing_data, list):
                existing_data = [existing_data]
                
            existing_data.append(test_acc_data)
            
            with open(test_acc_path, 'w') as f:
                json.dump(existing_data, f, indent=4)
        else:
            with open(test_acc_path, 'w') as f:
                json.dump([test_acc_data], f, indent=4)
        print(f"Test accuracy saved to {test_acc_path}")

        print("\n--- Evaluating robustness to image corruptions ---", flush=True)
        
        corruption_functions = [
            ("gaussian_noise", gaussian_noise_pil),
            ("shot_noise", shot_noise_pil),
            ("impulse_noise", impulse_noise_pil),
            ("speckle_noise", speckle_noise_pil),
            ("gaussian_blur", gaussian_blur_pil),
            ("defocus_blur", defocus_blur_pil),
            ("zoom_blur", zoom_blur_pil),
            ("contrast", contrast_pil),
            ("brightness", brightness_pil),
            ("saturate", saturate_pil),
            ("jpeg_compression", jpeg_compression_pil),
            ("pixelate", pixelate_pil),
        ]
        
        corruption_pbar = tqdm(corruption_functions, desc="Corruption Types", ncols=120, position=0)
        
        for corruption_name, corruption_func in corruption_pbar:
            corruption_pbar.set_description(f"Corruption: {corruption_name}")
            print(f"\n-- Evaluating with {corruption_name} corruption --", flush=True)
            
            severity_pbar = tqdm(range(1, 6), desc=f"{corruption_name} severities", 
                               leave=False, ncols=100, position=1)
            
            for severity in severity_pbar:
                severity_pbar.set_description(f"{corruption_name} sev{severity}")
                print(f"--- Evaluating {corruption_name} severity {severity} ---", flush=True)
                
                start_time = time.time()
                
                metrics = evaluate_with_corruption(
                    image_encoder=image_encoder, 
                    dataset_name=target_dataset_test, 
                    args=args,
                    corruption_func=corruption_func,
                    severity=severity
                )
                
                acc = metrics['top1']
                eval_time = time.time() - start_time
                
                print(f"Accuracy with {corruption_name} severity {severity}: {100 * acc:.2f}% (took {eval_time:.1f}s)", flush=True)
                
                wandb.log({
                    f"test_accuracy_{corruption_name}_severity_{severity}": 100 * acc,
                    "dataset": target_dataset_test,
                    "corruption_type": corruption_name,
                    "severity": severity,
                    "evaluation_time_seconds": eval_time
                })
                
                corruption_pbar.set_postfix({'current_acc': f'{100*acc:.1f}%', 'time': f'{eval_time:.1f}s'})
        
        print(f"\n✅ Completed all corruption evaluations!", flush=True)
            
        if best_coef is not None:
            try:
                coef_data = []
                if best_coef is not None:
                    best_coef_tensor = best_coef.cpu()
                    
                    is_custom_encoder = hasattr(ddp_model.module.image_encoder, 'trainable_coef_mask')
                    if is_custom_encoder:
                        best_mask = ddp_model.module.image_encoder.trainable_coef_mask.data.bool().cpu()
                        for task_idx in range(best_coef_tensor.size(0)):
                            task_name = coef_dataset_mapping.get(task_idx, f"task{task_idx}")
                            for block_idx in range(best_coef_tensor.size(1)):
                                for part_idx in range(best_coef_tensor.size(2)):
                                    if best_mask[task_idx, block_idx, part_idx]:
                                        val = best_coef_tensor[task_idx, block_idx, part_idx].item()
                                        coef_data.append([task_idx, task_name, block_idx, part_idx, val])
                        
                        coef_table = wandb.Table(
                            columns=["task_idx", "dataset", "block_idx", "part_idx", "coefficient_value"],
                            data=coef_data
                        )
                        wandb.log({"final_coefficient_values_table": coef_table})
                        
                        active_coef_values = best_coef_tensor[best_mask]
                        if active_coef_values.numel() > 0:
                           wandb.log({"final_active_coefficient_histogram": wandb.Histogram(active_coef_values.numpy())})
                        else:
                            print("No active coefficients found for histogram logging.")

                    elif args.blockwise_coef:
                        best_coef_values = best_coef_tensor
                        for dataset_idx in range(best_coef_values.size(0)):
                            layer_coefs = best_coef_values[dataset_idx]
                            dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                            for layer_idx, val in enumerate(layer_coefs):
                                coef_data.append([layer_idx, dataset_idx, dataset_name, val.item()])
                        
                        coef_table = wandb.Table(
                            columns=["layer", "index", "dataset", "coefficient_value"], 
                            data=coef_data
                        )
                        wandb.log({"coefficient_values_table": coef_table})
                        
                        for layer_idx in range(best_coef_values.size(0)):
                             wandb.log({
                                 f"coefficient_histogram_layer_{layer_idx}": 
                                 wandb.Histogram(best_coef_values[layer_idx].numpy())
                             })
                    elif args.partition:
                        best_coef_values = best_coef_tensor
                        print("WARNING: wandb.Table logging for original --partition mode not explicitly adapted. Logging histogram only.")
                        wandb.log({"coefficient_histogram": wandb.Histogram(best_coef_values.numpy())})
                    else:
                         best_coef_values = best_coef_tensor
                         for dataset_idx, val in enumerate(best_coef_values.flatten()):
                             dataset_name = coef_dataset_mapping.get(dataset_idx, f"unknown_{dataset_idx}")
                             coef_data.append([dataset_idx, dataset_name, val.item()])
                        
                         coef_table = wandb.Table(
                             columns=["index", "dataset", "coefficient_value"], 
                             data=coef_data
                         )
                         wandb.log({"coefficient_values_table": coef_table})
                         wandb.log({"coefficient_histogram": wandb.Histogram(best_coef_values.numpy())})
                else:
                     print("Warning: best_coef is None, cannot log final coefficient table/histogram.")
            except Exception as e:
                print(f"Error logging coefficients to wandb: {e}")

    
    experiment_end_time = datetime.now()
    experiment_duration = experiment_end_time - experiment_start_time
    formatted_end_time = experiment_end_time.strftime("%Y-%m-%d %H:%M:%S")
    
    hours, remainder = divmod(experiment_duration.total_seconds(), 3600)
    minutes, seconds = divmod(remainder, 60)
    formatted_duration = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"
    
    if is_main_process():
        print(f"\n[TIMING] Experiment for dataset {args.target_dataset} with {len(args.datasets)} completed at: {formatted_end_time}")
        print(f"[TIMING] Total duration: {formatted_duration} (H:M:S)")
        
        wandb.log({
            "experiment_start_time": formatted_start_time,
            "experiment_end_time": formatted_end_time,
            "experiment_duration_seconds": experiment_duration.total_seconds(),
            "experiment_duration_formatted": formatted_duration,
        })
    
    if is_main_process():
        wandb.finish(quiet=True)


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, pin the experiment
    # hyper-parameters, load the stored zero-shot accuracies and launch
    # one `main` worker per process via torch.multiprocessing.spawn.

    # Datasets assigned to `args.datasets` below.
    target_datasets = [
        "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45",
        "SUN397", "SVHN", "CIFAR10", "CIFAR100", "ImageNet", "STL10",
        "Food101", "Caltech101", "Caltech256", "FGVCAircraft",
        "Flowers102", "OxfordIIITPet", "CUB200", "PascalVOC",
        "Country211", "UCF101",
    ]

    args = parse_arguments()

    # Guarantee the attribute exists even if the parser does not define it.
    args.no_commit = getattr(args, 'no_commit', False)

    print("\n".join([
        "=" * 80,
        "AUTOMATIC EXPERIMENT TRACKING",
        "-" * 50,
        "This script automatically creates a git commit at the beginning of each run",
        "to ensure that experiment code states are tracked.",
        "Use --no-commit flag to skip this behavior if needed.",
        "=" * 80,
    ]))

    # The values below are hard-coded and override whatever was parsed.
    args.datasets = target_datasets
    args.lr = 1e-1
    args.epochs = 10
    if args.model == "ViT-L-14":
        # Smaller batch with two gradient-accumulation steps for ViT-L-14.
        args.batch_size = 64
        args.num_grad_accumulation = 2
    else:
        args.batch_size = 128
        args.num_grad_accumulation = 1

    args.print_every = 10

    args.seed = 0

    # Plain string concatenation (no path separator is inserted here).
    args.save += f"{args.model}"
    if args.subsample is not None:
        # Intentionally a no-op: nothing is changed when subsampling is set.
        pass

    print("\n".join([
        "=" * 80,
        "NOTE: When running multiple instances, use different ports to avoid conflicts",
        "Example: python src/learn_coef.py --port=29501",
        "Available port range: 29500-29509",
        "=" * 80,
    ]))

    # Zero-shot accuracies are read from the (model-suffixed) save location.
    zs_acc_path = os.path.join(args.save, "zeroshot_accuracies.json")
    with open(zs_acc_path, 'r') as zs_file:
        args.zs_acc = json.load(zs_file)

    # Launch `args.world_size` worker processes running `main`.
    torch.multiprocessing.spawn(main, args=(args,), nprocs=args.world_size)
