"""
Evaluate source (or base) CLIP models on specified domains.
Reports accuracy and loss for DomainNet.
Reports official WILDS metrics and loss for WILDS datasets (IWildCam, Camelyon17, FMOW).
"""

import torch.nn as nn
import torch
import open_clip
from datasets.get_dataset import get_dataloader
from experiments.helpers import encode_text
from tqdm import tqdm
from omegaconf import DictConfig
from omegaconf import OmegaConf

def eval_model(
        cfg: DictConfig,
        trial_num: int,
        model: torch.nn.Module,
        test_domains: list,
        split_type: str,
        create_val_from_train: bool
    ) -> dict:
    """Evaluate a CLIP model (or the pretrained base model) on target domains.

    The pretrained OpenAI checkpoint for ``cfg.backbone`` is always loaded: it
    supplies the validation preprocessing transform and is the model used to
    encode the class-prompt text features, even when ``model`` is given. Only
    the image encoder and ``logit_scale`` of the evaluated model are used.

    Parameters
    ----------
    cfg : DictConfig
        Experiment configuration. Fields read here: ``backbone``,
        ``dataset.name``, ``dataset.data_dir``, ``dataset.classes``,
        ``dataset.prompt_templates``, ``dataset.annotations_dir`` (DomainNet
        only), ``experiment.batch_size``, ``experiment.num_workers``,
        ``experiment.pin_memory`` and ``constants.wilds_datasets``.
    trial_num : int
        Seed for deterministic dataloader worker initialization.
    model : torch.nn.Module or None
        The CLIP model to evaluate. If None, the pretrained base model is
        evaluated instead. The model is moved to the evaluation device and
        left in eval mode.
    test_domains : list
        Domain identifiers to evaluate.
    split_type : str
        Dataset split to use (e.g., "test", "val").
    create_val_from_train : bool
        If True and split is train (WILDS), creates a validation holdout.

    Returns
    -------
    dict
        If WILDS dataset: returns WILDS metrics plus "test_loss".
        Else: returns {"test_acc", "test_loss"}.
        "test_loss" is the mean cross-entropy per sample over the whole split.

    Raises
    ------
    RuntimeError
        If the dataloader yields no samples.
    """

    dataset_name = cfg.dataset.name
    backbone = cfg.backbone
    print("Backbone: ", backbone)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The training transform returned by open_clip is not needed for evaluation.
    base_model, _, preprocess_val = open_clip.create_model_and_transforms(backbone, pretrained="openai")
    tokenizer = open_clip.get_tokenizer(backbone)
    base_model.to(device)

    test_dataset, test_loader = get_dataloader(
        dataset_name=dataset_name,
        trial_num=trial_num,
        path_to_annotations_dir=cfg.dataset.annotations_dir if dataset_name == "domainnet" else None,
        path_to_dataset_dir=cfg.dataset.data_dir,
        domains=test_domains,
        classes=cfg.dataset.classes,
        transform=preprocess_val,
        split_type=split_type,
        batch_size=cfg.experiment.batch_size,
        num_workers=cfg.experiment.num_workers,
        pin_memory=cfg.experiment.pin_memory,
        is_training=False,
        create_val_from_train=create_val_from_train
    )

    # Class text features always come from the pretrained base model.
    # NOTE(review): assumes encode_text returns a (embed_dim, num_classes)
    # matrix on `device`, so that `image_features @ text_features` yields
    # per-class logits -- confirm against experiments.helpers.encode_text.
    text_features = encode_text(
        cfg.dataset.classes,
        cfg.dataset.prompt_templates,
        base_model,
        tokenizer,
        device
    )

    clip_model = model if model is not None else base_model
    clip_model.to(device)
    clip_model.eval()

    # Sum (rather than average) the per-sample losses in each batch so the
    # final mean is taken over samples. Averaging per-batch means would
    # over-weight the samples of a smaller final batch.
    loss_func = nn.CrossEntropyLoss(reduction="sum")

    all_preds = []
    all_labels = []
    all_metadata = []

    num_correct = 0
    num_total = 0
    total_loss = 0.0

    with torch.no_grad():
        # The temperature is constant during evaluation; compute it once.
        logit_scale = clip_model.logit_scale.exp()

        for images, labels, metadata in tqdm(test_loader):
            # Metadata is only collected for the dataset-level metric, so it
            # is never needed on the compute device.
            images, labels = images.to(device), labels.to(device)

            image_features = clip_model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            similarity_scores = logit_scale * torch.matmul(image_features, text_features)

            total_loss += loss_func(similarity_scores, labels).item()

            preds = torch.argmax(similarity_scores, dim=1)

            num_correct += (preds == labels).sum().item()
            num_total += len(labels)

            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
            all_metadata.append(metadata.cpu())

    if num_total == 0:
        # Fail with a clear message instead of an opaque torch.cat / division error.
        raise RuntimeError(
            f"No samples to evaluate for domains {test_domains} (split '{split_type}')."
        )

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    all_metadata = torch.cat(all_metadata, dim=0)
    test_accuracy = num_correct / num_total
    test_loss = total_loss / num_total

    if dataset_name in cfg.constants.wilds_datasets:
        # The dataset's eval() returns a tuple; only its first element (the
        # metrics dict) is used here.
        eval_results, _ = test_dataset.eval(all_preds, all_labels, all_metadata)
        eval_results["test_loss"] = test_loss
    else:
        eval_results = {"test_acc": test_accuracy, "test_loss": test_loss}

    print("--------------------------------")
    print("Test accuracy for ", test_domains, ":", test_accuracy)
    print("--------------------------------")

    return eval_results

