"""Train a hypernetwork to merge source-domain LoRA adapters.
This script meta-trains a hypernetwork (cross-attention or MLP) that predicts
merge weights over source-domain LoRA deltas conditioned on a target-domain
representation (computed from support images). The predicted weights are used
to merge LoRA deltas and apply them via hooks over a frozen CLIP backbone.
The merged model is trained using query batches from the pseudo-target domain.
"""

from __future__ import annotations
import random
import os
from experiments.clipora.lora.inject import inject_linear_attention
from datasets.get_dataset import get_dataloader
from experiments.helpers import load_source_models, encode_text
from experiments.hypernet_helpers import remove_all_hooks_, register_lora_hooks, _find_module
from experiments.hypernet_helpers import precompute_lora_deltas, generate_merge_deltas, compute_domain_weights
from experiments.hypernet_helpers import render_domain_weight_heatmaps
from experiments.eval_hypernet import test_hypernet_accuracy

from experiments.hypernet_cross_attn import CrossAttnHyperNet
from experiments.hypernet_mlp import MLPHyperNet

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from contextlib import nullcontext
from omegaconf import OmegaConf
import open_clip
import wandb    
from torch.optim import AdamW
from collections import defaultdict

# -------------------------------------------------------------------------
#  Create data iterators for each source domain
# -------------------------------------------------------------------------

def create_data_iters(config, seed, preprocess, train_domains, dataset_name, generator=None):
    """Build a training dataset, DataLoader and live iterator for each domain.

    Parameters
    ----------
    config : DictConfig
        Configuration; reads ``dataset.data_dir``, ``dataset.classes``,
        ``dataset.annotations_dir`` (only when ``dataset_name`` is
        ``"domainnet"``) and ``experiment.batch_size`` / ``num_workers`` /
        ``pin_memory``.
    seed : int
        Seed for the freshly created ``torch.Generator``. Ignored when
        ``generator`` is supplied.
    preprocess : callable
        Image transform forwarded to ``get_dataloader``.
    train_domains : list[str]
        Domains to load; one single-domain loader is created per entry.
    dataset_name : str
        Dataset identifier forwarded to ``get_dataloader``.
    generator : Optional[torch.Generator]
        Existing generator whose ``initial_seed()`` is passed to
        ``get_dataloader`` as ``trial_num``. When ``None`` a new generator
        seeded with ``seed`` is used instead.

    Returns
    -------
    tuple(dict, dict, dict)
        ``(datasets, data_loaders, data_iters)``, each keyed by domain name.
    """
    rng = generator
    if rng is None:
        rng = torch.Generator()
        rng.manual_seed(seed)

    def _load_domain(domain):
        """Return the (dataset, loader) pair for a single training domain."""
        # Only DomainNet is given an annotations directory.
        annotations_dir = None
        if dataset_name == "domainnet":
            annotations_dir = config.dataset.annotations_dir
        return get_dataloader(
            dataset_name=dataset_name,
            trial_num=int(rng.initial_seed()),
            path_to_annotations_dir=annotations_dir,
            path_to_dataset_dir=config.dataset.data_dir,
            domains=[domain],
            classes=config.dataset.classes,
            transform=preprocess,
            split_type="train",
            batch_size=config.experiment.batch_size,
            num_workers=config.experiment.num_workers,
            pin_memory=config.experiment.pin_memory,
            is_training=True,
        )

    datasets, data_loaders = {}, {}
    for domain in train_domains:
        domain_dataset, domain_loader = _load_domain(domain)
        datasets[domain] = domain_dataset
        data_loaders[domain] = domain_loader

    # Iterators are created only after every loader has been built.
    data_iters = dict((domain, iter(data_loaders[domain])) for domain in train_domains)

    return datasets, data_loaders, data_iters

# -------------------------------------------------------------------------
# Train hypernetwork (uses hook‑based LoRA injection)
# -------------------------------------------------------------------------

# Maps each cross-attention ``hypernet_type`` to the ``output_mode`` passed to
# CrossAttnHyperNet.
_CROSS_ATTN_OUTPUT_MODES = {
    "ca_column": "column",
    "ca_layer": "layer",
    "ca_model": "model",
}
_SUPPORTED_HYPERNET_TYPES = ("mlp", *_CROSS_ATTN_OUTPUT_MODES)


def _validate_hypernet_type(hypernet_type):
    """Raise ``ValueError`` unless ``hypernet_type`` is a supported value."""
    if hypernet_type not in _SUPPORTED_HYPERNET_TYPES:
        raise ValueError(
            f"Unsupported hypernet_type {hypernet_type!r}; "
            f"expected one of {list(_SUPPORTED_HYPERNET_TYPES)}"
        )


def _build_hypernet(config, dataset_name, num_domains, device):
    """Instantiate the configured hypernetwork and record its settings in W&B."""
    hypernet_type = config.experiment.hypernet_type
    _validate_hypernet_type(hypernet_type)

    # NOTE(review): these hard-coded sizes look like CLIP ViT-B dimensions
    # (qkv = 3 x 768, 512-d image embedding) — confirm before using another
    # backbone.
    domain_dim = 512
    emb_dim = 128
    hidden_dim = 256
    column_dim_qkv = 2304
    column_dim_proj = 768

    if hypernet_type == "mlp":
        num_sites_max = 12
        dropout = 0.0
        hypernet = MLPHyperNet(
            domain_dim=domain_dim,
            num_domains=num_domains,
            emb_dim=emb_dim,
            hidden_dim=hidden_dim,
            column_dim_qkv=column_dim_qkv,
            column_dim_proj=column_dim_proj,
            num_sites_max=num_sites_max,
            dropout=dropout,
        ).to(device)
        wandb_key = "mlp_hypernet"
        hypernet_cfg = {
            "domain_dim": domain_dim,
            "emb_dim": emb_dim,
            "hidden_dim": hidden_dim,
            "column_dim_qkv": column_dim_qkv,
            "column_dim_proj": column_dim_proj,
            "num_sites_max": num_sites_max,
            "dropout": dropout,
        }
    else:
        domain_token_scale = config.experiment[dataset_name].domain_token_scale
        output_mode = _CROSS_ATTN_OUTPUT_MODES[hypernet_type]
        hypernet = CrossAttnHyperNet(
            domain_dim=domain_dim,
            num_domains=num_domains,
            emb_dim=emb_dim,
            hidden_dim=hidden_dim,
            column_dim_qkv=column_dim_qkv,
            column_dim_proj=column_dim_proj,
            domain_token_scale=domain_token_scale,
            output_mode=output_mode,
        ).to(device)
        wandb_key = "cross_attn_hypernet"
        hypernet_cfg = {
            "domain_dim": domain_dim,
            "emb_dim": emb_dim,
            "hidden_dim": hidden_dim,
            "column_dim_qkv": column_dim_qkv,
            "column_dim_proj": column_dim_proj,
            "domain_token_scale": domain_token_scale,
            "output_mode": output_mode,
        }

    hypernet_cfg["params_trainable"] = sum(
        p.numel() for p in hypernet.parameters() if p.requires_grad
    )
    wandb.config.update({wandb_key: hypernet_cfg}, allow_val_change=True)
    return hypernet


def train_hypernet(config, trial, train_domains, test_domains):
    """Meta-train a hypernetwork and evaluate on target domains.

    Each inner step samples a pseudo-target domain from ``train_domains``,
    draws a support and a query batch from it, derives a domain
    representation from the support images with the frozen base model, lets
    the hypernetwork produce merged LoRA deltas, applies them through
    temporary hooks and optimizes the hypernetwork on the query
    cross-entropy loss. After every epoch the hypernetwork is evaluated with
    ``test_hypernet_accuracy`` and metrics are logged to W&B. The final
    hypernetwork weights are saved under ``config.constants.model_save_path``.

    Parameters
    ----------
    config : DictConfig
        Hydra configuration with experiment, dataset, constants, and backbone.
    trial : int
        Trial index used for seeding and checkpoint naming.
    train_domains : list[str]
        Source domains available during meta-training.
    test_domains : list[str]
        Held-out target domains for testing. The first entry is used in run,
        metric and checkpoint names; ``"global"`` is used when it is empty
        or ``None``.

    Returns
    -------
    dict
        Test metrics from the last epoch (empty if no epoch was run).

    Raises
    ------
    ValueError
        If ``config.experiment.hypernet_type`` is not one of ``"mlp"``,
        ``"ca_column"``, ``"ca_layer"`` or ``"ca_model"``.
    RuntimeError
        If a freshly created DataLoader for a domain yields no batches.
    """
    # Fail fast on a misconfigured hypernet type, before any W&B run or model
    # is created.
    _validate_hypernet_type(config.experiment.hypernet_type)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset_name, backbone = config.dataset.name, config.backbone
    # Single source of truth for the target-domain tag used in names/keys.
    domain_prefix = test_domains[0] if test_domains else "global"
    wandb.init(
        entity  = config.constants.wandb_entity,
        project = config.experiment.name,
        name    = f"{config.experiment.hypernet_type}_{dataset_name}_{domain_prefix}_lr{config.experiment[dataset_name].lr_clip}",
        config  = OmegaConf.to_container(config, resolve=True, throw_on_missing=False),
    )

    print("--------------------------------")
    print("Device: ", device)
    print("Experiment: ", config.experiment)
    print("Train domains: ", train_domains)
    print("Test domains: ", test_domains)
    print("Dataset name: ", dataset_name)
    print("Backbone: ", backbone)
    print("--------------------------------")

    # Frozen CLIP backbone; only the hypernetwork is optimized.
    base_model, train_preprocess, val_preprocess = open_clip.create_model_and_transforms(
        backbone, pretrained="openai")
    base_model = inject_linear_attention(
        base_model, {"visual.transformer"},
        embed_dim=base_model.visual.transformer.width, num_heads=12)
    base_model.to(device)
    for p in base_model.parameters():
        p.requires_grad = False
    base_model.eval()

    source_models = load_source_models(config.experiment, train_domains, dataset_name, config.constants.model_save_path, backbone, trial, device)

    tokenizer = open_clip.get_tokenizer(backbone)

    generator = torch.Generator()
    generator.manual_seed(trial)

    datasets, data_loaders, data_iters = create_data_iters(
        config=config,
        seed=trial,
        preprocess=train_preprocess,
        train_domains=train_domains,
        dataset_name=dataset_name,
        generator=generator
    )

    text_features = encode_text(
        config.dataset.classes, config.dataset.prompt_templates,
        base_model, tokenizer, device)
    text_features = text_features.float()

    # LoRA is merged at the attention qkv and output projections of every
    # visual transformer block.
    n_blocks  = len(base_model.visual.transformer.resblocks)
    lora_sites = ([f"visual.transformer.resblocks.{i}.attn.qkv"  for i in range(n_blocks)] +
                  [f"visual.transformer.resblocks.{i}.attn.proj" for i in range(n_blocks)])

    precomputed_deltas = precompute_lora_deltas(source_models, lora_sites, device)

    hypernet = _build_hypernet(config, dataset_name, len(source_models), device)

    optimizer = torch.optim.AdamW(hypernet.parameters(),lr=config.experiment[dataset_name].lr_clip, betas=(0.9, 0.999), weight_decay=0.01)
    hypernet.train()

    print(f"Number of trainable parameters in hypernet: {sum(p.numel() for p in hypernet.parameters() if p.requires_grad)}")

    image_loss_func = nn.CrossEntropyLoss()

    # ------------------------------------------------------------------ #
    #  Functions used for generating heatmaps                            #
    # ------------------------------------------------------------------ #
    def _compute_domain_rep_for(domain_name: str, preprocess):
        """Compute a target-domain representation from a batch of images."""
        test_dataset, test_loader = get_dataloader(
            dataset_name=dataset_name,
            trial_num=trial,
            path_to_annotations_dir=config.dataset.annotations_dir if dataset_name == "domainnet" else None,
            path_to_dataset_dir=config.dataset.data_dir,
            domains=[domain_name],
            classes=config.dataset.classes,
            transform=preprocess,
            split_type="test",
            batch_size=config.experiment.batch_size,
            num_workers=config.experiment.num_workers,
            pin_memory=config.experiment.pin_memory,
            is_training=True
        )
        images, _, _ = next(iter(test_loader))
        images = images.to(device)
        with torch.no_grad():
            feats = base_model.encode_image(images).float()
            feats = F.normalize(feats, dim=-1, eps=1e-6)
            return feats.mean(dim=0)  # [domain_dim]

    def _log_test_domain_heatmaps(tag_prefix: str, preprocess, step: int):
        """Log hypernet domain-weight heatmaps to W&B for each test domain."""
        for td in test_domains:
            # Visualization is best-effort: a failure is reported but must
            # not abort training.
            try:
                dom_rep = _compute_domain_rep_for(td, preprocess)
                weights_info = compute_domain_weights(
                    hypernet=hypernet,
                    deltas_by_site=precomputed_deltas,
                    site_order=lora_sites,
                    domain_order=train_domains,
                    domain_representation=dom_rep,
                    device=device,
                    mask_domain_name=None,
                )
                vis = render_domain_weight_heatmaps(
                    weights_info,
                    domain_labels=train_domains,
                    out_dir=f"{config.constants.model_save_path}/visualizations",
                    prefix=f"{config.experiment.name}_{dataset_name}_{td}",
                    step=step,
                )
                to_log = {}
                if 'qkv' in vis:
                    to_log[f"{dataset_name}/{tag_prefix}/{td}/qkv_heatmap"] = wandb.Image(vis['qkv']['path'])
                if 'proj' in vis:
                    to_log[f"{dataset_name}/{tag_prefix}/{td}/proj_heatmap"] = wandb.Image(vis['proj']['path'])
                if to_log:
                    wandb.log(to_log)
            except Exception as e:
                print(f"[WARN] Test-domain heatmap logging failed for {td}: {e}")

    def _next_batch(domain):
        """Return the next batch for ``domain``, rebuilding its loader when exhausted."""
        try:
            return next(data_iters[domain])
        except StopIteration:
            pass

        # Advance the generator seed so the rebuilt loader is created with a
        # new trial_num (create_data_iters reads generator.initial_seed()).
        generator.manual_seed(int(generator.initial_seed()) + 1)
        new_datasets, new_loaders, new_iters = create_data_iters(
            config, trial, train_preprocess, [domain], dataset_name, generator
        )
        datasets[domain] = new_datasets[domain]
        data_loaders[domain] = new_loaders[domain]
        # Reuse the iterator that create_data_iters already built instead of
        # discarding it and creating a second one.
        data_iters[domain] = new_iters[domain]
        try:
            return next(data_iters[domain])
        except StopIteration:
            raise RuntimeError(
                f"DataLoader for domain {domain!r} yielded no batches"
            ) from None

    # Before training starts -> uncomment if you want to log heatmaps before training
    #_log_test_domain_heatmaps(tag_prefix="weights_before", preprocess=val_preprocess, step=0)

    # ------------------------------------------------------------------ #
    #  Meta‑training loop                                                #
    # ------------------------------------------------------------------ #

    outer_epochs = config.experiment[dataset_name].num_epochs
    epoch_bar = tqdm(range(outer_epochs), desc="Epochs", ncols=120)
    inner_steps = 1000
    wandb.config.update({"inner_steps": inner_steps}, allow_val_change=True)

    # Predictions/labels/metadata are only needed by the WILDS evaluator.
    is_wilds = dataset_name in config.constants.wilds_datasets
    # Defined up front so the function still returns a dict when
    # num_epochs == 0.
    test_results = {}

    for epoch in epoch_bar:

        # NOTE(review): test_hypernet_accuracy runs at the end of every epoch
        # and may leave the hypernet in eval mode — re-assert train mode here
        # (a no-op if it does not).
        hypernet.train()

        train_loss, train_num_correct, train_total_samples = 0.0, 0, 0
        inner_bar = tqdm(range(inner_steps), leave=False, ncols=100)

        all_preds = []
        all_labels = []
        all_metadata = []

        for inner_step in inner_bar:
            pseudo_target = random.choice(train_domains)
            # -------- support & query batches --------------------------
            # Two consecutive batches from the same domain iterator.
            support_batch = _next_batch(pseudo_target)
            query_batch = _next_batch(pseudo_target)

            support_images, _, _ = support_batch
            query_images, query_labels, query_metadata = query_batch
            support_images = support_images.to(device)
            query_images, query_labels = query_images.to(device), query_labels.to(device)

            # -------- merge LoRA deltas via hyper‑net ------------------
            # The domain representation comes from the frozen base model, so
            # no gradients are needed here.
            with torch.no_grad():
                s_feats = base_model.encode_image(support_images).float()
                s_feats = F.normalize(s_feats, dim=-1, eps=1e-6)
                domain_rep = s_feats.mean(dim=0)  # [domain_dim]

            # For iWildCam no domain name is passed as the mask.
            # NOTE(review): presumably mask_domain_name excludes the
            # pseudo-target's own adapter from the merge — confirm in
            # generate_merge_deltas.
            mask_domain = None if dataset_name == "iwildcam" else pseudo_target

            delta_dict = generate_merge_deltas(
                hypernet=hypernet,
                deltas_by_site=precomputed_deltas,
                site_order=lora_sites,
                domain_order=train_domains,
                domain_representation=domain_rep,
                device=device,
                mask_domain_name=mask_domain,
            )

            # -------- compute loss with hooks --------------------------
            remove_all_hooks_(base_model)
            # Enable debug verification for the first iteration of the first epoch
            debug_verification = (int(epoch) == 0 and int(inner_step) == 0)

            # Store original weight for verification (only for debug)
            if debug_verification and lora_sites:
                debug_module = _find_module(base_model, lora_sites[0])
                original_weight_before = debug_module.weight.clone()

            handles = register_lora_hooks(base_model, delta_dict, debug_verification=debug_verification)
            try:
                image_features = base_model.encode_image(query_images).float()
                image_features = F.normalize(image_features, dim=-1, eps=1e-6)
                logits = base_model.logit_scale.exp() * (image_features @ text_features)
                loss = image_loss_func(logits, query_labels)
                preds = logits.argmax(-1)

            finally:
                # Always detach the hooks, even if the forward pass failed.
                for handle in handles:
                    handle.remove()

                # Verify base weights are unchanged after hook removal
                if debug_verification and lora_sites:
                    weight_unchanged = torch.equal(debug_module.weight, original_weight_before)
                    print("\n=== POST-HOOK VERIFICATION ===")
                    print(f"Base model weights unchanged after hook removal: {weight_unchanged}")
                    print(f"Weight after hooks (first 3x3): \n{debug_module.weight[:3, :3]}")
                    print("=" * 30)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if is_wilds:
                all_preds.append(preds)
                all_labels.append(query_labels)
                all_metadata.append(query_metadata)

            train_num_correct += (preds == query_labels).sum().item()
            train_total_samples += query_labels.size(0)
            train_loss += loss.item()

        train_loss = train_loss / inner_steps
        train_accuracy = train_num_correct / train_total_samples

        if is_wilds:
            # The dataset's own evaluator computes the training metrics.
            all_preds_cpu = torch.cat(all_preds, dim=0).cpu()
            all_labels_cpu = torch.cat(all_labels, dim=0).cpu()
            all_metadata_cpu = torch.cat(all_metadata, dim=0).cpu()
            train_results, _ = datasets[train_domains[0]].eval(all_preds_cpu, all_labels_cpu, all_metadata_cpu)
            train_results["train_loss"] = train_loss
        else:
            train_results = {"train_accuracy": train_accuracy, "train_loss": train_loss}

        log_dict = {
            f"{dataset_name}/train/{domain_prefix}/epoch": epoch,
        }
        for metric_name, metric_value in train_results.items():
            clean_name = metric_name.lower().replace(" ", "_")
            log_dict[f"{dataset_name}/train/{domain_prefix}/{clean_name}"] = metric_value

        print("--------------------------------")
        print(f"Train results on {train_domains}: \n", train_results)

        test_results = test_hypernet_accuracy(
            config, trial, base_model, hypernet,
            source_models, lora_sites, text_features,
            test_domains, val_preprocess, precomputed_deltas)
        for metric_name, metric_value in test_results.items():
            clean_name = metric_name.lower().replace(" ", "_")
            log_dict[f"{dataset_name}/test/{domain_prefix}/{clean_name}"] = metric_value
        print(f"Test results on {test_domains}: \n", test_results)

        wandb.log(log_dict)
        print("--------------------------------")

        # After each epoch: log test-domain heatmaps
        #_log_test_domain_heatmaps(tag_prefix="weights_epoch_end", preprocess=val_preprocess, step=epoch + 1)

    save_dir = os.path.join(config.constants.model_save_path, f"{config.experiment.name}_{dataset_name}_{backbone.replace('/', '-')}")
    os.makedirs(save_dir, exist_ok=True)
    # DomainNet checkpoints are additionally tagged with the target domain.
    if dataset_name == "domainnet":
        model_name = f"trial{trial}_{config.experiment.name}_{dataset_name}_{domain_prefix}_{config.experiment.hypernet_type}.pt"
    else:
        model_name = f"trial{trial}_{config.experiment.name}_{dataset_name}_{config.experiment.hypernet_type}.pt"
    torch.save({'model_state_dict': hypernet.state_dict()}, os.path.join(save_dir, model_name))

    return test_results

