import gc
import os
import time
from contextlib import contextmanager
from typing import Any, Literal

import torch
from torch import OutOfMemoryError, Tensor, nn
from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from ml_utils import log, setup_huggingface


def get_device() -> Literal["cuda", "cpu"]:
    """Return ``"cuda"`` when PyTorch reports a usable CUDA device, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def init_torch(matmul_precision: Literal["highest", "high", "medium"] = "high") -> None:
    """Prepare the torch runtime and log basic facts about the first CUDA device.

    Runs a garbage collection pass. When CUDA is available it additionally
    empties the CUDA caching allocator, resets the peak-memory statistics,
    synchronizes, applies ``matmul_precision`` via
    ``torch.set_float32_matmul_precision`` and logs device properties of
    ``cuda:0``. Without CUDA it only logs that CUDA is unavailable (the matmul
    precision is left untouched in that case).

    Args:
        matmul_precision: Value forwarded to ``torch.set_float32_matmul_precision``.
    """
    gc.collect()
    if not torch.cuda.is_available():
        log.info("CUDA not available.")
        return

    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is deprecated: it emits a FutureWarning and
    # simply forwards to reset_peak_memory_stats(), so call the latter directly.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    torch.set_float32_matmul_precision(matmul_precision)

    device = torch.device("cuda:0")
    props = torch.cuda.get_device_properties(device)
    cap_major, cap_minor = torch.cuda.get_device_capability(device)
    log.info("Initializing torch environment...")
    log.info(f"CUDA available: {torch.cuda.is_available()}")
    log.info(f"Device name: {props.name}")
    log.info(f"Compute capability: {cap_major}.{cap_minor}")
    log.info(f"Total memory: {props.total_memory / 1024**3:.2f} GB")
    # Exclude emulated BF16 so the log reflects native support only, see
    # https://github.com/pytorch/pytorch/issues/124996
    log.info(f"BF16 supported by matmul: {torch.cuda.is_bf16_supported(including_emulation=False)}")
    log.info(f"TF32 allowed on matmul: {torch.backends.cuda.matmul.allow_tf32}")
    log.info(f"TF32 allowed on cudnn: {torch.backends.cudnn.allow_tf32}")


def set_seed(seed: int = 42) -> None:
    """Seed torch's CPU and CUDA generators and make cuDNN deterministic.

    Args:
        seed: Value passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Deterministic kernels, and no autotuning that could pick different ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


@contextmanager
def torch_timer(name: str = "training"):
    """Time the wrapped block and print its duration and peak CUDA memory.

    Before the block runs, garbage is collected and, when CUDA is available,
    the CUDA cache is emptied, peak-memory statistics are reset and the device
    is synchronized. After the block (also when it raises) the device is
    synchronized again and a summary line is printed. Without CUDA the reported
    peak memory is 0 MB.

    Args:
        name: Label used as prefix of the printed summary line.

    Yields:
        None.
    """
    cuda_ok = torch.cuda.is_available()

    gc.collect()
    if cuda_ok:
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is deprecated and only forwards to
        # reset_peak_memory_stats() after emitting a FutureWarning.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    try:
        yield
    finally:
        if cuda_ok:
            # Wait for queued kernels so their runtime is included.
            torch.cuda.synchronize()
        duration = time.perf_counter() - start_time
        mem_mb = int(torch.cuda.max_memory_allocated() / (1024**2)) if cuda_ok else 0
        print(f"[{name}] Time: {duration:.3f}s | Max memory: {mem_mb} MB")  # noqa: T201


def from_hf_to_hf(
    mdl_name: str,
    *,
    clf: bool = False,
    causal_lm: bool = False,
    to_repo: str = "aarabil",
) -> None:
    """Load a model and its tokenizer by name and push both to another hub namespace.

    The destination is ``"{to_repo}/{base}"`` where ``base`` is the part of
    ``mdl_name`` after its first ``/`` (or all of it when there is none).

    Args:
        mdl_name: Identifier passed to ``from_pretrained``.
        clf: Load with ``AutoModelForSequenceClassification``.
        causal_lm: Load with ``AutoModelForCausalLM``.
        to_repo: Namespace prefix of the destination repository.

    Raises:
        ValueError: If both ``clf`` and ``causal_lm`` are set.
    """
    if clf and causal_lm:
        msg = "Choose at most one of 'clf' or 'causal_lm'."
        raise ValueError(msg)

    setup_huggingface()
    # Resolve the token after setup_huggingface() in case that call configures the environment.
    hf_token = os.getenv("HF_TOKEN")

    # Pick the auto-class; the plain AutoModel is the fallback.
    if clf:
        loader = AutoModelForSequenceClassification
    elif causal_lm:
        loader = AutoModelForCausalLM
    else:
        loader = AutoModel

    *_, short_name = mdl_name.split("/", 1)
    destination = f"{to_repo}/{short_name}"

    loaded_model = loader.from_pretrained(mdl_name, token=hf_token)
    loaded_tokenizer = AutoTokenizer.from_pretrained(mdl_name, token=hf_token)

    for artifact in (loaded_model, loaded_tokenizer):
        artifact.push_to_hub(destination, token=hf_token)


def get_param_info(mdl: nn.Module) -> dict[str, Any]:
    """Summarize the parameters of ``mdl``.

    Args:
        mdl: Module whose ``named_parameters()`` are inspected.

    Returns:
        A dict with two keys: ``"n_params (in million)"`` (total element count
        divided by 1e6) and ``"params"``, mapping each parameter name to its
        ``"shape"`` (list of ints) and ``"values"`` (up to the first five
        elements in row-major order, as Python scalars).
    """
    param_info: dict[str, Any] = {}
    total_elements = 0
    for name, param in mdl.named_parameters():
        total_elements += param.numel()
        param_info[name] = {
            "shape": list(param.shape),
            # reshape (unlike view) also works for non-contiguous parameters,
            # e.g. ones created from a transposed tensor.
            "values": param.detach().reshape(-1)[:5].cpu().tolist(),
        }
    return {"n_params (in million)": total_elements / 1e6, "params": param_info}


def _batch_fits(
    mdl: nn.Module,
    x: Tensor,
    y: Tensor,
    *,
    loss_fn: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: Any,
    dev: str,
    use_amp: bool,
    train: bool,
) -> bool:
    """Run one dummy step on ``(x, y)``; return ``False`` if it raises a CUDA OOM.

    Living in its own function guarantees that every tensor created by the
    trial (outputs, loss, autograd graph) is released when it returns, so it
    cannot occupy memory during the next trial.
    """
    try:
        # -- Forward
        with torch.autocast(device_type=dev, enabled=use_amp):
            out = mdl(x)
            # Accept a raw tensor, an output object exposing
            # `last_hidden_state`, or anything indexable (first item is used).
            if isinstance(out, Tensor):
                hidden_embeddings = out
            elif hasattr(out, "last_hidden_state"):
                hidden_embeddings = out.last_hidden_state
            else:
                hidden_embeddings = out[0]
            # assumes hidden_embeddings is (B, S, E); averaging over dim 1
            # yields a (B, E) tensor used as stand-in logits — TODO confirm
            logits = hidden_embeddings.mean(dim=1)
            loss = loss_fn(logits, y)

        # -- Backward
        if train:
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()
    except OutOfMemoryError:
        return False
    return True


def search_max_bs(
    mdl: nn.Module,
    tokenizer: PreTrainedTokenizerBase,
    seq_len: int,
    mode: Literal["train", "test"],
    *,
    use_amp: bool = True,
) -> int:
    """Binary-search the largest batch size that runs without a CUDA OOM.

    Each candidate batch size is tried with a dummy batch of ``seq_len`` pad
    tokens: a forward pass plus cross-entropy loss, and in ``"train"`` mode
    also a backward pass and an AdamW step through a ``GradScaler``.

    Notes:
        * In ``"train"`` mode the dummy optimizer steps are applied to ``mdl``
          itself, so its parameters may be modified by this call.
        * In ``"test"`` mode the forward pass is not wrapped in
          ``torch.no_grad()``; the module is only switched to eval mode.
        * A batch size of 1 is never tried on its own: if every larger
          candidate fails, 1 is returned unverified.

    Args:
        mdl: Model to probe; moved to the CUDA device.
        tokenizer: Only ``pad_token_id`` is read (0 is used when it is falsy).
        seq_len: Sequence length of the dummy input.
        mode: ``"train"`` to include backward/optimizer step, ``"test"`` for
            forward only.
        use_amp: Enable autocast and gradient scaling.

    Returns:
        The largest batch size that completed without ``OutOfMemoryError``.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    init_torch()
    dev = get_device()
    if dev != "cuda":
        # Device properties and OOM probing below are CUDA-only; fail clearly
        # instead of with an opaque error from torch.cuda.
        msg = "search_max_bs requires a CUDA device."
        raise RuntimeError(msg)

    is_train = mode == "train"
    mdl = mdl.to(dev).train(mode=is_train)
    loss_fn = torch.nn.CrossEntropyLoss()
    opt = torch.optim.AdamW(mdl.parameters(), lr=2e-5)
    scaler = torch.amp.GradScaler(enabled=use_amp)  # type: ignore

    # Dummy tensor of shape (1, S) to be expanded to (B,S) with B the current bs
    pad_id: int = tokenizer.pad_token_id or 0  # type: ignore
    dummy_x = torch.full(size=(1, seq_len), fill_value=pad_id, dtype=torch.long, device=dev)
    dummy_y = torch.ones(size=(1,), dtype=torch.long, device=dev)

    # Determine crude upperbound of max bs
    vram = torch.cuda.get_device_properties(dev).total_memory
    lo, hi = 1, (2**14 if vram > 30e9 else 2**11)  # noqa: PLR2004

    # Binary search; mid is rounded up so the loop always makes progress.
    while lo < hi:
        mid = (lo + hi + 1) // 2
        torch.cuda.empty_cache()
        try:
            # x=(B,S);y=(B,) — expand() creates views, no extra memory.
            fits = _batch_fits(
                mdl,
                dummy_x.expand(mid, -1),
                dummy_y.expand(mid),
                loss_fn=loss_fn,
                opt=opt,
                scaler=scaler,
                dev=dev,
                use_amp=use_amp,
                train=is_train,
            )
        finally:
            # Drop gradients (possibly partial after an OOM) and return cached
            # blocks before the next trial.
            mdl.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()

        if fits:
            lo = mid  # increase bs
        else:
            hi = mid - 1  # decrease bs

    print(f"\n Max batch for {mode}: {lo} ✅ ")  # noqa: T201
    return lo
