import torch
import torch.nn as nn
import torch.nn.functional as F 
from torchvision.models import resnet50
from .backbones import projection_MLP


class BarlowTwinsLoss(torch.nn.Module):
    """Barlow Twins redundancy-reduction loss.

    The two batches of embeddings are standardized per feature along the
    batch dimension, their D x D cross-correlation matrix is formed, and
    the squared difference to the identity matrix is summed, with the
    off-diagonal terms scaled by ``lambda_param``.

    Args:
        device: Kept for backward compatibility and stored on the module.
            ``forward`` allocates its helper tensors on the device of its
            inputs, so this value no longer has to match them.
        lambda_param: Weight applied to the off-diagonal terms.
    """

    def __init__(self, device, lambda_param=5e-3):
        super().__init__()
        self.lambda_param = lambda_param
        self.device = device

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
        """Return the scalar loss for two (N, D) batches of embeddings.

        Args:
            z_a: Embeddings of the first view, shape (N, D).
            z_b: Embeddings of the second view, shape (N, D).

        Returns:
            A 0-dim tensor holding the summed loss.

        Note:
            Each feature is divided by its batch standard deviation, so a
            feature that is constant over the batch (or a batch of a single
            sample) yields NaN.
        """
        # Standardize every feature along the batch dimension.
        z_a_norm = (z_a - z_a.mean(0)) / z_a.std(0)  # NxD
        z_b_norm = (z_b - z_b.mean(0)) / z_b.std(0)  # NxD

        batch_size, dim = z_a.size(0), z_a.size(1)

        # Cross-correlation matrix between the two views.
        c = torch.mm(z_a_norm.T, z_b_norm) / batch_size  # DxD

        # Build the helpers on the device the data actually lives on rather
        # than on the device recorded at construction time; the two can
        # differ once the module or its inputs have been moved.
        identity = torch.eye(dim, device=c.device)
        off_diagonal = ~torch.eye(dim, dtype=torch.bool, device=c.device)

        c_diff = (c - identity).pow(2)  # DxD
        # Down-weight the off-diagonal terms by lambda; diagonal stays as is.
        weighted = torch.where(off_diagonal, c_diff * self.lambda_param, c_diff)

        return weighted.sum()


class BarlowTwins(nn.Module):
    """Backbone plus projection head trained with ``BarlowTwinsLoss``.

    Args:
        backbone: Feature extractor module; it must expose an
            ``output_dim`` attribute, which sizes the projector input.
        device: Forwarded to ``BarlowTwinsLoss``.
        lambda_param: Forwarded to ``BarlowTwinsLoss``.
    """

    def __init__(self, backbone, device, lambda_param=5e-3):
        super().__init__()

        self.backbone = backbone
        self.feature_dim = backbone.output_dim
        self.projector = projection_MLP(
            in_dim=backbone.output_dim,
            hidden_dim=2048,
            out_dim=2048,
        )

        # Backbone followed by projector, exposed as a single module.
        # forward() calls the two stages separately so that it can also
        # return the intermediate backbone features.
        self.encoder = nn.Sequential(self.backbone, self.projector)

        self.criterion = BarlowTwinsLoss(
            device=device,
            lambda_param=lambda_param,
        )

    def forward(self, x1, x2):
        """Compute the loss for two views and return it with the features.

        Args:
            x1: First view, passed to the backbone.
            x2: Second view, passed to the backbone.

        Returns:
            A dict with ``'loss'`` (the criterion output on the two
            projections) and ``'feature'`` (tuple of the two backbone
            outputs).
        """
        feat_first = self.backbone(x1)
        feat_second = self.backbone(x2)

        proj_first = self.projector(feat_first)
        proj_second = self.projector(feat_second)

        loss_value = self.criterion(proj_first, proj_second)
        return {'loss': loss_value, "feature": (feat_first, feat_second)}


