from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
    """Cross-entropy whose normalizer can optionally be supplied by the caller.

    With ``num_items_in_batch=None`` this is ``cross_entropy(..., reduction="mean")``,
    i.e. averaged over the targets that are not ``ignore_index``. Otherwise the
    per-target losses are summed and divided by ``num_items_in_batch``; presumably
    the caller passes a token count gathered over a larger scope (e.g. several
    micro-batches) — confirm against the trainer.

    Args:
        source: logits passed straight to ``nn.functional.cross_entropy``.
        target: class indices for ``source``.
        num_items_in_batch: optional divisor (int or tensor) for the summed loss.
        ignore_index: target value excluded from the loss.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        The scalar loss tensor.
    """
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        # A tensor count may live on another device (e.g. after a cross-device
        # reduction); align it with the loss before dividing.
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss


def ForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
    """Next-token cross-entropy for a causal language model.

    Logits are upcast to float32 and shifted against the labels so that the
    prediction at position t is scored against the label at position t + 1.
    Both are then flattened (logits to ``(-1, vocab_size)``), the labels are
    moved to the logits' device, and the pair is handed to
    ``fixed_cross_entropy`` together with ``num_items_in_batch`` and
    ``ignore_index``.

    Returns:
        Whatever ``fixed_cross_entropy`` returns for the flattened pair.
    """
    # Score in float32 even if the model produced reduced-precision logits.
    upcast = logits.float()
    # Drop the last prediction and the first label so position t lines up with token t + 1.
    predictions = upcast[..., :-1, :].contiguous().view(-1, vocab_size)
    # Labels follow the logits' device so model-parallel layouts still work.
    targets = labels[..., 1:].contiguous().view(-1).to(predictions.device)
    return fixed_cross_entropy(predictions, targets, num_items_in_batch, ignore_index, **kwargs)


def EmbeddingDistanceLoss(
    hidden_states: torch.Tensor, labels: torch.LongTensor, embedding: torch.nn.Embedding, num_items_in_batch: int = None, ignore_index=-100, **kwargs
):
    """L1 distance between each hidden state and the embedding of the next token.

    The hidden state at position t is compared to ``embedding(labels[..., t + 1])``.
    The per-token loss is the mean absolute difference over the feature axis.
    The targets are detached, so this loss sends no gradient into ``embedding``.

    Args:
        hidden_states: activations sliced as ``[..., :-1, :]``. Assumed to be
            (batch, seq, hidden) — TODO confirm. The last dim must match the
            embedding width.
        labels: token ids aligned with the sequence axis of ``hidden_states``.
        embedding: table that provides the target vectors.
        num_items_in_batch: if given (int or tensor), the summed per-token loss
            is divided by it. Otherwise the per-token losses are averaged.
        ignore_index: label value excluded from the loss.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        A scalar tensor. If no position is valid, this is a zero that stays
        attached to ``hidden_states``'s autograd graph, so ``.backward()`` and
        ``.item()`` still work.
    """
    # Keep labels with the activations (mirrors the LM loss's model-parallel handling).
    labels = labels.to(hidden_states.device)
    # Shift so that tokens < n predict n
    hidden_states = hidden_states[..., :-1, :]
    labels = labels[..., 1:]

    mask = labels != ignore_index
    selected = hidden_states[mask]

    if not mask.any():
        # Summing the empty selection yields a 0-valued tensor that keeps the graph
        # and dtype, instead of a bare Python float.
        return selected.sum()

    targets = embedding(labels[mask]).detach()
    # One value per valid token: mean |h - e| over the feature axis.
    per_token = F.l1_loss(selected, targets, reduction="none").mean(-1)
    if num_items_in_batch is None:
        return per_token.mean()
    if torch.is_tensor(num_items_in_batch):
        # A tensor count may live on another device; align before dividing.
        num_items_in_batch = num_items_in_batch.to(per_token.device)
    return per_token.sum() / num_items_in_batch


def EmbeddingSimilarityL2Loss(
    hidden_states: torch.Tensor, labels: torch.LongTensor, embedding: torch.nn.Embedding, num_items_in_batch: int = None, ignore_index=-100, **kwargs
):
    """Squared cosine distance between each hidden state and the next token's embedding.

    The hidden state at position t is compared to ``embedding(labels[..., t + 1])``.
    The per-token loss is ``(1 - cos_sim) ** 2`` over the feature axis. The
    targets are detached, so this loss sends no gradient into ``embedding``.

    Args:
        hidden_states: activations sliced as ``[..., :-1, :]``. Assumed to be
            (batch, seq, hidden) — TODO confirm. The last dim must match the
            embedding width.
        labels: token ids aligned with the sequence axis of ``hidden_states``.
        embedding: table that provides the target vectors.
        num_items_in_batch: if given (int or tensor), the summed per-token loss
            is divided by it. Otherwise the per-token losses are averaged.
        ignore_index: label value excluded from the loss.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        A scalar tensor. If no position is valid, this is a zero that stays
        attached to ``hidden_states``'s autograd graph, so ``.backward()`` and
        ``.item()`` still work.
    """
    # Keep labels with the activations (mirrors the LM loss's model-parallel handling).
    labels = labels.to(hidden_states.device)
    # Shift so that tokens < n predict n
    hidden_states = hidden_states[..., :-1, :]
    labels = labels[..., 1:]

    mask = labels != ignore_index
    selected = hidden_states[mask]

    if not mask.any():
        # Summing the empty selection yields a 0-valued tensor that keeps the graph
        # and dtype, instead of a bare Python float.
        return selected.sum()

    targets = embedding(labels[mask]).detach()
    # `selected` is (N, features), so dim=-1 is the feature axis (same as the old default dim=1).
    per_token = (1 - F.cosine_similarity(selected, targets, dim=-1)) ** 2
    if num_items_in_batch is None:
        return per_token.mean()
    if torch.is_tensor(num_items_in_batch):
        # A tensor count may live on another device; align before dividing.
        num_items_in_batch = num_items_in_batch.to(per_token.device)
    return per_token.sum() / num_items_in_batch


# === Losses ===

# Registries keyed by a short string name; each value is a loss function defined in this module.
LM_LOSSES = {'lm': ForCausalLMLoss}  # next-token cross-entropy on logits

NORM_LOSSES = {'l1': EmbeddingDistanceLoss}  # L1 distance to the next token's embedding

SIM_LOSSES = {'cos_l2': EmbeddingSimilarityL2Loss}  # squared (1 - cosine) to the next token's embedding

# ===
