from jax import numpy as jnp
from torch.utils.data import DataLoader
import jax
from flax import linen as nn


def mse_loss(y_pred, y_true):
    """Mean squared error between predictions and references.

    Args:
        y_pred: jnp.array(B)
        y_true: jnp.array(B)

    Returns:
        The mean over all elements of the squared difference.
    """
    residual = y_pred - y_true
    return jnp.mean(residual * residual)


def cross_entropy_loss(logits, target):
    """Batch-averaged cross entropy between logits and integer class ids.

    Args:
        logits: 2-D array of per-class scores; the last axis is the class axis.
        target: 1-D array of class ids, one per row of ``logits``.

    Returns:
        The negated mean of the target-class log-probabilities.
    """
    num_classes = logits.shape[-1]
    log_probs = nn.log_softmax(logits, axis=-1)
    target_one_hot = nn.one_hot(target, num_classes=num_classes)
    # The one-hot contraction picks the log-probability of the target class.
    picked = jnp.einsum("BH,BH->B", target_one_hot, log_probs)
    return -jnp.mean(picked, axis=-1)


def cross_entropy_loss_lm(logits, target, ignore_index=-100):
    """Cross entropy averaged per sequence over valid tokens, then over batch.

    Args:
        logits: jnp.array(BLH)
        target: jnp.array(BL, dtype=long)
        ignore_index: target value marking positions excluded from the loss.
            Positions are masked explicitly, so any value (negative or not)
            is accepted.

    Returns:
        Scalar loss. A sequence whose targets are all ``ignore_index``
        contributes 0 to the batch mean instead of NaN.
    """
    valid = target != ignore_index
    # Clamp the divisor: a fully ignored sequence would otherwise give 0 / 0.
    num_valid = jnp.maximum(valid.sum(axis=-1), 1)
    # Indices outside the range [0, num_classes) will be encoded as zeros:
    one_hot_target = nn.one_hot(target, num_classes=logits.shape[-1])
    token_ll = jnp.einsum(
        "BLH,BLH->BL", one_hot_target, nn.log_softmax(logits, axis=-1)
    )
    # Explicit mask so an in-range ignore_index (e.g. a pad id) is skipped too.
    token_ll = jnp.where(valid, token_ll, 0.0)
    seq_ll = jnp.sum(token_ll, axis=-1) / num_valid  # mean over valid tokens
    return -jnp.mean(seq_ll, axis=-1)  # mean over the batch


def cross_entropy_loss_nmt(logits, target, ignore_index=-100):
    """Cross entropy averaged per sequence over valid tokens, then over batch.

    Args:
        logits: jnp.array(BLH)
        target: jnp.array(BL, dtype=long)
        ignore_index: target value marking positions excluded from the loss.
            Positions are masked explicitly, so any value (negative or not)
            is accepted.

    Returns:
        Scalar loss. A sequence whose targets are all ``ignore_index``
        contributes 0 to the batch mean instead of NaN.
    """
    keep = target != ignore_index
    # Clamp the divisor: a fully ignored sequence would otherwise give 0 / 0.
    num_valid = jnp.maximum(keep.sum(axis=-1), 1)
    lprobs = nn.log_softmax(logits, axis=-1)
    # Indices outside the range [0, num_classes) will be encoded as zeros:
    one_hot_target = nn.one_hot(target, num_classes=logits.shape[-1])
    token_ll = jnp.einsum("BLH,BLH->BL", one_hot_target, lprobs)
    # Explicit mask so an in-range ignore_index (e.g. a pad id) is skipped too.
    token_ll = jnp.where(keep, token_ll, 0.0)
    seq_ll = jnp.sum(token_ll, axis=-1) / num_valid  # mean over valid tokens
    return -jnp.mean(seq_ll, axis=-1)  # mean over the batch


def smoothed_corss_entropy(logits, target, ignore_index=-100, eps=0.1):
    """Label-smoothed cross entropy, averaged per sequence then over batch.

    The target class receives weight ``1 - eps`` and every other class
    ``eps / (H - 1)``, where ``H`` is the size of the last axis of ``logits``.

    Args:
        logits: jnp.array(BLH)
        target: jnp.array(BL, dtype=long)
        ignore_index: target value marking positions excluded from the loss.
        eps: smoothing mass moved from the target class to the other classes.
            Defaults to 0.1, the value that used to be hard-coded.

    Returns:
        Scalar loss. A sequence whose targets are all ``ignore_index``
        contributes 0 to the batch mean instead of NaN.
    """
    lprobs = nn.log_softmax(logits, axis=-1)
    not_mask = target != ignore_index
    # Clamp the divisor: a fully ignored sequence would otherwise give 0 / 0.
    num_valid = jnp.maximum(not_mask.sum(axis=-1), 1)
    # Indices outside the range [0, num_classes) will be encoded as zeros:
    one_hot_target = nn.one_hot(target, num_classes=logits.shape[-1])
    nll_loss = -jnp.einsum("BLH,BLH->BL", one_hot_target, lprobs)
    smooth_loss = -lprobs.sum(axis=-1)

    # Drop ignored positions from both terms before reducing.
    nll_loss = jnp.where(not_mask, nll_loss, 0.0)
    smooth_loss = jnp.where(not_mask, smooth_loss, 0.0)

    nll_loss = nll_loss.sum(axis=-1)  # sum across seq len
    smooth_loss = smooth_loss.sum(axis=-1)  # sum across seq len
    # smooth_loss also contains the target class, hence the extra -eps_i.
    eps_i = eps / (lprobs.shape[-1] - 1)
    loss = (1.0 - eps - eps_i) * nll_loss + eps_i * smooth_loss
    loss = loss / num_valid  # mean across seq len
    return loss.mean(axis=-1)  # mean over batch size


def translation_loss(logits, target, ignore_index=-100, smoothing=True):
    """Dispatch to the smoothed or the plain sequence cross entropy.

    Args:
        logits: jnp.array(BLH)
        target: jnp.array(BL, dtype=long)
        ignore_index: forwarded unchanged to the selected loss function.
        smoothing: when truthy use ``smoothed_corss_entropy``, otherwise
            ``cross_entropy_loss_nmt``.

    Returns:
        Whatever the selected loss function returns.
    """
    loss_fn = smoothed_corss_entropy if smoothing else cross_entropy_loss_nmt
    return loss_fn(logits, target, ignore_index=ignore_index)
