import torch
import torch.nn.functional as F
import yaml


def get_model_identifiers_from_yaml(model_family):
    """
    Load the YAML configuration file for one model family.

    The file read is ``egu/config/{model_family}.yaml``; the path is relative,
    so it is resolved against the current working directory. The parsed
    document is returned as-is, with no schema validation.

    Example of the kind of content these configs were documented to hold
    (illustrative only; nothing here checks for these keys):

        models:
            llama2-7b:
                hf_key: "NousResearch/Llama-2-7b-chat-hf"
                question_start_tag: "[INST] "
                question_end_tag: " [/INST] "
                answer_tag: ""
                start_of_sequence_token: "<s>"

    Args:
        model_family: name interpolated into the config file name.

    Returns:
        Whatever the YAML document parses to (a dict for a mapping document,
        None for an empty file).

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
    """
    # NOTE(security): FullLoader is kept for backward compatibility with
    # existing config files. If these configs can ever come from an untrusted
    # source, switch to yaml.safe_load.
    # Explicit UTF-8 so parsing does not depend on the platform locale.
    with open(f"egu/config/{model_family}.yaml", "r", encoding="utf-8") as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def load_yaml(file_path):
    """
    Parse the YAML file at `file_path` with the safe loader.

    Args:
        file_path: path of the YAML file to read.

    Returns:
        The parsed document (None for an empty file).

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
    """
    # Explicit UTF-8 so parsing does not depend on the platform locale.
    with open(file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


def target_logprob(model, input_ids, attention_mask, labels):
    """
    Sum the log-probabilities the model assigns to the supervised label tokens.

    Positions whose label is -100 (e.g., prompt/pad) are left out of the sum.
    The inputs are handed to the model as-is; only the shifted labels are
    moved, onto the device of the logits.
    Returns one summed log-prob per example, shape [B].
    """
    forward_out = model(
        input_ids=input_ids, attention_mask=attention_mask, labels=labels
    )

    # Next-token alignment: logits at position t score the token at t + 1.
    shifted_logits = forward_out.logits[:, :-1, :].contiguous()  # [B, T-1, V]
    shifted_labels = labels[:, 1:].contiguous().to(
        device=shifted_logits.device, dtype=torch.long
    )  # [B, T-1]

    keep = shifted_labels != -100
    # -100 is not a valid vocabulary index, so ignored slots gather index 0;
    # whatever they pick up is zeroed out again below.
    gather_index = torch.where(
        keep, shifted_labels, torch.zeros_like(shifted_labels)
    )

    per_vocab_lp = F.log_softmax(shifted_logits, dim=-1)  # [B, T-1, V]
    per_token_lp = torch.gather(per_vocab_lp, -1, gather_index[..., None])[..., 0]
    per_token_lp = per_token_lp.masked_fill(~keep, 0)  # [B, T-1]

    return per_token_lp.sum(dim=-1)  # [B]


def target_logprob_single(model, input_ids, attention_mask, labels):
    """
    Computes the log-probability of the sequence defined by `labels`.
    - input_ids: full input prompt+completion
    - labels:    same shape, with -100 for masked-out positions (e.g. prompt)

    All three tensors are moved to the device of the model's first parameter
    before the forward pass. Labels are aligned to the logits with the causal
    shift (logits at position t score the token at position t + 1), the same
    convention `target_logprob` uses. Returns a tensor of shape [B].
    """

    device = next(model.parameters()).device

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    labels = labels.to(device)

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    # Causal shift: drop the last logit (it has no target) and the first label
    # (no logit predicts it). Without the shift each label would be scored
    # against the distribution for the *following* token.
    logits = outputs.logits[:, :-1, :]  # [B, L-1, V]
    targets = labels[:, 1:].to(device=logits.device, dtype=torch.long)  # [B, L-1]

    active_mask = targets.ne(-100)  # [B, L-1]
    # -100 is not a valid index for gather; use 0 as a placeholder there.
    targets_safe = targets.masked_fill(~active_mask, 0)

    logprobs = F.log_softmax(logits, dim=-1)
    token_logprobs = logprobs.gather(-1, targets_safe.unsqueeze(-1)).squeeze(-1)

    # Zero ignored positions by replacement rather than multiplication, so a
    # -inf/NaN log-prob at a placeholder position cannot leak into the sum.
    token_logprobs = token_logprobs.masked_fill(~active_mask, 0.0)

    return token_logprobs.sum(-1)  # [B]


def causal_logprob(model, input_ids, attention_mask):
    """
    Score every sequence in the batch under `model`.

    For each example, adds up the log-probabilities of tokens 1..L-1 given
    their prefix, counting only positions where `attention_mask` is 1.
    Returns a tensor of shape (B,).
    """
    # `labels=input_ids` is forwarded to the model, but only the logits of the
    # result are used here; any loss on the output is ignored.
    result = model(
        input_ids=input_ids, attention_mask=attention_mask, labels=input_ids
    )

    # Align predictions with targets: position t predicts token t + 1.
    next_token_logits = result.logits[:, :-1]  # (B, L-1, V)
    next_tokens = input_ids[:, 1:]  # (B, L-1)
    present = attention_mask[:, 1:]  # (B, L-1)

    vocab_lp = F.log_softmax(next_token_logits, dim=-1)

    # Pick out the log-prob assigned to the token that actually came next.
    picked = torch.gather(vocab_lp, 2, next_tokens[..., None])[..., 0]

    # Masked-out positions are multiplied by 0 and drop out of the total.
    return torch.sum(picked * present, dim=-1)  # shape (B,)
