"""Functions for loading and interacting with pretrained language models.

This module is designed to house all the annoying branching logic
that comes with supporting analysis of many slightly different model
implementations.
"""

import argparse
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterator, Literal, Optional, Sequence, overload

from src.utils import env_utils, tokenizer_utils
from src.utils.typing import Device, Layer, Mamba, Model, ModelInput, Tokenizer

import baukit
import torch
import transformers

logger = logging.getLogger(__name__)

# Model name constants. Each `*_SHORT` value is an alias that `load_model`
# expands to the corresponding full name before loading.
GPT_J_NAME_SHORT = "gptj"  # A useful alias for the CLI.
GPT_J_NAME = "EleutherAI/gpt-j-6B"

GPT_NEO_X_NAME_SHORT = "neox"
GPT_NEO_X_NAME = "EleutherAI/gpt-neox-20b"

# The LLaMA names are not in DOWNLOADABLE_MODELS, so `load_model` resolves them
# against the local models directory. The short alias expands to the 13B name.
LLAMA_13B_NAME = "llama-13b"
LLAMA_30B_NAME = "llama-30b"
LLAMA_NAME_SHORT = "llama"

MAMBA_3B_NAME = "state-spaces/mamba-2.8b-slimpj"
MAMBA_3B_SHORT = "mamba-3b"

# NOTE(review): the Pythia name is not listed in DOWNLOADABLE_MODELS below, so
# `load_model` looks for it under the local models directory — confirm intended.
PYTHIA_3B_NAME = "EleutherAI/pythia-2.8b-deduped"
PYTHIA_3B_SHORT = "pythia-3b"

# Names that `load_model` passes to `from_pretrained` unchanged. Any other
# relative name is treated as a path inside the project models directory.
DOWNLOADABLE_MODELS = frozenset({GPT_J_NAME, GPT_NEO_X_NAME, "gpt2-xl", MAMBA_3B_NAME})


@dataclass(frozen=True)
class ModelAndTokenizer:
    """A pretrained model and its tokenizer.

    The pair is frozen: the `model` and `tokenizer` references cannot be
    rebound, although the underlying objects can still be mutated in place
    (see `to_` and `eval_`).
    """

    model: Model
    tokenizer: Tokenizer

    @property
    def lm_head(self) -> torch.nn.Sequential:
        """Return the final layer norm followed by the LM head as one module."""
        ln_f = baukit.get_module(self.model, determine_final_layer_norm_path(self))
        lm_head = baukit.get_module(self.model, determine_lm_head_path(self))
        return torch.nn.Sequential(ln_f, lm_head)

    @property
    def name(self) -> str:
        """Return human-readable short name for this model.

        The name is chosen by model class alone, so every
        `GPTNeoXForCausalLM` instance reports the NeoX short name.

        Raises:
            ValueError: If the model class is not one of the known ones.

        """
        if isinstance(self.model, transformers.GPTNeoXForCausalLM):
            return GPT_NEO_X_NAME_SHORT
        elif isinstance(self.model, transformers.LlamaForCausalLM):
            return LLAMA_NAME_SHORT
        elif isinstance(self.model, transformers.GPTJForCausalLM):
            return GPT_J_NAME_SHORT
        elif isinstance(self.model, transformers.GPT2LMHeadModel):
            return "gpt2-xl"
        elif isinstance(self.model, Mamba):
            return MAMBA_3B_SHORT
        else:
            raise ValueError(f"unknown model name: {type(self.model).__name__}")

    def to_(self, device: Optional[Device]) -> None:
        """Send model to the device (in place)."""
        self.model.to(device)

    def eval_(self) -> None:
        """Set model to eval mode (in place)."""
        self.model.eval()

    @property
    def is_mamba(self) -> bool:
        """Return True if the wrapped model is a Mamba model.

        This is a pure type check. It deliberately does not consult `name`,
        which raises `ValueError` for model classes it does not recognise and
        would therefore make `__call__` fail on otherwise usable models.
        """
        return isinstance(self.model, Mamba)

    @property
    def is_mamba_fast(self) -> bool:
        """Return True for Mamba models that expose a `backbone` attribute."""
        return self.is_mamba and hasattr(self.model, "backbone")

    def __call__(self, *args, **kwargs) -> Any:
        """Call the model, forwarding positional and keyword arguments.

        For Mamba models only keyword arguments whose name starts with
        "input" are forwarded; everything else is dropped because Mamba can
        only handle input_ids.
        """
        if self.is_mamba:
            kwargs = {
                key: value for key, value in kwargs.items() if key.startswith("input")
            }
        return self.model(*args, **kwargs)


def unwrap_model(value: Model | ModelAndTokenizer) -> Model:
    """Return the bare model, whether given a model or a `ModelAndTokenizer`."""
    return value.model if isinstance(value, ModelAndTokenizer) else value


def unwrap_tokenizer(tokenizer: ModelAndTokenizer | Tokenizer) -> Tokenizer:
    """Return the bare tokenizer, whether given one or a `ModelAndTokenizer`."""
    return (
        tokenizer.tokenizer if isinstance(tokenizer, ModelAndTokenizer) else tokenizer
    )


def determine_embedding_layer_path(model: ModelAndTokenizer | Model) -> str:
    """Return the module path of the token embedding layer.

    Args:
        model: The model (or model/tokenizer pair) to inspect.

    Returns:
        Dotted path to the embedding module inside the model.

    Raises:
        ValueError: If the model type is not recognised.

    """
    model = unwrap_model(model)
    # Every GPTNeoXForCausalLM keeps its embedding under `gpt_neox`, whatever
    # the checkpoint name. This has to be tested before `is_gpt_variant`, which
    # also matches non-Pythia NeoX models and would hand back a `transformer.*`
    # path that the NeoX class does not have.
    if isinstance(model, transformers.GPTNeoXForCausalLM):
        return "gpt_neox.embed_in"
    elif is_gpt_variant(model):
        return "transformer.wte"
    elif isinstance(model, transformers.LlamaForCausalLM):
        return "model.embed_tokens"
    elif isinstance(model, Mamba):
        prefix = "backbone." if hasattr(model, "backbone") else ""
        return prefix + "embedding"
    elif is_pythia_variant(model):
        # Name-based fallback for Pythia models not caught by the class check.
        return "gpt_neox.embed_in"
    else:
        raise ValueError(f"unknown model type: {type(model).__name__}")


def determine_final_layer_norm_path(model: ModelAndTokenizer | Model) -> str:
    """Return the module path of the final layer norm (applied before the LM head).

    Args:
        model: The model (or model/tokenizer pair) to inspect.

    Returns:
        Dotted path to the final layer norm module inside the model.

    Raises:
        ValueError: If the model type is not recognised.

    """
    model = unwrap_model(model)
    # Every GPTNeoXForCausalLM keeps its final norm under `gpt_neox`, whatever
    # the checkpoint name. This has to be tested before `is_gpt_variant`, which
    # also matches non-Pythia NeoX models and would hand back a `transformer.*`
    # path that the NeoX class does not have.
    if isinstance(model, transformers.GPTNeoXForCausalLM):
        return "gpt_neox.final_layer_norm"
    elif is_gpt_variant(model):
        return "transformer.ln_f"
    elif isinstance(model, transformers.LlamaForCausalLM):
        return "model.norm"
    elif isinstance(model, Mamba):
        prefix = "backbone." if hasattr(model, "backbone") else ""
        return prefix + "norm_f"
    elif is_pythia_variant(model):
        # Name-based fallback for Pythia models not caught by the class check.
        return "gpt_neox.final_layer_norm"
    else:
        raise ValueError(f"unknown model type: {type(model).__name__}")


def determine_lm_head_path(model: ModelAndTokenizer | Model) -> str:
    """Return the module path of the LM head (the output projection to vocab).

    Args:
        model: The model (or model/tokenizer pair) to inspect.

    Returns:
        Dotted path to the LM head module inside the model.

    Raises:
        ValueError: If the model type is not recognised.

    """
    model = unwrap_model(model)
    # GPTNeoXForCausalLM names its output projection `embed_out`, whatever the
    # checkpoint name. This has to be tested before `is_gpt_variant`, which also
    # matches non-Pythia NeoX models.
    if isinstance(model, transformers.GPTNeoXForCausalLM):
        return "embed_out"
    elif is_gpt_variant(model):
        return "lm_head"
    elif isinstance(model, transformers.LlamaForCausalLM):
        # `lm_head` sits on LlamaForCausalLM itself, next to the inner `model`
        # submodule, not inside it.
        return "lm_head"
    elif isinstance(model, Mamba):
        return "lm_head"
    elif is_pythia_variant(model):
        # Name-based fallback for Pythia models not caught by the class check.
        return "embed_out"
    else:
        raise ValueError(f"unknown model type: {type(model).__name__}")


def determine_layers(model: ModelAndTokenizer | Model) -> tuple[int, ...]:
    """Return the indices `(0, 1, ..., n_layers - 1)` of the model's hidden layers.

    The layer count is read from `config.num_hidden_layers` for NeoX/Llama
    models, from the length of the layer list for Mamba models, and from
    `config.n_layer` for everything else.
    """
    model = unwrap_model(model)
    assert isinstance(model, Model)

    uses_num_hidden_layers = isinstance(
        model, (transformers.GPTNeoXForCausalLM, transformers.LlamaForCausalLM)
    )
    if uses_num_hidden_layers:
        count = model.config.num_hidden_layers
    elif isinstance(model, Mamba):
        # Some Mamba models keep their layers under a `backbone` attribute.
        blocks = model.backbone.layers if hasattr(model, "backbone") else model.layers
        count = len(blocks)
    else:
        count = model.config.n_layer

    return tuple(range(count))


@overload
def determine_layer_paths(
    model: ModelAndTokenizer | Model,
    layers: Optional[Sequence[Layer]] = ...,
    *,
    return_dict: Literal[False] = ...,
) -> Sequence[str]:
    """Return the layer paths as a sequence, ordered like `layers`."""
    ...


@overload
def determine_layer_paths(
    model: ModelAndTokenizer | Model,
    layers: Optional[Sequence[Layer]] = ...,
    *,
    return_dict: Literal[True],
) -> dict[Layer, str]:
    """Return a mapping from each requested layer to its path."""
    ...


def determine_layer_paths(
    model: ModelAndTokenizer | Model,
    layers: Optional[Sequence[Layer]] = None,
    *,
    return_dict: bool = False,
) -> Sequence[str] | dict[Layer, str]:
    """Resolve layers to their absolute module paths inside the model.

    Args:
        model: The model.
        layers: Layers to resolve: integer indices (negative values count from
            the end), "emb" for the embedding layer, or "ln_f" for the final
            layer norm. Defaults to every hidden layer.
        return_dict: If True, return a mapping from layer to layer path;
            otherwise return a tuple of paths in the same order as `layers`.

    Returns:
        Either the mapping or the tuple of paths, depending on `return_dict`.

    """
    model = unwrap_model(model)

    if layers is None:
        layers = determine_layers(model)

    assert isinstance(model, Model), type(model)

    # The stem under which numbered hidden layers live depends only on the
    # model, so work it out once up front.
    if isinstance(model, transformers.GPTNeoXForCausalLM):
        stem = "gpt_neox.layers"
    elif isinstance(model, transformers.LlamaForCausalLM):
        stem = "model.layers"
    elif isinstance(model, Mamba):
        stem = "backbone.layers" if hasattr(model, "backbone") else "layers"
    else:
        stem = "transformer.h"

    resolved: dict[Layer, str] = {}
    for layer in layers:
        if layer == "emb":
            resolved[layer] = determine_embedding_layer_path(model)
        elif layer == "ln_f":
            resolved[layer] = determine_final_layer_norm_path(model)
        else:
            # Negative indices are counted back from the last hidden layer.
            index = layer if layer >= 0 else len(determine_layers(model)) + layer
            resolved[layer] = f"{stem}.{index}"

    if return_dict:
        return resolved
    return tuple(resolved[layer] for layer in layers)


def determine_hidden_size(model: ModelAndTokenizer | Model) -> int:
    """Return the size of the model's hidden representations.

    Non-Mamba models report `config.hidden_size`. For Mamba models the size
    is taken from the last dimension of the embedding weight.
    """
    model = unwrap_model(model)

    if not isinstance(model, Mamba):
        return model.config.hidden_size

    embedding_path = "backbone.embedding" if hasattr(model, "backbone") else "embedding"
    embedding = baukit.get_module(model, embedding_path)
    return embedding.weight.shape[-1]


def determine_device(model: ModelAndTokenizer | Model) -> torch.device | None:
    """Return the device of one of the model's parameters, or None if it has none."""
    example = any_parameter(model)
    if example is None:
        return None
    return example.device


def determine_dtype(model: ModelAndTokenizer | Model) -> torch.dtype | None:
    """Return the dtype of one of the model's parameters, or None if it has none."""
    example = any_parameter(model)
    if example is None:
        return None
    return example.dtype


def any_parameter(model: ModelAndTokenizer | Model) -> torch.nn.Parameter | None:
    """Return the first parameter the model yields, or None if it has none."""
    for parameter in unwrap_model(model).parameters():
        return parameter
    return None


def tokenize_words(
    tokenizer: ModelAndTokenizer | Tokenizer,
    words: str | Sequence[str],
    spaces: bool = True,
    **kwargs: Any,
) -> ModelInput:
    """Tokenize one or more words, accounting for whether model expects spaces.

    Args:
        tokenizer: Tokenizer (or model/tokenizer pair) to use.
        words: A single word or a sequence of words.
        spaces: If True and the tokenizer is a GPT variant, prepend a space
            to every word before tokenizing.
        **kwargs: Forwarded to the tokenizer. Unless overridden, the call uses
            `add_special_tokens=False`, `padding="longest"` and
            `return_tensors="pt"`.

    Returns:
        Whatever the tokenizer returns for the batch of words.

    """
    tokenizer = unwrap_tokenizer(tokenizer)
    batch = [words] if isinstance(words, str) else words

    if spaces and is_gpt_variant(tokenizer):
        batch = [f" {word}" for word in batch]

    # Caller-supplied options win over the defaults.
    options = {
        "add_special_tokens": False,
        "padding": "longest",
        "return_tensors": "pt",
        **kwargs,
    }
    return tokenizer(batch, **options)


def maybe_prefix_eos(tokenizer: Tokenizer | ModelAndTokenizer, prompt: str) -> str:
    """Prefix prompt with the tokenizer's EOS token unless it already starts with it.

    Args:
        tokenizer: Tokenizer (or model/tokenizer pair) providing `eos_token`.
        prompt: The prompt text.

    Returns:
        `"<eos> <prompt>"` when an EOS token is available and the prompt does
        not already start with it; otherwise the prompt unchanged.

    """
    tokenizer = unwrap_tokenizer(tokenizer)
    # The attribute can exist and still be None (no EOS token configured).
    # Treat that like a missing attribute instead of letting `startswith(None)`
    # raise a TypeError.
    prefix = getattr(tokenizer, "eos_token", None)
    if not prefix or prompt.startswith(prefix):
        return prompt
    return prefix + " " + prompt


def is_pythia_variant(model: Model | ModelAndTokenizer) -> bool:
    """Determine if model is a Pythia variant, judging by its config name/path.

    Args:
        model: The model or model/tokenizer pair. Objects without a
            `config._name_or_path` (e.g. tokenizers passed in by
            `is_gpt_variant`) are tolerated and simply yield False.

    Returns:
        True if `model.config._name_or_path` is a string containing "pythia"
        (case-insensitive), False otherwise.

    """
    model = unwrap_model(model)
    try:
        name_or_path = model.config._name_or_path
    except AttributeError:
        # No config, or a config that does not record where it was loaded from.
        return False
    return isinstance(name_or_path, str) and "pythia" in name_or_path.lower()


def is_gpt_variant(mt: Model | Tokenizer | ModelAndTokenizer) -> bool:
    """Determine if model/tokenizer is GPT variant.

    Accepts a model, a tokenizer, or a `ModelAndTokenizer` (judged by its
    model). Pythia models are never reported as GPT variants.
    """
    if isinstance(mt, ModelAndTokenizer):
        mt = unwrap_model(mt)

    # Pythia checkpoints also load as GPTNeoXForCausalLM, but this module gives
    # them their own branches, so they must not match here.
    if is_pythia_variant(mt):
        return False

    gpt_classes = (
        transformers.GPT2LMHeadModel,
        transformers.GPTJForCausalLM,
        transformers.GPTNeoForCausalLM,
        transformers.GPTNeoXForCausalLM,
        transformers.GPT2TokenizerFast,
        transformers.GPTNeoXTokenizerFast,
    )
    return isinstance(mt, gpt_classes)


def determine_generate_kwargs(mt: ModelAndTokenizer) -> dict:
    """Return default kwargs for generation with this model.

    GPT variants get `pad_token_id` set to the tokenizer's EOS token id;
    every other model gets an empty dict.
    """
    if not is_gpt_variant(mt):
        return {}
    return {"pad_token_id": unwrap_tokenizer(mt).eos_token_id}


@contextmanager
def set_padding_side(
    tokenizer: Tokenizer | ModelAndTokenizer, padding_side: str = "right"
) -> Iterator[None]:
    """Context manager delegating to `tokenizer_utils.set_padding_side`.

    Accepts either a bare tokenizer or a `ModelAndTokenizer`, which is
    unwrapped before delegating.
    """
    bare_tokenizer = unwrap_tokenizer(tokenizer)
    delegate = tokenizer_utils.set_padding_side(
        bare_tokenizer, padding_side=padding_side
    )
    with delegate:
        yield


def map_to(
    orig: Any, device: Device | None = None, dtype: torch.dtype | None = None
) -> Any:
    """Map all tensors in the given value to the device.

    Args:
        orig: Any sequence of or mapping to tensors, or just a tensor.
        device: Device to send to.

    Returns:
        Same value, but with all tensors moved to the device.

    """
    if device is None and dtype is None:
        return orig

    result = orig
    if isinstance(orig, torch.Tensor):
        result = orig.to(device=device, dtype=dtype)
    elif isinstance(orig, dict):
        result = {
            key: map_to(value, device=device, dtype=dtype)
            for key, value in orig.items()
        }
    elif isinstance(orig, (list, tuple)):
        result = orig.__class__(
            map_to(value, device=device, dtype=dtype) for value in orig
        )
    assert isinstance(result, orig.__class__), f"{type(result)}/{type(orig)}"
    return result


def load_model(
    name: str, device: Optional[Device] = None, fp16: Optional[bool] = None
) -> ModelAndTokenizer:
    """Load the model given its string name.

    Short aliases (e.g. "gptj", "neox", "llama", "mamba-3b", "pythia-3b") are
    expanded to full model names first. Names not listed in
    `DOWNLOADABLE_MODELS` are resolved against the project models directory
    unless they are absolute or already point inside it.

    Args:
        name: Name of the model or path to it.
        device: If set, send model to this device. Defaults to CPU.
        fp16: Whether to use half precision. If not set, depends on model:
            enabled for GPT-J, GPT-NeoX and LLaMA variants, disabled otherwise.

    Returns:
        ModelAndTokenizer: Loaded model (in eval mode) and its tokenizer.

    """
    # Expand CLI short names. Note that the generic "llama" alias always maps
    # to the 13B name.
    if name == GPT_J_NAME_SHORT:
        name = GPT_J_NAME
    elif name == GPT_NEO_X_NAME_SHORT:
        name = GPT_NEO_X_NAME
    elif name == LLAMA_NAME_SHORT:
        name = LLAMA_13B_NAME
    elif name == MAMBA_3B_SHORT:
        name = MAMBA_3B_NAME
    elif name == PYTHIA_3B_SHORT:
        name = PYTHIA_3B_NAME

    # I usually save randomly initialized variants under the short name of the
    # corresponding real model (e.g. gptj_random, neox_random), so check here
    # if we are dealing with *any* variant of the big model.
    is_gpt_j_variant = name == GPT_J_NAME or GPT_J_NAME_SHORT in name
    is_neo_x_variant = name == GPT_NEO_X_NAME or GPT_NEO_X_NAME_SHORT in name
    is_llama_variant = (
        name in {LLAMA_13B_NAME, LLAMA_30B_NAME} or LLAMA_NAME_SHORT in name
    )
    is_mamba_variant = "mamba" in name.lower()

    # Half precision defaults to on only for the large GPT-J/NeoX/LLaMA models.
    if fp16 is None:
        fp16 = is_gpt_j_variant or is_neo_x_variant or is_llama_variant

    torch_dtype = torch.float16 if fp16 else None

    kwargs: dict = dict(torch_dtype=torch_dtype)
    if is_gpt_j_variant:
        kwargs["low_cpu_mem_usage"] = True
        if fp16:
            # Request the "float16" revision of the GPT-J checkpoint.
            kwargs["revision"] = "float16"

    # If model is not automatically downloadable from huggingface, assume it is
    # available locally in the project models directory.
    if name not in DOWNLOADABLE_MODELS:
        models_dir = env_utils.determine_models_dir()
        logger.debug(f"{name} not downloadable, will look for weights in {models_dir}")

        # Only relative names that are not already under models_dir get
        # re-rooted; absolute paths are used as given.
        path = Path(name)
        if not path.is_absolute() and not path.is_relative_to(models_dir):
            name = str(models_dir / name)

    logger.info(f"loading {name} (device={device}, fp16={fp16})")

    if is_mamba_variant:
        # The Mamba loader does not receive `kwargs`; the dtype is applied by
        # the `model.to(torch_dtype)` call below instead.
        model = Mamba.from_pretrained(name)
    else:
        model = transformers.AutoModelForCausalLM.from_pretrained(name, **kwargs)
    # NOTE(review): torch_dtype and device may both be None here; this relies on
    # `Module.to(None)` leaving the model unchanged — confirm.
    model.to(torch_dtype)
    model.to(device)
    model.eval()

    if is_llama_variant:
        # Force both pad and EOS to "</s>" with token id 2.
        tokenizer = transformers.LlamaTokenizerFast.from_pretrained(name)
        tokenizer.pad_token = tokenizer.eos_token = "</s>"
        tokenizer.pad_token_id = tokenizer.eos_token_id = 2
    else:
        if is_mamba_variant:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                "EleutherAI/gpt-neox-20b",  # Mamba was trained on the Pile with this exact tokenizer
            )
        else:
            tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        # Reuse the EOS token for padding.
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(
        f"dtype: {determine_dtype(model)}, device: {determine_device(model)}, memory: {get_model_size(model, 'GB') :.2f} GB"
    )

    return ModelAndTokenizer(model, tokenizer)


def get_model_size(
    model: torch.nn.Module, unit: Literal["B", "KB", "MB", "GB"] = "MB"
) -> float:
    """Return the combined size of the model's parameters and buffers.

    Args:
        model: Module whose parameters and buffers are measured.
        unit: Unit of the result; binary multiples (1 KB = 2**10 bytes).

    Returns:
        Total size in the requested unit.

    Raises:
        KeyError: If `unit` is not one of "B", "KB", "MB", "GB".

    """

    def total_bytes(tensors) -> int:
        return sum(tensor.nelement() * tensor.element_size() for tensor in tensors)

    size_in_bytes = total_bytes(model.parameters()) + total_bytes(model.buffers())
    bytes_per_unit = {"B": 1, "KB": 1 << 10, "MB": 1 << 20, "GB": 1 << 30}
    return size_in_bytes / bytes_per_unit[unit]


def add_model_args(parser: argparse.ArgumentParser) -> None:
    """Register the command-line options needed to load a model.

    Options added:
        --model / -m: The language model to load, defaulting to GPT-J.
        --device: The device to send model and inputs to.
        --fp16 / --no-fp16: Whether to use the half precision version of the
            model. When neither flag is passed the value stays None, so the
            default can depend on which model is loaded.
    """
    add = parser.add_argument
    add("--model", "-m", default=GPT_J_NAME_SHORT, help="model to edit")
    add("--device", help="device to train on")
    add(
        "--fp16",
        action=argparse.BooleanOptionalAction,
        help="set whether to use fp16",
    )
