import torch.nn as nn

import torch
import torch.nn.functional as F


class AngularPenaltySMLoss(nn.Module):
    """Softmax loss with an angular margin applied to the target-class logit.

    ``loss_type`` selects the variant: 'arcface', 'sphereface', 'cosface', or
    'crossentropy' (plain ``nn.CrossEntropyLoss`` on the incoming logits).
    """

    # Margin used when ``m`` is not given, per margin-based loss type.
    _DEFAULT_MARGINS = {'arcface': 0.2, 'sphereface': 1.35, 'cosface': 0.4}

    def __init__(self, args, loss_type='cosface', eps=1e-7, s=None, m=None):
        '''
        Angular Penalty Softmax Loss
        Four 'loss_types' available: ['arcface', 'sphereface', 'cosface', 'crossentropy']
        The margin-based losses are described in the following papers:

        ArcFace: https://arxiv.org/abs/1801.07698
        SphereFace: https://arxiv.org/abs/1704.08063
        CosFace/Ad Margin: https://arxiv.org/abs/1801.05599

        Args:
            args: config object. For the margin-based types ``args.temperature``
                is the default scale; ``forward`` also divides the incoming
                logits by it.
            loss_type: case-insensitive name of the variant (see above).
            eps: keeps cosines strictly inside (-1, 1) before ``acos``.
            s: logit scale; ``None`` -> ``args.temperature``.
            m: margin; ``None`` -> 0.2 (arcface), 1.35 (sphereface), 0.4 (cosface).

        Raises:
            AssertionError: if ``loss_type`` is not one of the four names.
        '''
        super(AngularPenaltySMLoss, self).__init__()
        loss_type = loss_type.lower()
        assert loss_type in ['arcface', 'sphereface', 'cosface', 'crossentropy']
        if loss_type in self._DEFAULT_MARGINS:
            # `is None` rather than truthiness so an explicit m=0 (or s=0)
            # is honoured instead of being replaced by the default.
            self.s = args.temperature if s is None else s
            self.m = self._DEFAULT_MARGINS[loss_type] if m is None else m
        self.loss_type = loss_type
        self.eps = eps
        self.args = args

        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, wf, labels):
        """Return the mean angular-margin softmax loss over the batch.

        Args:
            wf: 2-D logits, one row per sample and one column per class.
            labels: integer class index for each row of ``wf``.

        Returns:
            0-dim tensor.
        """
        if self.loss_type == 'crossentropy':
            return self.cross_entropy(wf, labels)

        # NOTE(review): assumes wf holds cosines multiplied by args.temperature,
        # so dividing recovers values in [-1, 1] (acos below needs that) — confirm.
        wf = wf / self.args.temperature
        labels = labels.reshape(-1).long()

        # Each sample's own-class logit, shape (B,). gather is O(B) instead of
        # materialising the (B, B) matrix that diagonal(wf.T[labels]) needs.
        target = wf.gather(1, labels.unsqueeze(1)).squeeze(1)
        if self.loss_type == 'cosface':
            numerator = self.s * (target - self.m)
        else:
            theta = torch.acos(torch.clamp(target, -1. + self.eps, 1 - self.eps))
            if self.loss_type == 'arcface':
                numerator = self.s * torch.cos(theta + self.m)
            else:  # sphereface
                numerator = self.s * torch.cos(self.m * theta)

        # All non-target logits in their original column order, shape (B, C - 1).
        target_mask = F.one_hot(labels, num_classes=wf.size(1)).bool()
        excl = wf[~target_mask].reshape(wf.size(0), wf.size(1) - 1)

        # log(exp(numerator) + sum(exp(s * excl))) evaluated stably: the direct
        # exp/sum form overflows to inf once s * logit exceeds ~88 in float32.
        log_denominator = torch.logsumexp(
            torch.cat((numerator.unsqueeze(1), self.s * excl), dim=1), dim=1)
        log_prob = numerator - log_denominator
        return -torch.mean(log_prob)


# def SupConLoss(features, labels, prototypes, temperature=0.1):
#     """
#     Supervised Contrastive Learning loss function.
#     Args:
#         features: tensor of shape (batch_size, feat_dim)
#         labels: tensor of shape (batch_size)
#         prototypes: tensor of shape (num_prototypes, feat_dim)
#         temperature: a temperature factor to scale the logits (default: 0.07)
#     Returns:
#         The computed SupCon loss
#     """
#     device = features.device
#
#     # Normalize the features and prototypes to the unit length
#     features = F.normalize(features, dim=1)
#     prototypes = F.normalize(prototypes, dim=1)
#
#     # Compute the cosine similarity matrix
#     expanded_features = torch.cat([features, prototypes], dim=0)
#     sim_matrix = torch.mm(expanded_features, expanded_features.T)
#
#
#     # Compute the mask for positive and negative samples
#     labels = labels.view(-1, 1)
#     mask_positive = torch.eq(labels, labels.T).float().to(device)
#     mask_negative = torch.ones_like(mask_positive).to(device) - mask_positive
#
#     # Compute the exponential of the similarity matrix scaled by temperature
#     exp_sim_matrix = torch.exp(sim_matrix / temperature)
#
#     # Apply the masks to the exponential of the similarity matrix
#     pos_exp_sim_matrix = (exp_sim_matrix * mask_positive)[:, :features.size(0)]
#     neg_exp_sim_matrix = (exp_sim_matrix * mask_negative)[:, features.size(0):]
#
#
#     # Sum over the positive (same label) elements for each instance in the batch
#     pos_sum = torch.sum(pos_exp_sim_matrix, dim=1,keepdim=True)
#
#
#     # Sum over the negative (different label) elements for each instance in the batch
#     neg_sum = torch.sum(neg_exp_sim_matrix, dim=1,keepdim=True)
#     # Compute the final loss
#     eps = 1e-8
#     loss = -torch.log((pos_sum + eps) / (pos_sum + neg_sum + eps))
#
#     # Average the loss over the batch
#     loss = loss.mean()
#     return loss
def SupConLoss(features, labels, prototypes, temperature=0.001):
    """Supervised-contrastive style loss with prototypes as extra negatives.

    For each sample i (after L2-normalising features and prototypes):

        loss_i = -log( sum_{j: y_j == y_i} exp(s_ij / T)
                       / (sum_j exp(s_ij / T) + sum_k exp(c_ik / T)) )

    where s_ij is the cosine similarity between features i and j (the positive
    sum includes j == i), c_ik is the cosine similarity between feature i and
    prototype k, and T = ``temperature``. The mean over the batch is returned.

    Evaluated in log space with ``logsumexp``. The direct ``exp`` form
    overflows float32 once a similarity / T exceeds ~88; at the default
    T=0.001 the self-similarity alone gives exp(1000) = inf and a NaN loss.

    Args:
        features: 2-D tensor, one row per sample.
        labels: per-sample labels; only the first ``features.shape[0]`` are used.
        prototypes: 2-D tensor whose rows share the feature dimension.
        temperature: divisor applied to the cosine similarities.

    Returns:
        0-dim tensor.
    """
    features = F.normalize(features, dim=1)
    prototypes = F.normalize(prototypes, dim=1)

    logits_features = torch.mm(features, features.T) / temperature
    logits_prototypes = torch.mm(features, prototypes.T) / temperature

    # Only the first batch_size labels describe `features`.
    labels = labels[:features.shape[0]].view(-1, 1)
    mask_positive = torch.eq(labels, labels.T)

    # Numerator: same-label pairs only. Never empty, because the diagonal
    # (each sample with itself) always matches.
    log_positive = torch.logsumexp(
        logits_features.masked_fill(~mask_positive, float('-inf')), dim=1)
    # Denominator: every feature pair (positive and negative) plus every prototype.
    log_all = torch.logsumexp(
        torch.cat((logits_features, logits_prototypes), dim=1), dim=1)

    return (log_all - log_positive).mean()


def CosineDistanceSum(features, labels, prototypes):
    """Add up several groups of cosine distances (1 - cosine similarity).

    The returned scalar is the sum of:
      * every feature-feature distance, minus the diagonal (self) terms;
      * every feature-prototype distance;
      * the feature-feature distances between samples whose labels differ
        (so those pairs are counted twice overall).

    Args:
        features: 2-D tensor, one row per sample.
        labels: per-sample labels; only the first ``features.shape[0]`` are used.
        prototypes: 2-D tensor whose rows share the feature dimension.

    Returns:
        0-dim tensor with the total distance.
    """
    unit_feats = F.normalize(features, dim=1)
    unit_protos = F.normalize(prototypes, dim=1)

    feat_dist = 1 - torch.mm(unit_feats, unit_feats.T)
    proto_dist = 1 - torch.mm(unit_feats, unit_protos.T)

    # 1.0 where the row and column samples carry different labels, else 0.0.
    label_col = labels[:unit_feats.shape[0]].view(-1, 1)
    cross_label = 1 - torch.eq(label_col, label_col.T).float()

    feat_total = torch.sum(feat_dist) - torch.sum(torch.diag(feat_dist))
    proto_total = torch.sum(proto_dist)
    cross_total = torch.sum(feat_dist * cross_label)

    return feat_total + proto_total + cross_total


def MaxMinCosineDistance(features, labels, prototypes):
    """Negative, scaled smallest cosine distance among pairs that should differ.

    Candidate distances (cosine distance = 1 - cosine similarity) are:
      * for each sample, the distance to its nearest sample with a different
        label (same-label pairs and the sample itself are excluded);
      * for each sample, the distance to its nearest prototype.
    The loss is ``-10 * min(candidates)``, so minimising it pushes the closest
    such pair apart (max-min separation).

    A row with no different-label partner contributes no feature candidate.
    If there are no candidates at all (e.g. one label and no prototypes), a
    zero that stays attached to the autograd graph is returned.

    Args:
        features: 2-D tensor, one row per sample.
        labels: per-sample labels; only the first ``features.shape[0]`` are used.
        prototypes: 2-D tensor whose rows share the feature dimension.

    Returns:
        0-dim tensor.
    """
    features = F.normalize(features, dim=1)
    prototypes = F.normalize(prototypes, dim=1)
    if features.shape[0] == 0:
        return features.sum() * 0.0

    cos_dist_features = 1 - torch.mm(features, features.T)
    cos_dist_prototypes = 1 - torch.mm(features, prototypes.T)

    labels = labels[:features.shape[0]].view(-1, 1)
    # False on the diagonal too, so self-pairs (distance 0) are excluded.
    different_labels = labels != labels.T

    # Excluded pairs must never be the row minimum: push them to +inf. (The
    # previous version zeroed them and then took the row *max*, i.e. the
    # farthest different-label sample, contradicting the max-min intent.)
    cos_dist_features = cos_dist_features.masked_fill(~different_labels, float('inf'))

    candidates = [torch.min(cos_dist_features, dim=1)[0]]
    if prototypes.shape[0] > 0:
        candidates.append(torch.min(cos_dist_prototypes, dim=1)[0])
    candidates = torch.cat(candidates)
    # Rows with no different-label partner are still +inf; drop them.
    candidates = candidates[torch.isfinite(candidates)]
    if candidates.numel() == 0:
        return features.sum() * 0.0

    # Negative because the smallest distance should be maximised.
    min_distance = torch.min(candidates)
    return -min_distance * 10


def ContrastiveMarginLoss(features, labels, prototypes, margin=0.5):
    """Hinge penalty on cosine distances that fall below ``margin``.

    Penalised pairs (cosine distance = 1 - cosine similarity):
      * feature-feature pairs whose labels differ: ``relu(margin - d)``;
      * every feature-prototype pair: ``relu(margin - d)``.
    Same-label feature pairs, including each sample with itself, contribute
    nothing. The summed penalty is divided by ``features.numel()``.

    Args:
        features: 2-D tensor, one row per sample.
        labels: per-sample labels; only the first ``features.shape[0]`` are used.
        prototypes: 2-D tensor whose rows share the feature dimension.
        margin: distance below which a pair is penalised.

    Returns:
        0-dim tensor.
    """
    features = F.normalize(features, dim=1)
    prototypes = F.normalize(prototypes, dim=1)

    dist_features = 1 - torch.mm(features, features.T)
    dist_prototypes = 1 - torch.mm(features, prototypes.T)

    labels = labels[:features.shape[0]].view(-1, 1)
    mask_different_labels = (labels != labels.T).to(dist_features.dtype)

    # Mask the hinge *output*, not the distance. Masking the distance turned
    # every same-label pair into relu(margin - 0) = margin, adding a constant
    # (gradient-free) term that the old diagonal subtraction only partly removed.
    loss_features = F.relu(margin - dist_features) * mask_different_labels
    loss_prototypes = F.relu(margin - dist_prototypes)

    loss = torch.sum(loss_features) + torch.sum(loss_prototypes)
    return loss / features.numel()


def AngularLoss(features, labels, prototypes, eps=1e-7):
    """Negative sum of angular distances to push features and prototypes apart.

    Angular distance is ``acos(cosine similarity)``. The returned value is
    ``-(sum over different-label feature pairs + sum over all feature-prototype
    pairs)``; minimising it maximises those angles.

    Args:
        features: 2-D tensor, one row per sample.
        labels: per-sample labels; only the first ``features.shape[0]`` are used.
        prototypes: 2-D tensor whose rows share the feature dimension.
        eps: cosines are clamped to ``[-1 + eps, 1 - eps]`` before ``acos``.
            ``acos`` has an infinite derivative at +-1, and self-similarities
            are exactly 1, so clamping to +-1 yields NaN gradients even for
            masked entries (0 * inf).

    Returns:
        0-dim tensor.
    """
    features = F.normalize(features, dim=1)
    prototypes = F.normalize(prototypes, dim=1)

    cosine_similarity_features = torch.mm(features, features.T)
    cosine_similarity_prototypes = torch.mm(features, prototypes.T)

    angular_dist_features = torch.acos(
        torch.clamp(cosine_similarity_features, -1.0 + eps, 1.0 - eps))
    angular_dist_prototypes = torch.acos(
        torch.clamp(cosine_similarity_prototypes, -1.0 + eps, 1.0 - eps))

    labels = labels[:features.shape[0]].view(-1, 1)
    # 1 where labels differ; 0 for same-label pairs and the diagonal.
    different_labels_mask = (labels != labels.T).to(angular_dist_features.dtype)

    sum_angular_distances_features = torch.sum(angular_dist_features * different_labels_mask)
    sum_angular_distances_prototypes = torch.sum(angular_dist_prototypes)

    # Negative because the angular distances should be maximised.
    loss = -(sum_angular_distances_features + sum_angular_distances_prototypes)
    return loss


def prototypical_loss_cosine(features, labels, args, session):
    """Prototypical-network loss using cosine similarity as the logits.

    Each class present in the batch gets a prototype: the normalised mean of
    its normalised features. Every sample is then classified against those
    prototypes with cross-entropy over the cosine similarities.

    Args:
        features: 2-D tensor, one row per sample.
        labels: per-sample labels; only the first ``features.shape[0]`` are used.
        args: config object; ``args.base_class`` and ``args.way`` are read to
            shift the labels to session-relative indices.
        session: incremental session number used in the label shift.

    Returns:
        0-dim tensor.
    """
    labels = labels[:features.shape[0]]
    labels = labels - args.base_class - (session - 1) * args.way

    features = F.normalize(features, dim=1)

    # `targets[i]` is the row of `unique_labels` (and hence of `prototypes`)
    # holding sample i's class. This stays valid even when the batch misses
    # some classes, where the raw shifted labels would not match the columns.
    unique_labels, targets = torch.unique(labels, return_inverse=True)
    prototypes = torch.stack([features[labels == label].mean(0) for label in unique_labels])
    prototypes = F.normalize(prototypes, dim=1)

    # Cosine similarity is larger for the correct prototype, which is exactly
    # what cross_entropy expects of logits. The previous `-cosine_similarity`
    # rewarded features for moving *away* from their own prototype.
    cosine_similarity = torch.mm(features, prototypes.t())

    loss = F.cross_entropy(cosine_similarity, targets)
    return loss

# def SupConLoss(features, labels, prototypes, temperature=1.0):
#     """
#     Supervised Contrastive Learning loss function where features are considered
#     as positive examples among themselves and as negative examples with respect
#     to the prototypes. The labels for features are assumed to be the same, indicating
#     a new class not represented by the prototypes.
#
#     Args:
#         features: tensor of shape (batch_size, feat_dim), new class examples.
#         labels: tensor of shape (batch_size), labels for the new class examples, assumed to be the same for all.
#         prototypes: tensor of shape (num_prototypes, feat_dim), existing classes examples.
#         temperature: a temperature factor to scale the logits (default: 1.0).
#
#     Returns:
#         The computed SupCon loss.
#     """
#     # device = features.device
#     #
#     # # Normalize the features and prototypes to the unit length
#     # features = F.normalize(features, dim=1)
#     # prototypes = F.normalize(prototypes, dim=1)
#     #
#     # # Compute the cosine similarity among features, and between features and prototypes
#     # sim_matrix_features = torch.mm(features, features.T) / temperature
#     # sim_matrix_prototypes = torch.mm(features, prototypes.T) / temperature
#     #
#     # # Create mask for positive and negative pairs
#     # labels = labels.view(-1, 1)
#     # mask_positive = torch.eq(labels, labels.T).float().to(device)
#     # mask_negative = torch.ones_like(sim_matrix_prototypes).to(device)
#     #
#     # # Compute the exponential of the similarity matrix
#     # exp_sim_matrix_features = torch.exp(sim_matrix_features)
#     # exp_sim_matrix_prototypes = torch.exp(sim_matrix_prototypes)
#     #
#     # # Apply masks to the exponential similarity matrix
#     # exp_sim_matrix_features *= mask_positive
#     # exp_sim_matrix_prototypes *= mask_negative
#     #
#     # # Sum over the positive elements for each instance in the batch
#     # pos_sum = torch.sum(exp_sim_matrix_features, dim=1, keepdim=True) - torch.exp(torch.tensor(1.0) / temperature)
#     #
#     # # Sum over the negative elements for each instance in the batch
#     # neg_sum = torch.sum(exp_sim_matrix_prototypes, dim=1, keepdim=True)
#     #
#     # # Compute the contrastive loss
#     # eps = 1e-8
#     # loss = -torch.log((pos_sum + eps) / (pos_sum + neg_sum + eps))
#     #
#     # # Average the loss over the batch
#     # loss = loss.mean()
#     #
#     # return loss
#     # labels = labels.view(-1, 1)
#     # batch_size = features.shape[0]
#     #
#     # # Create masks
#     # mask_positive = torch.ones(batch_size, batch_size, device=features.device) - torch.eye(batch_size,
#     #                                                                                        device=features.device)
#     #
#     # # Features similarity
#     # sim_matrix = torch.mm(features, features.t())
#     # sim_matrix = torch.div(sim_matrix, temperature)
#     # exp_sim_matrix = torch.exp(sim_matrix)
#     # pos_exp_sim_matrix = exp_sim_matrix * mask_positive
#     # sum_all = torch.sum(exp_sim_matrix, dim=1, keepdim=True)
#     # sum_positive = torch.sum(pos_exp_sim_matrix, dim=1, keepdim=True)
#     # log_score_positive = torch.log((sum_positive + 1e-8) / (sum_all - sum_positive + 1e-8))
#     #
#     # # Prototypes similarity
#     # sim_matrix_prototypes = torch.mm(features, prototypes.t())
#     # sim_matrix_prototypes = torch.div(sim_matrix_prototypes, temperature)
#     # exp_sim_matrix_prototypes = torch.exp(sim_matrix_prototypes)
#     #
#     # # Loss computation
#     # sum_negative = torch.sum(exp_sim_matrix_prototypes, dim=1, keepdim=True)
#     # log_score_negative = torch.log(sum_negative + 1e-8)
#     #
#     # log_score = log_score_positive + log_score_negative
#     # loss = -torch.mean(log_score)
#     # return loss
