from collections import defaultdict
from copy import deepcopy
import torch
import numpy as np
import random
import open_clip
from torch.utils.data import Subset
from src.dataset.eurosat import EuroSat
from src.dataset.cifar100 import CIFAR100
from src.dataset.sun397 import SUN397
from src.dataset.cars import Cars
from src.dataset.dtd import DTD
from src.dataset.svhn import SVHN
from src.dataset.gtsrb import GTSRB
from src.dataset.resisc45 import RESISC45
from src.dataset.imagenet_r import IMAGENETR
from torch.utils.data import DataLoader
import argparse
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm
import wandb 
import torchvision
from torchvision.transforms import Normalize
import logging, os, pickle
from sklearn.model_selection import train_test_split
logger = logging.getLogger(__name__)

class SubsetCustom(Subset):
    """``Subset`` that re-exposes labels and prompt metadata of the wrapped dataset.

    ``targets`` is rebuilt from the parent's ``labels`` (preferred) or ``targets``
    attribute, restricted to this subset's indices. ``class_names``, ``templates``
    and ``single_template`` are forwarded to the parent when it provides them.
    """

    def __init__(self, dataset, indices):
        super().__init__(dataset, indices)

    @property
    def targets(self):
        # "labels" wins over "targets" when the parent defines both.
        for label_attr in ("labels", "targets"):
            if hasattr(self.dataset, label_attr):
                return np.array(getattr(self.dataset, label_attr))[self.indices]
        return None

    @property
    def class_names(self):
        return getattr(self.dataset, 'class_names', None)

    @property
    def templates(self):
        return getattr(self.dataset, 'templates', None)

    def single_template(self, *args, **kwargs):
        if not hasattr(self.dataset, 'single_template'):
            return None
        return self.dataset.single_template(*args, **kwargs)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. "dataset"/"indices" are resolved
        # without touching self.dataset, so a half-built instance (e.g. during
        # unpickling or deepcopy) raises AttributeError instead of recursing.
        if attr in ('dataset', 'indices'):
            return super().__getattribute__(attr)
        forwarded = ('class_names', 'templates', 'single_template')
        if attr not in forwarded:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
        return getattr(self.dataset, attr)

def create_few_shot_subset(dataset, images_per_class, dataset_name, save_dir="few_shot_indices"):
    """Create (or reload from cache) a few-shot subset of a dataset.

    On first use, up to ``images_per_class`` indices are drawn per label with
    Python's ``random`` module and pickled to
    ``<save_dir>/<dataset_name>_indices_<images_per_class>_per_class.pkl``.
    Later calls with the same name and count reuse that file. The cache key
    does NOT include the random seed, so the same subset is returned whatever
    the current seed is.

    Args:
        dataset: PyTorch dataset yielding ``(image, label)`` pairs. Building new
            indices iterates the full dataset, so every sample is loaded once.
        images_per_class (int): Number of images per class to sample (classes
            with fewer samples contribute all of theirs).
        dataset_name (str): Name of the dataset for file naming.
        save_dir (str): Directory to save/load indices files (created if missing).

    Returns:
        SubsetCustom: Subset of the original dataset with the sampled indices.
    """
    os.makedirs(save_dir, exist_ok=True)

    indices_file = os.path.join(save_dir, f"{dataset_name}_indices_{images_per_class}_per_class.pkl")

    if os.path.exists(indices_file):
        logger.info(f"Loading existing indices from {indices_file}")
        # NOTE: pickle.load can execute arbitrary code; only load cache files this code wrote.
        with open(indices_file, 'rb') as f:
            sampled_indices = pickle.load(f)
    else:
        logger.info(f"Creating new indices for {dataset_name} with {images_per_class} images per class")
        class_indices = defaultdict(list)
        for idx, (_, label) in enumerate(dataset):
            # 0-d tensors hash by identity, so each tensor label would otherwise
            # become its own "class"; .item() turns tensors / numpy scalars into
            # plain Python values (ints are left untouched).
            key = label.item() if hasattr(label, "item") else label
            class_indices[key].append(idx)

        sampled_indices = []
        for indices in class_indices.values():
            sampled_indices.extend(random.sample(indices, min(images_per_class, len(indices))))

        with open(indices_file, 'wb') as f:
            pickle.dump(sampled_indices, f)
        logger.info(f"Saved indices to {indices_file}")

    return SubsetCustom(dataset, sampled_indices)

def parse_arguments():
    """Build the command-line parser, parse ``sys.argv`` and fill derived defaults.

    When ``--finetuned_checkpoint_A`` / ``--finetuned_checkpoint_B`` are not given,
    each is set to
    ``<base_folder>/clip-finetuned-weights/<dataset>/<arch>/<pretraining_backbone_X>/best.pt``
    for its matching backbone.

    Returns:
        argparse.Namespace: Parsed model, dataset, distillation, evaluation and
        logging options.
    """
    parser = argparse.ArgumentParser(description='CLIP task vector analysis')
    add = parser.add_argument

    # General, model and logging options
    add('--seed', default=54, type=int, help='Random seed for reproducibility')
    add('--arch', default='ViT-B-16', type=str, help='Model architecture.')
    add('--batch_size', default=32, type=int, help='Batch size for evaluation.')
    add('--workers', default=4, type=int, help='Number of data loading workers.')
    add('--pretraining_backbone_A', default='commonpool_l_s1b_b8k', type=str, help='Pretraining model A for backbone1.')
    add('--pretraining_backbone_B', default='laion400m_e32', type=str, help='Pretraining model B for backbone2.')
    add('--base_folder', default="./", type=str, help='Path of base folder')
    add('--wandb_mode', default='disabled', type=str, choices=['online', 'offline', 'disabled'], help='Wandb mode')
    add('--dataset', default='eurosat', type=str, help='Dataset to use')
    add('--wandb_group', default='dataset distillation', type=str, help='Wandb group name')
    add('--wandb_run_name', default=None, type=str, help='Wandb run name (default: <pretraining_backbone_A>_<pretraining_backbone_B>_<dataset>)')

    # Gradient distillation
    add('--distill_lr', default=0.01, type=float, help='Learning rate for distillation')
    add('--model_lr', default=1e-5, type=float, help='Learning rate for model training')
    add('--conv_lr', default=1e-3, type=float, help='Learning rate for convolutions')
    add('--num_synthetic', default=20, type=int, help='Number of synthetic images to use')
    add('--iterations', default=1, type=int, help='Number of distillation steps')
    add('--initialization', default='sample', type=str, choices=['random', 'mean', 'sample', 'coreset', 'herding', 'k-medoid'], help='Initialization method for synthetic images')
    add('--batch_per_class', default=2, type=int, help='Number of real images per class for distillation')
    add('--epochs', default=1, type=int, help='Number of epochs for distillation')
    add('--loss_type', default='cosine_distance', type=str, choices=['mse_full', 'mse_sign', 'cosine_distance', 'magnitude_weighted_sign'], help='Loss type for gradient matching')
    add('--optimizer', default='sgd', type=str, choices=['adam', 'sgd'], help='Optimizer for distillation')
    add('--gradient_accumulation_steps', default=1, type=int, help='Number of gradient accumulation steps')
    add('--perturbation_eps', default=None, type=float, help='Perturbation of theta during distillation')
    add('--train_model', action='store_true', help='Train the model during distillation')
    add('--reinit_epoch_freq', default=10, type=int, help='Frequency of reinitializing the model optimizer')
    add('--synth_mode', type=str, choices=['direct', 'conv', 'residual', 'residual_y', 'residual_pure', 'unet'], default='direct', help='Mode of synthetic image generation')

    # Evaluation, saving and augmentation
    add('--eval_alphas', default=10, type=int, help='Number of alpha values to evaluate')
    add('--save_dir', default='synthetic_images', type=str, help='Directory to save synthetic images')
    add('--dsa', action='store_true', help='Use DSA augmentation')
    add('--dsa_strategy', default='color_flip_rotate', type=str, help='DSA strategy to use')
    add('--save_every_n_epochs', type=int, default=0, help='Save distilled images during training every n epochs')
    add('--soup', action='store_true', help='Evaluate directly the soup of the images')

    # Finetuned checkpoints (derived below when omitted)
    add('--finetuned_checkpoint_A', default=None, type=str, help='Path to finetuned model A. If not set, defaults to base_folder/clip-finetuned-weights/<dataset>/<arch>/<pretraining_backbone_A>/best.pt')
    add('--finetuned_checkpoint_B', default=None, type=str, help='Path to finetuned model B. If not set, defaults to base_folder/clip-finetuned-weights/<dataset>/<arch>/<pretraining_backbone_B>/best.pt')

    add('--synth_batch_size', default=None, type=int, help='Batch size for synthetic images in distillation, define how many classes per batch to optimize (None=all)')

    add('--images_per_class', default=None, type=int, help='Number of images per class for few-shot learning (None=use full dataset)')

    args = parser.parse_args()

    # Derive missing checkpoint paths, A first then B.
    for tag in ('A', 'B'):
        checkpoint_attr = f'finetuned_checkpoint_{tag}'
        if getattr(args, checkpoint_attr) is not None:
            continue
        backbone = getattr(args, f'pretraining_backbone_{tag}')
        setattr(args, checkpoint_attr,
                f"{args.base_folder}/clip-finetuned-weights/{args.dataset}/{args.arch}/{backbone}/best.pt")
    return args

def set_seed(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch for reproducibility.

    CUDA generators on all devices are additionally seeded when CUDA is available.

    Args:
        seed (int): The seed value to set for random number generation.

    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def compare_signs(tb_signs, gradient_signs):
    """Measure element-wise sign agreement between two ``name -> tensor`` mappings.

    Only names present in both mappings are compared; names missing from
    ``gradient_signs`` are skipped.

    Args:
        tb_signs (dict): Reference sign tensors keyed by parameter name.
        gradient_signs (dict): Sign tensors to compare against, keyed the same way.

    Returns:
        tuple: ``(overall_agreement, per_param_agreement)``. ``overall_agreement``
        is the percentage of matching elements across all compared tensors, and
        ``per_param_agreement`` maps each compared name to its own percentage.
        Agreement that is undefined (no shared names, or an empty tensor) is
        reported as ``float("nan")`` instead of raising ``ZeroDivisionError``.
    """
    total_params = 0
    matching_signs = 0
    per_param_sign_agreement = {}

    for name, tb_sign in tb_signs.items():
        if name not in gradient_signs:
            continue
        grad_sign = gradient_signs[name]
        match = (tb_sign == grad_sign).sum().item()
        total = tb_sign.numel()

        per_param_sign_agreement[name] = match / total * 100 if total else float("nan")
        matching_signs += match
        total_params += total

    overall_sign_agreement = matching_signs / total_params * 100 if total_params else float("nan")
    return overall_sign_agreement, per_param_sign_agreement

def evaluate_with_task_vector(
    base_model,
    task_vector,
    alpha,
    test_dataloader,
    test_dataset,
    device,
    vector_name,
    best_acc,
    best_alpha,
    results_list,
    logger,
    display_name=None,
    metric_prefix=""
):
    """
    Evaluate ``base_model`` with ``task_vector`` applied to its visual tower at scale ``alpha``.

    The evaluation runs on a deep copy; ``base_model`` is only read (by ``deepcopy``).

    Args:
        base_model: Base model to apply the task vector to (must expose ``.visual``)
        task_vector: Object whose ``apply_to(module, scaling_coef=...)`` returns a module
            whose state dict is loaded into the copy's visual tower
        alpha: Scaling coefficient for the task vector
        test_dataloader: DataLoader for test data
        test_dataset: Test dataset (provides class names / templates for zero-shot heads)
        device: Device to run evaluation on
        vector_name: Name of the task vector, used in the wandb metric keys
        best_acc: Current best accuracy for this vector type
        best_alpha: Current best alpha for this vector type
        results_list: List that receives one ``{"alpha", "loss", "accuracy"}`` dict (mutated in place)
        logger: Logger used for the per-alpha info line
        display_name: Name shown in the log line (defaults to ``vector_name`` if None)
        metric_prefix: String prepended to the wandb metric keys and shown in the log line

    Returns:
        tuple: (updated best accuracy, updated best alpha). They change only when this
        alpha's accuracy is strictly greater than ``best_acc``.
    """
    if display_name is None:
        display_name = vector_name

    model_b_t = deepcopy(base_model)
    # Apply the task vector to the copy's own visual tower rather than to
    # base_model.visual: if apply_to modifies its argument in place, the base
    # model would otherwise accumulate task vectors across an alpha sweep.
    patched_visual = task_vector.apply_to(model_b_t.visual, scaling_coef=alpha)
    model_b_t.visual.load_state_dict(patched_visual.state_dict())
    loss, acc = evaluate_model(model_b_t, test_dataloader, test_dataset, device, prompt_ensemble=True)
    logger.info(f"Model {metric_prefix} + {display_name} | TASK : {acc}, loss {loss}")

    new_best_acc = best_acc
    new_best_alpha = best_alpha
    if acc > best_acc:
        new_best_acc = acc
        new_best_alpha = alpha

    # Store results
    results_list.append({
        "alpha": alpha,
        "loss": loss,
        "accuracy": acc
    })

    wandb.log({
        "alpha": alpha,
        f"{metric_prefix}{vector_name}_loss": loss,
        f"{metric_prefix}{vector_name}_accuracy": acc
    })

    return new_best_acc, new_best_alpha


def accuracy(output, target, topk=(1, 5)):
    """Compute top-k accuracy and return the value for the first requested k.

    Args:
        output (np.ndarray): Model outputs as probabilities or logits (N x C), or a
            1-D vector of binary scores thresholded at 0.5.
        target (np.ndarray): Ground truth labels (N,).
        topk (Tuple[int, ...]): Values of k to compute accuracy for. Values larger
            than the number of classes C are clamped (top-C contains every class).

    Returns:
        float: Accuracy in [0, 1] for ``topk[0]`` (or binary accuracy if output is 1-D).
    """
    if len(output.shape) == 1:
        acc = np.sum((output >= 0.5).astype(float) == target) / target.shape[0]
        return acc.item()

    batch_size = target.shape[0]
    # torch.topk raises when k exceeds the class dimension (e.g. the default k=5
    # on a 2-class problem), so never ask for more than C predictions.
    maxk = min(max(topk), output.shape[1])

    _, pred = torch.from_numpy(output).topk(maxk, dim=1)
    target_t = torch.from_numpy(target).view(batch_size, 1).expand(-1, maxk)
    correct = pred == target_t

    # Slicing past maxk simply keeps all maxk columns, which is correct for k >= C.
    topk_accuracy = [correct[:, :k].float().sum().item() / batch_size for k in topk]
    return topk_accuracy[0]
    
def setup_environment(args):
    """Seed NumPy, PyTorch and ``random`` from ``args.seed`` and pick a device.

    Args:
        args (argparse.Namespace): Parsed CLI arguments; only ``seed`` is read.

    Returns:
        str: ``'cuda:0'`` when CUDA is available, otherwise ``'cpu'``.
    """
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"

def _load_finetuned_copy(backbone, checkpoint_path, tag, device):
    """Deep-copy ``backbone`` and load ``checkpoint_path['model_state_dict']`` into it; keep the plain copy if the file is missing."""
    model_ft = deepcopy(backbone)
    try:
        # map_location lets GPU-saved checkpoints load on the requested device (e.g. a CPU-only host).
        checkpoint = torch.load(checkpoint_path, map_location=device)
    except FileNotFoundError:
        logger.warning(
            f"Finetuned checkpoint {checkpoint_path} not found. Using backbone {tag} as model {tag} ft.")
        return model_ft
    model_ft.load_state_dict(checkpoint['model_state_dict'])
    return model_ft


def get_models(args, device):
    """Create OpenCLIP backbones and load finetuned checkpoints when available.

    Each finetuned checkpoint is handled independently: a missing file for one
    model no longer discards the successfully loaded other one.

    Args:
        args (argparse.Namespace): CLI arguments with architecture and checkpoint info.
        device (str): Device to instantiate models on (also used as ``map_location``).

    Returns:
        tuple: ``(backbone_A, backbone_B, model_A_ft, model_B_ft, preprocess_A, preprocess_B)``
        where the ``*_ft`` models are deep copies of their backbone, with the
        finetuned weights loaded when the checkpoint file exists.
    """
    backbone_A, _, preprocess_A = open_clip.create_model_and_transforms(args.arch,
                                                                      pretrained=args.pretraining_backbone_A,
                                                                      cache_dir=f'{args.base_folder}/open_clip',
                                                                      device=device)
    backbone_B, _, preprocess_B = open_clip.create_model_and_transforms(args.arch,
                                                                      pretrained=args.pretraining_backbone_B,
                                                                      cache_dir=f'{args.base_folder}/open_clip',
                                                                      device=device)

    model_A_ft = _load_finetuned_copy(backbone_A, args.finetuned_checkpoint_A, 'A', device)
    model_B_ft = _load_finetuned_copy(backbone_B, args.finetuned_checkpoint_B, 'B', device)

    return backbone_A, backbone_B, model_A_ft, model_B_ft, preprocess_A, preprocess_B

def get_normalize_mean_std(preprocess):
    """Return the ``(mean, std)`` of the first ``Normalize`` step in a transform pipeline.

    Args:
        preprocess: A torchvision ``Compose``-like object with a ``transforms`` list.

    Returns:
        Tuple[Tuple[float, ...], Tuple[float, ...]]: (mean, std)

    Raises:
        ValueError: If no ``Normalize`` transform is present.
    """
    normalize = next((step for step in preprocess.transforms if isinstance(step, Normalize)), None)
    if normalize is None:
        raise ValueError("Normalize not found in preprocess")
    return (normalize.mean, normalize.std)

class SubsetWithAttrs(Subset):
    """``Subset`` that forwards unknown attribute lookups to the wrapped dataset.

    ``targets`` is rebuilt from the parent's ``labels`` (preferred) or ``targets``
    restricted to this subset's indices, or ``None`` if the parent has neither.
    """

    @property
    def targets(self):
        if hasattr(self.dataset, "labels"):
            return np.array(self.dataset.labels)[self.indices]
        elif hasattr(self.dataset, "targets"):
            return np.array(self.dataset.targets)[self.indices]
        else:
            return None

    def __getattr__(self, attr):
        # Only called when normal lookup fails. Read "dataset" from the instance
        # dict directly: while the object is half-built (copy.deepcopy, unpickling
        # in DataLoader worker processes) "dataset" is not set yet, and going
        # through self.dataset would re-enter __getattr__ until RecursionError.
        try:
            dataset = self.__dict__["dataset"]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{attr}'") from None
        return getattr(dataset, attr)
    
def get_validation_split(dataset, val_ratio=0.2, val_seed=42):
    """Split a dataset into stratified training and validation subsets.

    Indices ``0..len(dataset.targets)-1`` are split with scikit-learn's
    ``train_test_split``, stratified on ``dataset.targets``. Stratification
    needs enough samples per class. Very small (e.g. few-shot) datasets may be
    rejected by scikit-learn with a ``ValueError``.

    Args:
        dataset: Dataset to split; must expose ``targets``.
        val_ratio (float): Proportion of data to use for validation.
        val_seed (int): Random seed for reproducibility.

    Returns:
        Tuple[SubsetWithAttrs, SubsetWithAttrs]: (train_subset, val_subset)
    """
    labels = dataset.targets
    split = train_test_split(
        np.arange(len(labels)),
        test_size=val_ratio,
        stratify=labels,
        random_state=val_seed,
    )
    train_idx, val_idx = split
    return SubsetWithAttrs(dataset, train_idx), SubsetWithAttrs(dataset, val_idx)

def load_dataset(args, preprocess, support=False, validation=False, images_per_class=None):
    """Load the target dataset (and optionally ImageNet-R as support) with dataloaders.

    EuroSat is built from its explicit ``train``/``test``/``val`` splits. Every
    other supported dataset uses a wrapper class that exposes ``train_loader`` /
    ``test_loader``. For those, the validation set is carved out of the
    (possibly few-shot) training set with ``get_validation_split``, except DTD,
    which is constructed with ``validation=True`` and provides ``val_loader``.

    Args:
        args (argparse.Namespace): CLI arguments controlling dataset choice and paths
            (reads ``dataset``, ``base_folder``, ``workers``, ``batch_size``).
        preprocess: Transform to apply to images.
        support (bool): If True, also load the ImageNet-R support set.
        validation (bool): If True, also build a validation dataloader/dataset.
        images_per_class (int, optional): If provided, create few-shot subset with this many images per class.

    Returns:
        tuple: ``(train_loader, test_loader, val_loader, train_dataset, test_dataset,
        val_dataset, support_loader, support_dataset)``. The ``val_*`` entries are
        ``None`` unless ``validation`` is True, and the ``support_*`` entries are
        ``None`` unless ``support`` is True.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    val_ratio = 0.2
    val_seed = 42
    print("Loading dataset...")
    location = f'{args.base_folder}/datasets'
    # Wrapper classes sharing the (preprocess, location, num_workers, batch_size)
    # constructor and exposing train_loader / test_loader.
    wrapper_classes = {
        'cifar100': CIFAR100,
        'sun397': SUN397,
        'cars': Cars,
        'dtd': DTD,
        'svhn': SVHN,
        'gtsrb': GTSRB,
        'resisc45': RESISC45,
    }
    val_loader = None
    val_dataset = None

    if args.dataset == 'eurosat':
        eurosat_root = f"{args.base_folder}/datasets/eurosat"
        test_dataset = EuroSat(root=eurosat_root, split='test', transform=preprocess)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, drop_last=False)
        train_dataset = EuroSat(root=eurosat_root, split='train', transform=preprocess)

        # Apply few-shot subset if requested
        if images_per_class is not None:
            train_dataset = create_few_shot_subset(train_dataset, images_per_class, args.dataset)

        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.workers)
        if validation:
            val_dataset = EuroSat(root=eurosat_root, split='val', transform=preprocess)
            val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, drop_last=False)
    elif args.dataset in wrapper_classes:
        dataset_cls = wrapper_classes[args.dataset]
        wrapper_kwargs = dict(preprocess=preprocess, location=location, num_workers=args.workers, batch_size=args.batch_size)
        train_dataset = dataset_cls(**wrapper_kwargs)
        train_loader = train_dataset.train_loader
        test_dataset = dataset_cls(**wrapper_kwargs)
        test_loader = test_dataset.test_loader

        # Apply few-shot subset if requested
        if images_per_class is not None:
            train_dataset = create_few_shot_subset(train_dataset, images_per_class, args.dataset)
            train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)

        if validation:
            if args.dataset == 'dtd':
                # DTD is built with validation=True and exposes its own val_loader.
                val_dataset = DTD(**wrapper_kwargs, validation=True)
                val_loader = val_dataset.val_loader
            else:
                train_dataset, val_dataset = get_validation_split(train_dataset, val_ratio=val_ratio, val_seed=val_seed)
                train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
                val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    else:
        raise ValueError(f"Invalid dataset: {args.dataset}")

    support_loader = None
    support_dataset = None
    if support:
        support_dataset = IMAGENETR(preprocess=preprocess, location=location, num_workers=args.workers)
        support_loader = support_dataset.test_loader
        print(f'Number of support samples: {len(support_loader.dataset)}')
    print(f'Number of train samples: {len(train_loader.dataset)}')
    print(f'Number of test samples: {len(test_loader.dataset)}')
    if validation:
        print(f'Number of val samples: {len(val_loader.dataset)}')
    return train_loader, test_loader, val_loader, train_dataset, test_dataset, val_dataset, support_loader, support_dataset

    



def _zero_shot_text_features(model, dataset, device, prompt_ensemble):
    """Encode the dataset's class prompts into one L2-normalised text embedding per class."""
    with torch.no_grad():
        if prompt_ensemble:
            per_template = []
            for template in dataset.templates:
                prompts = [template(c.lower()) for c in dataset.class_names]
                tokens = open_clip.tokenize(prompts).to(device)
                per_template.append(F.normalize(model.encode_text(tokens), dim=-1))
            # The mean of unit vectors is not unit-norm; re-normalise as in standard
            # CLIP prompt ensembling so every class embedding has the same scale.
            return F.normalize(torch.mean(torch.stack(per_template), dim=0), dim=-1)

        prompts = [dataset.single_template(c.lower()) for c in dataset.class_names]
        tokens = open_clip.tokenize(prompts).to(device)
        return F.normalize(model.encode_text(tokens), dim=-1)


def evaluate_model(model, dataloader, dataset, device='cuda:0', prompt_ensemble=True, first_n_batches=None, disable_bar=True):
    """Evaluate a CLIP-like model zero-shot on a dataset.

    Class text embeddings are built from ``dataset.class_names`` (lower-cased)
    and either all ``dataset.templates`` (averaged, then re-normalised) or
    ``dataset.single_template``. Logits are ``logit_scale.exp()`` times the
    cosine similarity between image and class embeddings. The model is put in
    eval mode and left there.

    Args:
        model: Model exposing ``encode_image``, ``encode_text`` and ``logit_scale``.
        dataloader (DataLoader): Dataloader yielding ``(images, targets)`` batches.
        dataset: Dataset object with templates/single_template and class_names.
        device (str): Device string.
        prompt_ensemble (bool): If True, average over all templates; else use single template.
        first_n_batches (int | None): If set, evaluate only the first N batches.
        disable_bar (bool): Disable tqdm progress bar.

    Returns:
        Tuple[float, float]: (mean cross-entropy over the evaluated batches, top-1 accuracy).

    Raises:
        ValueError: If no batch was evaluated (empty dataloader or ``first_n_batches == 0``).
    """
    ce_loss = nn.CrossEntropyLoss()
    model.eval()

    text_features = _zero_shot_text_features(model, dataset, device, prompt_ensemble)

    total_loss = 0.0
    num_batches = 0
    all_probs = []
    all_labels = []
    for batch_idx, batch in tqdm(enumerate(dataloader), disable=disable_bar):
        if first_n_batches is not None and batch_idx == first_n_batches:
            break
        images, targets = batch

        images = images.to(device)
        targets = targets.to(device).long()  # long() needed for resisc45 labels

        with torch.no_grad():
            with torch.autocast(device_type="cuda"):
                image_features = F.normalize(model.encode_image(images), dim=-1)
                vl_logits = model.logit_scale.exp() * torch.einsum('ij,cj->ic', image_features, text_features)
            # Softmax and CE in float32: autocast produces half-precision logits.
            vl_logits = vl_logits.float()
            vl_prob = torch.softmax(vl_logits, dim=-1)
            loss = ce_loss(vl_logits, targets)

        all_probs.append(vl_prob.cpu().numpy())
        all_labels.append(targets.cpu().numpy())
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate_model received no batches to evaluate")

    # Average over the batches actually evaluated, not len(dataloader), so the
    # loss stays correct when first_n_batches truncates the loop.
    eval_avg_loss = total_loss / num_batches
    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    overall_acc = accuracy(all_probs, all_labels, topk=(1,))
    return eval_avg_loss, overall_acc

def evaluate_target_and_support(model, dataloaders: list, datasets: list, device, prompt_ensemble=True) -> dict:
    """Run ``evaluate_model`` on each ``(dataloader, dataset)`` pair.

    Pairs are taken with ``zip``, so extra entries in the longer list are ignored.

    Returns:
        dict: Maps each dataset's class name to its ``(loss, accuracy)`` tuple.
        A later dataset with the same class name overwrites an earlier one.
    """
    scores = {}
    for loader, ds in zip(dataloaders, datasets):
        ds_loss, ds_acc = evaluate_model(model, loader, ds, device, prompt_ensemble)
        key = ds.__class__.__name__
        scores[key] = (ds_loss, ds_acc)
    return scores

