import torch
import torch.nn.functional as F
import torch.nn as nn


def compute_rec_loss(output, target):
    """
    Return the mean cross-entropy between per-position logits and labels.

    Every leading dimension is collapsed, so each position is treated as an
    independent classification sample.

    Args:
        output (torch.Tensor): Logits whose last dimension indexes the classes,
            e.g. (batch_size, sequence_length, num_classes)
        target (torch.Tensor): Class indices with one entry per logit row,
            e.g. (batch_size, sequence_length); cast to int64 before use

    Returns:
        torch.Tensor: Scalar cross-entropy loss averaged over all positions
    """
    num_classes = output.size(-1)
    flat_logits = output.reshape(-1, num_classes)
    flat_labels = target.reshape(-1).long()
    return F.cross_entropy(flat_logits, flat_labels)


def compute_error_rates(output, target):
    """
    Measure how often the arg-max prediction disagrees with the target label.

    The target has its first two dimensions swapped before the comparison, so
    the arg-max of ``output`` must line up with ``target.transpose(0, 1)``.

    NOTE(review): with a (batch_size, sequence_length) target this means
    ``output`` has to be laid out as (sequence_length, batch_size, num_classes)
    -- confirm against the callers.

    Args:
        output (torch.Tensor): Model output logits; the last dimension indexes the classes
        target (torch.Tensor): Target labels; transposed on dims 0 and 1 and cast to int64

    Returns:
        tuple: (total_errors, position_errors)
            - total_errors (float): Fraction of mismatching entries over the whole tensor
            - position_errors (torch.Tensor): Mismatch fraction averaged over dim 1,
              i.e. one value per index along dim 0 of the predictions
    """
    # Build the 0/1 mismatch mask once and reduce it in two different ways.
    mismatches = (output.argmax(dim=-1) != target.transpose(0, 1).long()).float()

    per_position = mismatches.mean(dim=1)
    overall = mismatches.mean().item()

    return overall, per_position


def loss_function(recon_x, x, mean, logvar, prior_mean, prior_logvar):
    """
    Compute the loss terms for a Variational Autoencoder (VAE).

    The KL term is the closed-form divergence between two diagonal Gaussians,
    KL(N(mean, exp(logvar)) || N(prior_mean, exp(prior_logvar))), summed over
    all elements. With a zero-mean, zero-log-variance prior it reduces to the
    usual ``-0.5 * sum(1 + logvar - mean^2 - exp(logvar))``.

    Args:
        recon_x (torch.Tensor): Reconstructed input from the decoder
        x (torch.Tensor): Original input data
        mean (torch.Tensor): Mean of the latent (posterior) distribution
        logvar (torch.Tensor): Log variance of the latent (posterior) distribution
        prior_mean (torch.Tensor or None): Mean of the prior distribution; must
            broadcast against ``mean``. ``None`` means a zero mean.
        prior_logvar (torch.Tensor or None): Log variance of the prior distribution;
            must broadcast against ``logvar``. ``None`` means unit variance.

    Returns:
        tuple: (reconstruction_loss, kl_divergence_loss)
            - reconstruction_loss (torch.Tensor): Mean squared error between
              reconstructed and original data
            - kl_divergence_loss (torch.Tensor): Summed KL divergence between the
              posterior and the prior
    """
    recon_loss = F.mse_loss(recon_x, x)

    # A missing prior falls back to the standard normal N(0, I).
    prior_mean = (
        torch.zeros_like(mean)
        if prior_mean is None
        else torch.as_tensor(prior_mean, dtype=mean.dtype, device=mean.device)
    )
    prior_logvar = (
        torch.zeros_like(logvar)
        if prior_logvar is None
        else torch.as_tensor(prior_logvar, dtype=logvar.dtype, device=logvar.device)
    )

    # Clamp posterior and prior identically: this keeps exp() finite and makes
    # the divergence exactly zero when both distributions coincide.
    mean = torch.clamp(mean, min=-10, max=10)
    logvar = torch.clamp(logvar, min=-10, max=10)
    prior_mean = torch.clamp(prior_mean, min=-10, max=10)
    prior_logvar = torch.clamp(prior_logvar, min=-10, max=10)

    # KL(q || p) = 0.5 * sum(log(vp / vq) + (vq + (mq - mp)^2) / vp - 1)
    kl_div = 0.5 * torch.sum(
        prior_logvar
        - logvar
        + (logvar.exp() + (mean - prior_mean).pow(2)) * torch.exp(-prior_logvar)
        - 1
    )

    return recon_loss, kl_div


def focal_loss(logits, targets, alpha=1, gamma=2):
    """
    Compute the mean focal loss for a batch of classification logits.

    Each sample's cross-entropy is scaled by ``alpha * (1 - p_t) ** gamma``, where
    ``p_t`` is the probability assigned to the true class. Confidently correct
    samples therefore contribute little and hard samples dominate.

    Args:
        logits (torch.Tensor): Raw logits from the model with shape (batch_size, num_classes)
        targets (torch.Tensor): Target labels with shape (batch_size,)
        alpha (float): Constant scaling factor applied to every sample (default: 1)
        gamma (float): Focusing exponent; larger values down-weight easy samples more (default: 2)

    Returns:
        torch.Tensor: Scalar focal loss averaged over the batch
    """
    per_sample_ce = F.cross_entropy(logits, targets, reduction="none")

    # Cross-entropy is -log(p_t), so exp(-CE) recovers the true-class probability.
    true_class_prob = torch.exp(-per_sample_ce)
    modulating_factor = (1 - true_class_prob) ** gamma

    return (alpha * modulating_factor * per_sample_ce).mean()


def weighted_label_smoothing_loss(pred, target, smoothing=0.1, class_weights=None):
    """
    Compute weighted label smoothing loss for classification tasks.

    The one-hot target is softened so the true class receives
    ``1 - smoothing + smoothing / num_classes`` and every other class receives
    ``smoothing / num_classes``. The loss is the cross-entropy against that soft
    distribution, optionally scaled per sample by the weight of its true class.

    Args:
        pred (torch.Tensor): Model logits with shape (batch_size, num_classes)
        target (torch.Tensor): Integer (int64) target labels with shape (batch_size,)
        smoothing (float): Smoothing factor for label smoothing (default: 0.1)
        class_weights (list or torch.Tensor, optional): One weight per class to handle
            imbalance; any dtype or device is accepted and converted to float32 on
            ``pred``'s device

    Returns:
        torch.Tensor: Scalar loss averaged over the batch. The average divides by the
            batch size, not by the sum of the weights.
    """
    num_classes = pred.size(1)

    one_hot = torch.zeros_like(pred).scatter(1, target.unsqueeze(1), 1)
    smoothed_target = one_hot * (1 - smoothing) + smoothing / num_classes

    log_prob = F.log_softmax(pred, dim=1)
    per_sample = -(smoothed_target * log_prob).sum(dim=1)

    if class_weights is not None:
        # as_tensor accepts lists as well as tensors of any dtype/device, unlike
        # the legacy torch.FloatTensor constructor.
        weights = torch.as_tensor(
            class_weights, dtype=torch.float32, device=pred.device
        )
        per_sample = weights[target] * per_sample

    return per_sample.mean()
