import torch


class VICRegLoss(torch.nn.Module):
    """Variance / invariance / cross-correlation loss between two embedding batches.

    The returned loss is ``u * var_loss + v * invar_loss + cross_loss`` where:

    * ``var_loss`` is a hinge ``relu(1 - std)`` on the regularised per-feature
      standard deviation of ``x`` and of ``y`` (each averaged over features,
      then summed);
    * ``invar_loss`` is the mean squared error between ``x`` and ``y``;
    * ``cross_loss`` is ``sum((lmbd * (C - I)) ** 2)`` where ``C`` is the
      feature-by-feature cross-correlation matrix between ``x`` and ``y``.

    Note that the third term uses the cross-correlation between the two
    inputs, not the per-input covariance.

    Args:
        lmbd: Scale applied to ``C - I`` before squaring, so every entry of the
            cross term is effectively weighted by ``lmbd ** 2``.
        u: Weight of the variance term.
        v: Weight of the invariance term.
        epsilon: Added to the per-feature variance before the square root. It
            is used both in the variance hinge and when standardising features
            for the cross-correlation. This keeps the loss finite when a
            feature is constant across the batch.
    """

    def __init__(
        self,
        lmbd: float = 5e-3,
        u: float = 1,
        v: float = 1,
        epsilon: float = 1e-3,
    ) -> None:
        super().__init__()

        self.lmbd = lmbd
        self.u = u
        self.v = v
        self.epsilon = epsilon

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the scalar loss for two embedding batches.

        Args:
            x: First batch, indexed as ``(batch, features)``.
            y: Second batch, same shape as ``x``.

        Returns:
            A 0-dim tensor holding the weighted loss.

        Raises:
            ValueError: If the batch has fewer than 2 samples; the unbiased
                per-feature variance is undefined in that case.
        """
        bs = x.size(0)
        emb = x.size(1)
        if bs < 2:
            raise ValueError(
                f"VICRegLoss requires a batch of at least 2 samples, got {bs}"
            )

        relu = torch.nn.functional.relu

        # Regularised per-feature std (unbiased variance + epsilon).
        std_x = torch.sqrt(x.var(dim=0) + self.epsilon)
        std_y = torch.sqrt(y.var(dim=0) + self.epsilon)
        var_loss = torch.mean(relu(1 - std_x)) + torch.mean(relu(1 - std_y))

        invar_loss = torch.nn.functional.mse_loss(x, y)

        # Standardise with the same regularised std. A raw std would turn a
        # constant (collapsed) feature into 0 / 0 = NaN.
        x_norm = (x - x.mean(0)) / std_x
        y_norm = (y - y.mean(0)) / std_y
        # std_x/std_y use the unbiased variance (divide by bs - 1), so divide
        # by bs - 1 here as well. Otherwise identical inputs yield a diagonal
        # of (bs - 1) / bs and can never reach the identity target.
        cross_cor = (x_norm.T @ y_norm) / (bs - 1)
        eye = torch.eye(emb, device=x.device)
        # NOTE(review): lmbd scales the whole (C - I), diagonal included, and
        # is then squared, so the effective weight is lmbd ** 2 on every entry.
        # Barlow-Twins-style losses usually weight only the off-diagonal terms
        # by lambda. Confirm this is intended.
        cross_loss = (self.lmbd * (cross_cor - eye)).pow(2).sum()

        return self.u * var_loss + self.v * invar_loss + cross_loss


class HeadedVICReg(VICRegLoss):
    """``VICRegLoss`` computed after passing both inputs through one shared head.

    Args:
        head: Module applied to both ``x`` and ``y`` before the loss. It is
            assigned to ``self.head``, which registers it as a submodule.
        *args, **kwargs: Forwarded unchanged to ``VICRegLoss.__init__``.
    """

    def __init__(self, head: torch.nn.Module, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.head = head

    def forward(self, x, y) -> torch.Tensor:
        """Project ``x`` then ``y`` with ``self.head`` and return the parent loss."""
        projected_x = self.head(x)
        projected_y = self.head(y)
        return super().forward(projected_x, projected_y)