import torch
from torch import nn

# for arc face 
# from __future__ import print_function
# from __future__ import division
# import torch
# import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math
from typing import Optional

# def get_entropy(output):
#   prob = F.softmax(output,dim=1)
#   entropy = torch.sum(prob * torch.log(prob + 1e-5), dim=1)
  
#   return -torch.mean(entropy) 
def get_entropy(output):
    """Return the batch-mean Shannon entropy of softmax(output) over dim 1.

    Args:
        output: logits whose class scores lie along dim 1.

    Returns:
        0-d tensor: ``-mean_n sum_c p[n, c] * log(p[n, c] + 1e-5)``.
    """
    probabilities = output.softmax(dim=1)
    # The 1e-5 offset keeps log() finite when a probability underflows to 0.
    plogp = (probabilities * (probabilities + 1e-5).log()).sum(dim=1)
    return -plogp.mean()

def get_entropy_local(output):
    """Return the entropy of every local prediction, without reduction.

    Args:
        output: local logits laid out as (N, C, num_classes); softmax and the
            entropy sum are both taken over dim 2.

    Returns:
        Tensor of shape (N, C) holding ``-sum_k p * log(p + 1e-5)``.
    """
    probabilities = output.softmax(dim=2)
    plogp = probabilities * (probabilities + 1e-5).log()
    return -plogp.sum(dim=2)

def orthogonality_loss(features, labels):
    """
    Compute a loss that pushes batch features towards orthogonality.

    Pairs that share a label are pulled towards similarity 1, and pairs with
    different labels are pushed towards similarity 0 (orthogonal).

    Args:
        features (torch.Tensor): image features (batch_size, feature_dim).
            NOTE(review): similarity is the raw dot product, which equals the
            cosine similarity only if the caller L2-normalises the features.
        labels (torch.Tensor): labels (batch_size)

    Returns:
        torch.Tensor: loss value, normalised by batch_size ** 2.
    """
    batch_size = features.size(0)
    similarity = features.mm(features.T)

    same_label = (labels[:, None] == labels[None, :]).float()

    pull_term = ((1 - similarity) ** 2 * same_label).sum()
    push_term = (similarity ** 2 * (1 - same_label)).sum()
    return (pull_term + push_term) / (batch_size ** 2)

def entropy_local_topk(p, label, num_of_local_feature, top_k=3):
    """
    Mean entropy of the local predictions whose top-k classes contain the label.

    Each sample label is repeated ``num_of_local_feature`` times so that it
    lines up with the rows of ``p``. Only rows whose ``top_k`` most probable
    classes include that label contribute. Despite the historical wording,
    these are the Top-K-matching regions, not the non-Top-K ones.

    Args:
        p: local logits, presumably (batch_size * num_of_local_feature,
            num_classes) with rows grouped per sample -- TODO confirm.
        label: per-sample class indices (batch_size,).
        num_of_local_feature: number of local rows per sample.
        top_k: how many top classes to check; clamped to the class count.

    Returns:
        0-d tensor on ``p``'s device: the mean entropy of the selected rows,
        or zero when no row is selected.
    """
    label_repeat = label.repeat_interleave(num_of_local_feature)
    p = F.softmax(p, dim=-1)
    # topk raises if k exceeds the dimension size, so clamp it.
    k = min(top_k, p.size(1))
    pred_topk = torch.topk(p, k=k, dim=1)[1]
    contains_label = pred_topk.eq(label_repeat.to(pred_topk.device).unsqueeze(1)).any(dim=1)
    selected_p = p[contains_label]
    if selected_p.shape[0] == 0:
        # Device- and dtype-matched scalar zero instead of a hard-coded CUDA tensor.
        return p.new_zeros(())
    return -torch.mean(torch.sum(selected_p * torch.log(selected_p + 1e-5), 1))

def entropy_local_topk_distilled(local_out, local_out_expert, label, num_of_local_feature, top_k=3):
    """
    Mean entropy of ``local_out`` on the rows selected by the expert's top-k.

    A row is selected when the ground-truth label is among the ``top_k`` most
    probable classes of ``local_out_expert``. These are the Top-K-matching
    regions, despite the historical "non-Top-K" wording. The entropy is then
    computed from the corresponding rows of ``local_out``.

    Args:
        local_out: local logits whose entropy is measured, presumably
            (batch_size * num_of_local_feature, num_classes) -- TODO confirm.
        local_out_expert: expert local logits with the same shape, used only
            for row selection.
        label: per-sample class indices (batch_size,).
        num_of_local_feature: number of local rows per sample.
        top_k: how many top expert classes to check; clamped to the class count.

    Returns:
        0-d tensor on ``local_out``'s device: the mean entropy of the selected
        rows, or zero when no row is selected.
    """
    label_repeat = label.repeat_interleave(num_of_local_feature)

    local_out_expert = F.softmax(local_out_expert, dim=-1)
    local_out = F.softmax(local_out, dim=-1)

    # topk raises if k exceeds the dimension size, so clamp it.
    k = min(top_k, local_out_expert.size(1))
    pred_topk = torch.topk(local_out_expert, k=k, dim=1)[1]
    contains_label = pred_topk.eq(label_repeat.to(pred_topk.device).unsqueeze(1)).any(dim=1)

    selected_p = local_out[contains_label]
    if selected_p.shape[0] == 0:
        # Device- and dtype-matched scalar zero instead of a hard-coded CUDA tensor.
        return local_out.new_zeros(())
    return -torch.mean(torch.sum(selected_p * torch.log(selected_p + 1e-5), 1))


def cossine_embedding_loss(input, domain_label, label):
    """Batch-averaged CosineEmbeddingLoss between each sample and its classmates.

    For every anchor row ``input[idx]``, the comparison set is every row whose
    class label is in ``label[idx]``, including the anchor itself. Within that
    set, rows that also share the anchor's domain get target +1 and rows from
    another domain get target -1. The per-anchor losses are summed and divided
    by the batch size.

    Args:
        input: feature rows indexed along dim 0, presumably
            (batch_size, feature_dim) -- TODO confirm.
        domain_label: per-sample domain ids, mapped via ``int() * 2 - 1``
            (0/1 -> -1/+1) before comparison.
        label: per-sample class labels.

    Returns:
        Scalar tensor loss. An empty batch raises ZeroDivisionError.
    """
    criterion = torch.nn.CosineEmbeddingLoss()
    signed_domain = domain_label.int() * 2 - 1

    def _anchor_loss(anchor_idx):
        # Compare the anchor with all classmates; same domain -> +1, other -> -1.
        classmates = torch.isin(label, label[anchor_idx])
        same_domain = signed_domain == signed_domain[anchor_idx]
        targets = (classmates & same_domain).int() * 2 - 1
        return criterion(input[anchor_idx].unsqueeze(0), input[classmates], targets[classmates])

    total = sum(_anchor_loss(anchor_idx) for anchor_idx in range(input.size(0)))
    return total / input.size(0)

class Entropy(nn.Module):
    """Batch-mean Shannon entropy of predictions, with classes along dim 1.

    Args:
        is_activation: ``"softmax2d"`` applies ``Softmax(dim=1)`` before the
            entropy is computed. A falsy value (``None``, ``""``, ``False``)
            treats the input as probabilities already. Any other value raises
            ``ValueError``. Previously such values crashed later in ``forward``
            with an AttributeError.
        eps: stored as ``self.eps``. NOTE(review): ``forward`` still uses a
            fixed 1e-5 log offset, kept for numerical backward compatibility.
    """

    def __init__(self, is_activation: str = "softmax2d", eps: float = 1e-7):  # FIXME
        super().__init__()
        self.is_activation = is_activation
        if is_activation == "softmax2d":
            self.activation = torch.nn.Softmax(dim=1)
        elif not is_activation:
            self.activation = None
        else:
            raise ValueError(f"Unsupported activation: {is_activation!r}")
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``-mean_n sum_c p[n, c] * log(p[n, c] + 1e-5)`` as a 0-d tensor."""
        if self.activation is not None:
            x = self.activation(x)
        entropy = torch.sum(x * torch.log(x + 1e-5), dim=1)
        return -torch.mean(entropy)
class SoftNearestNeighborsLoss(nn.Module):
    """Soft Nearest Neighbors loss (Frosst et al., 2019).

    For each anchor ``i`` that has at least one other sample with the same label:

        -log( sum_{j != i, y_j == y_i} exp(-s * d_ij / T)
              / sum_{k != i}           exp(-s * d_ik / T) )

    Here ``d`` is the selected distance, ``T`` the temperature and
    ``s = embed_dim ** -0.5``. The terms are averaged over such anchors. The
    ratio is evaluated in log space with ``logsumexp``, so underflowing
    exponentials no longer drop anchors silently.

    Args:
        temperature: softness ``T`` of the neighbour distribution.
        distance_type: ``'L2'`` (squared Euclidean), ``'cosine'`` (1 - cosine
            similarity), ``'mahalanobis'`` (squared Mahalanobis) or ``'L1'``.
        mahalanobis_cov: optional (embed_dim, embed_dim) covariance matrix for
            ``'mahalanobis'``; the identity is used when ``None``.
    """

    def __init__(self, temperature=0.1, distance_type='L2', mahalanobis_cov=None):

        super().__init__()
        self.temperature = temperature
        self.distance_type = distance_type
        self.mahalanobis_cov = mahalanobis_cov

    def _pairwise_distance(self, candidates):
        """Return the (b, b) matrix of pairwise distances for ``self.distance_type``."""
        if self.distance_type == 'L2':
            return torch.cdist(candidates, candidates, p=2) ** 2

        if self.distance_type == 'cosine':
            normalized_candidates = nn.functional.normalize(candidates, p=2, dim=1)
            return 1 - torch.mm(normalized_candidates, normalized_candidates.T)

        if self.distance_type == 'mahalanobis':
            if self.mahalanobis_cov is None:
                # Identity covariance reduces to squared Euclidean distance.
                transformed = candidates
            else:
                cov = self.mahalanobis_cov.to(device=candidates.device, dtype=candidates.dtype)
                # With cov^-1 = L L^T, (x - y) cov^-1 (x - y)^T = ||(x - y) L||^2,
                # so project by L and take squared Euclidean distances.
                chol = torch.linalg.cholesky(torch.inverse(cov))
                transformed = candidates @ chol
            return torch.cdist(transformed, transformed, p=2) ** 2

        if self.distance_type == 'L1':
            return torch.cdist(candidates, candidates, p=1)

        raise ValueError(f"Unsupported distance type: {self.distance_type}")

    def forward(self, candidates, labels, labels_repeat = None):
        """Compute the loss.

        Args:
            candidates: embeddings of shape (b, embed_dim).
            labels: per-row labels (b,), or per-group labels when
                ``labels_repeat`` is given.
            labels_repeat: optional repeat count passed to
                ``labels.repeat_interleave`` before matching rows.

        Returns:
            0-d tensor loss. Returns zero, still attached to ``candidates``'
            graph, when no anchor has a same-label partner.

        Raises:
            ValueError: on a candidates/labels length mismatch or an
                unsupported ``distance_type``.
        """
        if labels_repeat is not None:
            labels = labels.repeat_interleave(labels_repeat)

        if len(candidates) != len(labels):
            raise ValueError(f"There are {len(candidates)} candidates, but only {(len(labels))} labels")
        device = candidates.device
        b, embed_dim = candidates.shape
        scale = embed_dim**-0.5

        distance_matrix = self._pairwise_distance(candidates)

        labels = labels.to(device)
        self_mask = torch.eye(b, dtype=torch.bool, device=device)
        positives = (labels.unsqueeze(1) == labels.unsqueeze(0)) & ~self_mask
        valid = positives.any(dim=1)
        if not valid.any():
            # No anchor has a positive pair, so the loss is undefined; return 0 instead of NaN.
            return candidates.sum() * 0.0

        # Exclude the self term from both numerator and denominator, as in the SNNL definition.
        logits = (-distance_matrix * scale / self.temperature).masked_fill(self_mask, float('-inf'))
        logits, positives = logits[valid], positives[valid]

        log_numerators = torch.logsumexp(logits.masked_fill(~positives, float('-inf')), dim=1)
        log_denominators = torch.logsumexp(logits, dim=1)
        return -(log_numerators - log_denominators).mean()
