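"""
DMLLossCosine.py

This module implements a custom PyTorch loss function, DMLLossCosine, for deep
metric learning. It takes every pair of samples (each sample is also paired
with itself), computes the cosine similarity of their embeddings, and penalises
(with mean squared error) the gap to 1.0 for same-label pairs and to 0.0 for
different-label pairs.

Usage:
- Create an instance of DMLLossCosine.
- Call the instance with an embeddings tensor and a tensor of class labels
  (one label per embedding row) to compute the loss.

"""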

import torch
import numpy as np
import pandas as pd
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import torchvision
from tqdm.notebook import tqdm
import logging
import random

class DMLLossCosine(nn.Module):
    """
    DMLLossCosine - Deep Metric Learning Loss using Cosine Similarity.

    For every pair of samples ``(i, j)`` with ``i <= j`` (so each sample is
    also paired with itself), this loss computes the cosine similarity of the
    two embeddings and takes the Mean Squared Error (MSE) against a target of
    1.0 when the two labels are equal and 0.0 otherwise.

    Attributes:
        similarity (nn.CosineSimilarity): Cosine similarity over the last dim.
        device (torch.device or None): Device to compute the loss on. ``None``
            (the default) means "use the device of the embeddings passed to
            ``forward``".

    Methods:
        forward: Computes the forward pass of the loss function.
        embeddings_distance_pairs: Computes pairs of embeddings and their distances.
    """

    def __init__(self, device=None):
        """
        Parameters:
            device (str or torch.device, optional): If given, embeddings and
                targets are moved to this device before the loss is computed.
                Defaults to None (follow the embeddings' device).
        """
        super().__init__()
        self.similarity = nn.CosineSimilarity(dim=-1, eps=1e-7)
        self.device = None if device is None else torch.device(device)

    def forward(self, embeddings, target):
        """
        Computes the forward pass of the loss function.

        Parameters:
            embeddings (torch.Tensor): Embeddings, one sample per index of dim 0.
            target (torch.Tensor): Label of each sample, aligned with
                ``embeddings`` along dim 0.

        Returns:
            loss (torch.Tensor): Scalar MSE between the pairwise cosine
                similarities and the same-label indicators.

        Raises:
            ValueError: If ``target`` is empty.
        """
        if self.device is not None:
            embeddings = embeddings.to(self.device)
            target = torch.as_tensor(target, device=self.device)
        first, second, distance = self.embeddings_distance_pairs(
            embeddings=embeddings, target=target
        )
        score = self.similarity(first, second)
        # Follow the scores' device/dtype instead of hard-coding "cuda", so the
        # loss works on CPU, on any GPU index, and with non-float32 embeddings.
        distance = distance.to(device=score.device, dtype=score.dtype)
        return nn.functional.mse_loss(score, distance)

    def embeddings_distance_pairs(self, embeddings, target):
        """
        Computes pairs of embeddings and their distances.

        The pairs are all ``(i, j)`` with ``0 <= i <= j < len(target)``,
        ordered by ``i`` and then by ``j``. Despite the name, the returned
        "distance" is a similarity target: 1.0 means "same label".

        Parameters:
            embeddings (torch.Tensor): Embeddings, indexed along dim 0.
            target (torch.Tensor or sequence): Label of each sample. A label
                may be a scalar or a vector; two labels match when all of
                their elements are equal.

        Returns:
            first (torch.Tensor): ``embeddings[i]`` for each pair.
            second (torch.Tensor): ``embeddings[j]`` for each pair.
            distance (torch.Tensor): 1.0 where the two labels match, else 0.0.

        Raises:
            ValueError: If ``target`` is empty.
        """
        labels = torch.as_tensor(target, device=embeddings.device)
        n = len(labels)
        if n == 0:
            raise ValueError("target must contain at least one sample")
        # Vectorised replacement for a Python double loop: the upper triangle
        # (diagonal included) yields the pairs in the same row-major order.
        rows, cols = torch.triu_indices(n, n, device=embeddings.device)
        # Flatten each sample's label so scalar and vector labels compare alike.
        labels = labels.reshape(n, -1)
        distance = (labels[rows] == labels[cols]).all(dim=1).to(torch.get_default_dtype())
        return embeddings[rows], embeddings[cols], distance
    

if __name__ == "__main__":
    # Small smoke test: random embeddings with 4 classes.
    random.seed(10)
    # torch.rand / torch.randint draw from torch's own RNG, not Python's
    # `random`, so seed torch as well to make the demo reproducible.
    torch.manual_seed(10)
    embeddings = torch.rand((128, 64))
    target = torch.randint(low=0, high=4, size=(128,))
    loss = DMLLossCosine()
    value = loss(embeddings, target)
    print(f"DMLLossCosine: {value.item():.6f}")
    

    
    
    

    
    



