import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, resnet18

# Disables cuDNN globally, as a side effect of merely importing this module;
# PyTorch then falls back to its non-cuDNN kernels for every model in the process.
# NOTE(review): the reason is not recorded here and this usually costs GPU
# throughput — confirm it is still needed (e.g. determinism or a cuDNN bug).
torch.backends.cudnn.enabled = False



def nlfda2_loss(z1, z2, penal_para=0.8):
    """Negative Fisher-discriminant trace ratio between two batches of views.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` are treated as the two members
    of class ``i``. After L2-normalising every row, the loss is

        -trace((S_w + penal_para * I)^{-1} S_b)

    where ``S_w`` is the within-class scatter, ``S_t`` is the total scatter of
    all ``2 * batch_size`` embeddings around their mean, and
    ``S_b = S_t - S_w`` is the between-class scatter. Minimising the loss
    favours large between-class scatter relative to within-class scatter.

    Args:
        z1: Tensor of shape ``(batch_size, dim)``; first-view embeddings.
        z2: Tensor of the same shape, row-aligned with ``z1``.
        penal_para: Ridge added to ``S_w`` to keep the system well
            conditioned; must be > 0. Defaults to 0.8.

    Returns:
        A 0-dim tensor on the same device and dtype as the inputs.

    Raises:
        ValueError: if ``z1``/``z2`` are not 2-D tensors of equal shape, or
            ``penal_para`` is not positive.
    """
    if z1.dim() != 2 or z1.shape != z2.shape:
        raise ValueError(
            "z1 and z2 must be 2-D tensors of equal shape, got "
            f"{tuple(z1.shape)} and {tuple(z2.shape)}"
        )
    if penal_para <= 0:
        raise ValueError(f"penal_para must be > 0, got {penal_para}")

    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    z = torch.cat([z1, z2], dim=0)  # (2 * batch_size, dim)

    # Total scatter of all embeddings around the global mean.
    z_centered = z - z.mean(dim=0, keepdim=True)
    s_total = z_centered.t() @ z_centered

    # Within-class scatter: class i = {z1[i], z2[i]} has mean (z1[i]+z2[i])/2,
    # so its scatter is d d^T / 2 with d = z1[i] - z2[i].
    diff = z1 - z2
    s_within = diff.t() @ diff / 2.0
    s_between = s_total - s_within

    # s_within is PSD, so adding penal_para * I (penal_para > 0) makes it SPD;
    # solving the system is cheaper and more stable than forming a pinv.
    # Device/dtype follow the inputs instead of assuming CUDA.
    eye = torch.eye(z.shape[1], dtype=z.dtype, device=z.device)
    ratio = torch.linalg.solve(s_within + penal_para * eye, s_between)
    return -torch.trace(ratio)



class projection_MLP(nn.Module):
    """Two-layer projection head: Linear -> ReLU -> Linear.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the projected output. Defaults to 128.
        hidden_dim: Width of the hidden layer. ``None`` (default) uses
            ``in_dim``.
    """

    def __init__(self, in_dim, out_dim=128, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = in_dim
        # The attribute names and Sequential layout determine the state_dict
        # keys (layer1.0.*, layer2.*); keep them stable for checkpoint loading.
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Project ``x`` of shape ``(..., in_dim)`` to ``(..., out_dim)``."""
        return self.layer2(self.layer1(x))

class SimCLR(nn.Module):
    """Siamese encoder (backbone + projection head) trained with ``nlfda2_loss``.

    Args:
        backbone: Feature extractor whose ``output_dim`` attribute gives the
            width of its output features; that width is the projection
            head's input size. ``None`` (default) builds a fresh torchvision
            ResNet-50 for this instance, with its ``fc`` layer replaced by
            ``nn.Identity`` so it emits the pooled features and
            ``output_dim`` set to the former ``fc.in_features``.
    """

    def __init__(self, backbone=None):
        super().__init__()
        if backbone is None:
            # Built per instance: a ``resnet50()`` default argument would be
            # constructed once at import time and shared by every instance.
            backbone = resnet50()
            backbone.output_dim = backbone.fc.in_features
            backbone.fc = nn.Identity()

        self.backbone = backbone
        self.projector = projection_MLP(backbone.output_dim)
        # Same module objects as self.backbone / self.projector, chained.
        self.encoder = nn.Sequential(
            self.backbone,
            self.projector,
        )

    def forward(self, x1, x2):
        """Encode both views and return ``{'loss': nlfda2_loss(z1, z2)}``."""
        z1 = self.encoder(x1)
        z2 = self.encoder(x2)
        loss = nlfda2_loss(z1, z2)
        return {'loss': loss}


















