import os
import sys
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from .Buffer import Buffer
# check the path:
from .base import *

# Co2L: Contrastive Continual Learning -- https://arxiv.org/abs/2106.14413


class Co2L(ContinualLearning):
    """Contrastive continual learner in the style of Co2L (arXiv:2106.14413).

    The trained network is ``encoder -> ReLU -> Linear(512, cls_output_dim)``.
    Each optimisation step minimises a supervised contrastive loss plus
    ``lambda_ird`` times an instance-wise relation distillation (IRD) term
    computed against ``self.past_encoder``. Samples drawn from ``self.buffer``
    are concatenated to every batch once the buffer is non-empty.

    ``past_encoder`` starts as ``None`` (no IRD term). Call
    :meth:`update_past_encoder` at a task boundary to freeze a snapshot of
    the current model and enable distillation for subsequent steps.

    NOTE(review): this class relies on ``ContinualLearning.__init__`` exposing
    the wrapped network as ``self.encoder`` and creating ``self.optimizer``;
    confirm against the base class.
    """

    def __init__(self,
                 encoder: nn.Module,
                 lr=0.001,
                 temperature: float = 0.07,
                 lambda_ird: float = 0.1,
                 cls_output_dim: int = 2,
                 device='cuda') -> None:
        """Build the network, loss hyper-parameters and the sample buffer.

        Args:
            encoder: Feature extractor. Its output is fed to
                ``nn.Linear(512, ...)``, so it must emit 512 features.
            lr: Forwarded to the base-class constructor.
            temperature: Divisor applied to the cosine similarities in both
                loss terms.
            lambda_ird: Weight of the IRD term in the total loss.
            cls_output_dim: Output width of the final linear layer.
            device: Device handed to the ``Buffer``.
        """
        encoder_relu_linear = nn.Sequential(
            encoder, nn.ReLU(), nn.Linear(512, cls_output_dim))
        super().__init__(encoder_relu_linear, lr)
        self.lambda_ird = lambda_ird
        self.temperature = temperature
        # There is no past model before the first snapshot. Aliasing the live
        # network here would make the IRD term compare the model with itself,
        # which is identically zero (value and gradient).
        self.past_encoder = None
        self.buffer = Buffer(capacity=2000, device=device)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """Return the output of the full ``encoder -> ReLU -> Linear`` stack."""
        z = self.encoder(x)
        return z

    def update_past_encoder(self) -> None:
        """Freeze a deep copy of the current model as the IRD reference.

        The copy has gradients disabled and is put in eval mode, so later
        optimisation steps on ``self.encoder`` do not affect it.
        """
        import copy  # local import keeps the file-level import block untouched

        snapshot = copy.deepcopy(self.encoder)
        for param in snapshot.parameters():
            param.requires_grad_(False)
        snapshot.eval()
        self.past_encoder = snapshot

    def add_to_buffer(self, examples: torch.Tensor, logits: torch.Tensor):
        """Forward ``examples`` and ``logits`` to ``Buffer.add_data_logits``."""
        self.buffer.add_data_logits(examples, logits)

    # Adapted from
    # https://github.com/HobbitLong/SupContrast/blob/master/losses.py
    def compute_sup_contrastive_loss(self,
                                     z: torch.Tensor,
                                     labels: torch.Tensor) -> torch.Tensor:
        """Compute the supervised contrastive loss over one batch.

        Every sample acts as an anchor; its positives are the *other* samples
        carrying the same label. Anchors without any positive are excluded
        from the average.

        Args:
            z: Embeddings, one row per sample (L2-normalised here).
            labels: One label per row of ``z``.

        Returns:
            Scalar loss. When no anchor has a positive, a zero that is still
            attached to ``z``'s graph, so ``backward()`` remains valid.
        """
        z = F.normalize(z, dim=1)
        batch_size = z.size(0)
        labels = labels.contiguous().view(-1, 1)

        # 1 everywhere except the diagonal: a sample is never its own
        # positive nor part of its own denominator.
        not_self = 1 - torch.eye(batch_size, batch_size).to(z.device)
        positives = torch.eq(labels, labels.T).float().to(z.device) * not_self

        pos_count = positives.sum(1)
        has_pos = pos_count > 0
        if not has_pos.any():
            # Dividing by a zero positive count gives NaN; even if hidden by
            # nanmean, the backward pass of that division would still spread
            # NaN gradients. Return a graph-connected zero instead.
            return (z * 0.0).sum()

        anchor_dot_contrast = torch.div(
            torch.matmul(z, z.T),
            self.temperature
        )
        # Subtract the (detached) row maximum for numerical stability.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        exp_logits = torch.exp(logits) * not_self
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-probability of the positives, only for anchors that
        # actually have positives (avoids 0/0 in forward and backward).
        pos_log_prob_sum = (positives * log_prob).sum(1)
        mean_log_prob_pos = pos_log_prob_sum[has_pos] / pos_count[has_pos]

        return -mean_log_prob_pos.mean()

    def compute_ird_loss(self,
                         z: torch.Tensor,
                         z_past: torch.Tensor) -> torch.Tensor:
        """Compute the Instance-wise Relation Distillation (IRD) loss.

        Both embedding sets are L2-normalised and turned into row-wise softmax
        distributions over temperature-scaled pairwise similarities
        (self-similarities included). The loss is the KL divergence from the
        past distribution to the current one, averaged over the batch.

        Args:
            z: Embeddings from the current model.
            z_past: Embeddings of the same samples from the past model.

        Returns:
            Scalar KL divergence (``reduction='batchmean'``).
        """
        z = F.normalize(z, dim=1)
        z_past = F.normalize(z_past, dim=1)

        similarity_current = torch.mm(z, z.T) / self.temperature
        similarity_past = torch.mm(z_past, z_past.T) / self.temperature

        # F.kl_div expects log-probabilities as input and probabilities as
        # target.
        log_prob_current = F.log_softmax(similarity_current, dim=1)
        prob_past = F.softmax(similarity_past, dim=1)

        ird_loss = F.kl_div(log_prob_current, prob_past, reduction='batchmean')
        return ird_loss

    def compute_loss(self,
                     x: torch.Tensor,
                     labels: torch.Tensor,
                     augmentation: torchvision.transforms.Compose) -> float:
        """Run one optimisation step and return the total loss as a float.

        Side effects: zeroes and steps ``self.optimizer``, and adds the
        incoming batch to ``self.buffer``.

        Args:
            x: Current batch of inputs.
            labels: Labels for ``x``.
            augmentation: Transform passed to ``Buffer.get_data`` for the
                samples drawn from the buffer.

        Returns:
            ``sup_contrastive_loss + lambda_ird * ird_loss`` as a Python
            float.
        """
        self.optimizer.zero_grad()

        # Remember the incoming batch: only these samples are stored in the
        # buffer afterwards. Storing the concatenated batch would re-insert
        # (augmented) replayed samples on every step.
        new_examples, new_labels = x, labels

        if not self.buffer.is_empty():
            buf_inputs, _, buf_labels = self.buffer.get_data(
                x.size(0),
                transform=augmentation)
            # Train on the current samples together with the buffer samples.
            x = torch.cat((x, buf_inputs), dim=0)
            labels = torch.cat((labels, buf_labels), dim=0)

        z = self.encoder(x)
        sup_contrastive_loss = self.compute_sup_contrastive_loss(z, labels)

        if self.past_encoder is not None:
            # A surrounding ``train()`` call may have flipped the snapshot
            # back to training mode; keep the reference deterministic.
            self.past_encoder.eval()
            with torch.no_grad():
                z_past = self.past_encoder(x)
            ird_loss = self.compute_ird_loss(z, z_past)
        else:
            ird_loss = 0.0

        total_loss = sup_contrastive_loss + self.lambda_ird * ird_loss
        total_loss.backward()
        self.optimizer.step()
        self.buffer.add_data(examples=new_examples, labels=new_labels)
        return total_loss.item()


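# NOTE: stale smoke test kept for reference only. It predates the current API:
# `Co2L.__init__` takes no `past_encoder` argument and `compute_loss` requires
# an `augmentation` argument.
# if __name__ == "__main__":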
#     import torchvision.models as models
#     resnet_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
#     resnet_transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms(antialias=True)

#     encoder_past = torch.nn.Sequential(*(list(resnet_model.children())[:-1]), torch.nn.Flatten())

#     model = Co2L(encoder_past, past_encoder=encoder_past)
#     # predictor = nn.Linear(512, 2)
#     x = torch.randn(10, 3, 28, 28)
#     # criterion = nn.CrossEntropyLoss()
#     labels = torch.randint(0, 2, (10,))
#     print(labels)
#     loss = model.compute_loss(x, labels)
#     print(loss)
