"""
Model loading utilities for BigBird experiments.

Provides functions to load the BigBird model and tokenizer, and to extract layer-0 embeddings.
"""
import os
from typing import Dict, Any, Optional, Union

import torch
from transformers import AutoTokenizer, BigBirdModel, BigBirdTokenizer


# Device that model inputs should be sent to; filled in by load_model().
MODEL_INPUT_DEVICE: Optional[torch.device] = None
# Module-wide default device: the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Per-size BigBird checkpoint settings; the keys are the accepted ``model_size`` values.
MODEL_CONFIGS: Dict[str, Dict[str, Any]] = dict(
    base=dict(
        model_id="google/bigbird-roberta-base",
        hidden_size=768,
        num_attention_heads=12,
    ),
    large=dict(
        model_id="google/bigbird-roberta-large",
        hidden_size=1024,
        num_attention_heads=16,
    ),
)


def get_model_config(model_size: str = "base") -> Dict[str, Any]:
    """Return the ``MODEL_CONFIGS`` entry for ``model_size``.

    The shared dict stored in ``MODEL_CONFIGS`` is returned (not a copy),
    so callers should treat it as read-only.

    Raises:
        ValueError: if ``model_size`` is not a key of ``MODEL_CONFIGS``.
    """
    try:
        return MODEL_CONFIGS[model_size]
    except KeyError:
        raise ValueError(
            f"Unknown model_size: {model_size}. Available: {list(MODEL_CONFIGS.keys())}"
        ) from None


def _hf_offline_mode() -> bool:
    """Return True if HuggingFace Hub offline mode is requested via the environment.

    Follows huggingface_hub's parsing: ``HF_HUB_OFFLINE`` (falling back to the
    legacy ``TRANSFORMERS_OFFLINE``) is enabled for "1", "ON", "YES" or "TRUE",
    case-insensitively -- not only for the literal "1".
    """
    raw = os.environ.get("HF_HUB_OFFLINE") or os.environ.get("TRANSFORMERS_OFFLINE") or ""
    return raw.strip().upper() in {"1", "ON", "YES", "TRUE"}


def _load_pretrained_with_fallback(model_id: str, torch_dtype: Optional[torch.dtype]):
    """Load ``BigBirdModel`` weights (on CPU), retrying once from a fallback revision.

    The retry pins ``revision="refs/pr/2"`` with ``use_safetensors=True`` and,
    unless offline mode is enabled, forces a fresh download. If the retry also
    fails, its exception propagates with the first failure attached as context.
    """
    print("Loading model on CPU...")
    try:
        return BigBirdModel.from_pretrained(model_id, torch_dtype=torch_dtype)
    except Exception as e:
        print(f"First load attempt failed: {e}")
        fallback_kwargs = {"revision": "refs/pr/2", "use_safetensors": True}
        # force_download needs network access, so only request it when online.
        if not _hf_offline_mode():
            fallback_kwargs["force_download"] = True
        print(f"Retrying with fallback_kwargs: {fallback_kwargs}")
        return BigBirdModel.from_pretrained(model_id, torch_dtype=torch_dtype, **fallback_kwargs)


def _place_model(bigbird, device_map, device: torch.device):
    """Put ``bigbird`` on a single device, or shard it with accelerate.

    Returns:
        ``(model, input_device)`` where ``input_device`` is the device holding
        the word-embedding weights, i.e. where input tensors must be sent.
    """
    if device_map is None:
        bigbird.to(device)
        print(f"Model loaded on single device: {device}")
        return bigbird, device

    # Imported lazily so accelerate is only required for multi-device sharding.
    from accelerate import infer_auto_device_map, dispatch_model

    print(f"Sharding model across GPUs (device_map={device_map})...")
    if device_map == "auto":
        device_map_computed = infer_auto_device_map(
            bigbird,
            # Keep each transformer layer whole on a single device.
            no_split_module_classes=["BigBirdLayer"],
        )
    else:
        device_map_computed = device_map

    print(f"Computed device_map: {device_map_computed}")
    bigbird = dispatch_model(bigbird, device_map=device_map_computed)
    # Inputs are consumed first by the word embeddings, so their device is the input device.
    input_device = bigbird.embeddings.word_embeddings.weight.device
    print(f"Model sharded across GPUs. Input device: {input_device}")
    return bigbird, input_device


def load_model(
    model_name: str = "BigBird",
    attention_type: str = "block_sparse",
    max_length: Optional[int] = None,
    model_size: str = "base",
    device_map: Optional[Union[str, Dict[str, Any]]] = None,
    device: Optional[Union[str, torch.device]] = None,
    torch_dtype: Optional[torch.dtype] = None,
):
    """
    Load BigBird model from HuggingFace in inference mode.

    Side effect: sets the module-level ``MODEL_INPUT_DEVICE`` to the device that
    input tensors should be moved to before calling the model.

    Args:
        model_name: Model family (currently only "BigBird" supported)
        attention_type: "block_sparse" or "original_full"
        max_length: Expected maximum sequence length. Position embeddings are NOT
            interpolated; if this exceeds the checkpoint's
            ``max_position_embeddings`` a warning is printed.
        model_size: "base" or "large"
        device_map: None for single device, "auto" for multi-GPU sharding via
            accelerate, or an explicit accelerate device-map dict
        device: Target device for single-device mode (default: "cuda:0" if
            available, else "cpu"); ignored when ``device_map`` is given
        torch_dtype: Optional dtype for model weights

    Returns:
        Loaded BigBird model, in eval mode

    Raises:
        ValueError: for an unsupported ``model_name``, ``attention_type``,
            ``device_map`` string, or ``model_size``.
    """
    global MODEL_INPUT_DEVICE

    # Validate cheap arguments before the (slow) checkpoint load.
    if model_name != "BigBird":
        raise ValueError("Only BigBird model is supported")
    if attention_type not in ("block_sparse", "original_full"):
        raise ValueError(
            f"attention_type must be 'block_sparse' or 'original_full', got {attention_type!r}"
        )
    if isinstance(device_map, str) and device_map != "auto":
        raise ValueError(
            f"Unsupported device_map string {device_map!r}; use None, 'auto' or a dict"
        )

    # Normalize so MODEL_INPUT_DEVICE is always a torch.device.
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    config_dict = get_model_config(model_size)
    model_id = config_dict["model_id"]
    print(f"Loading {model_id} (hidden_size={config_dict['hidden_size']}, heads={config_dict['num_attention_heads']})")

    bigbird = _load_pretrained_with_fallback(model_id, torch_dtype)
    bigbird.eval()
    bigbird.set_attention_type(attention_type)

    max_positions = bigbird.config.max_position_embeddings
    if max_length is not None and max_length > max_positions:
        print(
            f"Warning: max_length={max_length} exceeds max_position_embeddings={max_positions}; "
            "position embeddings are not interpolated."
        )

    bigbird, MODEL_INPUT_DEVICE = _place_model(bigbird, device_map, device)
    return bigbird


def load_tokenizer(model_name: str = "BigBird", model_size: str = "base"):
    """
    Load tokenizer for the specified model.

    Tries ``AutoTokenizer`` first; if that raises, the reason is printed and the
    ``BigBirdTokenizer`` class is loaded from the same checkpoint instead.

    Args:
        model_name: Model family (currently only "BigBird" supported)
        model_size: "base" or "large"

    Returns:
        Tokenizer instance

    Raises:
        ValueError: for an unsupported ``model_name`` or ``model_size``.
    """
    if model_name != "BigBird":
        raise ValueError("Only BigBird model is supported")

    model_id = get_model_config(model_size)["model_id"]
    try:
        return AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        # Best-effort fallback, but report why AutoTokenizer failed rather than hiding it.
        print(f"AutoTokenizer load failed ({e}); falling back to BigBirdTokenizer")
        return BigBirdTokenizer.from_pretrained(model_id)


def get_layer0_embeddings_direct(model, tokens):
    """
    Get embeddings directly WITHOUT running attention for layer 0.

    Computes word + position + token-type embeddings followed by the embedding
    LayerNorm, which is the input to the first transformer layer. No dropout is
    applied. If ``model.embeddings.rescale_embeddings`` is truthy, word
    embeddings are scaled by sqrt(hidden_size), as HuggingFace's
    ``BigBirdEmbeddings`` does.

    Args:
        model: BigBird model instance (must expose ``model.embeddings``)
        tokens: Dict-like with an 'input_ids' tensor of shape [batch, seq_len].
            An optional 'token_type_ids' tensor is used when present; otherwise
            every position gets token type 0.

    Returns:
        Embeddings tensor of shape [batch, seq_len, hidden_size]
    """
    embedding_layer = model.embeddings
    with torch.no_grad():
        input_ids = tokens['input_ids']
        seq_length = input_ids.shape[1]

        token_embeds = embedding_layer.word_embeddings(input_ids)
        # Mirror BigBirdEmbeddings for configs with rescale_embeddings enabled.
        if getattr(embedding_layer, "rescale_embeddings", False):
            token_embeds = token_embeds * (token_embeds.shape[-1] ** 0.5)

        # Positions 0..seq_len-1, broadcast over the batch dimension.
        position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
        position_embeds = embedding_layer.position_embeddings(position_ids)

        # Honour caller-supplied segment ids; default to segment 0 everywhere.
        token_type_ids = tokens.get('token_type_ids')
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        token_type_embeds = embedding_layer.token_type_embeddings(token_type_ids)

        embeddings = token_embeds + position_embeds + token_type_embeds
        embeddings = embedding_layer.LayerNorm(embeddings)
    return embeddings
