"""Helper utilities for training and evaluation:
- Saving/loading training checkpoints and experiment results
- Reproducibility helpers (seeding, worker init)
- Loading of per-domain LoRA adapters
- Encoding text features for zero-shot classification with CLIP
"""

import os
import json
import random
import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import open_clip
from open_clip.tokenizer import SimpleTokenizer
from peft import LoraConfig, get_peft_model
from experiments.clipora.lora.inject import inject_linear_attention
from omegaconf import DictConfig
from omegaconf.listconfig import ListConfig
import torch.serialization

def save_training_state(model: nn.Module, file_name: str, directory: str):
    """Write ``model``'s parameters to ``directory/file_name``.

    The checkpoint is a dict holding a single key, ``'model_state_dict'``,
    which is the format ``load_training_state`` expects to read back.

    Parameters
    ----------
    model : nn.Module
        Model whose ``state_dict()`` is serialized.
    file_name : str
        Checkpoint file name (e.g., "trial0_model.pt").
    directory : str
        Target directory; it is created if it does not exist yet.
    """
    checkpoint_path = os.path.join(directory, file_name)
    os.makedirs(directory, exist_ok=True)
    print(f'Saving checkpoint to {checkpoint_path}...')
    payload = {'model_state_dict': model.state_dict()}
    torch.save(payload, checkpoint_path)


def load_training_state(file_name: str, model_save_path: str):
    """Read a checkpoint written by ``save_training_state`` and return its state dict.

    Parameters
    ----------
    file_name : str
        Checkpoint file name.
    model_save_path : str
        Directory containing the checkpoint.

    Returns
    -------
    dict
        The value stored under ``'model_state_dict'``; tensors are mapped to CPU.

    Raises
    ------
    TypeError
        If the deserialized checkpoint is not a ``dict``.
    KeyError
        If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # Registers omegaconf's ListConfig as an allowed global for weights-only
    # unpickling. The torch.load call below uses weights_only=False, so this
    # registration does not restrict what that call can unpickle.
    torch.serialization.add_safe_globals([ListConfig])
    checkpoint_path = os.path.join(model_save_path, file_name)
    print(f'Loading checkpoint from {checkpoint_path}...')
    # NOTE(review): weights_only=False runs the full pickle machinery; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=False)

    if not isinstance(checkpoint, dict):
        raise TypeError(f'Checkpoint must be a dict containing "model_state_dict". Got type: {type(checkpoint)}')
    if 'model_state_dict' not in checkpoint:
        raise KeyError(f'Checkpoint missing "model_state_dict". Available keys: {list(checkpoint.keys())}')
    if len(checkpoint) > 1:
        # Extra entries (e.g. optimizer state from older runs) are deliberately ignored.
        print('Multiple keys found in checkpoint; using only "model_state_dict" and ignoring the rest.')
    return checkpoint['model_state_dict']


def save_results(results_dict: dict, save_dir: str):
    """Append one record of experiment results to ``save_dir/test_stats.json``.

    Each call appends a dashed separator line followed by ``results_dict``
    serialized as a single line of JSON. The resulting file is therefore a
    log of separated JSON records, not a single JSON document.

    Parameters
    ----------
    results_dict : dict
        JSON-serializable mapping of results/metrics to record.
    save_dir : str
        Directory holding ``test_stats.json``; created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)
    stats_path = os.path.join(save_dir, 'test_stats.json')
    separator = "\n--------------------------------------------------------------\n"
    with open(stats_path, 'a') as stats_file:
        stats_file.write(separator)
        stats_file.write(f"{json.dumps(results_dict)}\n")


def set_seeds(random_seed: int):
    """Seed all RNGs in use and switch PyTorch to deterministic kernels.

    Parameters
    ----------
    random_seed : int
        Seed applied to Python's ``random``, NumPy, and PyTorch (CPU and, when
        available, every CUDA device).

    Notes
    -----
    ``PYTHONHASHSEED`` is exported to ``os.environ``. Per the Python docs, the
    running interpreter's hash seed is fixed at startup, so this only affects
    child processes launched afterwards.
    """
    # Environment consulted by Python subprocesses and by cuBLAS (PyTorch
    # requires this workspace setting for deterministic cuBLAS behaviour).
    os.environ["PYTHONHASHSEED"] = str(random_seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # Deterministic kernels only; disable TF32 fast paths and cuDNN autotuning.
    torch.use_deterministic_algorithms(True)
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Seed the host-side generators, then every CUDA device if present.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)

def worker_init_fn(worker_id):
    """Re-seed NumPy, Python and PyTorch RNGs inside a DataLoader worker.

    The seed is taken from ``torch.initial_seed()`` and reduced modulo 2**32
    so NumPy accepts it; ``worker_id`` itself is not read. (Per the PyTorch
    DataLoader docs, each worker's initial seed already differs per worker.)
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(derived_seed)


def load_lora_model(
        lora_params: DictConfig,
        backbone: str,
        device: str = "cuda",
        *,
        lora_text: bool = False,
        lora_vision: bool = True,
    ):
    """Create an OpenAI-pretrained CLIP model with LoRA adapters injected.

    The attention layers of the selected encoders are first rewritten with
    ``inject_linear_attention``. The result is then wrapped by
    ``peft.get_peft_model`` using the LoRA hyper-parameters in ``lora_params``.

    Parameters
    ----------
    lora_params : DictConfig
        LoRA configuration; reads ``lora_r``, ``lora_alpha``, ``lora_dropout``
        and ``lora_target_modules``.
    backbone : str
        open_clip model name (e.g., "ViT-B-16", "ViT-L-14").
    device : str, optional
        Device on which to place the model, by default "cuda".
    lora_text : bool, optional
        Rewrite the text transformer's attention for LoRA, by default False.
    lora_vision : bool, optional
        Rewrite the vision transformer's attention for LoRA, by default True.

    Returns
    -------
    nn.Module
        The PEFT-wrapped CLIP model returned by ``get_peft_model``, on ``device``.
    """
    # Only the model is needed here; the image transforms are discarded.
    model, _, _ = open_clip.create_model_and_transforms(backbone, pretrained="openai")
    model_config = open_clip.get_model_config(backbone)
    vision_cfg = model_config["vision_cfg"]
    text_cfg = model_config["text_cfg"]

    if lora_text:
        model = inject_linear_attention(
            model=model,
            encoders={"transformer"},
            # The text transformer's attention runs at the text width, which
            # is not necessarily equal to the joint embedding size `embed_dim`.
            embed_dim=text_cfg["width"],
            num_heads=text_cfg["heads"],
        )
    if lora_vision:
        # open_clip builds the ViT with width // head_width heads (head_width
        # defaults to 64): 12 heads for ViT-B-16, 16 for ViT-L-14.
        vision_width = vision_cfg["width"]
        vision_heads = vision_width // vision_cfg.get("head_width", 64)
        model = inject_linear_attention(
            model=model,
            encoders={"visual.transformer"},
            embed_dim=vision_width,
            num_heads=vision_heads,
        )

    lora_config = LoraConfig(
        r=lora_params.lora_r,
        lora_alpha=lora_params.lora_alpha,
        lora_dropout=lora_params.lora_dropout,
        target_modules=lora_params.lora_target_modules,
    )
    model = get_peft_model(model, lora_config)
    # Move once, after all module surgery, so any layers created by the
    # injection or by PEFT wrapping also end up on `device`.
    model.to(device)

    # print_trainable_parameters is defined on the PeftModel wrapper itself
    # (the inner CLIP `visual` submodule does not have it).
    if hasattr(model, "print_trainable_parameters"):
        model.print_trainable_parameters()
    return model


def load_source_models(
        lora_config: DictConfig,
        train_domains: list[str],
        dataset_name: str,
        model_save_path: str,
        backbone: str,
        trial_num: int,
        device: str="cuda"
    ):
    """Build one LoRA CLIP model per source domain and restore its checkpoint.

    Checkpoints are expected at
    ``<model_save_path>/source_lora_per_domain_<dataset>_<backbone>/``
    ``trial<trial_num>_source_lora_per_domain_<dataset>_<domain>.pt``, where
    any "/" in the backbone name is replaced by "-" in the directory name.

    Parameters
    ----------
    lora_config : DictConfig
        LoRA configuration passed to ``load_lora_model`` for every domain.
    train_domains : list[str]
        Source domains to load, in order.
    dataset_name : str
        Dataset name embedded in the checkpoint directory and file names.
    model_save_path : str
        Root directory containing the checkpoint folders.
    backbone : str
        CLIP backbone name.
    trial_num : int
        Trial index embedded in the checkpoint file name.
    device : str, optional
        Device to place each model on, by default "cuda".

    Returns
    -------
    dict[str, nn.Module]
        Mapping from domain name to its restored LoRA model.
    """
    def _restore(domain):
        # Fresh LoRA model for this domain, then overwrite its weights from disk.
        checkpoint_dir = os.path.join(
            model_save_path,
            f'source_lora_per_domain_{dataset_name}_{backbone.replace("/", "-")}',
        )
        checkpoint_name = f'trial{trial_num}_source_lora_per_domain_{dataset_name}_{domain}.pt'
        domain_model = load_lora_model(lora_config, backbone, device=device)
        domain_model.load_state_dict(load_training_state(checkpoint_name, checkpoint_dir))
        return domain_model

    return {domain: _restore(domain) for domain in train_domains}

def encode_text(
        class_names: list[str],
        prompt_templates: list[str],
        model: nn.Module,
        tokenizer: SimpleTokenizer,
        device: str="cuda"
    ):
    """Build zero-shot CLIP class embeddings from class names and prompt templates.

    For each class, the name is normalized: underscores become spaces and the
    text is lower-cased. Each template is filled with the normalized name, and
    the prompts are encoded by the model's text encoder. Every prompt embedding
    is L2-normalized, the embeddings are averaged over templates, and the mean
    is L2-normalized again. Runs under ``torch.no_grad()``.

    Parameters
    ----------
    class_names : list[str]
        Class names to encode.
    prompt_templates : list[str]
        ``str.format`` templates with one ``{}`` slot, e.g. "a photo of a {}".
    model : nn.Module
        CLIP model providing ``encode_text``. A wrapper exposing ``.module``
        (e.g. DataParallel) is unwrapped first.
    tokenizer : SimpleTokenizer
        Tokenizer compatible with the given CLIP model.
    device : str, optional
        Device for the tokenized prompts and the returned tensor.

    Returns
    -------
    torch.Tensor
        Class embeddings stacked along dim=1, i.e. shape [D, C] with one
        normalized column per class.
    """
    with torch.no_grad():
        # Unwrap DataParallel-style wrappers that hold the real model in `.module`.
        clip_model = model.module if hasattr(model, 'module') else model
        columns = []
        for raw_name in tqdm(class_names):
            readable_name = raw_name.replace("_", " ").lower()
            prompts = [template.format(readable_name) for template in prompt_templates]
            tokens = tokenizer(prompts).to(device)
            prompt_embeddings = clip_model.encode_text(tokens)
            prompt_embeddings = prompt_embeddings / prompt_embeddings.norm(dim=-1, keepdim=True)
            centroid = prompt_embeddings.mean(dim=0)
            columns.append(centroid / centroid.norm())
        class_matrix = torch.stack(columns, dim=1).to(device)
    return class_matrix