"""
Contrastive Learning Loss Functions

Implements SimCLR (Simple Contrastive Learning of Representations) loss
for both standard ANNs and Spiking Neural Networks.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from util import accuracy


class SimCLRLoss(nn.Module):
    """
    SimCLR contrastive loss for standard ANNs.

    Given N samples, creates 2N augmented views. For each view, the other
    view of the same sample is the positive pair, and all other views
    are negatives.

    Loss = -log(exp(sim(z_i, z_j)/τ) / Σ_k exp(sim(z_i, z_k)/τ))

    All helper tensors are allocated on the device of the incoming features,
    so the loss works on CPU as well as on any CUDA device.

    Args:
        batch_size: Nominal number of samples per batch. Stored for
            reference; the actual sample count is inferred from the
            features at call time so a smaller final batch also works.
        n_views: Number of augmented views per sample (default: 2)
        temperature: Temperature scaling parameter τ (default: 0.07)
    """
    def __init__(self, batch_size, n_views=2, temperature=0.07):
        super().__init__()
        # CrossEntropyLoss holds no parameters or buffers, so it needs no
        # explicit device placement.
        self.criterion = torch.nn.CrossEntropyLoss()
        self.batch_size = batch_size
        self.n_views = n_views
        self.temperature = temperature

    def forward(self, features):
        """
        Compute the contrastive loss and retrieval accuracy for one batch.

        Args:
            features: 2-D tensor whose rows are the embeddings of all views,
                laid out view by view (all samples of view 0, then all
                samples of view 1, ...), i.e. ``n_views * samples`` rows.

        Returns:
            Tuple ``(loss, top1, top5)``: the cross-entropy loss, and the
            two values returned by ``accuracy(..., topk=(1, 5))``.

        Raises:
            ValueError: If the number of rows is zero or not a multiple of
                ``n_views``.
        """
        n_rows = features.shape[0]
        if n_rows == 0 or n_rows % self.n_views != 0:
            raise ValueError(
                f"features.shape[0]={n_rows} must be a positive multiple "
                f"of n_views={self.n_views}"
            )
        n_samples = n_rows // self.n_views
        device = features.device

        # Rows that originate from the same sample share an id; comparing the
        # ids pairwise marks every (row, row) combination from the same sample.
        sample_ids = torch.arange(n_samples, device=device).repeat(self.n_views)
        same_sample = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

        # Cosine similarity between every pair of rows
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)

        # Remove self-similarity (diagonal elements); each row loses one entry
        off_diagonal = ~torch.eye(n_rows, dtype=torch.bool, device=device)
        same_sample = same_sample[off_diagonal].view(n_rows, -1)
        similarity_matrix = similarity_matrix[off_diagonal].view(n_rows, -1)

        # Select positives and negatives
        positives = similarity_matrix[same_sample].view(n_rows, -1)
        negatives = similarity_matrix[~same_sample].view(n_rows, -1)

        # Positives are placed first, so class index 0 is the target
        logits = torch.cat([positives, negatives], dim=1)
        labels = torch.zeros(n_rows, dtype=torch.long, device=device)

        # Apply temperature scaling
        logits = logits / self.temperature
        loss = self.criterion(logits, labels)

        # Compute accuracy metrics
        top1, top5 = accuracy(logits, labels, topk=(1, 5))

        return loss, top1, top5


class SimCLRSNNLoss(nn.Module):
    """
    SimCLR contrastive loss adapted for Spiking Neural Networks.

    Key adaptations for SNNs:
    1. Processes features across multiple timesteps
    2. Randomly samples how many leading timesteps to use on each call
    3. Averages loss across selected timesteps

    The random timestep count is intended as a temporal augmentation so the
    network is trained on representations from varying simulation lengths.

    All helper tensors are allocated on the device of the incoming features,
    so the loss works on CPU as well as on any CUDA device.

    Args:
        timestep: Maximum number of simulation timesteps
        batch_size: Nominal number of samples per batch. Stored for
            reference; the actual sample count is inferred from the
            features at call time so a smaller final batch also works.
        n_views: Number of augmented views per sample (default: 2)
        temperature: Temperature scaling parameter (default: 0.07)
    """
    def __init__(self, timestep, batch_size, n_views=2, temperature=0.07):
        super().__init__()
        # CrossEntropyLoss holds no parameters or buffers, so it needs no
        # explicit device placement.
        self.criterion = torch.nn.CrossEntropyLoss()
        self.timestep = timestep
        self.batch_size = batch_size
        self.n_views = n_views
        self.temperature = temperature

    def forward(self, features):
        """
        Compute the timestep-averaged contrastive loss and accuracy.

        Args:
            features: 3-D tensor indexed as ``[row, timestep, feature]``.
                Rows are laid out view by view (all samples of view 0, then
                all samples of view 1, ...), i.e. ``n_views * samples`` rows.

        Returns:
            Tuple ``(loss, top1, top5)``, each averaged over the randomly
            chosen number of leading timesteps. ``top1``/``top5`` are
            averages of the values returned by ``accuracy(..., topk=(1, 5))``.

        Raises:
            ValueError: If the number of rows is zero or not a multiple of
                ``n_views``.
        """
        n_rows = features.shape[0]
        if n_rows == 0 or n_rows % self.n_views != 0:
            raise ValueError(
                f"features.shape[0]={n_rows} must be a positive multiple "
                f"of n_views={self.n_views}"
            )
        n_samples = n_rows // self.n_views
        device = features.device

        # The pair structure does not depend on the timestep, so the masks
        # and the target vector are built once, outside the loop.
        sample_ids = torch.arange(n_samples, device=device).repeat(self.n_views)
        same_sample = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)
        off_diagonal = ~torch.eye(n_rows, dtype=torch.bool, device=device)
        positive_mask = same_sample[off_diagonal].view(n_rows, -1)
        negative_mask = ~positive_mask
        # Positives are placed first in the logits, so class 0 is the target
        targets = torch.zeros(n_rows, dtype=torch.long, device=device)

        # Random temporal sampling: use timesteps 1 to set_timestep
        set_timestep = torch.randint(1, self.timestep + 1, (1,)).item()
        features = features[:, :set_timestep, :]

        losses = top1s = top5s = 0

        # Process each timestep
        for cur_timestep in range(set_timestep):
            # Integer indexing already drops the time axis; no squeeze needed
            cur_features = F.normalize(features[:, cur_timestep, :], dim=1)

            similarity_matrix = torch.matmul(cur_features, cur_features.T)
            # Remove self-similarity (diagonal elements)
            similarity_matrix = similarity_matrix[off_diagonal].view(n_rows, -1)

            positives = similarity_matrix[positive_mask].view(n_rows, -1)
            negatives = similarity_matrix[negative_mask].view(n_rows, -1)

            logits = torch.cat([positives, negatives], dim=1) / self.temperature
            losses += self.criterion(logits, targets)

            top1, top5 = accuracy(logits, targets, topk=(1, 5))
            top1s += top1
            top5s += top5

        # Average across timesteps
        losses = losses / set_timestep
        top1s = top1s / set_timestep
        top5s = top5s / set_timestep

        return losses, top1s, top5s
