import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

def get_distance(model1, model2):
    """Return the squared L2 distance between two models' flattened parameters.

    Each model's ``parameters()`` are concatenated into a single vector with
    ``nn.utils.parameters_to_vector``; the result is
    ``||theta_1 - theta_2||_2 ** 2`` as a 0-d tensor, computed under
    ``torch.no_grad()``. The two models are assumed to share the same
    parameter layout (otherwise the vector subtraction fails).
    """
    to_vector = nn.utils.parameters_to_vector
    with torch.no_grad():
        delta = to_vector(model1.parameters()) - to_vector(model2.parameters())
        squared = torch.square(torch.norm(delta))
    return squared


def get_distances_from_current_model(current_model, party_models):
    """Squared parameter distance from ``current_model`` to each party model.

    Args:
        current_model: reference model.
        party_models: indexable sequence of models (accessed as
            ``party_models[i]`` for ``i`` in ``range(len(party_models))``).

    Returns:
        np.ndarray of float64 with one distance per party model, in index order.
    """
    num_updates = len(party_models)
    distances = np.zeros(num_updates)
    for i in range(num_updates):
        # Call the module-level helper directly: no `Utils` name is in scope
        # in this module. float() also handles CUDA scalars, which numpy
        # cannot convert on assignment.
        distances[i] = float(get_distance(current_model, party_models[i]))
    return distances

def evaluate(testloader, model):
    """Return the top-1 accuracy of ``model`` on ``testloader`` as a percentage.

    The model is switched to eval mode (and left there). Each batch yielded by
    ``testloader`` must unpack as ``(images, labels)``; the predicted class is
    the argmax of the model output along dim 1.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate() received an empty testloader; accuracy is undefined")
    return 100 * correct / total

def agg_func(protos):
    """Average the prototype tensors collected for each label, in place.

    ``protos`` maps a label to a list of prototype tensors. A label with more
    than one tensor is replaced by the element-wise mean of their ``.data``
    (so that result carries no autograd history); a label with exactly one
    tensor keeps that same tensor object. An empty list raises IndexError.

    Returns the same dict object that was passed in.
    """
    for label in list(protos):
        members = protos[label]
        if len(members) <= 1:
            protos[label] = members[0]
            continue
        running = 0 * members[0].data
        for member in members:
            running += member.data
        protos[label] = running / len(members)
    return protos

def feature_extractor(model, x):
    """Collect the flattened outputs of the model's top-level Conv2d/Linear layers.

    Only the direct children of ``model`` (``model.children()``) are hooked,
    so Conv2d/Linear layers nested inside a child container (e.g. an inner
    ``nn.Sequential``) are not captured. Outputs are appended in the order the
    hooked layers execute during a single ``torch.no_grad()`` forward pass.

    Args:
        model: the ``nn.Module`` to probe.
        x: input passed to ``model(x)``.

    Returns:
        list of tensors, each reshaped to ``(output.size(0), -1)``.
    """
    features = []

    def hook(module, inputs, output):
        # reshape (unlike view) also works when the output is non-contiguous.
        features.append(output.reshape(output.size(0), -1))

    hooks = [
        layer.register_forward_hook(hook)
        for layer in model.children()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]

    # Always detach the hooks, even if the forward pass raises; otherwise they
    # stay registered and fire on every later forward call.
    try:
        with torch.no_grad():
            model(x)
    finally:
        for handle in hooks:
            handle.remove()

    return features

# CKA implementation
def cka_similarity(X, Y):
    """Linear CKA similarity between two feature matrices.

    After centering each column, computes
    ``||X^T Y||_F^2 / (||X^T X||_F * ||Y^T Y||_F)``. It equals 1 when Y is a
    nonzero scalar multiple of X.

    Args:
        X: 2-D tensor of shape (n, d1); rows are samples.
        Y: 2-D tensor of shape (n, d2) with the same number of rows.

    Returns:
        0-d tensor. NaN if either input is constant across rows (zero denominator).
    """
    X = X - X.mean(dim=0)  # center each feature column
    Y = Y - Y.mean(dim=0)

    n = X.shape[0]
    if n < max(X.shape[1], Y.shape[1]):
        # Few samples, wide features: use the n x n Gram matrices K = X X^T and
        # L = Y Y^T. Because ||X^T Y||_F^2 = trace(K L) = sum(K * L) and
        # ||X^T X||_F = ||K||_F, this gives the same value without
        # materialising d x d matrices (d is huge for flattened conv maps).
        gram_x = X @ X.T
        gram_y = Y @ Y.T
        cross = (gram_x * gram_y).sum()
        norm_x = torch.norm(gram_x, p='fro')
        norm_y = torch.norm(gram_y, p='fro')
    else:
        cross = torch.norm(X.T @ Y, p='fro') ** 2
        norm_x = torch.norm(X.T @ X, p='fro')
        norm_y = torch.norm(Y.T @ Y, p='fro')

    # Product of norms (not sqrt of product of squared norms) avoids overflow.
    return cross / (norm_x * norm_y)

# IS implementation
def kl_similarity(X, Y):
    """Mean row-wise KL divergence ``KL(softmax(X) || softmax(Y))``.

    Despite the name this is a divergence: it is 0 for identical inputs and
    grows as the row distributions differ.

    Args:
        X, Y: tensors of matching shape; each row is turned into a probability
            distribution with softmax over dim 1.

    Returns:
        0-d tensor: the per-row KL divergence averaged over rows.
    """
    # log_softmax gives exact, finite log-probabilities, so no additive epsilon
    # is needed. Adding 1e-6 after softmax un-normalised P and Q and badly
    # biased the result for wide features, where most probabilities are << 1e-6.
    log_p = F.log_softmax(X, dim=1)
    log_q = F.log_softmax(Y, dim=1)

    kl_div = torch.sum(log_p.exp() * (log_p - log_q), dim=1)

    return kl_div.mean()

def compare_models(model1, model2, x):
    """Layer-wise CKA and KL comparison of two models on the same input batch.

    Features come from ``feature_extractor`` for each model. They are paired
    by position with ``zip``, so if one model yields more hooked layers, the
    extra ones are silently ignored.

    Returns:
        (cka_similarities, kl_similarities): one entry per paired layer. CKA
        values are the 0-d tensors returned by ``cka_similarity``; KL values
        are Python floats.
    """
    features1 = feature_extractor(model1, x)
    features2 = feature_extractor(model2, x)

    cka_similarities = []
    kl_similarities = []
    for f1, f2 in zip(features1, features2):
        kl_similarities.append(kl_similarity(f1, f2).item())
        # Compute CKA once per layer (it was previously computed twice and the
        # first result discarded); kept as a tensor for backward compatibility.
        cka_similarities.append(cka_similarity(f1, f2))
    return cka_similarities, kl_similarities


class CombinedLoss(torch.nn.Module):
    """Cross-entropy plus a pairwise, domain-aware contrastive term.

    ``total = alpha * CE(logits, labels) + beta * contrastive``. For every
    unordered pair (i, j) with i < j, let ``s_ij = cos(e_i, e_j) / temperature``:

    * same label and different domain (positive pair) contributes ``-s_ij``;
    * different label (negative pair) contributes ``relu(s_ij)``;
    * same label and same domain is ignored.

    If there is at least one positive pair, the summed contrastive term is
    divided by the number of positive pairs; otherwise it is left unnormalised.
    """

    def __init__(self, temperature=0.07, alpha=1.0, beta=0.1):
        super(CombinedLoss, self).__init__()
        self.temperature = temperature  # divides every cosine similarity
        self.alpha = alpha  # weight of the classification loss
        self.beta = beta  # weight of the contrastive loss

    def forward(self, logits, labels, embeddings, domains):
        """
        :param logits: classification outputs, shape [batch_size, num_classes]
        :param labels: ground-truth labels, shape [batch_size]
        :param embeddings: feature embeddings, shape [batch_size, embedding_dim]
        :param domains: domain ids used to build positive pairs, shape [batch_size]
            (a tensor, or a numeric sequence accepted by torch.as_tensor)
        :return: scalar loss tensor
        """
        # Classification loss (cross-entropy).
        classification_loss = F.cross_entropy(logits, labels)

        batch_size = embeddings.shape[0]
        domains = torch.as_tensor(domains, device=labels.device)

        # All pairwise cosine similarities at once. Rows are normalised along
        # the embedding dimension; a per-row F.cosine_similarity with its
        # default dim=1 fails on 1-D rows.
        unit = F.normalize(embeddings, p=2, dim=1, eps=1e-8)
        sim = unit @ unit.T / self.temperature

        # Count each unordered pair exactly once (i < j).
        upper = torch.triu(
            torch.ones(batch_size, batch_size, device=labels.device), diagonal=1
        ).bool()
        same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
        same_domain = domains.unsqueeze(0) == domains.unsqueeze(1)
        positive_mask = upper & same_label & ~same_domain
        negative_mask = upper & ~same_label

        # Pull positive pairs together; push negatives apart only when their
        # similarity is positive.
        contrastive_loss = -sim[positive_mask].sum() + F.relu(sim[negative_mask]).sum()

        num_positive_pairs = int(positive_mask.sum())
        if num_positive_pairs > 0:
            contrastive_loss = contrastive_loss / num_positive_pairs

        total_loss = self.alpha * classification_loss + self.beta * contrastive_loss
        return total_loss


