import torch
import torch.nn as nn
import scipy.sparse as sp
from time import perf_counter
from utils import sparse_eye, sparse_full,sparse_mx_to_torch_sparse_tensor
from einops import rearrange
# from normalization import aug_normalized_adjacency, normalized_adjacency
import pdb
import numpy as np
import torch.nn.functional as F
from sklearn.manifold import TSNE
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
from matplotlib import rcParams
import scanpy as sc
# Render every matplotlib figure produced by this module in Arial.
rcParams['font.family'] = 'Arial'


def postive_metric(adj, labels, normalize=False):
    """Count, per node, the same-label nodes whose weight in ``adj`` is <= 0.

    For every node ``i`` this considers all nodes ``j`` with
    ``labels[j] == labels[i]`` (``i`` itself included) and counts those with
    ``adj[i, j] <= 0``, i.e. same-class pairs that are not positively
    connected. The per-node counts are then averaged over all nodes.

    Args:
        adj: ``[n, n]`` adjacency / propagation matrix, sparse or dense.
        labels: length-``n`` tensor of node labels on the same device as ``adj``.
        normalize: if True, divide each node's count by its number of
            same-label nodes before averaging, giving a fraction in [0, 1].
            The default (False) returns the raw mean count.

    Returns:
        float: mean (optionally normalized) count over all nodes.
    """
    adj = adj.to_dense() if adj.is_sparse else adj
    # same_label[i, j] is True when nodes i and j share a label.
    same_label = labels.unsqueeze(1) == labels.unsqueeze(0)
    hits = ((adj <= 0) & same_label).sum(dim=1).float()
    if normalize:
        # Each node is its own same-label node, but clamp to be safe.
        hits = hits / same_label.sum(dim=1).clamp(min=1).float()
    return hits.mean().item()

def negative_metric(adj, labels, normalize=False):
    """Count, per node, the different-label nodes whose weight in ``adj`` is >= 0.

    For every node ``i`` this considers all nodes ``j`` with
    ``labels[j] != labels[i]`` and counts those with ``adj[i, j] >= 0``, i.e.
    different-class pairs that are not negatively connected. The per-node
    counts are then averaged over all nodes.

    Args:
        adj: ``[n, n]`` adjacency / propagation matrix, sparse or dense.
        labels: length-``n`` tensor of node labels on the same device as ``adj``.
        normalize: if True, divide each node's count by its number of
            different-label nodes before averaging. Nodes with no
            different-label node contribute 0. The default (False) returns
            the raw mean count.

    Returns:
        float: mean (optionally normalized) count over all nodes.
    """
    adj = adj.to_dense() if adj.is_sparse else adj
    # diff_label[i, j] is True when nodes i and j have different labels.
    diff_label = labels.unsqueeze(1) != labels.unsqueeze(0)
    hits = ((adj >= 0) & diff_label).sum(dim=1).float()
    if normalize:
        # A single-class graph has no different-label nodes: avoid 0/0.
        hits = hits / diff_label.sum(dim=1).clamp(min=1).float()
    return hits.mean().item()

def sb_metric(adj, labels):
    """Return the mean of ``postive_metric`` and ``negative_metric`` ("sb").

    Also prints both component scores and their mean to stdout.
    """
    pos_score, neg_score = postive_metric(adj, labels), negative_metric(adj, labels)
    balance = (pos_score + neg_score) / 2
    print(f'p: {pos_score:.4f}, n:{neg_score:.4f}, sb:{balance:.4f}.')
    return balance


def adj_vis(adj, model, data='sbm'):
    """Save a 'hot' heatmap of ``adj`` to ``{model}_{data}_heatmap.png``.

    The heatmap is drawn on a fresh figure, which is closed after saving.
    Repeated calls therefore neither stack colorbars on one canvas nor leak
    open figures.

    Args:
        adj: ``[n, n]`` tensor, sparse or dense, on any device.
        model: name used in the title and file name.
        data: dataset name used in the title and file name.
    """
    adj = adj.detach().cpu()
    if adj.is_sparse:
        # imshow needs a dense array.
        adj = adj.to_dense()
    fig = plt.figure()
    plt.imshow(adj, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.title(f'{model},{data} Heatmap')
    filename = f'{model}_{data}_heatmap.png'
    plt.savefig(filename, format='png')
    plt.close(fig)

def plot_features_scanpy(features,labels,layers,model):
    """Embed ``features`` in 2-D with t-SNE and save a per-label scatter plot.

    Despite the name, scanpy is not used. The embedding comes from
    ``sklearn.manifold.TSNE`` (``random_state=42``), and the figure is saved
    to ``sbm_{model}_{layers}.pdf`` with a transparent background.

    Args:
        features: ``[n, d]`` torch tensor; it is detached and moved to CPU here.
        labels: length-``n`` tensor, array or sequence of node labels.
        layers: used only in the output file name.
        model: used only in the output file name.
    """
    features_np = features.detach().cpu().numpy()
    # Accept numpy arrays / lists as well as tensors.
    labels = torch.as_tensor(labels)

    # TSNE requires perplexity < n_samples. Its default is 30, so the value
    # only shrinks for tiny inputs and is unchanged otherwise.
    perplexity = min(30.0, float(features_np.shape[0] - 1))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    features_reduced = tsne.fit_transform(features_np)

    unique_labels = torch.unique(labels)
    plt.figure(figsize=(8, 8))
    # plt.cm.get_cmap was removed in Matplotlib 3.9. pyplot.get_cmap keeps
    # the same (name, lut) signature and works on old and new versions.
    cmap = plt.get_cmap('viridis', len(unique_labels))

    for i, label in enumerate(unique_labels):
        mask = (labels == label).detach().cpu().numpy()
        plt.scatter(features_reduced[mask, 0], features_reduced[mask, 1], s=50, c=[cmap(i)], label=str(label.item()), alpha=0.6)

    # Use one square range for both axes so the equal aspect ratio holds.
    x_min, x_max = features_reduced[:, 0].min() - 1, features_reduced[:, 0].max() + 1
    y_min, y_max = features_reduced[:, 1].min() - 1, features_reduced[:, 1].max() + 1
    common_min = min(x_min, y_min)
    common_max = max(x_max, y_max)
    plt.xlim(common_min, common_max)
    plt.ylim(common_min, common_max)
    plt.gca().set_aspect('equal', adjustable='box')
    plt.legend(fontsize=22)
    plt.xticks(fontsize=22)
    plt.yticks(fontsize=22)

    plt.savefig(f'sbm_{model}_{layers}.pdf', transparent=True, bbox_inches='tight', pad_inches=0.03)
    plt.close()


def create_label_induced_negative_graph(y, idx_train=None):
    """Build a signed label-agreement matrix from node labels.

    Entry ``[i, j]`` is ``1.`` when nodes ``i`` and ``j`` have different
    labels and ``-1.`` when they share a label, including the diagonal. When
    ``idx_train`` is given, every row and column of a node outside the
    training set is set to 0, so only train-train pairs carry label
    information.

    Args:
        y: 1-D tensor of ``n`` node labels.
        idx_train: optional training-node selector. It is either a boolean
            mask of length ``n`` or integer node indices in any order (not
            necessarily a contiguous prefix ``0..m-1``).

    Returns:
        ``[n, n]`` float tensor on ``y``'s device.
    """
    # different -> 2*1-1 = 1, same -> 2*0-1 = -1
    label_matrix = (y.unsqueeze(1) != y.unsqueeze(0)).float() * 2 - 1
    if idx_train is not None:
        n = label_matrix.size(0)
        idx_train = torch.as_tensor(idx_train, device=y.device)
        if idx_train.dtype == torch.bool:
            train_mask = idx_train
        else:
            # Works for arbitrary index sets. Bitwise ``~`` on int indices
            # would produce negative (wrapping) indices instead.
            train_mask = torch.zeros(n, dtype=torch.bool, device=y.device)
            train_mask[idx_train] = True
        not_train = ~train_mask
        label_matrix[not_train, :] = 0  # rows of non-training nodes
        label_matrix[:, not_train] = 0  # columns of non-training nodes
    return label_matrix


def calculate_patch_sim(x):
    """Mean cosine similarity over all unordered pairs of distinct rows of ``x``.

    x size: [n_nodes, dim]
    return: Python float; 0.0 when ``x`` has fewer than two rows. No pairs
    exist in that case, and the naive formula would divide 0 by 0 (nan).
    """
    n = x.size(0)
    if n < 2:
        return 0.0
    unit = F.normalize(x, dim=-1)
    # sum_{i<j} <u_i, u_j> = (||sum_i u_i||^2 - sum_i ||u_i||^2) / 2.
    # This gives the upper-triangle sum without building the [n, n] matrix;
    # zero rows (normalized to 0) are handled by the explicit second term.
    total = unit.sum(dim=0)
    pair_sum = (total.dot(total) - unit.pow(2).sum()) / 2
    return (pair_sum / (n * (n - 1) / 2)).item()


def calculate_erank(x):
    """Effective rank of a single 2-D matrix ``x``.

    x size: [n_nodes, dim]
    return: ``exp(-sum_k p_k log p_k)``, where ``p`` is the singular values
    of ``x`` divided by their sum. Zero singular values contribute 0 (not
    nan). An all-zero matrix returns 0.0.
    """
    # Only the singular values are needed; skip computing U and V.
    s = torch.linalg.svdvals(x)
    total = s.sum()
    if total == 0:
        return 0.0
    p = s / total
    # xlogy(0, 0) == 0, whereas 0 * log(0) would be nan for rank-deficient x.
    return torch.exp(-torch.xlogy(p, p).sum()).item()


def cal_variance(x):
    """Sum, over rows, of the unbiased variance along the last dimension."""
    per_row = x.var(dim=-1)
    return per_row.sum().item()


def dropedge(adj, p=0.5):
    """Randomly keep a ``p`` fraction of the stored entries of sparse ``adj``.

    Despite the name, ``p`` is the fraction of non-zeros that is *kept*.
    ``int(p * nnz)`` entries of the coalesced tensor are chosen uniformly
    without replacement (via ``torch.randperm``), and the rest are dropped.
    Kept values are not rescaled.

    Returns a new sparse COO tensor of the same size; it is not marked as
    coalesced.
    """
    coalesced = adj.coalesce()
    edge_index = coalesced.indices()
    num_entries = edge_index.size(1)
    n_keep = int(p * num_entries)
    keep = torch.randperm(num_entries, device=edge_index.device)[:n_keep]
    return torch.sparse_coo_tensor(
        edge_index[:, keep],
        coalesced.values()[keep],
        size=coalesced.size(),
    )

def op_aug_normalized_adjacency(adj):
    """Return the dense, self-loop-augmented, symmetrically normalized ``adj``.

    Computes ``D^{-1/2} (A + I) D^{-1/2}``, where ``D`` is the diagonal matrix
    of row sums of ``A + I``. Rows whose sum is 0 get a scaling factor of 0
    instead of ``inf``.

    Args:
        adj: ``[n, n]`` adjacency matrix, sparse COO or dense.

    Returns:
        Dense ``[n, n]`` tensor on ``adj``'s device.
    """
    # Densify up front. The previous ``isinstance(adj, torch.Tensor)`` check
    # was always true for tensors, so sparse input was never densified before
    # dense-only operations such as torch.diag.
    if adj.is_sparse:
        adj = adj.to_dense()
    # Add self-connections.
    adj = adj + torch.eye(adj.size(0), dtype=adj.dtype, device=adj.device)
    d_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    # Zero row sums give inf; map them to 0.
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
    # Broadcasting gives D A D without two O(n^3) dense matmuls with a diagonal.
    return d_inv_sqrt.unsqueeze(1) * adj * d_inv_sqrt.unsqueeze(0)

# preprocessing stage
def center_precompute(features, adj, T, K):
    """Propagate ``features`` K times with the operator ``adj - J / n``.

    ``J`` is the all-ones ``[n, n]`` matrix, so each step is
    ``X <- adj @ X - column_mean(X)`` (the mean is broadcast to every row).
    At steps 0, 5, 10, ... the mean pairwise cosine similarity and the total
    row variance of the result are printed.

    Args:
        features: dense ``[n, d]`` node features.
        adj: ``[n, n]`` propagation matrix.
        T: only checked against 0. (disables propagation).
        K: number of propagation steps.

    Returns:
        (features, precompute_time). Returns the input unchanged and 0. when
        K == 0 or T == 0.
    """
    if K == 0 or T == 0.:
        return features, 0.
    t = perf_counter()
    for i in range(K):
        # (adj - J/n) @ X == adj @ X - mean over rows of X. Computing it this
        # way avoids materializing a dense [n, n] matrix filled with 1/n.
        features = torch.spmm(adj, features) - features.mean(dim=0, keepdim=True)
        if i % 5 == 0:
            sim = calculate_patch_sim(features)
            var = cal_variance(features)
            print(f'{i}: {sim:.4f}, {var:.4f}')
    precompute_time = perf_counter() - t
    return features, precompute_time

# preprocessing stage
def sgc_precompute(features, adj, T, K, labels=None):
    """SGC-style feature smoothing: apply ``adj`` to ``features`` K times.

    Args:
        features: dense ``[n, d]`` node features.
        adj: ``[n, n]`` propagation matrix (sparse or dense), used as-is.
        T: only checked against 0. (disables propagation); otherwise unused.
        K: number of propagation steps.
        labels: unused here.

    Returns:
        (features after K steps, elapsed seconds). Returns the input unchanged
        and 0. when K == 0 or T == 0.
    """
    if K == 0 or T == 0.:
        return features, 0.
    t = perf_counter()
    # The per-step similarity/variance diagnostics that used to run here were
    # O(n^2) each and their results were discarded, so they are no longer
    # computed.
    for _ in range(K):
        features = torch.spmm(adj, features)
    precompute_time = perf_counter() - t
    return features, precompute_time


# preprocessing stage
def sign_precompute(features, adj, T, K, b, labels=None, tau=1.0): # adj sparse, feature dense
    """Propagate features with a similarity-induced signed operator, K times.

    A row-stochastic "negative" graph is built once from the *input*
    features: ``N = softmax(-cos_sim(X, X) / tau)`` row-wise, so the most
    dissimilar nodes get the largest weights. Each step then applies
    ``op = (1 - T) * (adj - b * N) + T * I``, followed by centering and
    L2-normalizing every row.

    Args:
        features: dense ``[n, d]`` node features.
        adj: ``[n, n]`` propagation matrix (sparse or dense).
        T: weight of the identity (self) term.
        K: number of propagation steps.
        b: strength of the negative graph.
        labels: unused here.
        tau: softmax temperature for the negative graph (default 1.0, the
            previously hard-coded value).

    Returns:
        (features, elapsed seconds).
    """
    t = perf_counter()
    unit = nn.functional.normalize(features, dim=1)
    neg = nn.functional.softmax(-unit @ unit.T / tau, dim=1)
    # The operator does not change between iterations, so build it once.
    # ``-b*neg + adj`` is dense (neg is dense), so adding T to its diagonal in
    # place matches adding T * I.
    op = (1 - T) * (-b * neg + adj)
    op.diagonal().add_(T)
    for _ in range(K):
        features = torch.spmm(op, features)
        features = features - features.mean(dim=1, keepdim=True)
        features = nn.functional.normalize(features, dim=1)
    precompute_time = perf_counter() - t
    return features, precompute_time


# preprocessing stage
def label_precompute(features, adj, labels, idx_train,  T, K, b):
    """Propagate features with a label-induced signed operator, K times.

    ``N = create_label_induced_negative_graph(labels, idx_train)`` holds +1
    for different-label pairs and -1 for same-label pairs (rows and columns
    of non-training nodes are zeroed when ``idx_train`` is given). It is
    softmaxed row-wise. The operator ``(1 - T) * (adj - b * softmax(N)) +
    T * I`` is built once. After every step each row is centered and
    L2-normalized.

    NOTE(review): a row of N that is all zeros (a non-training node) becomes
    uniform 1/n after the softmax rather than staying zero. Confirm this is
    intended.

    Returns:
        (features, elapsed seconds).
    """
    start = perf_counter()
    signed_labels = create_label_induced_negative_graph(labels, idx_train)
    negative = nn.functional.softmax(signed_labels, dim=1)
    identity = sparse_eye(adj.shape[0]).to(adj.device)
    propagator = (1 - T) * (- b * negative + adj) + T * identity
    for _ in range(K):
        features = torch.spmm(propagator, features)
        features = features - features.mean(dim=1, keepdim=True)
        features = nn.functional.normalize(features, dim=1)
    elapsed = perf_counter() - start
    return features, elapsed


# preprocessing stage
def base_precompute(features, adj, T, K, mode, labels=None):
    """Propagate ``features`` K times with ``adj`` plus a mode-specific step.

    Every iteration first computes ``x = adj @ x`` and then applies the
    variant selected by ``mode``:

    * ``'drop'``  : before the loop, ``adj`` is replaced by ``dropedge(adj, T)``.
    * ``'CN'``    : ``x = (1 + T) x - T x_neg``, where ``x_neg`` is a
                    softmax(cosine)-weighted average over non-neighbours.
    * ``'LN'``    : center and L2-normalize each row.
    * ``'BN'``    : center and L2-normalize each column.
    * ``'PN'``    : PairNorm: center columns, then rescale so that the root
                    mean squared row norm is ``T``.
    * ``'res'``   : ``x = (1 - T) x_prev + T x``.
    * ``'appnp'`` : ``x = (1 - T) x_0 + T x``.
    * ``'jk-net'``: on the last step, concatenate ``x`` with ``x_prev``.
    * ``'dagnn'`` : ``x = 0.34 x + 0.33 x_prev + 0.33 x_0``.
    * anything else: plain propagation.

    Args:
        features: dense ``[n, d]`` node features.
        adj: ``[n, n]`` propagation matrix (sparse or dense).
        T: mode-specific coefficient (see above).
        K: number of propagation steps.
        mode: one of the strings above.
        labels: unused here.

    Returns:
        (x, elapsed seconds).
    """
    x = x0 = features
    t = perf_counter()
    if mode == 'drop':
        adj = dropedge(adj, T)
    if mode == 'CN':
        # The neighbour mask depends only on adj, so build it once.
        adj_dense = adj.to_dense() if adj.is_sparse else adj
        neighbor_mask = adj_dense > 1e-5
    for i in range(K):
        x_old = x
        x = torch.spmm(adj, x)
        if mode == 'CN':
            unit = nn.functional.normalize(x, dim=1)
            # Neighbours (and self-loops, if present in adj) get -inf, so the
            # softmax only weights non-neighbours.
            # NOTE(review): a row whose entries are all neighbours would
            # softmax to nan.
            sim = (unit @ unit.T).masked_fill(neighbor_mask, float('-inf'))
            sim = nn.functional.softmax(sim, dim=1)
            x = (1 + T) * x - T * (sim @ x)
        elif mode == 'LN':
            x = x - x.mean(dim=1, keepdim=True)
            x = nn.functional.normalize(x, dim=1)
        elif mode == 'BN':
            x = x - x.mean(dim=0, keepdim=True)
            x = nn.functional.normalize(x, dim=0)
        elif mode == 'PN':
            x = x - x.mean(dim=0)
            rownorm_mean = (1e-6 + x.pow(2).sum(dim=1).mean()).sqrt()
            x = T * x / rownorm_mean
        elif mode == 'res':
            x = (1 - T) * x_old + T * x
        elif mode == 'appnp':
            # The old branch also ran an unused inner loop (shadowing ``i``)
            # that computed ``adj**0``, which raises for sparse adj.
            x = (1 - T) * x0 + T * x
        elif mode == 'jk-net' and i == K - 1:
            x = torch.cat((x, x_old), dim=-1)
        elif mode == 'dagnn':
            x = 0.34 * x + 0.33 * x_old + 0.33 * x0
    precompute_time = perf_counter() - t
    return x, precompute_time

# # preprocessing stage
# def dgc_precompute(features, adj, T, K):
#     # integration with the forward Euler scheme by default
#     if K == 0 or T == 0.:
#         return features, 0.
#     delta = T / K
#     t = perf_counter()
#     eye = sparse_eye(adj.shape[0]).to(adj.device)
#     op = (1 - delta) * eye + delta * adj
#     for i in range(K):
#         features = torch.spmm(op, features)
#     precompute_time = perf_counter()-t
#     return features, precompute_time



# classification stage
class DGC(nn.Module):
    """Logistic-regression classification head (same as SGC's).

    A single ``nn.Linear(nfeat, nclass)``. It assumes the node features were
    already smoothed by k-step graph propagation, so no graph is used here.
    """

    def __init__(self, nfeat, nclass):
        super().__init__()
        self.W = nn.Linear(nfeat, nclass)

    def forward(self, x):
        """Return unnormalized class scores of shape ``[..., nclass]``."""
        logits = self.W(x)
        return logits


if __name__ == "__main__":
    y1 = torch.tensor([0, 0, 0])
    idx_train1 = torch.tensor([0,1])

    # Example 2: Multiclass labels
    y2 = torch.tensor([0, 1, 2, 1, 0])
    idx_train2 = torch.tensor([1, 2, 4])

    # Example 3: No training indices provided
    y3 = torch.tensor([2, 2, 0, 1, 0])
    idx_train3 = torch.tensor([0,1,2])

    # Testing the function with examples
    result1 = create_label_induced_negative_graph(y1, idx_train1)
    # result2 = create_label_induced_negative_graph(y2, idx_train2)
    # result3 = create_label_induced_negative_graph(y3)
    # result4 = create_label_induced_negative_graph(y3, idx_train3)

    print("Example 1 result:")
    print(result1)

    # print("\nExample 2 result:")
    # print(result2)

    # print("\nExample 3 result:")
    # print(result3)

    # print("\nExample 4 result:")
    # print(result4)
