import re

import torch
import torch.nn as nn


class Loss:
    """Callable training loss: next-token cross-entropy plus an optional digit term.

    ``loss_type`` selects the objective:

    * ``"sft"``        -- cross-entropy only (``get_sft_loss``).
    * ``"digit"``      -- cross-entropy + ``get_digit_loss``.
    * ``"digit_base"`` -- cross-entropy + ``get_digit_base_loss``.

    Any other value raises ``ValueError`` when the loss is evaluated.
    """

    def __init__(self, tokenizer, loss_type: str = "sft", ignore_index: int = -100):
        self.tokenizer = tokenizer
        self.loss_type = loss_type
        # Stored on the instance but not forwarded to the loss helpers.
        self.ignore_index = ignore_index

    def __call__(self, outputs, labels, num_items_in_batch: int = 1):
        """Evaluate the configured loss for one forward pass.

        Args:
            outputs: model output exposing ``logits`` and ``labels``.
            labels: ignored; ``outputs.labels`` is used instead, because the
                visual tokens inserted at given positions modify the original labels.
            num_items_in_batch: accepted for interface compatibility; unused.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``loss_type`` is not a supported value.
        """
        del labels  # superseded by outputs.labels (see docstring)
        target = outputs.labels
        logits = outputs.logits.float()
        base = get_sft_loss(logits, target, attention_mask=None)

        if self.loss_type == "sft":
            return base

        auxiliary = {
            "digit": get_digit_loss,
            "digit_base": get_digit_base_loss,
        }.get(self.loss_type)
        if auxiliary is None:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        extra = auxiliary(
            self.tokenizer,
            logits,
            target,
            target_temperature=2.0,
            beta=0.05,
        )
        return base + extra


def is_number_regex(input_str):
    r"""Return True if ``input_str`` consists solely of ASCII digits ``0-9``.

    ``re.fullmatch`` is used instead of ``re.match(r"^[0-9]+$", ...)``: with
    ``match``, ``$`` also matches just before a trailing newline, so a token
    such as ``"12\n"`` would wrongly be treated as a number.
    """
    return re.fullmatch(r"[0-9]+", input_str) is not None


def l2_dist(x, y):
    """Pairwise squared differences between every entry of ``x`` and every entry of ``y``.

    Despite the name, no square root is taken: the result is the squared
    (L2-squared) distance between scalars.

    Args:
        x: tensor of any shape (previously only 2-D ``(B, L)`` was accepted).
        y: 1-D tensor of length ``V``.

    Returns:
        Tensor of shape ``x.shape + (V,)`` with ``out[..., j] = (x[...] - y[j]) ** 2``.
    """
    # Broadcasting (..., 1) against (V,) gives (..., V) without flatten/reshape.
    return (x.unsqueeze(-1) - y) ** 2


def get_sft_loss(logits, labels, attention_mask=None, ignore_index: int = -100):
    """Next-token cross-entropy loss (causal-LM supervised fine-tuning objective).

    Logits at position ``t`` are scored against ``labels`` at position ``t + 1``.

    Args:
        logits: ``(..., L, V)`` scores.
        labels: ``(..., L)`` target token ids; entries equal to ``ignore_index``
            do not contribute.
        attention_mask: optional ``(..., L)`` mask. When given, only shifted
            positions whose mask entry is non-zero are kept.
        ignore_index: label value excluded from the loss (default ``-100``).

    Returns:
        Scalar mean cross-entropy over the non-ignored positions. PyTorch yields
        NaN for the mean reduction when every position is ignored.
    """
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    if attention_mask is not None:
        # Boolean indexing flattens the kept positions to (N, V) / (N,).
        keep = attention_mask[..., 1:] != 0
        shift_logits = shift_logits[keep.to(logits.device)]
        shift_labels = shift_labels[keep.to(labels.device)]

    loss_fct = nn.CrossEntropyLoss(ignore_index=ignore_index)
    return loss_fct(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1).to(shift_logits.device),
    )


def _digit_run_positions(is_digit):
    """Rank each True entry within its contiguous run, counting from the run's end.

    Works along the last dimension. For a row ``[0, 1, 1, 1, 0, 1]`` the result
    is ``[0, 3, 2, 1, 0, 1]``: False positions get 0, and the last entry of
    every run gets 1, so the first digit of a number gets the largest rank.
    """
    # Work right-to-left so a plain running count restarts at each run's end.
    x = is_digit.flip(-1).long()
    cumsum = torch.cumsum(x, dim=-1)
    # At each gap (x == 0) record the count reached so far; carrying the
    # latest such value forward (cummax) and subtracting it restarts the count.
    reset_mask = x == 0
    reset_diff = torch.zeros_like(cumsum)
    reset_diff[..., 1:] = cumsum[..., :-1] * reset_mask[..., 1:].to(dtype=cumsum.dtype)
    adjusted_cumsum = cumsum - reset_diff.cummax(-1).values
    return adjusted_cumsum.flip(-1)


def get_digit_loss(
    tokenizer,
    logits,
    labels,
    target_temperature=2.0,
    beta: float = 0.1,
    start_token_id: int = 529,
    end_token_id: int = 829,
):
    """Distance-aware soft-target KL loss on digit tokens inside tag spans.

    Only digit labels lying between a ``start_token_id`` token and the next
    ``end_token_id`` token are scored. The defaults are 529 and 829, which the
    original code identifies as ``"<"`` and ``"</"``.

    For each scored digit, the target distribution over all digit tokens is
    ``softmax(-(value - true_value) ** 2 / target_temperature)``. Numerically
    close digits therefore get more mass. The model's digit-restricted
    distribution is compared to it with KL(target || model).

    Each position's KL is weighted by its rank within its run of consecutive
    digits: 1 for the last digit, 2 for the one before, and so on.

    Args:
        tokenizer: provides ``get_vocab()``; entries accepted by
            ``is_number_regex`` are the digit tokens.
        logits: ``(B, L, V)`` model scores.
        labels: ``(B, L)`` target ids, ``-100`` at ignored positions.
        target_temperature: softness of the target distribution.
        beta: scale applied to the result.
        start_token_id: token id that opens a scored span.
        end_token_id: token id that closes a scored span.

    Returns:
        Scalar ``beta * weighted-mean KL``. If no digit is scored, the result is
        0 (for finite logits) rather than NaN.
    """
    labels = labels.clone()
    # Span depth: +1 at each start token, -1 at each end token. Using "<= 0"
    # (not "== 0") keeps a stray end token with no matching start, e.g. after
    # truncation, from marking the rest of the row as inside a span.
    depth = torch.cumsum(
        (labels == start_token_id).long() - (labels == end_token_id).long(), dim=-1
    )
    labels[depth <= 0] = -100

    # Next-token alignment: logits at t predict labels at t + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    digits = [(k, v) for k, v in tokenizer.get_vocab().items() if is_number_regex(k)]
    digit_values = (
        torch.tensor([int(k) for k, _ in digits])
        .to(shift_labels.device)
        .to(logits.dtype)
    )
    digit_indices = torch.tensor([v for _, v in digits]).to(shift_labels.device)

    # is_digit_token[b, t, j]: the label at (b, t) is the j-th digit token.
    is_digit_token = shift_labels[..., None] == digit_indices
    weights_int = _digit_run_positions(is_digit_token.any(-1))

    # argmax picks the matching digit; non-digit rows get index 0, but their
    # weight is 0 so they do not contribute.
    true_values = digit_values[is_digit_token.long().argmax(-1)]
    digit_targets = (-l2_dist(true_values, digit_values) / target_temperature).softmax(-1)

    digit_logits = shift_logits[..., digit_indices]
    loss_fn = nn.KLDivLoss(reduction="none")
    loss = loss_fn(digit_logits.log_softmax(-1), digit_targets)  # (B, L-1, n_digits)
    loss = loss.sum(-1)  # summing over classes gives the KL divergence

    loss_coeff = weights_int.to(loss.dtype)
    # Weights are non-negative integers; clamping only affects the all-zero
    # case, where the numerator is 0 too, giving 0 instead of 0 / 0 = NaN.
    loss = (loss * loss_coeff).sum() / loss_coeff.sum().clamp(min=1)

    return beta * loss


def get_digit_base_loss(
    tokenizer,
    logits,
    labels,
    target_temperature=2.0,
    beta: float = 1.0,
):
    """Cross-entropy over digit tokens only, averaged over digit label positions.

    At every position whose shifted label is a digit token, the logits are
    restricted to the digit tokens and scored with cross-entropy against that
    exact digit. All other positions are ignored. Unlike ``get_digit_loss``,
    it uses hard targets, with no tag-span filtering and no per-digit weighting.

    Args:
        tokenizer: provides ``get_vocab()``; entries accepted by
            ``is_number_regex`` are the digit tokens.
        logits: ``(..., L, V)`` model scores.
        labels: ``(..., L)`` target ids.
        target_temperature: unused; accepted so the call signature matches
            ``get_digit_loss``.
        beta: scale applied to the result.

    Returns:
        Scalar ``beta * mean CE``. If no label is a digit token, the result is 0
        (for finite logits) rather than NaN.
    """
    # Next-token alignment: logits at t predict labels at t + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    digit_ids = [v for k, v in tokenizer.get_vocab().items() if is_number_regex(k)]
    digit_indices = torch.tensor(digit_ids).to(shift_labels.device)

    # is_digit_token[..., t, j]: the label at t is the j-th digit token.
    is_digit_token = shift_labels[..., None] == digit_indices
    # Index of the matching digit; 0 at non-digit positions, which are masked out below.
    data_targets = is_digit_token.long().argmax(-1)

    digit_logits = shift_logits[..., digit_indices]
    loss_fn = nn.CrossEntropyLoss(reduction="none")
    loss = loss_fn(
        digit_logits.reshape(-1, digit_logits.size(-1)), data_targets.reshape(-1)
    )
    loss = loss.view(*data_targets.shape)

    mask = is_digit_token.any(-1).to(loss.dtype)
    # Mask entries are 0/1; clamping only changes the no-digit case (0 / 0 -> 0).
    loss = (loss * mask).sum() / mask.sum().clamp(min=1)

    return beta * loss


if __name__ == "__main__":
    # Smoke test: evaluate the "digit" loss on random logits for one
    # LLaVA-1.5 conversation and print the result.
    from types import SimpleNamespace

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
    tokenizer = processor.tokenizer
    loss_fn = Loss(tokenizer, loss_type="digit")
    conv = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "<image0>\n<task>Stack objects in this order <scene><p>red letter V</p> at <b>(0.254, 0.570), {0.094, 0.156}</b>.\n<p>green paisley letter V</p> at <b>(0.500, 0.570), {0.094, 0.156}</b>.\n<p>rainbow letter V</p> at <b>(0.746, 0.570), {0.094, 0.156}</b>.</scene> <scene><p>red letter V</p> at <b>(0.254, 0.570), {0.094, 0.156}</b>.\n<p>rainbow letter V</p> at <b>(0.746, 0.578), {0.094, 0.133}</b>.\n<p>green paisley letter V</p> at <b>(0.746, 0.516), {0.094, 0.148}</b>.</scene> <scene><p>rainbow letter V</p> at <b>(0.746, 0.602), {0.094, 0.086}</b>.\n<p>green paisley letter V</p> at <b>(0.746, 0.531), {0.094, 0.125}</b>.\n<p>red letter V</p> at <b>(0.746, 0.469), {0.094, 0.156}</b>.</scene>.</task>\nEvery action you take must include two locations in the format of <b>(x, y)</b> and one clockwise rotation angle in the format of <r>[r]</r>. The first location is the image coordinate where you use a suction cup to pick up the object, and the second location is where you place the object.The image coordinate ranges from 0 to 1. The rotation angle indicates how many degrees you rotate the object clockwise, and it ranges from -359 to 359.\nYou have finished: Step 1: Pick up the object at <b>(0.465, 0.617)</b>, rotate <r>[0]</r> degrees, and drop it at <b>(0.711, 0.617)</b>.",
                }
            ],
        },
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "Step 2: Pick up the <p>red letter V</p> at <b>(0.281, 0.633)</b>, rotate <r>[0]</r> degrees, and drop it at <b>(0.770, 0.633)</b>.",
                }
            ],
        },
    ]
    text = processor.apply_chat_template(conv)
    prefix = processor.apply_chat_template(conv[:1], add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs.input_ids
    # Count prompt tokens with the same call used for the full text, so any
    # special tokens it adds (e.g. BOS) are counted on both sides.
    # tokenizer.tokenize() does not add them, so the prompt could be under-counted.
    prefix_len = len(tokenizer(prefix).input_ids)
    labels = input_ids.clone()
    labels[:, :prefix_len] = -100  # supervise only the assistant reply
    logits = torch.randn(1, input_ids.shape[1], len(tokenizer))
    outputs = SimpleNamespace(logits=logits, labels=labels)
    loss = loss_fn(outputs, labels)
    print(f"loss ({loss_fn.loss_type}): {loss.item():.4f}")
