import torch
import numpy as np
import utils
from torch import nn
from torch.nn.functional import normalize
from utils import *

def compute_mmd(x, y, kernel='rbf', sigma=1.0):
    """Compute the squared Maximum Mean Discrepancy (MMD) between two sets of embeddings.

    Uses the Gaussian (RBF) kernel ``k(a, b) = exp(-||a - b||^2 / (2 * sigma^2))`` and the
    biased estimator ``mean(Kxx) + mean(Kyy) - 2 * mean(Kxy)``. The means include the
    diagonal kernel entries.

    Args:
        x: 2-D tensor of shape (n, d).
        y: 2-D tensor of shape (m, d); must have the same feature dimension as ``x``.
        kernel: kernel name; only ``'rbf'`` is supported.
        sigma: RBF bandwidth. Defaults to 1.0, the value that was previously hard-coded.

    Returns:
        0-dim tensor holding the MMD^2 estimate.

    Raises:
        NotImplementedError: if ``kernel`` is not ``'rbf'``.
        ValueError: if ``sigma`` is not positive.
    """
    if kernel != 'rbf':
        raise NotImplementedError(f"Unsupported kernel type: {kernel}")
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    gamma = 1.0 / (2 * sigma ** 2)

    def gaussian_kernel(a, b):
        """Return the (len(a), len(b)) RBF kernel matrix between the rows of a and b."""
        a_norm = a.pow(2).sum(1).reshape(-1, 1)
        b_norm = b.pow(2).sum(1).reshape(1, -1)
        # The expanded form ||a||^2 + ||b||^2 - 2ab can dip slightly below zero through
        # floating-point cancellation. Clamp it so the kernel value never exceeds 1.
        dist = (a_norm + b_norm - 2.0 * torch.matmul(a, b.t())).clamp_min(0.0)
        return torch.exp(-gamma * dist)

    k_xx = gaussian_kernel(x, x)
    k_yy = gaussian_kernel(y, y)
    k_xy = gaussian_kernel(x, y)
    return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()


class AdaGAE(torch.nn.Module):
    """Two-layer graph autoencoder for a single view.

    Encoding multiplies by a weight matrix and then propagates with the supplied
    ``Laplacian``, twice, with a ReLU in between. Decoding turns pairwise embedding
    distances into a row-stochastic reconstruction matrix via a row-wise softmax.
    """

    def __init__(self, layer_dims, z_pass_linear=False):
        """Create the encoder weights.

        Args:
            layer_dims: sequence whose first three entries are (input dim, hidden dim,
                embedding dim). Any further entries are ignored.
            z_pass_linear: if True, also register two ``nn.Linear`` layers mapping
                embedding -> hidden -> embedding. ``forward`` does not use them.
        """
        super().__init__()
        in_dim = layer_dims[0]
        hidden_dim = layer_dims[1]
        out_dim = layer_dims[2]

        # Registration order (w1, w2, then the optional linears) fixes both the RNG draws
        # and the state_dict key order, so it is kept stable.
        self.w1 = self.get_weight_initial([in_dim, hidden_dim])
        self.w2 = self.get_weight_initial([hidden_dim, out_dim])

        if z_pass_linear:
            self.z_pass_linear1 = torch.nn.Linear(out_dim, hidden_dim)
            self.z_pass_linear2 = torch.nn.Linear(hidden_dim, out_dim)

    def get_weight_initial(self, shape):
        """Return a trainable ``shape`` matrix drawn from U(-limit, limit).

        ``limit = sqrt(6 / (shape[0] + shape[1]))`` (Glorot/Xavier-uniform bound).
        """
        fan_sum = shape[0] + shape[1]
        limit = np.sqrt(6.0 / fan_sum)
        initial = torch.rand(shape) * 2 * limit - limit
        return torch.nn.Parameter(initial, requires_grad=True)

    def forward(self, xi, Laplacian):
        """Encode ``xi`` over the graph and build the softmax reconstruction matrix.

        Args:
            xi: (n, layer_dims[0]) feature matrix.
            Laplacian: (n, n) propagation matrix applied after each weight multiplication.

        Returns:
            ``(embedding, recons_w)`` where ``embedding`` is (n, layer_dims[2]) and
            ``recons_w`` is ``softmax(-distances, dim=1) + 1e-10``.
        """
        # Encoder: weight -> propagate -> ReLU -> weight -> propagate.
        hidden = torch.nn.functional.relu(Laplacian.mm(xi.matmul(self.w1)))
        embedding = Laplacian.mm(hidden.matmul(self.w2))

        # Decoder: closer nodes get larger reconstruction weight.
        # utils.distance is called on transposed embeddings; it is assumed to return an
        # (n, n) pairwise-distance matrix — confirm against utils.
        pairwise = utils.distance(embedding.t(), embedding.t())
        recons_w = torch.nn.functional.softmax(-pairwise, dim=1)
        # The small offset keeps later divisions/logs away from exact zeros.
        return embedding, recons_w + 10 ** -10

    def cal_loss(self, raw_weights, recons, weights, embeding, lam):
        """Return ``(re_loss, tr_loss)`` for one view.

        ``re_loss`` is the row-wise sum of ``raw_weights * log(raw_weights / recons)``
        (a KL-divergence-style term), averaged over rows. ``tr_loss`` is
        ``trace(Z^T (D - W) Z) / n``, where ``W = weights``, ``D`` is its degree matrix and
        ``Z = embeding``.

        ``lam`` is accepted for interface compatibility but is not used here; the caller
        applies it.
        """
        kl_terms = raw_weights * torch.log(raw_weights / recons + 10 ** -10)
        re_loss = kl_terms.sum(dim=1).mean()

        # Unnormalised graph Laplacian D - W built from the given affinity matrix.
        graph_laplacian = torch.diag(weights.sum(dim=1)) - weights
        n_nodes = embeding.shape[0]
        smoothness = embeding.t().matmul(graph_laplacian).matmul(embeding)
        tr_loss = torch.trace(smoothness) / n_nodes
        return re_loss, tr_loss


class AdaGAEMV(torch.nn.Module):
    """Multi-view model: one ``AdaGAE`` per view plus a learnable fusion of their embeddings.

    Each view is encoded by its own ``AdaGAE``. The fused embedding is the per-view
    embeddings combined with weights ``softmax(self.view_weights)``, which are non-negative
    and sum to 1.
    """

    def __init__(self, X, layers, device):
        """Build one ``AdaGAE`` per view.

        Args:
            X: sequence of per-view feature matrices. Only ``len(X)`` and each view's
                ``x.shape[1]`` (its input dimension) are read here.
            layers: layer sizes (list or tuple) appended after each view's input dimension
                to form that view's ``AdaGAE`` ``layer_dims``.
            device: device the per-view ``AdaGAE`` modules and the view-weight parameter
                are placed on.
        """
        super().__init__()
        # All views share the hidden/embedding sizes but have their own input dimension.
        layers_list = [[x.shape[1], *layers] for x in X]
        self.gae_list = torch.nn.ModuleList([AdaGAE(layer).to(device) for layer in layers_list])

        self.num_views = len(X)
        self.device = device

        # Learnable adaptive per-view weight logits, initialised uniformly. They are created
        # directly on `device` so they live alongside the per-view modules moved there above;
        # previously they were always left on the CPU.
        self.view_weights = torch.nn.Parameter(
            torch.ones(self.num_views, device=device) / self.num_views, requires_grad=True
        )

    def _normalized_view_weights(self):
        """Return the view weights softmax-normalised so they sum to 1."""
        return torch.nn.functional.softmax(self.view_weights, dim=0)

    def forward(self, X, lapacian_mv):
        """Run every view's ``AdaGAE`` and fuse the resulting embeddings.

        Args:
            X: per-view feature matrices, indexed by view.
            lapacian_mv: per-view propagation matrices, indexed by view.

        Returns:
            ``(embedding_list, recons_w_list, fused_embedding, norm_weights)``, where
            ``fused_embedding`` is ``sum_i norm_weights[i] * embedding_list[i]``.
        """
        embedding_list = []
        recons_w_list = []
        for i in range(self.num_views):
            embedding, recons_w = self.gae_list[i](X[i], lapacian_mv[i])
            embedding_list.append(embedding)
            recons_w_list.append(recons_w)

        norm_weights = self._normalized_view_weights()
        # Weighted sum of the per-view embeddings. Assumes every view produces an
        # embedding of the same shape.
        fused_embedding = sum(w * emb for w, emb in zip(norm_weights, embedding_list))

        return embedding_list, recons_w_list, fused_embedding, norm_weights

    def cal_loss(self, raw_weights_mv, recons_w_list, weights_mv, embedding_list, fused_embedding, lam, alpha_mmd,
                 beta_cluster, n_clusters):
        """Combine the per-view and fusion losses into the training objective.

        ``total = re_loss + lam * tr_loss + alpha_mmd * mmd_loss + beta_cluster * cluster_loss``

        - ``re_loss`` / ``tr_loss``: each view's ``AdaGAE.cal_loss`` terms, weighted by the
          normalised view weights.
        - ``mmd_loss``: mean over views of ``compute_mmd(view embedding, fused_embedding)``.
        - ``cluster_loss``: KL divergence between a target distribution ``P`` and soft
          assignments ``Q``, computed on ``fused_embedding`` only.

        Returns:
            Tuple ``(total_loss, re_loss, tr_loss, mmd_loss, cluster_loss)``.
        """
        re_loss_list = []
        tr_loss_list = []
        mmd_loss_list = []

        norm_weights = self._normalized_view_weights()

        for i in range(self.num_views):
            re_loss, tr_loss = self.gae_list[i].cal_loss(
                raw_weights_mv[i],
                recons_w_list[i],
                weights_mv[i],
                embedding_list[i],
                lam,
            )
            re_loss_list.append(re_loss)
            tr_loss_list.append(tr_loss)

            # MMD loss: pulls each view's embedding distribution toward the fused one.
            mmd_loss_list.append(compute_mmd(embedding_list[i], fused_embedding))

        re_loss = sum(w * l for w, l in zip(norm_weights, re_loss_list))
        tr_loss = sum(w * l for w, l in zip(norm_weights, tr_loss_list))
        mmd_loss = sum(mmd_loss_list) / self.num_views

        # Clustering loss, on the fused embedding only.
        # soft_kmeans_assign / target_distribution are expected to come from
        # `from utils import *` — confirm.
        # NOTE(review): Q and P are both computed under no_grad, so cluster_loss carries no
        # gradient and beta_cluster only changes the reported total, not training. A
        # DEC-style objective would keep Q differentiable and detach only P. This is left
        # unchanged because it is not visible here whether soft_kmeans_assign is
        # differentiable.
        with torch.no_grad():
            Q = soft_kmeans_assign(fused_embedding, n_clusters)  # N x K
            P = target_distribution(Q)  # N x K
        # Use the fully qualified name: `F` is never imported in this file and only resolved
        # if `utils` happened to re-export it through the star import.
        cluster_loss = torch.nn.functional.kl_div(Q.log(), P, reduction='batchmean')

        # Total loss.
        total_loss = re_loss + lam * tr_loss + alpha_mmd * mmd_loss + beta_cluster * cluster_loss
        return total_loss, re_loss, tr_loss, mmd_loss, cluster_loss


