from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)


@dataclass
class HFModelBundle:
    """A loaded causal language model together with its tokenizer and device.

    Built by ``load_hf_model``; the generation helpers in this module read all
    three fields and move tokenized prompts to ``device`` before generating.
    """

    # Concrete tokenizer returned by ``AutoTokenizer.from_pretrained``
    # (``AutoTokenizer`` itself is only a factory and is never instantiated).
    tokenizer: PreTrainedTokenizerBase
    # Concrete model returned by ``AutoModelForCausalLM.from_pretrained``.
    model: PreTrainedModel
    # Device that tokenized inputs are moved to before ``generate`` is called.
    device: torch.device


def load_hf_model(model_name: str, cache_dir: Optional[str] = None, device: Optional[str] = None) -> HFModelBundle:
    """Load a causal LM and its tokenizer, with the whole model on one device.

    Args:
        model_name: Hub id or local path forwarded to ``from_pretrained``.
        cache_dir: Optional cache directory forwarded to ``from_pretrained``.
        device: Target device string (e.g. ``"cuda:1"`` or ``"cpu"``). Defaults to
            ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.

    Returns:
        An ``HFModelBundle`` whose model sits entirely on ``device`` in eval mode.

    Note:
        ``trust_remote_code=True`` executes Python code shipped with the model
        repository; only load repositories you trust.
    """
    dev = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, trust_remote_code=True)
    # No ``device_map="auto"`` here: an accelerate-dispatched model (possibly sharded
    # across GPUs or offloaded) must not be moved afterwards with ``.to()``, and the
    # generation helpers send inputs to ``bundle.device``, so the whole model has to
    # live on exactly that device.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        cache_dir=cache_dir,
        trust_remote_code=True,
    ).to(dev)
    model.eval()
    return HFModelBundle(tokenizer=tokenizer, model=model, device=dev)


def generate_outputs_with_logits(
    bundle: HFModelBundle,
    prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.95,
) -> Tuple[str, List[float], List[float], List[float]]:
    """Sample one completion and return its text plus per-token statistics.

    Args:
        bundle: Tokenizer, model and device to generate with.
        prompt: Raw prompt text; it is tokenized and moved to ``bundle.device``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (must be > 0 since sampling is enabled).
        top_p: Nucleus-sampling threshold.

    Returns:
        ``(generated_text, token_logits, token_log_probs, token_probs)``, with one
        list entry per generated token. The values come from ``generate``'s
        ``scores`` (the processed scores used for sampling), so the probabilities
        describe the temperature/top-p sampling distribution.
    """
    tokenizer, model, device = bundle.tokenizer, bundle.model, bundle.device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # Without do_sample=True, temperature/top_p are ignored (greedy decoding)
            # unless the model's own generation config happens to enable sampling.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_scores=True,
        )
    generated_tokens = outputs.sequences[0, inputs.input_ids.shape[1]:]  # (num_steps,)
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    if len(outputs.scores) == 0:
        # Nothing was generated; torch.stack would fail on an empty sequence.
        return generated_text, [], [], []

    # (num_steps, vocab) scores of the single returned sequence, in float32 for stable reductions.
    step_logits = torch.stack([step[0] for step in outputs.scores]).float()
    token_ids = generated_tokens.to(step_logits.device).unsqueeze(-1)
    token_logits = step_logits.gather(-1, token_ids).squeeze(-1)
    # log_softmax evaluated only at the chosen token: numerically stable and avoids
    # materializing full-vocabulary probability / log-probability tensors.
    token_log_probs = token_logits - torch.logsumexp(step_logits, dim=-1)
    token_probs = token_log_probs.exp()

    return (
        generated_text,
        token_logits.cpu().tolist(),
        token_log_probs.cpu().tolist(),
        token_probs.cpu().tolist(),
    )


def _eos_token_ids(bundle: HFModelBundle) -> set[int]:
    """Return the end-of-sequence token ids configured for ``bundle`` (possibly empty).

    Prefers ``model.generation_config.eos_token_id`` and falls back to
    ``tokenizer.eos_token_id``; either may be ``None``, an int, or a list of ints.
    """
    eos = getattr(getattr(bundle.model, "generation_config", None), "eos_token_id", None)
    if eos is None:
        eos = getattr(bundle.tokenizer, "eos_token_id", None)
    if isinstance(eos, torch.Tensor):
        eos = eos.tolist()
    if eos is None:
        return set()
    if isinstance(eos, int):
        return {eos}
    return {int(token_id) for token_id in eos}


def _per_token_stats(
    scores: Tuple[torch.Tensor, ...], generated: torch.Tensor
) -> Tuple[List[List[float]], List[List[float]], List[List[float]]]:
    """Return per-sequence lists of (logits, log-probs, probs) of the generated tokens.

    ``scores`` holds one ``(n_sequences, vocab)`` tensor per generation step and
    ``generated`` is the ``(n_sequences, num_steps)`` tensor of generated token ids.
    Results are plain Python lists, so no device tensor outlives this call.
    """
    n_rows = generated.shape[0]
    if len(scores) == 0:
        return [[] for _ in range(n_rows)], [[] for _ in range(n_rows)], [[] for _ in range(n_rows)]
    step_logits = torch.stack(tuple(scores), dim=1).float()  # (n_sequences, num_steps, vocab)
    token_ids = generated.to(step_logits.device).unsqueeze(-1)
    token_logits = step_logits.gather(-1, token_ids).squeeze(-1)  # (n_sequences, num_steps)
    # log_softmax evaluated only at the chosen token: numerically stable and avoids
    # materializing full-vocabulary probability / log-probability tensors per sample.
    token_log_probs = token_logits - torch.logsumexp(step_logits, dim=-1)
    return (
        token_logits.cpu().tolist(),
        token_log_probs.cpu().tolist(),
        token_log_probs.exp().cpu().tolist(),
    )


def generate_multiple_outputs_with_logits(
    bundle: HFModelBundle,
    prompt: str,
    n_samples: int,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.95,
) -> List[Dict[str, Any]]:
    """Sample ``n_samples`` completions of ``prompt`` with per-token statistics.

    Args:
        bundle: Tokenizer, model and device to generate with.
        prompt: Raw prompt text; it is tokenized and moved to ``bundle.device``.
        n_samples: Number of completions (``num_return_sequences``).
        max_new_tokens: Upper bound on the number of generated tokens per completion.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        A list of dicts: {generated_text, token_logits, token_probs, token_log_probs}.
        The per-token lists end at (and include) the first end-of-sequence token, so
        padding appended to completions that finished early is not reported. Values
        come from ``generate``'s processed ``scores``, i.e. they describe the
        temperature/top-p sampling distribution.
    """
    tokenizer, model, device = bundle.tokenizer, bundle.model, bundle.device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=n_samples,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_scores=True,
        )

    generated = outputs.sequences[:, inputs.input_ids.shape[1]:]  # (n_samples, num_steps)
    texts = [tokenizer.decode(row, skip_special_tokens=True).strip() for row in generated]
    logit_rows, log_prob_rows, prob_rows = _per_token_stats(outputs.scores, generated)
    token_id_rows = generated.tolist()
    # Drop the last references to the (potentially large) device tensors so the
    # cache flush below can actually release their memory.
    del outputs, generated, inputs

    eos_ids = _eos_token_ids(bundle)
    all_outputs: List[Dict[str, Any]] = []
    for text, ids, logit_row, log_prob_row, prob_row in zip(
        texts, token_id_rows, logit_rows, log_prob_rows, prob_rows
    ):
        # Keep tokens up to and including the first EOS; later positions are padding.
        n_valid = next((pos + 1 for pos, tok in enumerate(ids) if tok in eos_ids), len(ids))
        all_outputs.append(
            {
                "generated_text": text,
                "token_logits": logit_row[:n_valid],
                "token_probs": prob_row[:n_valid],
                "token_log_probs": log_prob_row[:n_valid],
            }
        )

    # Best-effort release of cached CUDA memory; a no-op on CPU-only setups.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    return all_outputs
