import torch
import torch.nn.functional as F
from torch import Tensor


def get_feature_loss(loss_name: str):
    """Return a feature-comparison function selected by name.

    Supported names:

    * ``'l2'``      -- mean squared error between ``x1`` and ``x2``.
    * ``'l2_norm'`` -- mean squared error divided by the mean of ``x2 ** 2``.
    * ``'cosine'``  -- cosine similarity computed along the last dimension
      and averaged over all leading dimensions.

    NOTE: the ``'cosine'`` variant returns the mean cosine *similarity*
    itself (1 for aligned vectors), not ``1 - similarity``; callers that
    minimize the result must account for that. No epsilon is added to the
    denominators of ``'l2_norm'`` or ``'cosine'``, so all-zero inputs give
    non-finite values.

    Args:
        loss_name: One of ``'l2'``, ``'l2_norm'`` or ``'cosine'``.

    Returns:
        A callable ``loss_fn(x1, x2)`` that maps two tensors to a scalar
        tensor.

    Raises:
        ValueError: If ``loss_name`` is not one of the supported names.
    """
    if loss_name == 'l2':
        def loss_fn(x1: Tensor, x2: Tensor) -> Tensor:
            return F.mse_loss(x1, x2, reduction='mean')
    elif loss_name == 'l2_norm':
        def loss_fn(x1: Tensor, x2: Tensor) -> Tensor:
            # Normalize the squared error by the mean energy of x2.
            return F.mse_loss(x1, x2, reduction='mean') / x2.pow(2).mean()
    elif loss_name == 'cosine':
        def loss_fn(x1: Tensor, x2: Tensor) -> Tensor:
            # Collapse all leading dims into one batch axis. reshape (rather
            # than flatten(0, -2)) also accepts a single 1-D feature vector.
            x1 = x1.reshape(-1, x1.shape[-1])
            x2 = x2.reshape(-1, x2.shape[-1])
            # Row-wise dot products as a batched (1 x D) @ (D x 1) matmul.
            numerator = torch.bmm(x1.unsqueeze(-2), x2.unsqueeze(-1)).view(-1)
            denominator = x1.norm(dim=-1) * x2.norm(dim=-1)
            return (numerator / denominator).mean()
    else:
        # Previously an unknown name fell through and surfaced as an
        # UnboundLocalError on the return statement below.
        raise ValueError(
            f"Unknown feature loss {loss_name!r}; "
            "expected one of 'l2', 'l2_norm', 'cosine'."
        )

    return loss_fn
