# -*- coding: UTF-8 -*-
"""
BYOL (Bootstrap Your Own Latent) wrapper combined with an SGHMC sampler (the combination is referred to as NRCC).

BYOL is a self-supervised learning algorithm that learns visual representations without using labeled data. This module
provides a consistent, high-level interface for training the model:

    - `BYOLWrapper`: the model wrapper (online/query encoder, momentum/key encoder, projector and predictor heads).
    - `SGHMC`: the sampler that `BYOLWrapper.cal_transform` uses to perturb input images.

Constructor arguments (encoder, num_cluster, in_dim, temperature, ...) are documented on `BYOLWrapper.__init__`.
"""


import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import torchvision.transforms as transforms


from utils.gather_layer import GatherLayer
from models.moco.moco_wrapper import MoCoWrapper


class BYOLWrapper(MoCoWrapper):
    """
    BYOL (Bootstrap Your Own Latent) wrapper combined with an SGHMC sampler.

    The module owns an online ("query") encoder + projector, a momentum ("key") copy of both that is
    excluded from gradient updates, and a predictor head applied to the query projections.

    Key methods:
    - `forward`: entry point; dispatches to `forward_v2` by default (`v2=True`), otherwise runs the plain BYOL path.
    - `forward_v2`: BYOL loss on both view orderings plus a weighted SGHMC term built from sampled images.
    - `forward_loss`: BYOL loss for a single (query, key) ordering.
    - `forward_k`: key-branch embeddings, computed without gradients.
    - `cal_transform` / `sghmc_loss`: image sampling through `SGHMC` and the loss computed on the sampled features.
    - `_dequeue_and_enqueue`: writes new keys into the ring-buffer queue (only used when `queue_size > 0`).

    NOTE(review): `_batch_shuffle_ddp`, `_batch_unshuffle_ddp`, `concat_all_gather`,
    `_momentum_update_key_encoder` and the `psedo_labels` attribute are not defined in this class;
    presumably they are provided by `MoCoWrapper` or set by the training code — confirm.

    Bootstrap Your Own Latent A New Approach to Self-Supervised Learning
    https://github.com/lucidrains/byol-pytorch/tree/master/byol_pytorch
    """

    def __init__(self,
                 encoder,
                 num_cluster,
                 in_dim,
                 temperature,
                 hidden_size=4096,
                 fea_dim=256,
                 byol_momentum=0.999,
                 symmetric=True,
                 shuffling_bn=True,
                 latent_std=0.001,
                 queue_size=0):
        """
        Initializes a BYOL (Bootstrap Your Own Latent) wrapper module.

        Args:
            encoder (nn.Module): The encoder module; used as the query encoder and deep-copied for the key encoder.
            num_cluster (int): The number of clusters (stored on the instance; not used in this class).
            in_dim (int): Output dimension of `encoder`, i.e. the input dimension of the projector.
            temperature (float): Temperature value (stored on the instance; not used in this class).
            hidden_size (int, optional): Hidden size of the projector and predictor. Defaults to 4096.
            fea_dim (int, optional): Output dimension of the projector and predictor. Defaults to 256.
            byol_momentum (float, optional): Momentum coefficient, stored as `self.m`. Defaults to 0.999.
            symmetric (bool, optional): Whether the non-v2 forward path uses the symmetric loss. Defaults to True.
            shuffling_bn (bool, optional): Whether `forward_k` shuffles the batch before the key encoder. Defaults to True.
            latent_std (float, optional): Std of the Gaussian noise added to query projections. Defaults to 0.001.
            queue_size (int, optional): Number of rows in the key queue; 0 disables the queue. Defaults to 0.
        """
        # Only nn.Module is initialised here; MoCoWrapper.__init__ is deliberately not called.
        nn.Module.__init__(self)

        self.symmetric = symmetric
        self.m = byol_momentum
        self.shuffling_bn = shuffling_bn
        self.num_cluster = num_cluster
        self.temperature = temperature
        self.fea_dim = fea_dim
        self.latent_std = latent_std
        self.queue_size = queue_size

        # create the encoders
        self.encoder_q = encoder
        self.projector_q = nn.Sequential(
            nn.Linear(in_dim, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, fea_dim)
        )
        self.encoder_k = copy.deepcopy(self.encoder_q)
        self.projector_k = copy.deepcopy(self.projector_q)

        self.predictor = nn.Sequential(nn.Linear(fea_dim, hidden_size),
                                       nn.BatchNorm1d(hidden_size),
                                       nn.ReLU(inplace=True),
                                       nn.Linear(hidden_size, fea_dim)
                                       )
        # Plain Python lists pairing query parameters with their key counterparts.
        self.q_params = list(self.encoder_q.parameters()) + list(self.projector_q.parameters())
        self.k_params = list(self.encoder_k.parameters()) + list(self.projector_k.parameters())

        self.SGHMC = SGHMC()

        for param_q, param_k in zip(self.q_params, self.k_params):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        for m in self.predictor.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.fill_(0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d,
                                nn.GroupNorm, nn.SyncBatchNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Key encoder followed by key projector; shares the very same module objects.
        self.encoder = nn.Sequential(self.encoder_k, self.projector_k)
        if self.queue_size > 0:
            # create the queue
            self.register_buffer("queue", torch.randn(queue_size, fea_dim))
            self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
            self.register_buffer("queue_ind", torch.zeros(queue_size, dtype=torch.long))
        else:
            # no queue: register empty placeholders so the attributes still exist
            self.register_buffer("queue", None)
            self.register_buffer("queue_ptr", None)
            self.register_buffer("queue_ind", None)

    def q_distr(self, x, y, normalize=True, temperature=1):
        """
        Computes `1 / (1 + cos(x_i, y_j)^2)` for every pair of rows of `x` and `y`.

        Args:
            x (torch.Tensor): Feature vectors; assumed to be 2-D `(N, feature_dim)` — the code normalizes along dim 1.
            y (torch.Tensor): Feature vectors; assumed to be 2-D `(M, feature_dim)`.
            normalize (bool, optional): If True, rows are L2-normalized before the cosine similarity. Defaults to True.
            temperature (float, optional): Accepted for interface compatibility; currently unused.

        Returns:
            torch.Tensor: Pairwise values with shape `(N, M)`, each in `[0.5, 1]`.
        """

        cos_sim = nn.CosineSimilarity(dim=2)
        if normalize:
            x = x / (x.norm(dim=1, keepdim=True) + 1e-10)
            y = y / (y.norm(dim=1, keepdim=True) + 1e-10)
        # broadcast (N, 1, D) against (1, M, D) to obtain all pairwise similarities
        sim = cos_sim(x.unsqueeze(1), y.unsqueeze(0))
        return (1 / (1 + sim**2))

    def Resize_image(self, img, sze):
        """
        Resizes the input image batch to `sze x sze` using bilinear interpolation.

        Args:
            img (torch.Tensor): The input image tensor to be resized.
            sze (int): The desired size (width and height) of the resized image.

        Returns:
            torch.Tensor: The resized image tensor.
        """

        resized_image = F.interpolate(img, size=(sze, sze), mode='bilinear', align_corners=False)
        return resized_image

    def cal_transform(self, im_k, im_q, sample_size=32):
        """
        Perturbs `im_k` with the SGHMC sampler, using `im_q` as the reference batch.

        `im_k` is downsampled to `sample_size`, passed through `self.SGHMC` with the log-density
        `log(q_distr(encoder_k(x), encoder_k(im_q)))`, and the result is resized back to the
        original spatial size of `im_k`.

        Args:
            im_k (torch.Tensor): Images used as the starting point of the sampler.
            im_q (torch.Tensor): Images whose key-encoder features serve as the reference.
            sample_size (int, optional): Spatial size at which sampling is performed. Defaults to 32.

        Returns:
            torch.Tensor: The sampled images, detached, with the spatial size of `im_k` and placed on `im_k`'s device.
        """

        full_size = im_k.size(2)

        def log_pdf(x):
            # x arrives at the sampling resolution; bring it back to full size before encoding.
            # NOTE: the reference features are recomputed on every call; hoisting them would
            # change how often any BatchNorm statistics inside encoder_k are updated, so it is left as is.
            feats_x = self.encoder_k(self.Resize_image(x, full_size))
            feats_ref = self.encoder_k(im_q)
            return torch.log(self.q_distr(feats_x, feats_ref))

        sampled = self.SGHMC(log_pdf, self.Resize_image(im_k, sample_size)).detach()
        # Return on the caller's device instead of hard-coding CUDA.
        return self.Resize_image(sampled, full_size).detach().to(im_k.device)

    def sghmc_loss(self, q1, q2, s, temperature=0.5, pos_weight=0.0):
        """
        Computes the SGHMC loss for query features `q1`, `q2` and sampled features `s`.

        The loss is `mean_i logsumexp_j(q1_i . s_j / temperature)` minus `pos_weight` times
        `mean_i logsumexp_j(q1_i . q2_j / temperature)`. With the default `pos_weight=0.0` the
        second (positive) term contributes nothing and is not computed.

        Args:
            q1 (torch.Tensor): The first set of query vectors.
            q2 (torch.Tensor): The second set of query vectors (only used when `pos_weight != 0`).
            s (torch.Tensor): The set of vectors to compute negative similarities against.
            temperature (float, optional): Divisor applied to the dot products. Defaults to 0.5.
            pos_weight (float, optional): Weight of the positive term. Defaults to 0.0.

        Returns:
            torch.Tensor: The scalar SGHMC loss.
        """

        neg_logits = torch.div(torch.matmul(q1, s.T), temperature)
        loss = torch.logsumexp(neg_logits, dim=1, keepdim=True).mean()

        if pos_weight != 0:
            pos_logits = torch.div(torch.matmul(q1, q2.T), temperature)
            loss = loss - pos_weight * torch.logsumexp(pos_logits, dim=1, keepdim=True).mean()

        return loss

    def forward_k(self, im_k, psedo_labels):
        """
        Computes the key embeddings for the BYOL model without tracking gradients.

        Args:
            im_k (torch.Tensor): The input image tensor for the key network.
            psedo_labels (torch.Tensor): The pseudo-labels for the input image (currently unused).

        Returns:
            torch.Tensor: The L2-normalized key embeddings.
            torch.Tensor: The result of `concat_all_gather` on those embeddings
                (presumably the embeddings gathered from all processes — confirm).
        """
        with torch.no_grad():  # no gradient to keys
            if self.shuffling_bn:
                # shuffle for making use of BN
                im_k_, idx_unshuffle = self._batch_shuffle_ddp(im_k)
                k = self.encoder_k(im_k_)  # keys: NxC
                k = k.float()
                k = self.projector_k(k)
                k = nn.functional.normalize(k, dim=1)
                # undo shuffle
                k = self._batch_unshuffle_ddp(k, idx_unshuffle)
            else:
                k = self.encoder_k(im_k)  # keys: NxC
                k = self.projector_k(k)
                k = nn.functional.normalize(k, dim=1)

            k = k.detach_()
            all_k = self.concat_all_gather(k)

        return k, all_k

    def forward_loss(self, im_q, im_k, psedo_labels: torch.Tensor):
        """
        Computes the BYOL loss for one (query, key) ordering.

        Args:
            im_q (torch.Tensor): The query image tensor.
            im_k (torch.Tensor): The key image tensor.
            psedo_labels (torch.Tensor): The pseudo-labels for the images.

        Returns:
            tuple: A tuple containing the following elements:
                - contrastive_loss (torch.Tensor): `-2 * mean(cosine_similarity(predictor(q + noise), k))`.
                - all_q (torch.Tensor): The L2-normalized query projections collected through `GatherLayer`.
                - all_k (torch.Tensor): The second output of `forward_k`.
        """

        q = self.encoder_q(im_q)  # queries: NxC
        q = self.projector_q(q)

        batch_psedo_labels = psedo_labels
        batch_all_psedo_labels = self.concat_all_gather(batch_psedo_labels)
        k, all_k = self.forward_k(im_k, batch_all_psedo_labels)

        # Gaussian perturbation of the query projections before the predictor
        noise_q = q + torch.randn_like(q) * self.latent_std

        # Same as the usual (2 - 2 * cos) BYOL loss up to the additive constant 2.
        contrastive_loss = - 2 * F.cosine_similarity(self.predictor(noise_q), k).mean()
        all_q = F.normalize(torch.cat(GatherLayer.apply(q), dim=0), dim=1)

        return contrastive_loss,  all_q, all_k

    def forward(self, im_q, im_k, indices, momentum_update=True, v2=True):
        """
        Performs the forward pass of the BYOL (Bootstrap Your Own Latent) model.

        Args:
            im_q (torch.Tensor): The query image tensor.
            im_k (torch.Tensor): The key image tensor.
            indices (torch.Tensor): The indices of the images; used to index `self.psedo_labels`.
            momentum_update (bool, optional): Whether to perform momentum update on the key encoder. Defaults to True.
            v2 (bool, optional): Whether to delegate to `forward_v2`. Defaults to True.

        Returns:
            tuple: `(contrastive_loss, q)` — the loss and the normalized query features.
        """

        if v2:
            return self.forward_v2(im_q, im_k, indices, momentum_update=momentum_update)

        psedo_labels = self.psedo_labels[indices]
        if self.symmetric:
            contrastive_loss1,  q1, k1 = self.forward_loss(im_q, im_k, psedo_labels)
            contrastive_loss2,  q2, k2 = self.forward_loss(im_k, im_q, psedo_labels)
            contrastive_loss = 0.5 * (contrastive_loss1 + contrastive_loss2)

            q = torch.cat([q1, q2], dim=0)
            k = torch.cat([k1, k2], dim=0)
        else:  # asymmetric loss
            contrastive_loss,  q, k = self.forward_loss(im_q, im_k, psedo_labels)

        if momentum_update:
            # update the key encoder
            with torch.no_grad():  # no gradient to keys
                self._momentum_update_key_encoder()

        if self.queue_size > 0:
            indices = self.concat_all_gather(indices)
            if self.symmetric:
                # keys of both orderings were concatenated, so the indices are duplicated to match
                indices = indices.repeat(2)
            self._dequeue_and_enqueue(k, indices)

        return contrastive_loss,  q

    def forward_v2(self, im_q_, im_k_, indices, momentum_update=True, sghmc_weight=0.1):
        """
        Performs the forward pass of the BYOL-SGHMC model.

        Both view orderings are processed in one batch (`[im_q_, im_k_]` as queries against
        `[im_k_, im_q_]` as keys). Additional images are sampled with `cal_transform`, encoded by the
        query branch without gradients, and used in `sghmc_loss`.

        Args:
            im_q_ (torch.Tensor): The query images.
            im_k_ (torch.Tensor): The key images.
            indices (torch.Tensor): The indices of the images; used to index `self.psedo_labels`.
            momentum_update (bool, optional): Whether to update the key encoder with momentum. Defaults to True.
            sghmc_weight (float, optional): Weight of the SGHMC term added to the BYOL loss. Defaults to 0.1.

        Returns:
            tuple: `(loss, all_q)` — the BYOL loss plus the weighted SGHMC loss, and the normalized query features.
        """

        if momentum_update:
            # update the key encoder (in this path it happens before the losses are computed)
            with torch.no_grad():  # no gradient to keys
                self._momentum_update_key_encoder()
        im_s = self.cal_transform(im_q_, im_k_)

        # Features of the sampled images: used as constants, so no autograd graph is built.
        with torch.no_grad():
            s = self.encoder_q(im_s)
            s = self.projector_q(s)
        s = F.normalize(s, dim=1)

        im_q = torch.cat([im_q_, im_k_], dim=0)
        im_k = torch.cat([im_k_, im_q_], dim=0)

        psedo_labels = self.psedo_labels[indices]
        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = self.projector_q(q)

        # first half: projections of im_q_, second half: projections of im_k_
        q1, q2 = q.chunk(2, dim=0)
        sghmc_loss = self.sghmc_loss(F.normalize(q1), F.normalize(q2), s)

        batch_psedo_labels = psedo_labels
        batch_all_psedo_labels = self.concat_all_gather(batch_psedo_labels)
        k,  all_k = self.forward_k(im_k, batch_all_psedo_labels.repeat(2))

        noise_q = q + torch.randn_like(q) * self.latent_std

        contrastive_loss = (2 - 2 * F.cosine_similarity(self.predictor(noise_q), k)).mean()
        all_q = F.normalize(torch.cat(GatherLayer.apply(q), dim=0), dim=1)

        contrastive_loss = contrastive_loss + sghmc_weight * sghmc_loss

        return contrastive_loss, all_q

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, indices):
        """
        Overwrites the queue slots starting at the current pointer with the new keys and indices,
        then advances the pointer (wrapping around at `queue_size`).

        Args:
            keys (torch.Tensor): The new keys to be enqueued.
            indices (torch.Tensor): The indices corresponding to the new keys.

        Returns:
            None
        """

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        # for simplicity the queue must hold a whole number of batches
        assert self.queue_size % batch_size == 0, \
            "queue_size must be a multiple of the number of enqueued keys"

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[ptr:ptr + batch_size] = keys
        self.queue_ind[ptr:ptr + batch_size] = indices
        ptr = (ptr + batch_size) % self.queue_size  # move pointer

        self.queue_ptr[0] = ptr

    
    
        
      
class SGHMC:
    """
    Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) style sampler.

    An instance is callable: `sampler(log_pdf, init)` starts from `init`, draws a random momentum,
    performs `L` momentum/position update steps driven by the gradient of `log_pdf`, and returns the
    updated tensor. `init` is processed in chunks of `bs` samples along dimension 0.

    Call arguments:
        log_pdf (callable): Maps a tensor chunk to a tensor whose sum is differentiable w.r.t. the chunk.
        init (torch.Tensor): The starting point of the sampler.

    Call returns:
        torch.Tensor: The sampled tensor, detached and moved to the CPU, with the same shape as `init`.

    NOTE(review): the momentum update subtracts `epsilon * grad(log_pdf)`, i.e. it moves against the
    gradient of `log_pdf`. This is kept exactly as in the original code — confirm it is intended.
    """

    def __init__(self, epsilon=0.05, alpha=0.99, gamma=0.1, L=1, clip=1.0, bs=16):
        """
        Initializes the hyperparameters of the sampler.

        Args:
            epsilon (float, optional): Step size applied to the gradient and to the position update. Defaults to 0.05.
            alpha (float, optional): Scale of the Gaussian noise added to the momentum. Defaults to 0.99.
            gamma (float, optional): Friction coefficient applied to the momentum. Defaults to 0.1.
            L (int, optional): Number of update steps per chunk. Defaults to 1.
            clip (float, optional): Element-wise bound used to clamp the gradient to `[-clip, clip]`. Defaults to 1.0.
            bs (int, optional): Number of samples processed per chunk. Defaults to 16.
        """
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.L = L
        self.clip = clip
        self.bs = bs

    def __get_noise__(self, count):
        """Returns standard Gaussian noise with the shape, dtype and device of `count`."""
        return torch.randn_like(count)

    def __call__(self, log_pdf, init):
        """
        Runs the sampler on `init`; see the class docstring for the contract.

        The computation runs on CUDA when it is available (as the original implementation required)
        and otherwise on the device of `init`, so the sampler also works on CPU-only machines.
        The result keeps the dtype of `init`.
        """
        device = torch.device("cuda") if torch.cuda.is_available() else init.device

        samples = []
        for start in range(0, init.shape[0], self.bs):
            q = init[start:start + self.bs].detach().clone().to(device).requires_grad_(True)
            p = torch.randn_like(q)  # initial momentum

            # Simulate stochastic Hamiltonian dynamics
            for _ in range(self.L):
                # Gradient of log_pdf w.r.t. the current position, clamped element-wise
                grad = torch.autograd.grad(log_pdf(q).sum(), q, create_graph=False)[0]
                grad = torch.clamp(grad, -self.clip, self.clip).detach()
                noise = self.__get_noise__(grad)

                p = p - self.epsilon * grad - self.gamma * p + self.alpha * noise

                # Position update; detach so the autograd graph does not grow across steps.
                q = (q + self.epsilon * p).detach().requires_grad_(True)

            samples.append(q.detach())

        if not samples:
            # Empty input: nothing to sample, return an empty CPU tensor of matching shape.
            return torch.empty_like(init, device="cpu")

        # Concatenate once instead of re-copying the accumulated output for every chunk.
        return torch.cat(samples, dim=0).cpu()
