import torch
from torch.nn import functional

def discrete_loss(logits, labels, ignore_label=None, weight=None, label_smoothing=0.0, reduction='mean'):
    """Cross-entropy between class logits and integer class labels.

    All leading dimensions of ``logits`` are flattened into a single batch axis,
    and ``labels`` is flattened to match.

    Args:
        logits: tensor whose last dimension holds the K class scores.
        labels: class-index tensor with one entry per flattened row of ``logits``.
        ignore_label: label value excluded from the loss. ``None`` (default)
            means no label is ignored.
        weight: optional per-class weights, forwarded to ``cross_entropy``.
        label_smoothing: smoothing amount, forwarded to ``cross_entropy``.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``, forwarded to ``cross_entropy``.

    Returns:
        The tensor produced by ``functional.cross_entropy``.
    """
    num_classes = logits.size(-1)
    logits = logits.reshape(-1, num_classes)
    labels = labels.reshape(-1)
    # cross_entropy requires an int ignore_index; passing None raises TypeError.
    # -100 is PyTorch's own default and never matches a valid class index.
    ignore_index = -100 if ignore_label is None else ignore_label
    loss = functional.cross_entropy(logits, labels, weight=weight, label_smoothing=label_smoothing,
                                    ignore_index=ignore_index, reduction=reduction)
    return loss

def continuous_loss(inputs, targets, ignore_value=None, reduction='mean'):
    """Mean-squared-error loss between ``inputs`` and ``targets``.

    If ``ignore_value`` is given, every position where ``targets`` equals it is
    dropped from both tensors with the same boolean mask before the loss is
    computed. Boolean masking returns the kept elements as a flattened tensor.

    Args:
        inputs: predicted values.
        targets: target values; also used to build the ignore mask.
        ignore_value: target value marking entries to skip, or ``None``.
        reduction: forwarded to ``functional.mse_loss``.

    Returns:
        The tensor produced by ``functional.mse_loss``.
    """
    if ignore_value is None:
        return functional.mse_loss(inputs, targets, reduction=reduction)
    keep = targets != ignore_value
    return functional.mse_loss(inputs[keep], targets[keep], reduction=reduction)

def action_prediction_loss(y, targets, unknown_action=None, actions_discrete=True, 
                           weight=None, label_smoothing=0.0, reduction='mean'):
    """Supervised loss on predicted actions, skipping unknown-action targets.

    Args:
        y: model predictions (logits when discrete, values when continuous).
        targets: target actions.
        unknown_action: target value meaning "no label"; such entries are ignored.
        actions_discrete: choose cross-entropy (True) or MSE (False).
        weight, label_smoothing: forwarded to ``discrete_loss`` (discrete case only).
        reduction: forwarded to the underlying loss.

    Returns:
        The loss tensor. If every target equals ``unknown_action``, a float32
        scalar zero on ``y``'s device is returned instead.
    """
    # Nothing to supervise: avoid calling the loss on an all-ignored batch.
    if unknown_action is not None and (targets == unknown_action).all():
        return torch.tensor(0.0, dtype=torch.float32, device=y.device)
    if actions_discrete:
        return discrete_loss(y, targets, ignore_label=unknown_action, weight=weight,
                             label_smoothing=label_smoothing, reduction=reduction)
    return continuous_loss(y, targets, ignore_value=unknown_action, reduction=reduction)

def var_loss(z):
    """Hinge penalty on per-dimension standard deviation.

    Computes, for each column of ``z``, ``sqrt(var + 1e-4)`` using the variance
    over dim 0. It then penalizes ``max(0, 1 - std)`` and averages over columns.

    Args:
        z: tensor whose dim 0 is reduced over (presumably (batch, dim)).

    Returns:
        Scalar tensor with the mean hinge penalty.
    """
    per_dim_std = (z.var(dim=0) + 1e-04).sqrt()
    hinge = functional.relu(1.0 - per_dim_std)
    return hinge.mean()

def covar_loss(z):
    """Penalty on off-diagonal entries of the sample covariance of ``z``.

    Args:
        z: 2-D tensor of shape (batch_size, dim). batch_size must be >= 2; with a
            single row the ``batch_size - 1`` divisor is zero and the result is NaN.

    Returns:
        Scalar tensor: sum of squared off-diagonal covariance entries divided by ``dim``.
    """
    batch_size, dim = z.shape
    centered = z - z.mean(dim=0)
    # Unbiased (n - 1) sample covariance, shape (dim, dim).
    cov_z = centered.T @ centered / (batch_size - 1)
    # Build the mask on z's device so CUDA inputs are not indexed with a CPU mask.
    off_diag = ~torch.eye(dim, dtype=torch.bool, device=z.device)
    return cov_z[off_diag].pow(2).sum() / dim

def vic_loss(za, zb, **params):
    """Variance / invariance / covariance (VICReg-style) loss between two embeddings.

    Both inputs are flattened to (-1, dim), where dim is ``za``'s last dimension,
    for the variance and covariance terms. The invariance term is the MSE
    between ``za`` and ``zb``.

    Keyword Args:
        eta: weight of the variance term (required).
        beta: weight of the covariance term (required).
        gamma: weight of the invariance (MSE) term (required).

    Returns:
        Scalar tensor: gamma * mse + eta * (var(za) + var(zb)) + beta * (cov(za) + cov(zb)).

    Raises:
        KeyError: if any of ``eta``, ``beta`` or ``gamma`` is missing.
    """
    dim = za.shape[-1]
    flat_a = za.reshape(-1, dim)
    flat_b = zb.reshape(-1, dim)
    variance = (var_loss(flat_a) + var_loss(flat_b)) * params['eta']
    covariance = (covar_loss(flat_a) + covar_loss(flat_b)) * params['beta']
    similarity = functional.mse_loss(za, zb, reduction='mean') * params['gamma']
    return similarity + variance + covariance