import re

import torch
from torch import nn
import torch.nn.functional as F

from utils import get_cont_distances


IGNORE_INDEX = -100


def get_lprobs(logits, target, attention_mask):
    """
    Return the mask-weighted mean log-probability of the target tokens.

    Position ``t`` of ``logits`` is scored against position ``t + 1`` of
    ``target`` (the last logit step and the first target are dropped). The
    per-token log-probabilities are then averaged over the last remaining
    dimension, using ``attention_mask[..., 1:]`` as weights.

    Args:
        logits: Unnormalised scores; log-softmax is taken over the last
            dimension and the shift is applied on the second-to-last one.
        target: Integer token ids aligned with ``logits``. Negative ids are
            replaced by 0 before the lookup, so they are expected to be
            excluded through ``attention_mask``.
        attention_mask: Weights aligned with ``target``; a zero removes the
            position from the mean.

    Returns:
        The masked mean with the sequence dimension reduced away. Rows whose
        shifted mask sums to zero yield 0 instead of NaN.
    """
    token_lprobs = logits[..., :-1, :].log_softmax(-1)
    labels = target[..., 1:].clone()
    weights = attention_mask[..., 1:].to(token_lprobs.dtype)
    # Negative (ignored) ids are not valid gather indices; the value picked
    # up at index 0 is expected to be cancelled by a zero weight.
    labels[labels < 0] = 0
    token_lprobs = token_lprobs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

    totals = (token_lprobs * weights).sum(-1)
    counts = weights.sum(-1)
    # A fully masked row has totals == 0 and counts == 0; dividing by 1
    # there returns 0 rather than propagating NaN into downstream losses.
    counts = counts.masked_fill(counts == 0, 1)
    return totals / counts


def get_sft_loss(logits, labels, attention_mask, label_smoothing: float = 0.0):
    """
    Next-token cross-entropy loss, optionally with label smoothing.

    Logits at position ``t`` are scored against the label at ``t + 1``.
    Negative labels are normalised to ``IGNORE_INDEX`` and contribute no
    loss.

    Args:
        logits: Unnormalised scores whose last dimension is the vocabulary.
        labels: Integer token ids aligned with ``logits``; negative values
            mark positions to ignore.
        attention_mask: Optional mask aligned with ``labels``. When given,
            only positions whose shifted mask is non-zero are kept, which
            flattens the kept positions into a single leading dimension.
        label_smoothing: Weight of the uniform-target term. ``0`` selects
            the plain cross-entropy path.

    Returns:
        ``label_smoothing == 0``: scalar mean cross-entropy over the
        non-ignored positions.
        ``label_smoothing > 0``: per-position smoothed loss (not reduced),
        with ignored positions set to 0.
    """
    labels = labels.clone()
    labels[labels < 0] = IGNORE_INDEX
    if attention_mask is not None:
        shift_attention_mask = attention_mask[..., 1:]
        shift_logits = logits[..., :-1, :][
            shift_attention_mask.to(logits.device) != 0
        ].contiguous()
        shift_labels = labels[..., 1:][
            shift_attention_mask.to(labels.device) != 0
        ].contiguous()
    else:
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
    # Both branches below need the labels next to the logits.
    shift_labels = shift_labels.to(shift_logits.device)

    if label_smoothing > 0:
        lprobs = shift_logits.log_softmax(-1)
        # IGNORE_INDEX is not a valid gather index: look up a dummy class
        # for ignored positions and zero their loss afterwards.
        ignored = shift_labels == IGNORE_INDEX
        safe_labels = shift_labels.masked_fill(ignored, 0)
        nll_loss = -lprobs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
        smooth_loss = -lprobs.mean(-1)
        loss = (1.0 - label_smoothing) * nll_loss + label_smoothing * smooth_loss
        # NOTE(review): this branch returns an unreduced per-position
        # tensor while the branch below returns a scalar mean. The shape is
        # kept as-is for backward compatibility; confirm what callers expect.
        return loss.masked_fill(ignored, 0.0)

    # Pass the ignore index explicitly so it stays in sync with the
    # normalisation above instead of relying on the library default.
    loss_fct = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    return loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )


def is_number_regex(input_str):
    """
    Return True when ``input_str`` is a canonically written non-negative integer.

    The string must consist of ASCII digits only and must survive an
    ``int`` round trip unchanged, which rejects forms such as ``"007"``.
    """
    only_digits = re.match(r"^[0-9]+$", input_str) is not None
    # The round trip filters out leading zeros (and a trailing newline,
    # which ``$`` alone would let through).
    return only_digits and str(int(input_str)) == input_str


def l2_dist(x, y):
    """
    Squared difference between every entry of ``x`` and every entry of ``y``.

    Args:
        x: 2-D tensor of shape (B, L).
        y: 1-D tensor of V reference values.

    Returns:
        Tensor of shape (B, L, V) where ``out[b, l, v] == (x[b, l] - y[v]) ** 2``.
    """
    batch, length = x.shape
    column = x.reshape(-1, 1)  # (B*L, 1)
    row = y[None, :]  # (1, V)
    diff = column - row  # broadcast to (B*L, V)
    return (diff * diff).reshape(batch, length, -1)


def get_digit_loss(
    tokenizer,
    logits,
    labels,
    attention_mask=None,
    target_temperature=1.0,
    beta: float = 1.0,
    use_place_weighting: bool = True,
):
    """
    Distance-aware KL loss over the number tokens of the vocabulary.

    For every position whose (shifted) label is a number token, the target
    is a soft distribution over all number tokens,
    ``softmax(-(value - label_value) ** 2 / target_temperature)``, and the
    prediction is the softmax of the logits restricted to those tokens. The
    per-position KL(target || prediction) is averaged with weights that are
    zero on non-number positions.

    Args:
        tokenizer: Object exposing ``vocab`` (token string -> id); tokens
            accepted by ``is_number_regex`` are the number tokens.
        logits: Scores of shape (B, L, vocab).
        labels: Token ids of shape (B, L). Must be 2-D.
        attention_mask: Accepted for interface compatibility; not used.
        target_temperature: Temperature of the soft target distribution.
        beta: Scalar multiplier applied to the final loss.
        use_place_weighting: If True, each number position is weighted by
            the count of number positions from it to the end of its
            contiguous run (the last one gets 1, the one before it 2, ...).
            If False, every number position gets weight 1.

    Returns:
        Scalar loss. A batch with no number-token labels yields 0 instead
        of NaN.
    """
    # shift
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    digits = [(k, v) for k, v in tokenizer.vocab.items() if is_number_regex(k)]
    digit_values = (
        torch.tensor([int(k) for k, v in digits])
        .to(shift_labels.device)
        .to(logits.dtype)
    )
    digit_indices = torch.tensor([v for k, v in digits]).to(shift_labels.device)

    # mask[b, l, j] is True when the label at (b, l) is the j-th number token.
    mask = shift_labels[..., None] == digit_indices[None, None, :]
    # Flip so that the cumulative count below runs from the end of the
    # sequence towards its start.
    any_mask = mask.any(-1).flip(-1)

    x = any_mask.long()
    # Compute the cumulative sum
    cumsum = torch.cumsum(x, dim=-1)

    # Create a mask to identify positions where cumulative sum should reset
    reset_mask = x == 0

    # Compute the differences in the cumulative sum at reset points
    reset_diff = torch.zeros_like(cumsum)
    reset_diff[:, 1:] = cumsum[:, :-1] * reset_mask[:, 1:].to(dtype=cumsum.dtype)
    # Compute the cumulative sum again after reset
    adjusted_cumsum = cumsum - reset_diff.cummax(-1).values
    # Back in original order: pos is 0 on non-number positions and otherwise
    # the 1-based position counted from the end of the contiguous run.
    pos = adjusted_cumsum.flip(-1)

    # Non-number positions fall back to index 0 here; they receive zero
    # weight below, so the placeholder target does not affect the loss.
    data_targets = mask.long().argmax(-1)
    data_targets = digit_values[data_targets.reshape(-1)].reshape(*data_targets.shape)
    digit_targets = l2_dist(data_targets, digit_values)
    digit_targets = (-digit_targets / target_temperature).softmax(-1)

    digit_logits = shift_logits[..., digit_indices]
    loss_fn = nn.KLDivLoss(reduction="none")
    loss = loss_fn(digit_logits.log_softmax(-1), digit_targets)  # BLV
    loss = loss.sum(-1)  # definition of KL

    if use_place_weighting:
        loss_coeff = pos.to(loss.dtype)
    else:
        loss_coeff = (pos > 0).to(loss.dtype)
    total_weight = loss_coeff.sum()
    # Without any number token in the batch the weighted sum and the total
    # weight are both 0; divide by 1 so the loss is 0 rather than NaN.
    total_weight = total_weight.masked_fill(total_weight == 0, 1)
    loss = (loss * loss_coeff).sum() / total_weight
    return beta * loss


def get_digit_loss_with_cont(
    tokenizer,
    logits,
    labels,
    attention_mask=None,
    target_temperature=1.0,
    beta: float = 1.0,
    use_place_weighting: bool = True,
    contrastive_logits=None,
    contrastive_labels=None,
):
    """
    Distance-aware KL loss over number tokens plus one contrastive class.

    Same construction as a soft-target KL over the number tokens (target
    ``softmax(-squared_distance / target_temperature)``, prediction softmax
    of the number-token logits), except that one extra class is appended:
    its logit is the contrastive logit of the label token and its squared
    distance comes from ``get_cont_distances``.

    Args:
        tokenizer: Object exposing ``vocab`` (token string -> id); tokens
            accepted by ``is_number_regex`` are the number tokens.
        logits: Scores of shape (B, L, vocab).
        labels: Token ids of shape (B, L). Must be 2-D.
        attention_mask: Accepted for interface compatibility; not used.
        target_temperature: Temperature of the soft target distribution.
        beta: Scalar multiplier applied to the final loss.
        use_place_weighting: If True, each number position is weighted by
            the count of number positions from it to the end of its
            contiguous run; if False, by 1.
        contrastive_logits: Required. Scores shaped like ``logits``.
        contrastive_labels: Required. Token ids shaped like ``labels``.
            Must hold a number token wherever ``labels`` does, otherwise
            the reshape of the inverse index map below fails.

    Returns:
        Scalar loss. A batch with no number-token labels yields 0 instead
        of NaN.

    Raises:
        AssertionError: If either contrastive argument is None.
    """
    assert contrastive_logits is not None
    assert contrastive_labels is not None

    # shift
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_cont_logits = contrastive_logits[..., :-1, :].contiguous()
    shift_cont_labels = contrastive_labels[..., 1:].contiguous()

    # Negative (ignored) label ids are not valid gather indices. Those
    # positions are never number tokens and therefore get zero weight
    # below, so a placeholder index of 0 does not affect the loss.
    gather_index = shift_labels.masked_fill(shift_labels < 0, 0)
    cont_logits = shift_cont_logits.gather(-1, gather_index[..., None])  # BL1

    digits = [(k, v) for k, v in tokenizer.vocab.items() if is_number_regex(k)]
    digit_values = (
        torch.tensor([int(k) for k, v in digits])
        .to(shift_labels.device)
        .to(logits.dtype)
    )
    digit_indices = torch.tensor([v for k, v in digits]).to(shift_labels.device)

    # mask[b, l, j] is True when the label at (b, l) is the j-th number token.
    mask = shift_labels[..., None] == digit_indices[None, None, :]
    any_mask = mask.any(-1)

    # cont targets
    # Non-number positions are overwritten with the first number token so
    # that every position has exactly one match in the inverse map.
    ts1 = shift_labels.clone()
    ts1[~any_mask] = digit_indices[0]
    # inverse_map with digit_indices
    _ts1 = (ts1[..., None] == digit_indices).nonzero(as_tuple=True)
    _ts1 = _ts1[2].reshape(ts1.shape)
    ts1 = digit_values[_ts1]

    # NOTE: the mask comes from `labels`, not from `contrastive_labels`.
    ts2 = shift_cont_labels.clone()
    ts2[~any_mask] = digit_indices[0]
    # inverse_map with digit_indices
    _ts2 = (ts2[..., None] == digit_indices).nonzero(as_tuple=True)
    _ts2 = _ts2[2].reshape(ts2.shape)
    ts2 = digit_values[_ts2]

    # get_cont_distances is a project helper not visible here; presumably
    # it returns one distance per position, shape (B, L) - TODO confirm.
    cont_dist = (get_cont_distances(ts1, ts2) ** 2)[..., None]

    # Flip so that the cumulative count runs from the end of the sequence.
    x = any_mask.flip(-1).long()
    # Compute the cumulative sum
    cumsum = torch.cumsum(x, dim=-1)

    # Create a mask to identify positions where cumulative sum should reset
    reset_mask = x == 0

    # Compute the differences in the cumulative sum at reset points
    reset_diff = torch.zeros_like(cumsum)
    reset_diff[:, 1:] = cumsum[:, :-1] * reset_mask[:, 1:].to(dtype=cumsum.dtype)
    # Compute the cumulative sum again after reset
    adjusted_cumsum = cumsum - reset_diff.cummax(-1).values
    # Back in original order: pos is 0 on non-number positions and otherwise
    # the 1-based position counted from the end of the contiguous run.
    pos = adjusted_cumsum.flip(-1)

    data_targets = mask.long().argmax(-1)
    data_targets = digit_values[data_targets.reshape(-1)].reshape(*data_targets.shape)
    digit_targets = l2_dist(data_targets, digit_values)
    # Append the contrastive class as the last entry of the distribution.
    digit_targets = torch.cat([digit_targets, cont_dist], dim=-1)
    digit_targets = (-digit_targets / target_temperature).softmax(-1)

    digit_logits = shift_logits[..., digit_indices]
    digit_logits = torch.cat([digit_logits, cont_logits], dim=-1)
    loss_fn = nn.KLDivLoss(reduction="none")
    loss = loss_fn(digit_logits.log_softmax(-1), digit_targets)  # BLV
    loss = loss.sum(-1)  # definition of KL

    if use_place_weighting:
        loss_coeff = pos.to(loss.dtype)
    else:
        loss_coeff = (pos > 0).to(loss.dtype)
    total_weight = loss_coeff.sum()
    # Without any number token in the batch the weighted sum and the total
    # weight are both 0; divide by 1 so the loss is 0 rather than NaN.
    total_weight = total_weight.masked_fill(total_weight == 0, 1)
    loss = (loss * loss_coeff).sum() / total_weight
    return beta * loss


def get_digit_base_loss(
    tokenizer,
    logits,
    labels,
    attention_mask=None,
    target_temperature=2.0,
    beta: float = 1.0,
):
    """
    Hard-target cross-entropy restricted to the number tokens.

    For every position whose (shifted) label is a number token, computes
    the cross-entropy between the logits restricted to the number tokens
    and the label's index within that subset, then averages over those
    positions.

    Args:
        tokenizer: Object exposing ``vocab`` (token string -> id); tokens
            accepted by ``is_number_regex`` are the number tokens.
        logits: Scores whose last dimension is the vocabulary.
        labels: Token ids aligned with ``logits``.
        attention_mask: Accepted for interface compatibility; not used.
        target_temperature: Accepted for interface compatibility; not used.
        beta: Scalar multiplier applied to the final loss.

    Returns:
        Scalar loss. A batch with no number-token labels yields 0 instead
        of NaN.
    """
    # shift
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    digit_indices = torch.tensor(
        [v for k, v in tokenizer.vocab.items() if is_number_regex(k)]
    ).to(shift_labels.device)

    # mask[..., j] is True when the label is the j-th number token.
    mask = shift_labels[..., None] == digit_indices[None, None, :]
    any_mask = mask.any(-1)
    # Non-number positions fall back to class 0; they are weighted by 0
    # below, so the placeholder target does not affect the loss.
    data_targets = mask.long().argmax(-1)

    digit_logits = shift_logits[..., digit_indices]
    loss_fn = nn.CrossEntropyLoss(reduction="none")
    V = digit_logits.shape[-1]
    loss = loss_fn(digit_logits.reshape(-1, V), data_targets.reshape(-1))
    loss = loss.reshape(*any_mask.shape)

    weights = any_mask.to(loss.dtype)
    count = weights.sum()
    # Without any number token in the batch both the weighted sum and the
    # count are 0; divide by 1 so the loss is 0 rather than NaN.
    count = count.masked_fill(count == 0, 1)
    loss = (loss * weights).sum() / count
    return beta * loss


if __name__ == "__main__":
    # Smoke test: run the contrastive digit loss once on random logits.
    from transformers import AutoTokenizer
    from data.utils import encode

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")

    # `encode` is a project helper; torch.stack requires it to return
    # tensors of identical shape for the inputs in each pair.
    labels = torch.stack(
        [encode(tokenizer, v, split_digit=True) for v in ["x 400", "x 390"]]
    )
    cont_labels = torch.stack(
        [encode(tokenizer, v, split_digit=True) for v in ["x 399", "x 388"]]
    )
    logits = torch.randn(2, labels.shape[1], len(tokenizer))
    cont_logits = torch.randn(2, labels.shape[1], len(tokenizer))

    loss = get_digit_loss_with_cont(
        tokenizer,
        logits,
        labels,
        contrastive_logits=cont_logits,
        contrastive_labels=cont_labels,
    )
    print(loss)
