import numpy as np
import sys
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import utils
import scipy.sparse as sparse
from scipy.sparse import issparse
from sklearn.preprocessing import normalize
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
import argparse
import random
from tqdm import tqdm
import os
from dataloader import load_data
import torch.nn.functional as F  
from optuna.visualization import plot_optimization_history, plot_param_importances, plot_contour
from sklearn.cluster import k_means
import warnings
from scipy.sparse import SparseEfficiencyWarning
from scipy.stats import entropy


# Silence SciPy's SparseEfficiencyWarning; presumably triggered by the in-place
# diagonal update of a CSR affinity matrix during evaluation.
warnings.filterwarnings('ignore', category=SparseEfficiencyWarning)


# Must be set before CUDA is first initialised to take effect. The value is a
# comma-separated list of device ordinals. It previously used a full-width comma
# ("，", U+FF0C), which is not the ASCII separator CUDA expects, so the device
# list could be misparsed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class MLP(nn.Module):
    """Fully connected encoder: (Linear -> ReLU) blocks, then a tanh-squashed output layer.

    Args:
        input_dims: Width of each input row.
        hid_dims: Sequence of hidden-layer widths (must be non-empty).
        out_dims: Width of the output embedding.
        kaiming_init: If True, re-initialise all weights via ``reset_parameters``.
    """

    def __init__(self, input_dims, hid_dims, out_dims, kaiming_init=False):
        super().__init__()
        self.input_dims = input_dims
        self.hid_dims = hid_dims
        self.output_dims = out_dims

        # Hidden stack: one (Linear, ReLU) pair per consecutive width pair.
        widths = [input_dims, *hid_dims]
        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(nn.Linear(fan_in, fan_out))
            self.layers.append(nn.ReLU())

        self.out_layer = nn.Linear(hid_dims[-1], out_dims)
        if kaiming_init:
            self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform hidden weights, Xavier-uniform output weights, zero biases."""
        hidden_linears = (module for module in self.layers if isinstance(module, nn.Linear))
        for linear in hidden_linears:
            init.kaiming_uniform_(linear.weight)
            init.zeros_(linear.bias)
        init.xavier_uniform_(self.out_layer.weight)
        init.zeros_(self.out_layer.bias)

    def forward(self, x):
        """Encode ``x``; every output element lies in [-1, 1] because of the final tanh."""
        out = x
        for module in self.layers:
            out = module(out)
        return torch.tanh(self.out_layer(out))


class AdaptiveSoftThreshold(nn.Module):
    """Elementwise soft-thresholding with a learnable threshold.

    Computes ``sign(c) * relu(|c| - bias)``. ``bias`` is a float32 parameter of
    shape ``(dim,)``, initialised to zero, and broadcasts against ``c``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Assigning an nn.Parameter registers it under the name "bias".
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, c):
        shrunk = torch.relu(torch.abs(c) - self.bias)
        return torch.sign(c) * shrunk


class ViewConsistencyModule(nn.Module):
    """Cross-view consistency loss through a shared projection head.

    For every ordered pair of distinct views (i, j), view i's embedding is
    projected (Linear -> LayerNorm -> ReLU -> Linear), and the loss term is
    ``1 - mean_rows cos(proj(e_i), e_j)``. Terms are averaged over all
    ``V * (V - 1)`` ordered pairs. With fewer than two views the loss is ``0.0``.
    """

    def __init__(self, feature_dim):
        super().__init__()
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, feature_dim),
        )

    def forward(self, view_embeddings):
        num_views = len(view_embeddings)
        if num_views < 2:
            # No pairs to compare; same scalar the pairwise loop would produce.
            return 0.0

        # A view's projection and its normalised target do not depend on the
        # partner view. Compute each once (V projector passes) rather than once
        # per ordered pair (V * (V - 1) passes).
        preds = [F.normalize(self.projector(emb), dim=1) for emb in view_embeddings]
        targets = [F.normalize(emb, dim=1) for emb in view_embeddings]

        consistency_loss = 0.0
        for i in range(num_views):
            for j in range(num_views):
                if i != j:
                    similarity = torch.sum(preds[i] * targets[j], dim=1).mean()
                    consistency_loss += 1.0 - similarity

        return consistency_loss / (num_views * (num_views - 1))


class Mango(nn.Module):
    """Multi-view self-expressive network.

    Holds one ``MLP`` encoder per view (``nets``), a shared ``ViewConsistencyModule``
    (``view_consistency``), and a learnable soft threshold (``thres``).
    Self-expression coefficients between query and key samples are the
    soft-thresholded inner products of their embeddings, scaled by ``1 / out_dims``.
    """

    def __init__(self, input_dims, hid_dims, out_dims, num_views=6):
        super().__init__()
        self.input_dims = input_dims
        self.hid_dims = hid_dims
        self.out_dims = out_dims
        self.shrink = 1.0 / out_dims
        self.num_views = num_views
        # Submodules are built in this order on purpose: it fixes parameter
        # registration order and RNG consumption during initialisation.
        self.view_consistency = ViewConsistencyModule(feature_dim=out_dims)
        encoders = [
            MLP(input_dims=input_dims[v], hid_dims=hid_dims, out_dims=out_dims, kaiming_init=True)
            for v in range(num_views)
        ]
        self.nets = nn.ModuleList(encoders)
        self.thres = AdaptiveSoftThreshold(1)

    def embedding(self, *query_views):
        """Encode each view with its own encoder; returns a tuple, one entry per view.

        Raises:
            ValueError: if the number of views passed differs from ``num_views``.
        """
        if len(query_views) != self.num_views:
            raise ValueError(f"Expected {self.num_views} views, got {len(query_views)}")
        return tuple(net(view) for net, view in zip(self.nets, query_views))

    def get_coeff(self, q_emb, k_emb):
        """Return ``shrink * soft_threshold(q_emb @ k_emb.T)``, shape (n_query, n_key)."""
        gram = q_emb.mm(k_emb.t())
        return self.shrink * self.thres(gram)


def get_sparse_rep(mango, data_views, batch_size=10, non_zeros=200, beta_weights=None):
    """Build a row-sparse N x N self-expression coefficient matrix.

    For every block of ``batch_size`` query samples, coefficients against all N
    samples are computed per view with ``mango.get_coeff`` on the view embeddings.
    They are fused across views: weighted by ``beta_weights[v]`` if given,
    otherwise averaged. NaN/inf values are replaced, each sample's coefficient
    on itself is zeroed, and only the ``non_zeros`` entries of largest magnitude
    per row are kept.

    Args:
        mango: Model exposing ``embedding(*views)`` and ``get_coeff(q, k)``.
            It is run in eval mode; its previous train/eval mode is restored.
        data_views: List/tuple of per-view tensors sharing the same first
            dimension N.
        batch_size: Rows per block; must divide N.
        non_zeros: Entries kept per row (clipped to N).
        beta_weights: Optional per-view weights, indexed as ``beta_weights[v]``.

    Returns:
        ``scipy.sparse.csr_matrix`` of shape (N, N) storing exactly
        ``min(N, non_zeros)`` entries per row.

    Raises:
        ValueError: if ``data_views`` is not a list or tuple.
        Exception: if ``batch_size`` does not divide N.
    """
    if not isinstance(data_views, (list, tuple)):
        raise ValueError("data_views must be a list or tuple of tensors")

    num_views = len(data_views)
    N = data_views[0].shape[0]
    non_zeros = min(N, non_zeros)
    if N % batch_size != 0:
        raise Exception("batch_size should be a factor of dataset size.")
    num_blocks = N // batch_size

    # Run on the model's own device. Fall back to CUDA, the previously
    # hard-coded target, if the model exposes no parameters.
    try:
        device = next(mango.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    def _block(b):
        """Rows [b*batch_size, (b+1)*batch_size) of every view, moved to ``device``."""
        return [view[b * batch_size:(b + 1) * batch_size].to(device) for view in data_views]

    C = torch.empty([batch_size, N])
    local_rows = torch.arange(batch_size)
    values, columns = [], []
    was_training = getattr(mango, "training", None)
    try:
        with torch.no_grad():
            mango.eval()
            # Key embeddings do not depend on the query block, so compute each
            # key block once instead of once per query block.
            key_embs = [mango.embedding(*_block(j)) for j in range(num_blocks)]
            for i in range(num_blocks):
                q_embs = mango.embedding(*_block(i))
                for j, k_embs in enumerate(key_embs):
                    coeffs = [mango.get_coeff(q_embs[v], k_embs[v]).cpu() for v in range(num_views)]
                    if beta_weights is not None:
                        fused = sum(beta_weights[v] * coeffs[v] for v in range(num_views))
                    else:
                        fused = sum(coeffs) / num_views
                    C[:, j * batch_size:(j + 1) * batch_size] = torch.nan_to_num(
                        fused, nan=0.0, posinf=1.0, neginf=-1.0)

                # Zero each query's coefficient on itself (global column i*batch_size + r).
                C[local_rows, local_rows + i * batch_size] = 0.0

                _, index = torch.topk(torch.abs(C), dim=1, k=non_zeros)
                values.append(C.gather(1, index).reshape([-1]).numpy())
                columns.append(index.reshape([-1]).numpy())
    finally:
        if was_training is not None:
            mango.train(was_training)

    if values:
        val = np.concatenate(values, axis=0)
        col = np.concatenate(columns, axis=0)
    else:  # N == 0: nothing to gather.
        val = np.empty(0, dtype=np.float32)
        col = np.empty(0, dtype=np.int64)
    val = np.nan_to_num(val, nan=0.0, posinf=1.0, neginf=-1.0)
    # Every row stores exactly ``non_zeros`` entries.
    indptr = np.arange(N + 1) * non_zeros
    return sparse.csr_matrix((val, col, indptr), shape=[N, N])


def get_knn_graph(data, k):
    """Return the indices of each sample's k nearest neighbours, excluding itself.

    Args:
        data: 2-D tensor (N, D); rows are samples.
        k: Number of neighbours per sample; must satisfy ``k <= N - 1``.

    Returns:
        Long tensor of shape (N, k). Row i lists neighbour indices of sample i,
        nearest first, by squared Euclidean distance.
    """
    with torch.no_grad():
        # Squared Euclidean distances: |a|^2 - 2 a.b + |b|^2.
        dot_product = torch.mm(data, data.t())
        squared_norm = torch.diag(dot_product).unsqueeze(1)
        distance_matrix = squared_norm - 2.0 * dot_product + squared_norm.t()

        # Mask self explicitly instead of assuming it sorts first. Duplicate rows
        # or round-off can tie with, or undercut, the zero self-distance.
        distance_matrix.fill_diagonal_(float("inf"))

        # Smallest distances are the nearest neighbours. The previous code negated
        # the distances AND asked for the smallest, which returned the farthest points.
        _, indices = torch.topk(distance_matrix, k=k, dim=1, largest=False)
    return indices


def regularizer(c, lmbd=1.0):
    """Elastic-net penalty: ``lmbd * ||c||_1 + (1 - lmbd) / 2 * ||c||_2^2``.

    ``lmbd = 1`` gives pure L1; ``lmbd = 0`` gives pure (halved) squared L2.
    """
    l1_term = torch.abs(c).sum()
    l2_term = torch.pow(c, 2).sum()
    ridge_scale = (1.0 - lmbd) / 2.0
    return lmbd * l1_term + ridge_scale * l2_term





def get_knn_Aff(C_sparse_normalized, k=3):
    """Build a symmetric k-nearest-neighbour affinity matrix.

    Each row of ``C_sparse_normalized`` is treated as a feature vector. Every
    row is linked with weight 1 to its ``k`` nearest *other* rows by squared
    Euclidean distance. The directed graph is then symmetrised as
    ``0.5 * (C + C.T)``, so mutual neighbours get weight 1 and one-sided links 0.5.

    Args:
        C_sparse_normalized: (n, d) SciPy sparse matrix; dense array-likes are
            also accepted.
        k: Neighbours per row; clipped to ``n - 1``. The previous ``topk(k + 1)``
            raised when n <= k.

    Returns:
        SciPy sparse (n, n) affinity matrix with a zero diagonal.
    """
    if sparse.issparse(C_sparse_normalized):
        X_dense = C_sparse_normalized.toarray()
    else:
        X_dense = np.asarray(C_sparse_normalized)
    # Use the GPU when present (this was previously a hard-coded .cuda());
    # otherwise fall back to the CPU.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X_tensor = torch.tensor(X_dense, dtype=torch.float32, device=dev)
    n = X_tensor.shape[0]

    k = min(k, n - 1)
    if k <= 0:
        return sparse.csr_matrix((n, n))

    # Squared Euclidean distances: |a|^2 - 2 a.b + |b|^2.
    sq_norms = torch.sum(X_tensor ** 2, dim=1)
    dist_matrix = sq_norms.unsqueeze(1) - 2 * (X_tensor @ X_tensor.t()) + sq_norms.unsqueeze(0)

    # Exclude self explicitly instead of dropping the first topk column.
    # Duplicate rows (e.g. all-zero rows) and float round-off can tie with or
    # undercut the self-distance; that let a row pick itself and lose a real neighbour.
    dist_matrix.fill_diagonal_(float("inf"))
    _, indices = torch.topk(dist_matrix, k=k, dim=1, largest=False)
    indices = indices.cpu().numpy()

    rows = np.repeat(np.arange(n), k)
    cols = indices.reshape(-1)
    data = np.ones(rows.shape[0])
    C_knn = sparse.csr_matrix((data, (rows, cols)), shape=(n, n))

    return 0.5 * (C_knn + C_knn.T)


def robust_intra_view_contrastive_loss(embeddings, temperature=0.3, t=1, alpha=1.2, sigma=0.3, beta=0.6, use_random_walk=True):
    """Robust InfoNCE-style loss within a single view.

    Rows are L2-normalised and ``s = z @ z.T / temperature``. Each sample's
    positive is itself (the diagonal of ``s``). Negatives are all other samples,
    with row-normalised weights ``w``:
      - ``use_random_walk=True``: the off-diagonal part of
        ``utils.random_walk_correction(z, t=t, alpha=alpha, sigma=sigma)``.
        Semantics are defined in ``utils``; not visible here.
      - otherwise: uniform weights.
    The weighted negative mass is raised to the power ``beta`` (the original
    comment states this is to reduce the influence of false negatives).

    Returns:
        Scalar tensor, the mean over i of
        ``log(exp(s_ii) + (sum_j w_ij * exp(s_ij)) ** beta) - s_ii``.
    """
    z = nn.functional.normalize(embeddings, dim=1)
    size = z.size(0)
    logits = torch.matmul(z, z.t()) / temperature

    off_diagonal = 1 - torch.eye(size, device=z.device)
    if not use_random_walk:
        neg_mask = off_diagonal
    else:
        neg_mask = off_diagonal * utils.random_walk_correction(z, t=t, alpha=alpha, sigma=sigma)

    weights = neg_mask / (neg_mask.sum(dim=1, keepdim=True) + 1e-8)
    neg_term = torch.pow((torch.exp(logits) * weights).sum(dim=1), beta)

    positives = torch.diagonal(logits)
    per_sample = torch.log(torch.exp(positives) + neg_term) - positives
    return per_sample.mean()


def robust_cross_view_contrastive_loss(embeddings1, embeddings2, temperature=0.1, beta=0.6):
    """Robust InfoNCE-style loss aligning two views of the same N samples.

    Rows of both inputs are L2-normalised and ``s = z1 @ z2.T / temperature``.
    Sample i in view 1 is paired positively with sample i in view 2. All j != i
    are negatives with uniform, row-normalised weights. The weighted negative
    mass is raised to the power ``beta`` before entering the log.

    Returns:
        Scalar tensor, the mean over i of
        ``log(exp(s_ii) + (sum_j w_ij * exp(s_ij)) ** beta) - s_ii``.
    """
    size = embeddings1.size(0)
    za = nn.functional.normalize(embeddings1, dim=1)
    zb = nn.functional.normalize(embeddings2, dim=1)

    logits = torch.matmul(za, zb.t()) / temperature

    off_diagonal = 1 - torch.eye(size, device=za.device)
    weights = off_diagonal / (off_diagonal.sum(dim=1, keepdim=True) + 1e-8)
    neg_term = torch.pow((torch.exp(logits) * weights).sum(dim=1), beta)

    positives = torch.diagonal(logits)
    return (torch.log(torch.exp(positives) + neg_term) - positives).mean()

def evaluate(mango, data_views, labels, num_subspaces, spectral_dim, non_zeros=200,
             n_neighbors=10, batch_size=10000,
             knn_mode='symmetric', beta_weights=None,
             diffusion=True, diffusion_steps=3, diffusion_top_k=30, temperature=0.2):
    """Cluster the data with the learned self-expression model and score the result.

    Pipeline:
      1. Sparse coefficient matrix from ``get_sparse_rep``.
      2. Row L2 normalisation.
      3. Symmetric kNN affinity from ``get_knn_Aff``.
      4. Optional ``utils.multi_scale_diffusion``.
      5. Top ``spectral_dim`` eigenvectors of ``I - L_sym``, then k-means.
    If the spectral stage raises, k-means runs on the raw view features
    concatenated per sample.

    Args:
        mango: Trained model passed to ``get_sparse_rep``.
        data_views: List of per-view tensors, all with N rows.
        labels: Ground-truth cluster labels, length N.
        num_subspaces: Number of clusters.
        spectral_dim: Number of eigenvectors in the spectral embedding.
        non_zeros: Coefficients kept per row of the sparse representation.
        n_neighbors: k for the kNN affinity.
        batch_size: Block size for ``get_sparse_rep`` (capped at 20000).
        knn_mode: Accepted for backward compatibility; currently unused.
        beta_weights: Optional per-view fusion weights for ``get_sparse_rep``.
        diffusion, diffusion_steps, diffusion_top_k, temperature: Control the
            optional diffusion step (forwarded to ``utils.multi_scale_diffusion``).

    Returns:
        ``(acc, nmi, ari)``. ``acc`` comes from ``utils.clustering_accuracy``;
        NMI uses the geometric normaliser.
    """
    # Import the submodules explicitly so ``csgraph`` / ``linalg`` are guaranteed loaded.
    from scipy.sparse import csgraph
    from scipy.sparse import linalg as sparse_linalg

    # Sparse self-expression coefficients.
    C_sparse = get_sparse_rep(mango=mango,
                              data_views=data_views,
                              batch_size=min(batch_size, 20000),
                              non_zeros=non_zeros,
                              beta_weights=beta_weights)

    # Replace any NaN / inf coefficients.
    if not np.isfinite(C_sparse.data).all():
        C_sparse.data = np.nan_to_num(C_sparse.data, nan=0.0, posinf=1.0, neginf=-1.0)

    # Make sure the matrix has stored entries: inject N tiny random values.
    if C_sparse.nnz == 0:
        noise = sparse.lil_matrix(C_sparse.shape, dtype=np.float32)
        rows = np.random.randint(0, C_sparse.shape[0], size=C_sparse.shape[0])
        cols = np.random.randint(0, C_sparse.shape[1], size=C_sparse.shape[0])
        data = np.random.uniform(0.001, 0.01, size=C_sparse.shape[0])
        for i, j, v in zip(rows, cols, data):
            noise[i, j] = v
        C_sparse = C_sparse + noise.tocsr()

    try:
        C_sparse_normalized = normalize(C_sparse, norm='l2', axis=1).astype(np.float32)
    except ValueError as e:
        print(f"归一化失败: {e}")
        # Fallback: divide each stored entry by its row sum.
        C_csr = C_sparse.tocsr()
        row_sums = np.array(C_csr.sum(axis=1)).flatten()
        row_sums[row_sums == 0] = 1.0
        # Row index of every *stored* entry, aligned with C_csr.data.
        # C.nonzero() skips explicitly stored zeros (get_sparse_rep can store
        # some), which previously misaligned the division with .data.
        entry_rows = np.repeat(np.arange(C_csr.shape[0]), np.diff(C_csr.indptr))
        C_csr.data = C_csr.data / row_sums[entry_rows]
        C_sparse_normalized = C_csr.astype(np.float32)

    Aff = get_knn_Aff(C_sparse_normalized, k=n_neighbors)
    if isinstance(Aff, sparse.spmatrix) and not np.isfinite(Aff.data).all():
        Aff.data = np.nan_to_num(Aff.data, nan=0.0, posinf=1.0, neginf=-1.0)

    if diffusion:
        Aff_norm = normalize(Aff, norm='l1', axis=1)
        Aff_diff = utils.multi_scale_diffusion(
            Aff_norm,
            max_steps=diffusion_steps,
            top_k=diffusion_top_k,
            temperature=temperature
        )
    else:
        Aff_diff = Aff

    try:
        if sparse.issparse(Aff_diff):
            # Symmetrise and add a tiny self-loop to every node.
            Aff_diff = 0.5 * (Aff_diff + Aff_diff.T)
            diag_indices = np.arange(Aff_diff.shape[0])
            Aff_diff[diag_indices, diag_indices] += 1e-6

        laplacian = csgraph.laplacian(Aff_diff, normed=True)
        # The largest eigenvalues of I - L_sym correspond to the smallest of L_sym.
        shifted = sparse.identity(laplacian.shape[0]) - laplacian
        eig_kwargs = dict(sigma=None, which='LA', maxiter=10000, tol=1e-6)
        try:
            _, vec = sparse_linalg.eigsh(shifted, k=spectral_dim, **eig_kwargs)
        except sparse_linalg.ArpackNoConvergence:
            try:
                _, vec = sparse_linalg.eigsh(shifted, k=min(spectral_dim, 3), **eig_kwargs)
            except sparse_linalg.ArpackNoConvergence:
                # NOTE(review): last-resort random embedding; the resulting
                # clusters carry no information from the model.
                vec = np.random.randn(laplacian.shape[0], spectral_dim)
        embedding = normalize(vec)

        _, preds, _ = k_means(embedding, num_subspaces, random_state=42, n_init=10)

    except Exception as e:
        print(f"谱聚类失败: {e}")
        # Concatenate views feature-wise so there is exactly one row per sample.
        # np.vstack stacked views as extra *samples* (V*N rows), mismatching
        # ``labels``, and failed outright when view widths differed.
        data = np.concatenate(
            [view.detach().cpu().numpy().reshape(view.shape[0], -1) for view in data_views],
            axis=1,
        )
        _, preds, _ = k_means(data, num_subspaces, random_state=42, n_init=10)

    acc = utils.clustering_accuracy(labels, preds)
    nmi = normalized_mutual_info_score(labels, preds, average_method='geometric')
    ari = adjusted_rand_score(labels, preds)

    return acc, nmi, ari


def same_seeds(seed):
    """Seed Python, NumPy and PyTorch (CPU and, if present, all CUDA devices) RNGs.

    Also disables cuDNN autotuning and requests deterministic cuDNN kernels.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


import sys

def _parse_args():
    """Parse CLI options with the tuned defaults; unknown flags (e.g. Jupyter's ``-f``) are ignored."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="MSRCv1")
    parser.add_argument('--extractor_type', type=str, default='mlp')
    parser.add_argument('--lr', type=float, default=0.0033)
    # nargs='+' so a CLI override yields a list, as MLP expects; plain type=int
    # would produce a single int.
    parser.add_argument('--hid_dims', type=int, nargs='+', default=[1024, 512, 256])
    parser.add_argument('--out_dims', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=210)
    parser.add_argument('--epoch', type=int, default=400)
    parser.add_argument('--alpha', type=float, default=1000)
    parser.add_argument('--beta', type=float, default=180000)
    parser.add_argument('--gamma', type=float, default=27000)
    parser.add_argument('--spectral_dim', type=int, default=8)
    parser.add_argument('--lmbd', type=float, default=0.3)
    parser.add_argument('--temperature', type=float, default=0.6)
    parser.add_argument('--random_walk_steps', type=int, default=3)
    parser.add_argument('--seed', type=int, default=7)

    # Tuned hyper-parameters become defaults, so explicit CLI flags still win.
    # Previously sys.argv was overwritten, which silently discarded every CLI flag.
    best_params = {
        'lr': 0.0033,
        'alpha': 1000,
        'beta': 180000,
        'gamma': 27000,
        'spectral_dim': 8,
    }
    parser.set_defaults(**best_params)
    # parse_known_args tolerates extra argv entries such as notebook-kernel flags.
    args, _unknown = parser.parse_known_args()
    return args


def _standardize_view(sample):
    """Convert one raw view into a float32 CUDA tensor.

    Densifies sparse / array-like input, z-scores each feature column (1e-8
    added to the std), replaces NaN/inf, then applies ``utils.p_normalize``.

    Raises:
        ValueError: if conversion to an array or the normalisation fails.
    """
    if not isinstance(sample, np.ndarray):
        try:
            sample = sample.toarray() if issparse(sample) else np.array(sample)
        except Exception as e:
            raise ValueError(f"数据转换失败: {e}") from e

    if sample.ndim == 1:
        # Presumably one scalar feature per sample, so make it (N, 1).
        # reshape(1, -1) produced a single N-feature "sample", which broke the
        # row alignment with the other views and zeroed the z-score.
        sample = sample.reshape(-1, 1)

    try:
        mean = np.mean(sample, axis=0, keepdims=True)
        std = np.std(sample, axis=0, keepdims=True) + 1e-8
        data = (sample - mean) / std
        if np.isnan(data).any() or np.isinf(data).any():
            data = np.nan_to_num(data, nan=0.0, posinf=1.0, neginf=-1.0)
    except Exception as e:
        raise ValueError(f"数据归一化失败：{e}") from e

    data = torch.from_numpy(data).float()
    return utils.p_normalize(data).cuda()


def main():
    """Train Mango on one dataset and report clustering metrics.

    Per mini-batch, the loss combines:
      - self-expressive reconstruction against all samples (``alpha``),
      - the elastic-net penalty on the coefficients,
      - intra-view and cross-view robust contrastive terms (``beta``),
      - the view-consistency term (``gamma``),
    and is normalised by ``batch_size * num_views``. Evaluation runs when
    ``epoch % 50 == 0 and epoch > 350``.

    NOTE(review): ``--seed`` is parsed but never applied (``same_seeds`` is not
    called). The loss is printed only when ``epoch % 500 == 0``, which never
    happens with the default 400 epochs.
    """
    args = _parse_args()

    dataset, dims, view, data_size, class_num = load_data(args.dataset)
    num_views = view

    all_samples = [getattr(dataset, f'data{i}') for i in range(1, num_views + 1)]
    full_labels = dataset.y - np.min(dataset.y)
    data_views = [_standardize_view(sample) for sample in all_samples]

    mango = Mango(dims, args.hid_dims, args.out_dims, num_views=num_views).cuda()
    optimizer = optim.Adam(mango.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epoch, eta_min=0.0)

    block_size = min(data_size, 20000)
    # At least one step per epoch. With batch_size > data_size this was 0,
    # which skipped training and divided by zero when averaging the loss.
    n_iter_per_epoch = max(1, data_size // args.batch_size)
    n_step_per_iter = data_size // block_size
    num_pairs = num_views * (num_views - 1) // 2

    for epoch in range(1, args.epoch + 1):
        randidx = torch.randperm(data_size)
        total_loss = 0

        for i in range(n_iter_per_epoch):
            mango.train()
            batch_idx = randidx[i * args.batch_size: (i + 1) * args.batch_size]
            batch_views = [data_views[v][batch_idx].cuda() for v in range(num_views)]
            q_embs = mango.embedding(*batch_views)
            rec_batches = [torch.zeros_like(bv) for bv in batch_views]
            reg = torch.zeros([1]).cuda()

            intra_loss = sum(robust_intra_view_contrastive_loss(q_embs[v], args.temperature, t=args.random_walk_steps)
                             for v in range(num_views)) / num_views

            # Guard the single-view case, which previously divided by zero.
            if num_pairs:
                cross_loss = sum(robust_cross_view_contrastive_loss(q_embs[a], q_embs[b], args.temperature)
                                 for a in range(num_views)
                                 for b in range(a + 1, num_views)) / num_pairs
            else:
                cross_loss = 0.0

            # Self-expressive reconstruction of the batch from key blocks of the data.
            for j in range(n_step_per_iter):
                block_views = [data_views[v][j * block_size: (j + 1) * block_size].cuda()
                               for v in range(num_views)]
                k_embs = mango.embedding(*block_views)

                for v in range(num_views):
                    c = mango.get_coeff(q_embs[v], k_embs[v])
                    rec_batches[v] = rec_batches[v] + c.mm(block_views[v])
                    reg = reg + regularizer(c, args.lmbd)

            # Remove each query's coefficient on itself, i.e. the get_coeff
            # diagonal: shrink * thres(<q, q>). The key blocks above presumably
            # contain every query sample.
            for v in range(num_views):
                diag_c = mango.thres((q_embs[v] * q_embs[v]).sum(dim=1, keepdim=True)) * mango.shrink
                rec_batches[v] = rec_batches[v] - diag_c * batch_views[v]
                reg = reg - regularizer(diag_c, args.lmbd)

            rec_loss = sum(torch.sum(torch.pow(batch_views[v] - rec_batches[v], 2))
                           for v in range(num_views))
            view_consistency_loss = mango.view_consistency(q_embs)

            loss = (0.5 * args.alpha * rec_loss + reg +
                    args.beta * (intra_loss + 0.1 * cross_loss) +
                    args.gamma * view_consistency_loss) / (args.batch_size * num_views)

            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(mango.parameters(), 0.001)
            optimizer.step()

        scheduler.step()

        avg_loss = total_loss / n_iter_per_epoch
        if epoch % 500 == 0:
            print(f"Epoch {epoch}/{args.epoch}, Loss: {avg_loss:.6f}")

        if epoch % 50 == 0 and epoch > 350:
            acc, nmi, ari = evaluate(
                mango, data_views=data_views, labels=full_labels,
                num_subspaces=class_num,
                spectral_dim=args.spectral_dim,
                batch_size=block_size,
                temperature=args.temperature
            )

            print(f"Epoch {epoch}, ACC: {acc:.6f}, NMI: {nmi:.6f}, ARI: {ari:.6f}")


if __name__ == "__main__":
    # Ten back-to-back end-to-end runs; no per-run seeding is done here.
    for _run in range(10):
        main()
