import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_similarity_matrix(matrix1, matrix2, eps=1e-12):
    """Return the pairwise cosine-similarity matrix between two sets of rows.

    Entry ``[i, j]`` is the cosine similarity between ``matrix1[i]`` and
    ``matrix2[j]``.

    Args:
        matrix1: 2-D tensor of shape (N, D).
        matrix2: 2-D tensor of shape (M, D).
        eps: Lower bound applied to each row norm before dividing, so an
            all-zero row yields similarity 0 instead of NaN (0 / 0).

    Returns:
        Tensor of shape (N, M).
    """
    # F.normalize divides each row by max(||row||_2, eps).
    normalized_matrix1 = F.normalize(matrix1, p=2, dim=1, eps=eps)
    normalized_matrix2 = F.normalize(matrix2, p=2, dim=1, eps=eps)

    # Matrix product of unit rows == pairwise cosine similarity.
    return torch.matmul(normalized_matrix1, normalized_matrix2.T)

def sim(a, b):
    """Return the cosine similarity of each row of ``a`` with the same row of ``b``.

    Only the matched pairs ``(a[i], b[i])`` are needed, so they are computed
    directly in O(N) rather than by building the full N x M similarity
    matrix and taking its diagonal.

    Args:
        a: 2-D tensor of shape (N, D).
        b: 2-D tensor of shape (M, D).

    Returns:
        1-D tensor of length ``min(N, M)``. This is the same length that
        ``torch.diag`` yields for a non-square matrix. All-zero rows give
        similarity 0 rather than NaN.
    """
    k = min(a.shape[0], b.shape[0])
    a_unit = F.normalize(a[:k], p=2, dim=1)
    b_unit = F.normalize(b[:k], p=2, dim=1)
    return (a_unit * b_unit).sum(dim=1)

def sim_loss(a, b):
    """Return the average paired-row cosine similarity between ``a`` and ``b``.

    Higher values mean the paired rows are more alike; the value is not
    negated here.
    """
    paired = sim(a, b)
    count = paired.shape[0]
    return paired.sum() / count

def my_loss101(xs, xrrs, xrs, temperature_f, l1_factor):
    """Return a positive-pair contrastive loss across three row-aligned inputs.

    For each row index ``i`` the exponentiated, temperature-scaled cosine
    similarity is computed for the pairs (xs, xrs), (xs, xrrs) and
    (xrs, xrrs). The loss is ``-mean_i log(l1_factor + e1_i + e2_i + e3_i)``.

    Args:
        xs, xrrs, xrs: 2-D tensors whose rows are paired by index (assumed to
            share the feature dimension).
        temperature_f: Temperature dividing the cosine similarities.
        l1_factor: Constant added inside the logarithm.

    Returns:
        Scalar tensor.
    """
    def positive_scores(a, b):
        # exp(cos(a_i, b_i) / T) for matched rows only. The previous version
        # built the full N x M matrix and also summed the negative pairs,
        # but those sums were never used.
        # NOTE(review): confirm no negative-pair denominator was intended here
        # (feature_loss uses one).
        # Truncating to the shorter input mirrors torch.diag on a non-square
        # matrix.
        k = min(a.shape[0], b.shape[0])
        cos = (F.normalize(a[:k], p=2, dim=1) * F.normalize(b[:k], p=2, dim=1)).sum(dim=1)
        return torch.exp(cos / temperature_f)

    x_e1 = positive_scores(xs, xrs)
    x_e2 = positive_scores(xs, xrrs)
    x_e3 = positive_scores(xrs, xrrs)

    loss_tensor = torch.log(l1_factor + x_e1 + x_e2 + x_e3)

    return -torch.sum(loss_tensor) / loss_tensor.shape[0]

def feature_loss(fv, cf, temperature_f, l2_factor):
    """Return a contrastive loss between row-paired tensors ``fv`` and ``cf``.

    Let ``S[k, j] = exp(cos(fv_k, cf_j) / temperature_f)``. For index ``i``:

    * the positive score is ``S[i, i]``;
    * the negative score is half the sum of the rest of column ``i``,
      ``(sum_k S[k, i] - S[i, i]) / 2``.

    The loss is ``-mean_i log(l2_factor + positive_i / negative_i)``.

    With a single row there are no off-diagonal entries, so the negative
    score is 0 and the ratio becomes infinite.
    """
    kernel = torch.exp(compute_similarity_matrix(fv, cf) / temperature_f)

    positives = kernel.diagonal()
    # Column sums: every fv row scored against cf_i, minus the matched pair.
    negatives = (kernel.sum(dim=0) - positives) / 2

    per_row = torch.log(l2_factor + positives / negatives)
    return -(per_row.sum() / per_row.shape[0])

class ContrastiveLoss(nn.Module):
    """Temperature-scaled cross-entropy between two batches of embeddings.

    Every row of ``x_q`` is scored against every row of ``x_k`` by cosine
    similarity divided by ``temperature``. The per-row negative log-softmax
    is weighted by ``mask_pos`` (normalised by each row's mask sum) and
    averaged over all matrix entries.
    """

    def __init__(self, temperature=1.0):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, x_q, x_k, mask_pos=None):
        """Compute the loss.

        Args:
            x_q: 2-D tensor of query embeddings, one sample per row.
            x_k: 2-D tensor of key embeddings, one sample per row.
            mask_pos: Optional positive-pair weights broadcastable to the
                (rows of x_q, rows of x_k) logit matrix. Defaults to the
                identity, so each query's positive is the key with the same
                index.

        Returns:
            Scalar tensor.
        """
        x_q = nn.functional.normalize(x_q)
        x_k = nn.functional.normalize(x_k)
        N = x_q.shape[0]
        if mask_pos is None:
            # Create the default mask on the inputs' device/dtype instead of
            # forcing .cuda(). That call breaks on CPU and picks the wrong
            # GPU when the inputs are not on the current device.
            mask_pos = torch.eye(N, device=x_q.device, dtype=x_q.dtype)
        logits = torch.matmul(x_q, x_k.T) / self.temperature
        # log_softmax is numerically stable. -log(softmax(.)) underflows to
        # -log(0) = inf for large logits, and inf * 0 under the mask is NaN.
        nll = -F.log_softmax(logits, dim=1)
        pos_weight = mask_pos.sum(dim=1, keepdim=True)
        # Rows with no positive weight have an all-zero numerator. Dividing
        # them by 1 makes them contribute 0 instead of 0 / 0 = NaN. Rows with
        # nonzero weight are unchanged.
        pos_weight = torch.where(pos_weight != 0, pos_weight, torch.ones_like(pos_weight))
        nll_loss = nll * mask_pos / pos_weight
        return nll_loss.mean()


@torch.no_grad()
def kernel_affinity(z, temperature=0.1, step: int = 5, alpha: float = 0.5):
    """Return a lazy random-walk affinity matrix over the rows of ``z``.

    Steps:

    1. Build a Gaussian-style kernel ``exp(-(2 - 2 z z^T) / temperature)``.
    2. Row-normalise it into a transition matrix.
    3. Raise it to ``step``.
    4. Blend with the identity: ``alpha * I + (1 - alpha) * G``.

    Each row of the result sums to 1. Gradients are not tracked.

    Args:
        z: 2-D tensor with one sample per row.
        temperature: Kernel bandwidth.
        step: Number of random-walk steps (matrix power).
        alpha: Self-loop weight of the identity blend (previously fixed at 0.5).

    Returns:
        Square tensor with one row/column per row of ``z``, on ``z``'s device.
    """
    # NOTE(review): normalisation is along dim=0 (columns), so the rows of z
    # are not unit length. That means 2 - 2 * z @ z.T is not the squared
    # Euclidean distance it resembles (that identity needs unit rows, i.e.
    # dim=1). Kept as-is to preserve behaviour; confirm the intended axis.
    z = nn.functional.normalize(z, dim=0)
    G = (2 - 2 * (z @ z.t())).clamp(min=0.)
    G = torch.exp(-G / temperature)
    G = G / G.sum(dim=1, keepdim=True)

    G = torch.matrix_power(G, step)
    # Build the identity on G's device/dtype instead of forcing .cuda().
    eye = torch.eye(G.shape[0], device=G.device, dtype=G.dtype)
    return eye * alpha + G * (1 - alpha)

@torch.no_grad()
def kernel_affinity2(z, temperature=0.1, step: int = 5):
    """Return the ``step``-step random-walk transition matrix over the rows of ``z``.

    The kernel ``exp(-(2 - 2 z z^T) / temperature)`` is row-normalised into a
    transition matrix and raised to the power ``step``. Each row of the
    result sums to 1, and gradients are not tracked.
    """
    # NOTE(review): z is normalised along dim=0 (columns), so the rows are
    # not unit length and 2 - 2 z z^T is not a true squared distance.
    # Confirm the intended axis.
    z_cols = nn.functional.normalize(z, dim=0)
    gram = z_cols @ z_cols.t()
    pseudo_dist = torch.clamp(2 - 2 * gram, min=0.)
    kernel = torch.exp(-pseudo_dist / temperature)
    transition = kernel / kernel.sum(dim=1, keepdim=True)
    return torch.matrix_power(transition, step)


def loss1(client, xrs, zs, hs, h, ws, pt, labels, centers):
    """Return reconstruction + feature-contrastive loss averaged over the client's views.

    For every view ``v`` in ``client.have_view`` it adds two terms:

    * the MSE between ``client.xs[v]`` and ``xrs[v]``;
    * ``feature_loss(zs[v], h, ...)``, using ``client.args.temperature_f``
      and ``client.args.l2_factor``.

    Each group of terms is averaged over the views. A group with no views
    contributes 0, so with no views the result is the int 0.

    ``hs``, ``ws``, ``pt``, ``labels`` and ``centers`` are accepted for
    interface compatibility with the other loss variants but are not read.
    """
    def _average(terms):
        # Mean of a list of losses; 0 when the list is empty.
        return sum(terms) / len(terms) if terms else 0

    recon_terms, feature_terms = [], []
    for v in client.have_view:
        recon_terms.append(F.mse_loss(client.xs[v], xrs[v]))
        feature_terms.append(
            feature_loss(zs[v], h, client.args.temperature_f, client.args.l2_factor)
        )

    return _average(recon_terms) + _average(feature_terms)


def loss2(client, xrs, zs, hs, h, ws, pt, labels, centers):
    """Return the negated sum of affinity-weighted pairwise similarities of ``pt``.

    Steps:

    1. Compute the pairwise cosine-similarity matrix of the rows of ``pt``.
    2. L2-normalise each row of it (``F.normalize`` defaults: p=2, dim=1).
    3. Weight it element-wise by ``kernel_affinity2(h)``.
    4. Zero the diagonal, sum the rest, and negate.

    ``kernel_affinity2`` runs under ``torch.no_grad``, so gradients reach
    only ``pt`` through this term. The remaining parameters are accepted for
    interface compatibility but not read.
    """
    proto_sim = F.normalize(
        F.cosine_similarity(pt.unsqueeze(1), pt.unsqueeze(0), dim=2)
    )
    weighted = kernel_affinity2(h) * proto_sim
    # Self-pairs are excluded from the objective.
    weighted.fill_diagonal_(0)
    return -weighted.sum()


def loss3(client, xrs, zs, hs, h, ws, pt, labels, centers):
    """Return reconstruction + feature-contrastive + affinity loss.

    Per view in ``client.have_view`` it computes two terms:

    * the MSE between ``client.xs[v]`` and ``xrs[v]``;
    * ``feature_loss(zs[v], h, ...)``.

    Each group is averaged over the views; an empty group contributes 0.

    On top of that it adds the negated sum of the row-normalised pairwise
    cosine similarities of ``pt``, weighted by ``kernel_affinity2(h)`` with
    the diagonal zeroed.

    ``hs``, ``ws``, ``labels`` and ``centers`` are accepted for interface
    compatibility but are not read.
    """
    def _average(terms):
        # Mean of a list of losses; 0 when the list is empty.
        return sum(terms) / len(terms) if terms else 0

    recon_terms = []
    feature_terms = []
    for view in client.have_view:
        recon_terms.append(F.mse_loss(client.xs[view], xrs[view]))
        feature_terms.append(
            feature_loss(zs[view], h, client.args.temperature_f, client.args.l2_factor)
        )

    proto_sim = F.normalize(
        F.cosine_similarity(pt.unsqueeze(1), pt.unsqueeze(0), dim=2)
    )
    affinity = kernel_affinity2(h) * proto_sim
    # Self-pairs are excluded from the affinity term.
    affinity.fill_diagonal_(0)

    return (
        _average(recon_terms)
        + _average(feature_terms)
        + _average([-affinity.sum()])
    )