# SPDX-License-Identifier: MIT
from __future__ import annotations
from typing import Iterable, Set, Union, Optional, Tuple
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def get_model_device_dtype(model: AutoModelForCausalLM):
    """Return ``(device, dtype)`` taken from the model's first parameter.

    Only the first parameter yielded by ``model.parameters()`` is inspected.
    A model whose parameters span several devices or dtypes reports just
    that one. If the model has no parameters, ``next`` raises
    ``StopIteration``.
    """
    first_param = next(iter(model.parameters()))
    device = first_param.device
    dtype = first_param.dtype
    return device, dtype


def set_seed(seed: int = 2026) -> None:
    """Seed Python's ``random`` module and torch with ``seed``.

    When CUDA is available, every CUDA device is also seeded through
    ``torch.cuda.manual_seed_all``.
    """
    from random import seed as _seed_stdlib

    _seed_stdlib(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def get_device(explicit: str | None = None) -> str:
    """Return the device string to use.

    An explicitly passed value is returned unchanged, even if it is empty.
    When ``explicit`` is ``None``, returns ``"cuda"`` if torch reports CUDA
    as available and ``"cpu"`` otherwise.
    """
    chosen = explicit
    if chosen is None:
        chosen = "cuda" if torch.cuda.is_available() else "cpu"
    return chosen


def _add_token_ids(ids: Set[int], value: object) -> None:
    """Add the token id(s) in ``value`` to ``ids`` in place; ``None`` is ignored.

    ``str``/``bytes`` and non-iterable values are converted with ``int()``
    and added as a single id. Any other iterable is flattened one level,
    and each element is converted with ``int()``.
    """
    if value is None:
        return
    if isinstance(value, (str, bytes)):
        # Strings are iterable but denote one id (e.g. "2"), not characters.
        ids.add(int(value))
        return
    try:
        items = iter(value)
    except TypeError:
        # Plain ints land here. So do 0-d tensors/arrays: they define
        # __iter__ (so an isinstance(..., Iterable) check accepts them), but
        # iter() on them raises TypeError. Treat all of these as one id.
        ids.add(int(value))
        return
    ids.update(int(x) for x in items)


def get_eos_token_ids(
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    eos_override: Optional[Union[int, Iterable[int]]] = None,
) -> Set[int]:
    """Collect every end-of-sequence token id known for this tokenizer/model pair.

    Args:
        tokenizer: Object whose ``eos_token_id`` attribute, if present and not
            ``None``, contributes id(s).
        model: Object whose ``config.eos_token_id``, if present and not
            ``None``, contributes id(s). A model without a ``config``
            attribute is tolerated and contributes nothing.
        eos_override: Optional extra id, or iterable of ids, to include.

    Returns:
        The union of all ids found, as Python ints. It may be empty.
    """
    ids: Set[int] = set()
    _add_token_ids(ids, eos_override)
    _add_token_ids(ids, getattr(tokenizer, "eos_token_id", None))
    # Look up config defensively, mirroring the tokenizer lookup above.
    config = getattr(model, "config", None)
    _add_token_ids(ids, getattr(config, "eos_token_id", None))
    return ids


def build_position_ids(attn_mask: torch.Tensor) -> torch.Tensor:
    """Derive integer position ids from an attention mask.

    Along dim 1, positions count only the non-zero mask entries, starting at
    0 at the first non-zero entry. Entries where the mask equals 0
    (presumably padding) get position 0. The result is a new ``long``
    tensor; ``attn_mask`` itself is not modified.

    Assumes a 2-D ``(batch, seq)`` mask (the cumsum runs over dim 1); confirm
    this with callers.
    """
    running_count = torch.cumsum(attn_mask.long(), dim=1)
    positions = running_count - 1
    is_masked_out = attn_mask.eq(0)
    return positions.masked_fill(is_masked_out, 0)