import random
import os
import numpy as np
import math
from collections import Counter
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import scipy.sparse as sp

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import get_laplacian
from deeprobust.graph.utils import sparse_mx_to_torch_sparse_tensor
from torch_sparse import SparseTensor
from torch_geometric.utils import negative_sampling, add_self_loops, train_test_split_edges
from torch_geometric.utils import to_undirected
from torch_geometric.utils import to_networkx

def seed_everything(seed):
    """Seed every global random number generator used by this project.

    Seeds Python's ``random`` module, NumPy's global RNG, PyTorch's CPU
    generator and the generators of the current and all CUDA devices. It also
    exports ``PYTHONHASHSEED``.

    NOTE: ``PYTHONHASHSEED`` is read by the interpreter at start-up. Setting it
    here therefore only affects child processes, not ``hash()`` in the running
    process.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        torch.manual_seed,
        np.random.seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def init_params(module):
    """Re-initialise the parameters of one module in place.

    * ``nn.Linear``: the weight and (if present) the bias are drawn from
      ``U(-b, b)``, where ``b = 1 / sqrt(weight.size(1))``. The weight is drawn
      first, then the bias.
    * ``nn.Embedding``: the weight is drawn from ``N(0, 0.02**2)``.

    Modules of any other type are left untouched. That makes the function
    suitable as a callback for ``nn.Module.apply``.
    """
    if isinstance(module, nn.Linear):
        bound = 1.0 / math.sqrt(module.weight.shape[1])
        for tensor in (module.weight, module.bias):
            if tensor is not None:
                tensor.data.uniform_(-bound, bound)
    if isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)


def normalize_features(mx):
    """Row-normalize a dense 2-D feature tensor so that every row sums to 1.

    Rows that sum to 0 become all-zero rows instead of inf/nan.

    Args:
        mx: dense ``(n, d)`` tensor.

    Returns:
        A new ``(n, d)`` tensor. ``mx`` itself is not modified.
    """
    rowsum = mx.sum(1)
    r_inv = torch.pow(rowsum, -1)
    r_inv[torch.isinf(r_inv)] = 0.
    # Broadcast the per-row scale instead of materialising diag(r_inv) @ mx.
    # The result is the same, but cost drops from O(n^2) to O(n * d).
    return r_inv.unsqueeze(-1) * mx


def normalize_adj(mx):
    """Symmetrically normalize an adjacency matrix after adding self-loops.

    A' = (D + I)^-1/2 * ( A + I ) * (D + I)^-1/2

    ``D`` is the diagonal row-sum matrix of ``A``, so ``D + I`` holds the row
    sums of ``A + I``. A zero row sum gets a scale factor of 0 instead of inf.

    Args:
        mx: square adjacency matrix. Any SciPy sparse format, or a dense
            array-like accepted by ``scipy.sparse.csr_matrix``.

    Returns:
        The normalized matrix as a ``scipy.sparse.csr_matrix``.
    """
    # Converting straight to CSR avoids the deprecated ``sp.lil`` namespace.
    # The old lil round-trip was wasted anyway, since ``lil + eye`` returns CSR.
    adj = sp.csr_matrix(mx)
    adj = adj + sp.eye(adj.shape[0])
    rowsum = np.asarray(adj.sum(1)).flatten()
    d_inv_sqrt = np.power(rowsum, -1/2)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt)


def normalize_adj_to_sparse_tensor(mx, device="cuda"):
    """Normalize ``mx`` with ``normalize_adj`` and wrap it as a ``SparseTensor``.

    Args:
        mx: adjacency matrix accepted by ``normalize_adj``.
        device: target device of the returned tensor. The default ``"cuda"``
            matches the previous hard-coded ``.cuda()`` call. Pass ``"cpu"`` on
            machines without a GPU, or ``None`` to skip the transfer.

    Returns:
        A ``torch_sparse.SparseTensor`` with the same sparsity pattern and values
        as the normalized matrix.
    """
    # The helper's result is used through ``_indices()``/``_values()``, i.e. as a
    # torch sparse COO tensor.
    coo = sparse_mx_to_torch_sparse_tensor(normalize_adj(mx))
    indices = coo._indices()
    sparsetensor = SparseTensor(row=indices[0], col=indices[1], value=coo._values(), sparse_sizes=coo.size())
    if device is not None:
        sparsetensor = sparsetensor.to(device)
    return sparsetensor


def get_syn_eigen(real_eigenvals, real_eigenvecs, eigen_k, ratio, step=1):
    """Select ``eigen_k`` eigenpairs from both ends of the given ordering.

    ``k1 = ceil(eigen_k * ratio)`` pairs are taken from the start of
    ``real_eigenvals`` and ``k2 = eigen_k - k1`` from the end. Within each end,
    every ``step``-th pair is taken. The sizes are printed as a progress message.
    If ``eigen_k`` is too large for the number of pairs and ``step``, the two
    ranges overlap and duplicates are returned.

    Args:
        real_eigenvals: ``(N,)`` eigenvalues, presumably sorted. The order is
            taken as given; this function does not sort.
        real_eigenvecs: ``(n, N)`` matrix whose columns match ``real_eigenvals``.
        eigen_k: total number of pairs to select.
        ratio: fraction of ``eigen_k`` taken from the start.
        step: stride between selected indices.

    Returns:
        ``(eigenvals, eigenvecs)``: the ``k1`` head values followed by the ``k2``
        tail values, and the matching columns.
    """
    # Round away floating-point noise before ceil. For example, 100 * 0.07
    # evaluates to 7.000000000000001, which would otherwise give 8 instead of 7.
    k1 = math.ceil(round(eigen_k * ratio, 9))
    k2 = eigen_k - k1
    print("k1:", k1, ",", "k2:", k2)
    eigen_sum = real_eigenvals.shape[0]
    k1_list = range(0, (k1 - 1) * step + 1, step)
    k2_list = range(eigen_sum - (k2 - 1) * step - 1, eigen_sum, step)
    eigenvals = torch.cat([real_eigenvals[k1_list], real_eigenvals[k2_list]])
    eigenvecs = torch.cat([real_eigenvecs[:, k1_list], real_eigenvecs[:, k2_list]], dim=1)
    return eigenvals, eigenvecs


def get_subspace_embed(eigenvecs, x):
    """Project ``x`` onto each basis vector and lift it back as a rank-1 embedding.

    Args:
        eigenvecs: ``(n, k)`` matrix whose columns are the basis vectors ``u_i``.
        x: ``(n, d)`` feature matrix.

    Returns:
        ``x_trans``: ``(k, d)`` coefficients ``eigenvecs.T @ x``.
        ``sub_embed``: ``(k, n, d)`` tensor where
        ``sub_embed[i] = outer(u_i, x_trans[i])``.
    """
    coeffs = eigenvecs.T @ x
    # One outer product per basis vector: (n, k) x (k, d) -> (k, n, d).
    per_vector_embed = torch.einsum("nk,kd->knd", eigenvecs, coeffs)
    return coeffs, per_vector_embed


def get_subspace_covariance_matrix(eigenvecs, x):
    """Return one rank-1 "covariance" matrix per basis vector.

    ``x`` is projected onto the columns of ``eigenvecs`` (``(k, d)``
    coefficients). Each coefficient row is scaled to unit L2 norm, and its outer
    product with itself is returned.

    Args:
        eigenvecs: ``(n, k)`` basis matrix.
        x: ``(n, d)`` feature matrix.

    Returns:
        ``(k, d, d)`` tensor; entry ``i`` is ``outer(v_i, v_i)`` for the
        normalized coefficient row ``v_i``.
    """
    unit_coeffs = F.normalize(input=eigenvecs.T @ x, p=2, dim=1)
    return torch.einsum("ki,kj->kij", unit_coeffs, unit_coeffs)

  
def get_embed_sum(eigenvals, eigenvecs, x):
    """Apply the filter ``U diag(1 - eigenvals) U^T`` to ``x``.

    Args:
        eigenvals: ``(k,)`` values paired with the columns of ``eigenvecs``.
        eigenvecs: ``(n, k)`` basis matrix ``U``.
        x: ``(n, d)`` feature matrix.

    Returns:
        ``(n, d)`` filtered features.
    """
    response = torch.diag(1 - eigenvals)  # k x k scaling of the coefficients
    return eigenvecs @ (response @ (eigenvecs.T @ x))


def get_embed_mean(embed_sum, label, num_classes=-1):
    """Compute the L2-normalized mean embedding of every class.

    Args:
        embed_sum: ``(n, d)`` node embeddings.
        label: ``(n,)`` integer class labels.
        num_classes: number of classes. The default ``-1`` infers
            ``label.max() + 1``, as ``F.one_hot`` does.

    Returns:
        ``(c, d)`` tensor. Row ``j`` is the mean embedding of class ``j``, scaled
        to unit L2 norm. A class with no nodes yields an all-zero row; the
        previous ``1/0`` weighting produced ``inf * 0 = nan`` rows instead.
    """
    class_matrix = F.one_hot(label, num_classes=num_classes).to(embed_sum.dtype).T  # c x n
    class_total = class_matrix @ embed_sum  # c x d
    # clamp(min=1) leaves real counts untouched and turns empty classes into a
    # harmless scaling of an all-zero row.
    mean_weight = (1 / class_matrix.sum(1).clamp(min=1)).unsqueeze(-1)  # c x 1
    embed_mean = mean_weight * class_total
    return F.normalize(input=embed_mean, p=2, dim=1)


def get_train_lcc(idx_lcc, idx_train, y_full, num_nodes, num_classes):
    """Restrict the training set to the largest connected component (LCC).

    Training nodes already inside the LCC are kept. Some classes lose training
    nodes because those nodes lie outside the LCC. For each such class, the
    same number of non-training LCC nodes of that class is drawn with the
    global ``np.random`` RNG and added. If there are not enough candidates,
    the draw is truncated.

    Args:
        idx_lcc: node ids in the LCC. The position of a node in this sequence
            is its LCC-local index.
        idx_train: training node ids.
        y_full: label tensor for all nodes (moved to CPU / NumPy here).
        num_nodes: total number of nodes in the graph.
        num_classes: number of classes to rebalance.

    Returns:
        ``(idx_train_lcc, idx_map)``: the selected training node ids, and for
        each of them its position in ``idx_lcc``. When the LCC spans the whole
        graph, ``idx_map`` is ``idx_train`` itself.
        NOTE(review): that shortcut assumes positions in ``idx_lcc`` equal node
        ids; confirm with callers.
    """
    idx_train_lcc = list(set(idx_train).intersection(set(idx_lcc)))
    y_full = y_full.cpu().numpy()
    if len(idx_lcc) == num_nodes:
        return idx_train_lcc, idx_train

    y_train = y_full[idx_train]
    y_train_lcc = y_full[idx_train_lcc]

    # Candidate replacements are LCC nodes that are not training nodes. This
    # exact set expression is kept because its iteration order decides which
    # nodes the seeded permutation below picks.
    y_lcc_idx = list((set(range(num_nodes)) - set(idx_train)).intersection(set(idx_lcc)))
    y_lcc_ = y_full[y_lcc_idx]
    candidate_nodes = np.array(y_lcc_idx)  # hoisted out of the per-class loop
    counter_train = Counter(y_train)
    counter_train_lcc = Counter(y_train_lcc)
    idx = np.arange(len(y_lcc_))
    for c in range(num_classes):
        num_c = counter_train[c] - counter_train_lcc[c]
        if num_c > 0:
            idx_c = candidate_nodes[list(idx[y_lcc_ == c])]
            idx_train_lcc += list(np.random.permutation(idx_c)[:num_c])

    # One O(|idx_lcc|) pass replaces a list.index() scan per training node.
    # setdefault keeps the first occurrence, exactly like list.index.
    position = {}
    for pos, node in enumerate(idx_lcc):
        position.setdefault(node, pos)
    idx_map = [position[i] for i in idx_train_lcc]
    return idx_train_lcc, idx_map



def train_test_split_edges_direct(data, val_ratio: float = 0.05, test_ratio: float = 0.1):
    """Split the upper-triangular edges of ``data`` into train/val/test sets, in place.

    Only edges with ``row < col`` are used. They are shuffled with the global
    ``torch`` RNG and split into ``floor(val_ratio * E)`` validation edges,
    ``floor(test_ratio * E)`` test edges, and the remaining training positives.
    Training positives are *not* symmetrized. The same number of validation
    and test negatives is sampled from the upper-triangular non-edges.

    Side effects on ``data`` (which is also returned):
      * ``edge_index`` and ``edge_attr`` are set to ``None``;
      * ``{train,val,test}_pos_edge_index`` are added, plus
        ``{train,val,test}_pos_edge_attr`` when ``edge_attr`` was set;
      * ``val_neg_edge_index``, ``test_neg_edge_index`` and
        ``train_neg_adj_mask`` are added. The mask is an ``(N, N)`` bool mask
        of the upper-triangular non-edges left after the val/test negatives
        are removed.

    NOTE: the negative mask is dense, so memory grows as O(N^2).
    """
    assert 'batch' not in data  # No batch-mode.

    num_nodes = data.num_nodes
    row, col = data.edge_index
    edge_attr = data.edge_attr
    data.edge_index = data.edge_attr = None

    # Keep only the upper-triangular portion.
    mask = row < col
    row, col = row[mask], col[mask]
    if edge_attr is not None:
        edge_attr = edge_attr[mask]

    n_v = int(math.floor(val_ratio * row.size(0)))
    n_t = int(math.floor(test_ratio * row.size(0)))

    # Positive edges.
    perm = torch.randperm(row.size(0))
    row, col = row[perm], col[perm]
    if edge_attr is not None:
        edge_attr = edge_attr[perm]

    r, c = row[:n_v], col[:n_v]
    data.val_pos_edge_index = torch.stack([r, c], dim=0)
    if edge_attr is not None:
        data.val_pos_edge_attr = edge_attr[:n_v]

    r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]
    data.test_pos_edge_index = torch.stack([r, c], dim=0)
    if edge_attr is not None:
        data.test_pos_edge_attr = edge_attr[n_v:n_v + n_t]

    r, c = row[n_v + n_t:], col[n_v + n_t:]
    data.train_pos_edge_index = torch.stack([r, c], dim=0)
    if edge_attr is not None:
        # Previously missing: val/test kept their attributes but train lost them.
        data.train_pos_edge_attr = edge_attr[n_v + n_t:]

    # Negative edges. The mask is built on the edges' device so indexing it
    # with ``row``/``col`` also works for GPU tensors.
    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8, device=row.device)
    neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)
    neg_adj_mask[row, col] = 0

    neg_row, neg_col = neg_adj_mask.nonzero(as_tuple=False).t()
    perm = torch.randperm(neg_row.size(0))[:n_v + n_t]
    neg_row, neg_col = neg_row[perm], neg_col[perm]

    neg_adj_mask[neg_row, neg_col] = 0
    data.train_neg_adj_mask = neg_adj_mask

    row, col = neg_row[:n_v], neg_col[:n_v]
    data.val_neg_edge_index = torch.stack([row, col], dim=0)

    row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]
    data.test_neg_edge_index = torch.stack([row, col], dim=0)

    return data



def do_edge_split_direct(dataset, val_ratio=0.05, test_ratio=0.1, seed=234):
    """Split the upper-triangular edges of ``dataset`` into train/valid/test sets.

    Edges with ``row < col`` are shuffled and split into validation
    (``floor(val_ratio * E)``), test (``floor(test_ratio * E)``) and training
    (the rest) positives. ``E`` negatives are drawn with ``negative_sampling``
    and split with the same boundaries.

    NOTE: this re-seeds the *global* ``random`` and ``torch`` RNGs with ``seed``
    so the split is reproducible. Random draws made by the caller afterwards
    are affected.

    Args:
        dataset: graph object with ``edge_index`` (``2 x E_total``) and
            ``num_nodes``. It is only read, never modified.
        val_ratio: fraction of upper-triangular edges used for validation.
        test_ratio: fraction of upper-triangular edges used for testing.
        seed: RNG seed. Defaults to the previously hard-coded value 234.

    Returns:
        ``{'train'|'valid'|'test': {'edge': (n, 2), 'edge_neg': (m, 2)}}``.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    num_nodes = dataset.num_nodes
    row, col = dataset.edge_index

    # Keep the upper-triangular portion of the positive edges.
    mask = row < col
    row, col = row[mask], col[mask]
    n_v = int(math.floor(val_ratio * row.size(0)))
    n_t = int(math.floor(test_ratio * row.size(0)))

    # Positive edges: shuffle, then slice val | test | train.
    perm = torch.randperm(row.size(0))
    row, col = row[perm], col[perm]
    val_pos = torch.stack([row[:n_v], col[:n_v]], dim=0)
    test_pos = torch.stack([row[n_v:n_v + n_t], col[n_v:n_v + n_t]], dim=0)
    train_pos = torch.stack([row[n_v + n_t:], col[n_v + n_t:]], dim=0)

    # Negative edges (cannot guarantee (i,j) and (j,i) won't both appear).
    neg_edge_index = negative_sampling(dataset.edge_index, num_nodes=num_nodes, num_neg_samples=row.size(0))

    return {
        'train': {'edge': train_pos.t(), 'edge_neg': neg_edge_index[:, n_v + n_t:].t()},
        'valid': {'edge': val_pos.t(), 'edge_neg': neg_edge_index[:, :n_v].t()},
        'test': {'edge': test_pos.t(), 'edge_neg': neg_edge_index[:, n_v:n_v + n_t].t()},
    }



def _split_edge_permutation(num_edge, mask_ratio):
    """Shuffle ``range(num_edge)`` and split it into (kept, masked) index tensors.

    ``int(num_edge * mask_ratio)`` positions are masked. An explicit split point
    is used because negative slicing (``index[:-mask_num]``) collapses to an
    empty slice when ``mask_num == 0``.
    """
    index = np.arange(num_edge)
    np.random.shuffle(index)
    split = num_edge - int(num_edge * mask_ratio)
    return torch.from_numpy(index[:split]), torch.from_numpy(index[split:])


def edgemask_um(split_edge, device, num_nodes, mask_ratio=0):
    """Randomly hold out a fraction of training edges and build the visible graph.

    Args:
        split_edge: either an ``(E, 2)`` edge tensor, or a split dict whose
            ``split_edge['train']['edge']`` (``(E, 2)``) is used.
        device: device the held-out edges are moved to.
        num_nodes: number of nodes, used when adding self-loops.
        mask_ratio: fraction of edges to hold out. Defaults to 0, the value that
            was previously hard-coded. 0 now keeps every edge visible; the old
            negative-index slicing turned 0 into "mask every edge".

    Returns:
        ``adj``: transposed ``SparseTensor`` of the visible edges, made
            undirected and with self-loops.
        ``edge_index``: the matching ``(2, E')`` edge index.
        ``edge_index_mask``: ``(2, M)`` held-out edges on ``device``.
    """
    if isinstance(split_edge, torch.Tensor):
        edge_index = split_edge
    else:
        edge_index = split_edge['train']['edge']
    pre_index, mask_index = _split_edge_permutation(len(edge_index), mask_ratio)
    edge_index_train = edge_index[pre_index].t()
    edge_index_mask = edge_index[mask_index].t()
    edge_index = to_undirected(edge_index_train)
    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
    adj = SparseTensor.from_edge_index(edge_index).t()
    return adj, edge_index, edge_index_mask.to(device)


def sort_by_labels(adj_matrix, label_matrix):
    """Reorder the rows and columns of ``adj_matrix`` so nodes are grouped by label.

    ``label_matrix`` may be a 1-D label vector. It may also be a 2-D matrix, in
    which case the per-row argmax is used as the label.

    Returns:
        ``(sorted_adj_matrix, sorted_indices)``. ``sorted_indices`` comes from
        ``np.argsort`` with the default (non-stable) sort kind, so the order of
        ties is not guaranteed.
    """
    labels = np.argmax(label_matrix, axis=1) if label_matrix.ndim > 1 else label_matrix
    order = np.argsort(labels)
    rows_reordered = adj_matrix[order]
    return rows_reordered[:, order], order

from sklearn.metrics.pairwise import cosine_similarity
def sort_by_soft_labels(adj_matrix, label_matrix):
    """Reorder ``adj_matrix`` by how similar each node's label row is to all others.

    For every row of ``label_matrix``, the cosine similarities to all rows
    (itself included) are summed. Nodes are ordered by that total, largest
    first.

    Returns:
        ``(sorted_adj_matrix, sorted_indices)``.
    """
    total_similarity = np.sum(cosine_similarity(label_matrix), axis=1)
    order = np.argsort(-total_similarity)
    rows_reordered = adj_matrix[order]
    return rows_reordered[:, order], order












def row_normalize(edge_index, edge_weight, num_nodes):
    """Divide each edge weight by the out-degree of its source node.

    The degree is the *edge count* per source node ``edge_index[0]``, not the
    weighted degree. Rows therefore sum to 1 only when, for example, all edge
    weights are 1. Source nodes with zero degree get a factor of 0 instead of inf.

    Returns:
        ``(edge_index, normalized_edge_weight)``; ``edge_index`` is returned
        unchanged.
    """
    from torch_geometric.utils import degree

    source = edge_index[0]
    inv_degree = 1.0 / degree(source, num_nodes, dtype=torch.float)
    inv_degree = inv_degree.masked_fill(inv_degree == float('inf'), 0)
    return edge_index, edge_weight * inv_degree[source]


def build_adjacency_matrix(data):
    """Build a dense ``(N, N)`` float64 NumPy adjacency matrix with self-loops.

    ``N`` is ``data.x.shape[0]``. The diagonal starts at 1. Each edge
    ``(src, dst)`` in ``data.edge_index`` is then written with weight 1, or with
    the matching entry of ``data.edge_weight`` when that attribute is present and
    not ``None``. Explicit self-loop edges therefore overwrite the diagonal.

    Args:
        data: graph object with ``x``, ``edge_index`` (``2 x E``) and optionally
            ``edge_weight`` (``E``). Tensors may live on any device and may
            require grad.

    Returns:
        ``np.ndarray`` of shape ``(N, N)``.
    """
    def _as_numpy(value):
        # ``.numpy()`` alone raises for CUDA tensors and for tensors that
        # require grad, so detach and move to CPU first.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    num_nodes = data.x.shape[0]
    adj_matrix = np.eye(num_nodes, dtype=float)

    src, dst = _as_numpy(data.edge_index)
    edge_weight = getattr(data, "edge_weight", None)
    if edge_weight is None:
        adj_matrix[src, dst] = 1
    else:
        adj_matrix[src, dst] = _as_numpy(edge_weight)

    return adj_matrix




def generate_condensed_z_y(data, P2):
    """Derive soft class labels for condensed nodes from an assignment matrix.

    Each original node is hard-assigned to the column of ``P2`` with the largest
    value (argmax over dim 1). For every column (condensed node), the one-hot
    labels of the *training* nodes assigned to it are summed and L1-normalized
    into a label distribution. Non-training nodes contribute zero vectors, and
    columns without training nodes stay all-zero.

    Args:
        data: graph object providing ``y``, ``train_mask``, ``num_nodes`` and
            ``num_classes``.
        P2: ``(data.num_nodes, M)`` score tensor. It is detached, so no
            gradient flows through the result.

    Returns:
        ``(M, data.num_classes)`` tensor with the dtype and device of ``P2``.
    """
    P = P2.detach()
    P_one_hot = F.one_hot(P.argmax(dim=1), num_classes=P.shape[1]).to(P.dtype)

    # One-hot labels for training nodes and zero rows elsewhere. They are built
    # in P's dtype/device, because torch.mm below fails on mixed dtypes (the old
    # hard-coded .float() broke for e.g. float64 P).
    one_hot_labels = torch.zeros(data.num_nodes, data.num_classes, dtype=P.dtype, device=P.device)
    train_labels = data.y[data.train_mask]
    one_hot_labels[data.train_mask] = F.one_hot(
        train_labels, num_classes=data.num_classes
    ).to(device=P.device, dtype=P.dtype)

    s_emb_label = torch.mm(P_one_hot.t(), one_hot_labels)  # M x C label counts
    return F.normalize(s_emb_label.clamp(min=0), p=1, dim=1)



