"""
Evaluate on a target domain using baseline strategies that select or merge LoRA adapters.
Implements baseline strategies that pick or merge source-domain LoRA models
using a few unlabeled support images from the target domain, then evaluates
on the target domain:
- best_model: choose the single source with lowest entropy on support images
- avg_combination: uniformly weighted, parameter-wise merging of all source LoRA adapters
- weighted_combination: entropy-weighted average of source LoRA adapters
"""

from __future__ import annotations
import random
from experiments.clipora.lora.inject import inject_linear_attention
from datasets.get_dataset import get_dataloader
from experiments.helpers import load_source_models, encode_text, load_lora_model
from experiments.hypernet_helpers import remove_all_hooks_, register_lora_hooks, _find_module

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from omegaconf import OmegaConf
import open_clip
import wandb    

import math
import torch
import torch.nn.functional as F
from tqdm import tqdm


def build_weighted_lora(
        lora_cfg,
        backbone: str,
        source_models: dict[str, torch.nn.Module],
        weights: dict[str, float],
        device: str = "cuda",
    ) -> torch.nn.Module:
    """Create a new LoRA model whose trainable params are a weighted average.

    Only parameters of the freshly loaded model that have
    ``requires_grad=True`` are overwritten; every other parameter keeps the
    value it had when ``load_lora_model`` returned.

    Parameters
    ----------
    lora_cfg
        LoRA configuration used to instantiate a fresh LoRA model.
    backbone : str
        CLIP backbone identifier ("ViT-B-16").
    source_models : dict[str, torch.nn.Module]
        Mapping domain -> LoRA model trained on that source domain.
    weights : dict[str, float]
        Weights per domain; applied as a linear combination without any
        normalisation. Each value must be convertible with ``float()``.
    device : str
        Device to place the constructed LoRA model on.

    Returns
    -------
    torch.nn.Module
        The merged LoRA model in eval mode.

    Raises
    ------
    KeyError
        If a domain in ``source_models`` has no entry in ``weights``, or a
        trainable parameter name is missing from a source model's state dict.
    """
    merged = load_lora_model(lora_cfg, backbone, device=device)

    # Resolve the coefficients first so that a missing or non-scalar weight
    # fails before any merging work is done.
    coefficients = {domain: float(weights[domain]) for domain in source_models}

    # Build each source state dict once; calling state_dict() inside the
    # parameter loop would rebuild the mapping for every (param, model) pair.
    source_states = {domain: model.state_dict() for domain, model in source_models.items()}

    with torch.no_grad():
        for name, param in merged.named_parameters():
            if not param.requires_grad:
                continue
            blended = torch.zeros_like(param)
            for domain, state in source_states.items():
                # .to() is a no-op when the source already lives on the target device.
                blended += coefficients[domain] * state[name].to(blended.device)
            # Write straight into the parameter storage.
            param.copy_(blended)

    merged.eval()
    return merged

def compute_entropy(text_features, model, support_loader):
    """Compute the mean predictive entropy of a model on one support batch.

    Only the first batch yielded by ``support_loader`` is used (few-shot
    setting). Logits are the scaled similarity between L2-normalised image
    features and ``text_features``; the per-sample entropy of the softmax
    distribution is averaged over the batch so that callers receive a single
    scalar they can compare or turn into a weight.

    Parameters
    ----------
    text_features : torch.Tensor
        Class text embeddings for computing logits via similarity. They are
        right-multiplied onto the image features, so the first dimension
        must match the image feature dimension.
    model : torch.nn.Module
        Model used to compute image features (must expose ``encode_image``).
    support_loader : iterable
        Iterable over unlabeled support batches for the target domain; the
        first element of each batch is taken as the image tensor.

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the mean entropy over the first support batch.

    Raises
    ------
    ValueError
        If ``support_loader`` yields no batches.
    """
    model.eval()
    device = next(model.parameters()).device

    # Fetch just the first batch instead of looping and breaking.
    first_batch = next(iter(support_loader), None)
    if first_batch is None:
        raise ValueError("support_loader yielded no batches; cannot compute entropy.")
    images = first_batch[0]

    with torch.no_grad():
        tf = text_features.to(device)
        images = images.to(device, non_blocking=True)
        image_features = F.normalize(model.encode_image(images), dim=-1)
        logits = 100.0 * torch.matmul(image_features, tf)
        # Derive probabilities from log-probabilities for numerical stability.
        log_probs = F.log_softmax(logits, dim=-1)
        per_sample_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)

    # Reduce to a scalar: callers rank models and build weights from it.
    return per_sample_entropy.mean()
    
def _make_target_loader(config, dataset_name, trial, target_domain, preprocess, is_training):
    """Build the (dataset, loader) pair for the test split of one target domain."""
    return get_dataloader(
        dataset_name=dataset_name,
        trial_num=trial,
        path_to_annotations_dir=config.dataset.annotations_dir if dataset_name == "domainnet" else None,
        path_to_dataset_dir=config.dataset.data_dir,
        domains=[target_domain],
        classes=config.dataset.classes,
        transform=preprocess,
        split_type="test",
        batch_size=config.experiment.batch_size,
        num_workers=config.experiment.num_workers,
        pin_memory=config.experiment.pin_memory,
        is_training=is_training,
    )


def _support_entropies(text_features, source_models, train_domains, support_loader):
    """Return {domain: mean entropy as float} for every source model on one shared support batch."""
    support_batch = next(iter(support_loader), None)
    if support_batch is None:
        raise ValueError("Support loader yielded no batches; cannot score source models.")
    # Reuse the same batch for every model so that all sources are scored on
    # identical images; re-iterating the loader could hand each model a
    # different batch if the loader shuffles.
    shared_loader = [support_batch]

    scores = {}
    for domain in train_domains:
        entropy = compute_entropy(text_features, source_models[domain], shared_loader)
        # Reduce to a Python float so the value can be compared and exponentiated.
        scores[domain] = float(torch.as_tensor(entropy, dtype=torch.float32).mean())
    return scores


def _build_target_model(config, backbone, source_models, train_domains,
                        text_features, support_loader, target_domain, device):
    """Select or merge source models according to ``config.experiment.strategy``."""
    strategy = config.experiment.strategy

    if strategy == "best_model":
        entropies = _support_entropies(text_features, source_models, train_domains, support_loader)
        best_model = min(entropies, key=entropies.get)
        print("Best source model for ", target_domain, ":", best_model)
        target_model = source_models[best_model]

    elif strategy == "avg_combination":
        if not train_domains:
            raise ValueError("avg_combination requires at least one train domain.")
        share = 1 / len(train_domains)
        # no_grad keeps the accumulation from building an autograd graph.
        with torch.no_grad():
            reference = source_models[train_domains[0]]
            averaged = {n: torch.zeros_like(p) for n, p in reference.named_parameters()}
            for domain in train_domains:
                for n, p in source_models[domain].named_parameters():
                    averaged[n] += p * share
        target_model = load_lora_model(config.experiment, backbone, device)
        target_model.load_state_dict(averaged)

    elif strategy == "weighted_combination":
        entropies = _support_entropies(text_features, source_models, train_domains, support_loader)
        # Lower entropy -> larger weight; normalise so the weights sum to one.
        raw_weights = {d: math.exp(-h) for d, h in entropies.items()}
        total_weight = sum(raw_weights.values())
        norm_weights = {d: w / total_weight for d, w in raw_weights.items()}
        selected_sources = {d: source_models[d] for d in train_domains}
        target_model = build_weighted_lora(
            config.experiment, config.backbone, selected_sources, norm_weights, str(device))

    else:
        raise ValueError(f"Unknown baseline strategy: {strategy!r}")

    # Source models are loaded on CPU, so the chosen or merged model must be
    # moved to the evaluation device before inputs on that device reach it.
    return target_model.to(device)


def test_baselines(config, trial, train_domains, test_domains):
    """Evaluate baseline strategies on target domains using support images.

    Strategies
    ---------
    - best_model: select source model with lowest entropy on support images
    - avg_combination: uniform average of LoRA params from all source adapters
    - weighted_combination: entropy-weighted average of LoRA adapters

    Parameters
    ----------
    config
        Experiment configuration; ``config.experiment.strategy`` selects the
        baseline.
    trial
        Trial identifier forwarded to the data and model loaders.
    train_domains
        Source domains whose LoRA models are loaded and combined.
    test_domains
        Target domains to evaluate on.

    Returns
    -------
    dict
        For WILDS datasets: official WILDS metrics plus averaged test loss.
        For DomainNet: {"test_acc", "test_loss"}, where ``test_acc`` is a
        fraction in [0, 1] and ``test_loss`` is the mean of per-domain losses.

    Raises
    ------
    ValueError
        If the strategy is unknown, the support loader is empty, or a target
        domain contains no test samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset_name, backbone = config.dataset.name, config.backbone
    # NOTE(review): the run is initialised here but nothing in this function
    # logs to it or finishes it — presumably the caller does; confirm.
    wandb.init(
        entity  = config.constants.wandb_entity,
        project = config.experiment.name,
        name    = f"{dataset_name}_{test_domains[0]}_lr{config.experiment[dataset_name].lr_clip}",
        config  = OmegaConf.to_container(config, resolve=True, throw_on_missing=False),
    )

    print("--------------------------------")
    print("Experiment: ", config.experiment)
    print("Train domains: ", train_domains)
    print("Test domains: ", test_domains)
    print("Dataset name: ", dataset_name)
    print("Backbone: ", backbone)
    print("--------------------------------")

    # Frozen base model, used here only to encode the class prompts.
    base_model, _, preprocess = open_clip.create_model_and_transforms(
        backbone, pretrained="openai")
    base_model = inject_linear_attention(
        base_model, {"visual.transformer"},
        embed_dim=base_model.visual.transformer.width, num_heads=12)
    base_model.to(device)
    for p in base_model.parameters():
        p.requires_grad = False

    tokenizer = open_clip.get_tokenizer(backbone)
    text_features = encode_text(
        config.dataset.classes, config.dataset.prompt_templates,
        base_model, tokenizer, device)

    source_models = load_source_models(
        config.experiment, train_domains, dataset_name,
        config.constants.model_save_path, backbone, trial, device="cpu")

    total_correct = 0
    total_samples = 0
    total_loss = 0.0

    all_predictions = []
    all_labels = []
    all_metadata = []

    image_loss_func = nn.CrossEntropyLoss()

    for target_domain in tqdm(test_domains, desc="Testing domains", ncols=90):

        # Support images come from the target domain's test split, loaded in training mode.
        support_dataset, support_loader = _make_target_loader(
            config, dataset_name, trial, target_domain, preprocess, is_training=True)

        target_model = _build_target_model(
            config, backbone, source_models, train_domains,
            text_features, support_loader, target_domain, device)

        test_dataset, test_loader = _make_target_loader(
            config, dataset_name, trial, target_domain, preprocess, is_training=False)

        domain_correct = 0
        domain_total = 0
        domain_loss = 0.0

        target_model.eval()

        with torch.no_grad():
            for images, labels, metadata in tqdm(test_loader, desc=f"→ {target_domain}", leave=False, ncols=80):

                images, labels = images.to(device), labels.to(device)
                image_features = F.normalize(target_model.encode_image(images), dim=-1)
                logits = 100.0 * image_features @ text_features
                loss = image_loss_func(logits, labels)
                predictions = logits.argmax(-1)

                domain_loss += loss.item()
                domain_correct += (predictions == labels).sum().item()
                domain_total += labels.size(0)

                all_predictions.append(predictions.cpu())
                all_labels.append(labels.cpu())
                all_metadata.append(metadata)

        if domain_total == 0:
            raise ValueError(f"No test samples found for target domain {target_domain!r}.")

        domain_accuracy = 100.0 * domain_correct / domain_total
        # Mean of per-batch losses (not weighted by batch size).
        domain_loss = domain_loss / len(test_loader)

        total_correct += domain_correct
        total_samples += domain_total
        total_loss += domain_loss

        print(f"[{target_domain}] accuracy={domain_accuracy:.2f}%  loss={domain_loss:.4f}")

    if dataset_name in config.constants.wilds_datasets:
        # Concatenate only here: the stacked tensors are consumed solely by
        # the WILDS evaluator, so other datasets skip concatenating metadata.
        all_predictions = torch.cat(all_predictions, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        all_metadata = torch.cat(all_metadata, dim=0)
        # Uses the dataset object of the last evaluated target domain.
        test_results, _ = test_dataset.eval(all_predictions, all_labels, all_metadata)
        test_results["test_loss"] = total_loss / len(test_domains)
    else:
        test_accuracy = total_correct / total_samples
        test_results = {"test_acc": test_accuracy, "test_loss": total_loss / len(test_domains)}

    print("--------------------------------")
    print("Test results for ", test_domains, ":", test_results)
    print("--------------------------------")

    return test_results
