"""
Evaluates the hyper-network on the test domains.
Uses unlabelled support images to build merged LoRA deltas.
Uses hooks to insert the merged LoRA deltas into the base model.
Returns per-domain and overall metrics.
"""

from experiments.hypernet_helpers import remove_all_hooks_, register_lora_hooks, generate_merge_deltas
from datasets.get_dataset import get_dataloader

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

def _build_test_split_loader(config, trial, dataset_name, domain, preprocess, *, is_training):
    """Return the ``(dataset, dataloader)`` pair for the test split of one domain.

    The support loader and the evaluation loader are built from identical
    arguments and differ only in ``is_training``, which is forwarded to
    ``get_dataloader`` unchanged.
    """
    return get_dataloader(
        dataset_name=dataset_name,
        trial_num=trial,
        # Only DomainNet is given an annotations directory.
        path_to_annotations_dir=config.dataset.annotations_dir if dataset_name == "domainnet" else None,
        path_to_dataset_dir=config.dataset.data_dir,
        domains=[domain],
        classes=config.dataset.classes,
        transform=preprocess,
        split_type="test",
        batch_size=config.experiment.batch_size,
        num_workers=config.experiment.num_workers,
        pin_memory=config.experiment.pin_memory,
        is_training=is_training,
    )


@torch.inference_mode()
def test_hypernet_accuracy(
    config, trial,
    base_model: nn.Module,
    hypernet: nn.Module,
    source_models: dict,
    lora_sites: list[str],
    text_features: torch.Tensor,
    test_domains: list[str],
    preprocess,
    precomputed_deltas: dict
    ):
    """Evaluate a hyper-network that merges LoRA deltas into a base model.

    For every test domain, one batch of support images is encoded with the
    base model, L2-normalised and averaged into a domain representation. The
    hyper-network turns that representation into merged LoRA deltas, which
    are attached to the base model through hooks while the domain's test
    split is classified against ``text_features``.

    Parameters
    ----------
    config
        Experiment configuration; reads ``dataset.name``,
        ``dataset.annotations_dir``, ``dataset.data_dir``,
        ``dataset.classes``, ``experiment.batch_size``,
        ``experiment.num_workers``, ``experiment.pin_memory`` and
        ``constants.wilds_datasets``.
    trial : int
        Passed to ``get_dataloader`` as ``trial_num``.
    base_model : nn.Module
        Frozen backbone (CLIP). Must expose ``visual.conv1.weight`` (used to
        infer the device) and ``encode_image``.
    hypernet : nn.Module
        Network that produces merge weights for the LoRA deltas.
    source_models : dict
        Mapping from training domain name to its source model. Only the key
        order is used here (as ``domain_order``).
    lora_sites : list[str]
        Ordered site names where LoRA deltas are applied.
    text_features : torch.Tensor
        Class text embeddings. They are right-multiplied onto the image
        features without a transpose, so presumably shaped
        ``[feature_dim, num_classes]`` -- TODO confirm against caller.
    test_domains : list[str]
        Domain identifiers to evaluate.
    preprocess
        Preprocessing transform applied to images.
    precomputed_deltas : dict
        Mapping of LoRA deltas indexed by site and domain.

    Returns
    -------
    dict
        Evaluation metrics. For WILDS datasets, the metrics returned by the
        dataset's ``eval`` plus ``"test_loss"``; otherwise
        ``{"test_acc", "test_loss"}``. ``"test_acc"`` is a fraction
        (``correct / total``), whereas the per-domain accuracy printed
        during evaluation is a percentage. ``"test_loss"`` is the mean over
        domains of each domain's mean per-batch loss.

    Raises
    ------
    ValueError
        If ``test_domains`` is empty, if a domain yields no support batch,
        or if a domain's test loader yields no samples.

    Notes
    -----
    ``hypernet`` is always switched to train mode on exit (also when an
    exception is raised); ``base_model`` is left in eval mode.
    """
    if len(test_domains) == 0:
        raise ValueError("test_domains must contain at least one domain to evaluate")

    device = base_model.visual.conv1.weight.device
    dataset_name = config.dataset.name
    is_wilds = dataset_name in config.constants.wilds_datasets
    # The source-domain order does not change between test domains.
    train_domains = list(source_models.keys())

    image_loss_func = nn.CrossEntropyLoss()

    total_correct = 0
    total_samples = 0
    total_loss = 0.0

    # Only the WILDS evaluator consumes the raw predictions/labels/metadata,
    # so they are gathered for WILDS datasets only.
    all_predictions = []
    all_labels = []
    all_metadata = []
    test_dataset = None

    hypernet.eval()
    base_model.eval()

    try:
        for target_domain in tqdm(test_domains, desc="Testing domains", ncols=90):
            # ------------- build domain representation ----------------------
            _, support_loader = _build_test_split_loader(
                config, trial, dataset_name, target_domain, preprocess, is_training=True
            )

            # A single batch of (unlabelled) support images describes the domain.
            support_batch = next(iter(support_loader), None)
            if support_batch is None:
                raise ValueError(f"No support images available for domain '{target_domain}'")
            support_images, _, _ = support_batch
            support_images = support_images.to(device)

            # Gradients are already disabled by the inference_mode decorator.
            s_feats = base_model.encode_image(support_images).float()
            s_feats = F.normalize(s_feats, dim=-1, eps=1e-6)
            domain_rep = s_feats.mean(dim=0)  # [domain_dim]

            delta_dict = generate_merge_deltas(
                hypernet=hypernet,
                deltas_by_site=precomputed_deltas,
                site_order=lora_sites,
                domain_order=train_domains,
                domain_representation=domain_rep,
                device=device
            )

            test_dataset, test_loader = _build_test_split_loader(
                config, trial, dataset_name, target_domain, preprocess, is_training=False
            )

            domain_correct = 0
            domain_total = 0
            domain_loss = 0.0

            # The merged deltas are fixed for the whole domain, so the hooks are
            # installed once here rather than once per batch. They are removed
            # in ``finally`` so the next domain's support features are computed
            # without them, as before.
            remove_all_hooks_(base_model)
            handles = register_lora_hooks(base_model, delta_dict)
            try:
                for images, labels, metadata in tqdm(test_loader, desc=f"→ {target_domain}", leave=False, ncols=80):
                    images, labels = images.to(device), labels.to(device)

                    image_features = F.normalize(base_model.encode_image(images), dim=-1)
                    logits = 100.0 * image_features @ text_features
                    loss = image_loss_func(logits, labels)
                    predictions = logits.argmax(-1)

                    domain_loss += loss.item()
                    domain_correct += (predictions == labels).sum().item()
                    domain_total += labels.size(0)

                    if is_wilds:
                        all_predictions.append(predictions.cpu())
                        all_labels.append(labels.cpu())
                        all_metadata.append(metadata)
            finally:
                for handle in handles:
                    handle.remove()

            if domain_total == 0:
                raise ValueError(f"No test samples available for domain '{target_domain}'")

            # ---------- per-domain metrics ---------------------------------
            domain_accuracy = 100.0 * domain_correct / domain_total
            domain_loss = domain_loss / len(test_loader)

            total_correct += domain_correct
            total_samples += domain_total
            total_loss += domain_loss

            print(f"[{target_domain}] accuracy={domain_accuracy:.2f}%  loss={domain_loss:.4f}")

        # ---------------- aggregated metrics -------------------------------
        mean_domain_loss = total_loss / len(test_domains)

        if is_wilds:
            # NOTE: the dataset object of the last evaluated domain is used
            # to score the predictions of all domains.
            test_results, _ = test_dataset.eval(
                torch.cat(all_predictions, dim=0),
                torch.cat(all_labels, dim=0),
                torch.cat(all_metadata, dim=0),
            )
            test_results["test_loss"] = mean_domain_loss
        else:
            test_results = {
                "test_acc": total_correct / total_samples,
                "test_loss": mean_domain_loss,
            }
    finally:
        # Hand the hyper-network back in train mode even if evaluation failed.
        hypernet.train()

    return test_results



