import torch
import torch.nn as nn
import torch.nn.functional as F 

import repitl.kernel_utils as ku
import repitl.matrix_itl as itl
import repitl.difference_of_entropies as dent

import numpy as np

"""
Implementation for FroSSL
While this looks different from the simple pseudocode presented in the paper, it is equivalent.
"""
class FroSSL_Loss(nn.Module):
    """FroSSL objective: view-invariance (MSE) minus the kernel entropies of both views.

    ``args`` must expose:
      * ``frossl_kernel_type`` -- ``'gaussian'`` or ``'linear'``
      * ``frossl_alpha``       -- forwarded as ``alpha`` to ``itl.matrixAlphaEntropy``
    """

    def __init__(self, args):
        super().__init__()
        self.kernel_type = args.frossl_kernel_type
        self.alpha = args.frossl_alpha

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor):
        """Return ``mse(z_a, z_b) - (entropy(K_a) + entropy(K_b))``.

        Both inputs are indexed as NxD (batch x features). Each view is first
        standardized per feature over the batch, then every feature column is
        rescaled to have L2 norm ``sqrt(D)``.

        Raises:
            NotImplementedError: if ``kernel_type`` is neither ``'gaussian'``
                nor ``'linear'``.
        """
        num_samples, num_features = z_a.size(0), z_a.size(1)
        column_norm_target = num_features ** 0.5

        prepared = []
        for view in (z_a, z_b):
            # standardize along the batch dimension
            view = (view - view.mean(0)) / view.std(0)  # NxD
            # rescale every column to norm sqrt(D)
            prepared.append(column_norm_target * view / torch.norm(view, dim=0))
        z_a, z_b = prepared

        # invariance term between the two normalized views
        invariance = F.mse_loss(z_a, z_b)

        # one kernel matrix per view
        if self.kernel_type == 'gaussian':
            bandwidth = (num_features / 2) ** 0.5
            kernel_a = ku.gaussianKernel(z_a, z_a, bandwidth)
            kernel_b = ku.gaussianKernel(z_b, z_b, bandwidth)
        elif self.kernel_type == 'linear':
            kernel_a = (z_a.T @ z_a) / num_samples
            kernel_b = (z_b.T @ z_b) / num_samples
        else:
            raise NotImplementedError('Kernel type not implemented')

        # entropy term, summed over both views
        entropy_a = itl.matrixAlphaEntropy(kernel_a, alpha=self.alpha)
        entropy_b = itl.matrixAlphaEntropy(kernel_b, alpha=self.alpha)
        total_entropy = entropy_a + entropy_b

        # minimizing this keeps the views aligned while maximizing entropy
        return invariance - total_entropy
    

"""
Adapted from: Barlow Twins: Self-Supervised Learning via Redundancy Reduction
"""
class Barlow_Twins_Loss(nn.Module):
    """Barlow Twins redundancy-reduction loss.

    ``args`` must expose ``barlow_lambda``, the weight applied to the
    off-diagonal entries of the squared cross-correlation error.
    """

    def __init__(self, args):
        super().__init__()
        self.lambda_param = args.barlow_lambda

    def forward(self, z_a, z_b):
        """Return the sum of squared differences between the cross-correlation
        matrix of the two views and the identity, with off-diagonal terms
        weighted by ``lambda_param``.

        Both inputs are indexed as NxD (batch x features).
        """
        # normalize repr. along the batch dimension
        z_a_norm = (z_a - z_a.mean(0)) / z_a.std(0)  # NxD
        z_b_norm = (z_b - z_b.mean(0)) / z_b.std(0)  # NxD

        N = z_a.size(0)
        D = z_a.size(1)

        # cross-correlation matrix
        c = torch.mm(z_a_norm.T, z_b_norm) / N  # DxD

        # squared deviation from the identity target
        c_diff = (c - torch.eye(D, device=c.device)).pow(2)  # DxD

        # Build the diagonal mask on the same device as the data (the mask was
        # previously always allocated on CPU) and weight the off-diagonal
        # entries out-of-place so no tensor in the autograd graph is mutated.
        on_diag = torch.eye(D, dtype=torch.bool, device=c.device)
        weighted = torch.where(on_diag, c_diff, c_diff * self.lambda_param)

        return weighted.sum()


"""
Taken from: TiCo: Transformation Invariance and Covariance Contrast for Self-Supervised Visual Representation Learning

Description: kind of like an alternative form of barlow twins
"""
class TiCo_Loss(nn.Module):
    """TiCo loss: cosine alignment of the two views plus a covariance-contrast
    penalty computed against an exponential moving average of the batch
    second-moment matrix.

    ``args`` is accepted for interface parity with the other losses; the
    hyperparameters ``beta`` (EMA decay) and ``rho`` (penalty weight) are fixed.
    """

    def __init__(self, args):
        super().__init__()

        self.beta = 0.99
        self.rho = 2
        # running DxD second-moment matrix; created lazily on the first forward
        # call, once the feature dimension and device are known
        self.C = None

    def forward(self, z_1, z_2):
        """Return the TiCo loss for a pair of NxD embeddings and update ``self.C``.

        Side effects: ``self.C`` (and ``self.new_C``) are set to the detached,
        updated running matrix.
        """
        if self.C is None:
            # Running statistic, not a trainable quantity: it is replaced by a
            # detached tensor after every step, so it must not require grad.
            self.C = torch.zeros((z_1.shape[1], z_1.shape[1]), device=z_1.device)

        # project each sample onto the unit sphere
        z_1 = F.normalize(z_1, dim=1)
        z_2 = F.normalize(z_2, dim=1)

        # batch second-moment matrix of the first view
        B = (z_1.T @ z_1) / z_1.shape[0]
        new_C = self.beta * self.C + (1 - self.beta) * B

        alignment = -(z_1 * z_2).sum(dim=1).mean()
        contrast = self.rho * ((z_1 @ new_C) * z_1).sum(dim=1).mean()
        loss = alignment + contrast

        # Persist only detached state so the autograd graph of this step is
        # not kept alive through module attributes until the next call.
        self.C = new_C.detach()
        self.new_C = self.C

        return loss

"""
Taken from: A Simple Framework for Contrastive Learning of Visual Representations

Description: A batch-type method for contrastive loss
"""
class SimCLR_Loss(nn.Module):
    """Batch-wise contrastive (NT-Xent style) loss over two views.

    ``args`` is accepted for interface parity with the other losses; the
    temperature is fixed at 0.2.
    """

    def __init__(self, args):
        super().__init__()
        self.temperature = 0.2
        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")

    def forward(self, z_a, z_b, **kwargs):
        """Return the mean cross-entropy over all ``2 * batch_size`` anchors.

        For each anchor the candidates are every sample of the other view
        followed by every *other* sample of its own view; the correct class is
        the same-index sample of the other view. Extra keyword arguments are
        accepted and ignored.
        """
        n = z_a.shape[0]

        view_a = F.normalize(z_a, p=2, dim=1)
        view_b = F.normalize(z_b, p=2, dim=1)

        def scaled_similarity(left, right):
            # pairwise dot products, shape (n, n), divided by the temperature
            return torch.einsum("nc,mc->nm", left, right) / self.temperature

        not_self = ~torch.eye(n, device=view_a.device, dtype=torch.bool)

        def without_self_pairs(square):
            # drop the diagonal (a sample compared with itself): (n, n) -> (n, n - 1)
            return square[not_self].view(n, -1)

        # cross-view similarities come first so the positive sits at column i
        rows_for_a = torch.cat(
            [scaled_similarity(view_a, view_b),
             without_self_pairs(scaled_similarity(view_a, view_a))],
            dim=1,
        )
        rows_for_b = torch.cat(
            [scaled_similarity(view_b, view_a),
             without_self_pairs(scaled_similarity(view_b, view_b))],
            dim=1,
        )
        # final logits have shape (2 * n, 2 * n - 1)
        logits = torch.cat([rows_for_a, rows_for_b], dim=0)

        # the positive for anchor i is column i, for both halves of the batch
        targets = torch.arange(n, device=view_a.device, dtype=torch.long).repeat(2)

        return self.cross_entropy(logits, targets)

"""
Implementation adapted from solo learn library
"""
class VICREG_Loss(nn.Module):
    """VICReg loss: weighted sum of invariance, variance and covariance terms.

    ``args`` is accepted for interface parity with the other losses; the three
    term weights are fixed (25 / 25 / 1).
    """

    def __init__(self, args):
        super().__init__()
        self.inv_loss_weight = 25.0
        self.var_loss_weight = 25.0
        self.cov_loss_weight = 1.0

    def compute_invariance_loss(self, z_a, z_b):
        """Mean squared error between the two views."""
        return F.mse_loss(z_a, z_b)

    def compute_variance_loss(self, z_a, z_b):
        """Hinge penalty on per-feature standard deviations below 1, summed over both views."""
        eps = 1e-4

        def hinge_on_std(z):
            # eps keeps the square root stable when a feature has zero variance
            feature_std = torch.sqrt(z.var(dim=0) + eps)
            return torch.mean(F.relu(1 - feature_std))

        return hinge_on_std(z_a) + hinge_on_std(z_b)

    def compute_covariance_loss(self, z_a, z_b):
        """Sum of squared off-diagonal covariance entries divided by D, summed over both views."""
        N, D = z_a.size()
        off_diagonal = ~torch.eye(D, device=z_a.device).bool()

        def off_diagonal_energy(z):
            centered = z - z.mean(dim=0)
            covariance = (centered.T @ centered) / (N - 1)
            return covariance[off_diagonal].pow(2).sum() / D

        return off_diagonal_energy(z_a) + off_diagonal_energy(z_b)

    def forward(self, z_a, z_b):
        """Return the weighted combination of the three VICReg terms for NxD inputs."""
        weighted_inv = self.compute_invariance_loss(z_a, z_b) * self.inv_loss_weight
        weighted_var = self.compute_variance_loss(z_a, z_b) * self.var_loss_weight
        weighted_cov = self.compute_covariance_loss(z_a, z_b) * self.cov_loss_weight

        return weighted_inv + weighted_var + weighted_cov


"""
CorInfoMax entropy loss: 
Adapted from: https://github.com/serdarozsoy/corinfomax-ssl/blob/c0b917ebbf80ad272f62da398b3f1c05245a89bf/imagenet100_resnet18/pretrain/loss.py#L7

I've made some aesthetic changes to mesh with the codebase
"""
class CorInfoMax(nn.Module):
    """CorInfoMax loss: negative log-determinant of a running covariance
    estimate plus a weighted MSE invariance term.

    ``args`` must expose ``projector_dim``, the feature dimension D used to
    size the running mean ``mu`` (D) and running covariance ``R`` (DxD).
    """

    def __init__(self, args):
        super().__init__()

        dim = args.projector_dim

        # these are default values in original codebase
        self.R_ini = 1
        self.la_R = 0.01
        self.la_mu = 0.0
        self.R_eps_weight = 1e-8  # from paper
        self.alpha = 1000  # from paper

        # running statistics, kept as plain tensors and moved to the input's
        # device at the start of every forward call
        self.R = self.R_ini*torch.eye(dim, requires_grad=False)
        self.mu = torch.zeros(dim, requires_grad=False)

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Return the loss for two NxD views and update the running statistics.

        Side effects: sets ``self.new_mu`` / ``self.new_R`` (graph-attached
        updates from this step) and overwrites ``self.mu`` / ``self.R`` with
        their detached values.
        """
        device = z1.device

        # make sure the running statistics live where the data lives
        self.mu = self.mu.to(device)
        self.R = self.R.to(device)

        N, D = z1.size()

        # running mean update from both views stacked along the batch axis
        stacked = torch.cat((z1, z2), 0)
        batch_mean = torch.mean(stacked, 0)
        self.new_mu = self.la_mu*(self.mu) + (1-self.la_mu)*(batch_mean)

        # running covariance update around the refreshed mean
        centered = stacked - self.new_mu
        batch_cov = (centered.T @ centered) / (2*N)
        self.new_R = self.la_R*(self.R) + (1-self.la_R)*(batch_cov)

        # small multiple of the identity added before taking the log-determinant
        jitter = self.R_eps_weight*torch.eye(D, device=device)
        cov_loss = -torch.logdet(self.new_R + jitter) / D
        invariance_loss = self.alpha * F.mse_loss(z1, z2)

        # keep only detached copies as state so the next step does not
        # backpropagate through this one
        self.R = self.new_R.detach()
        self.mu = self.new_mu.detach()

        return cov_loss + invariance_loss