import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.manifold import TSNE

from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression, Perceptron
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import seaborn as sns

import warnings
from sklearn.exceptions import DataConversionWarning

# Suppress specific noisy warnings: FutureWarnings raised from the sklearn and
# matplotlib modules, plus sklearn's DataConversionWarning.
for _noisy_module in ("sklearn", "matplotlib"):
    warnings.filterwarnings("ignore", category=FutureWarning, module=_noisy_module)
del _noisy_module
warnings.filterwarnings("ignore", category=DataConversionWarning)

def cal_heterophilous_ratio(M_hat, avg_D, delta, l=1):
    """Return pairwise Euclidean distances between the rows of a scaled ``M_hat**l``.

    The l-th matrix power of ``M_hat`` is divided (with broadcasting) by
    ``sqrt(delta**2 / 0.18 + 2 / avg_D)``, and the distance between every
    ordered pair of resulting rows is returned.

    Args:
        M_hat: square class-to-class matrix.
        avg_D: average degree term, broadcast against ``M_hat**l``. In this
            script it is a (K, 1) column, so each row gets its own scale factor.
        delta: spread term entering the scale factor.
        l: matrix power to apply to ``M_hat`` (default 1).

    Returns:
        Array whose (i, j) entry is ``||r_i - r_j||_2`` for the scaled rows r.
    """
    scale = np.sqrt(delta * delta / 0.18 + 2. / avg_D)
    scaled_rows = np.linalg.matrix_power(M_hat, l) / scale
    # scaled_rows[i] - scaled_rows[j] for every pair (i, j)
    pair_diff = np.expand_dims(scaled_rows, 1) - scaled_rows
    return np.sqrt(np.square(pair_diff).sum(axis=-1))

def softmax(x):
    """Row-wise softmax of a 2-D array (normalised along axis 1).

    The row maximum is subtracted before exponentiating. Softmax is invariant
    to a per-row shift, so the result is mathematically unchanged, but this
    prevents ``np.exp`` from overflowing to ``inf`` (which produced
    ``inf / inf = NaN``) for large scores, and from underflowing every entry
    to 0 (``0 / 0 = NaN``) for very negative rows.

    Args:
        x: 2-D array of scores, one row per sample.

    Returns:
        Array of the same shape whose rows are non-negative and sum to 1.
    """
    # Shift each row so its largest entry is 0 -> exp values lie in (0, 1].
    shifted = x - np.max(x, axis=1, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)

def LR_test(X, y):
    """Fit a default LogisticRegression on (X, y) and score it on the same data.

    Returns:
        (accuracy, y_pred): training-set accuracy and the predicted labels.
        Because the model is evaluated on the data it was fit on, this measures
        how well a linear model can fit X, not held-out generalisation.
    """
    classifier = LogisticRegression()
    predictions = classifier.fit(X, y).predict(X)
    return accuracy_score(y, predictions), predictions

def Perceptron_test(X, y):
    """Fit a default Perceptron on (X, y) and score it on the same data.

    Returns:
        (accuracy, y_pred): training-set accuracy and the predicted labels.
        Evaluation is on the training data itself, not on a held-out split.
    """
    classifier = Perceptron()
    predictions = classifier.fit(X, y).predict(X)
    return accuracy_score(y, predictions), predictions

def SVM_test(X, y):
    """Fit a default LinearSVC on (X, y) and score it on the same data.

    Returns:
        (accuracy, y_pred): training-set accuracy and the predicted labels.
        Evaluation is on the training data itself, not on a held-out split.
    """
    classifier = LinearSVC()
    predictions = classifier.fit(X, y).predict(X)
    return accuracy_score(y, predictions), predictions

def Bayes_test(X, y, M_hat, l, class_means=None):
    """Classify X with the linear Gaussian Bayes rule on propagated class means.

    The class means are mixed through ``l`` powers of ``M_hat``
    (``mu^{(l)} = M_hat^l @ class_means``). Each sample is then assigned to the
    class maximising ``x . mu_k - 0.5 * ||mu_k||^2``. That is the equal-prior
    Bayes decision rule for Gaussians sharing one isotropic covariance, up to a
    positive scale common to all classes.

    Args:
        X: feature matrix with one row per sample. Presumably (n, C) with C
            matching the columns of ``class_means``; TODO confirm for callers.
        y: integer labels; label k must correspond to row k of the means.
        M_hat: square class-to-class mixing matrix.
        l: number of propagation steps (power applied to ``M_hat``).
        class_means: optional (K, C) matrix of un-propagated class means.
            Defaults to the module-level ``mus`` so existing 4-argument calls
            keep working.

    Returns:
        (accuracy, y_pred): accuracy against ``y`` and the predicted labels.
    """
    if class_means is None:
        class_means = mus
    # mu_k^{(l)}: row k of M_hat^l @ class_means.
    mus_l = np.linalg.matrix_power(M_hat, l) @ class_means
    # The discriminant needs the SQUARED norm of each mean. Using the plain
    # norm is only harmless when every ||mu_k|| happens to be equal.
    bias = -0.5 * np.sum(mus_l * mus_l, axis=1)
    scores = X @ mus_l.T + bias
    # softmax is monotone within each row, so argmax on the raw scores gives
    # the same labels without the overflow risk of exponentiating them.
    y_pred = np.argmax(scores, axis=1)
    acc = accuracy_score(y, y_pred)
    return acc, y_pred


# Generate N samples belonging to 5 classes in R^5. Class k is an isotropic
# Gaussian (std 0.6) centred on the k-th unit vector (row k of ``mus``).
num_nodes_per_class = 1000
num_nodes = 5 * num_nodes_per_class
mus = np.eye(5)

# One draw per class, in class order (c1 .. c5).
c1_X, c2_X, c3_X, c4_X, c5_X = (
    np.random.normal(centre, 0.6, (num_nodes_per_class, 5)) for centre in mus
)
c1_y, c2_y, c3_y, c4_y, c5_y = (
    np.full(num_nodes_per_class, float(label)) for label in range(5)
)
X = np.vstack((c1_X, c2_X, c3_X, c4_X, c5_X))
y = np.hstack((c1_y, c2_y, c3_y, c4_y, c5_y)).astype(int)
y_one_hot = np.eye(5)[y]

# Class-to-class neighbourhood matrix. Each row puts weight a on its own class,
# b = 2a on the next class (cyclically) and c on the other three. Rows sum to
# a + 2a + 3(1/3 - a) = 1.
a = 0.20    # valid range [0, 1/3], so that c >= 0
b = 2 * a
c = (1/3)-a
_idx = np.arange(5)
M_hat = np.full((5, 5), c)
M_hat[_idx, _idx] = a
M_hat[_idx, (_idx + 1) % 5] = b
del _idx

# Edge probability between a class-i node and a class-j node is M_hat[i, j] * avg_d.
# Before perturbation this gives roughly num_nodes_per_class * avg_d (= 25)
# expected neighbours per node.
avg_d = 25 / num_nodes_per_class
M = M_hat * avg_d
avg_D = np.full((5, 1), num_nodes_per_class * avg_d)
edge_prob = np.random.uniform(0, 1, size=(num_nodes, num_nodes))

# Noise sweep. For each perturbation strength delta:
#   1. perturb every node's class-wise edge probabilities,
#   2. sample a graph (A is not symmetrised) and row-normalise it: S = D^{-1} A,
#   3. fit the classifier on one-hop aggregated features S @ X and compare its
#      confusion matrix with the raw-feature baseline,
#   4. correlate that confusion change with the predicted pairwise class distance.
noise_s = np.arange(0., 0.01, 0.0002).tolist()
output_delta = []
output_acc = []
output_max_dis = []
output_min_dis = []
output_cor = []

model_function = LR_test
num_classes = y_one_hot.shape[1]
class_labels = np.arange(num_classes)

# Loop invariants, hoisted out of the sweep.
y_idx = y.astype(np.int64)
row_idx = np.arange(num_nodes)[:, np.newaxis]
# Unperturbed edge probability M[y_i, y_j] for every ordered node pair (i, j).
base_edge_prob = M[y_idx[:, np.newaxis], y_idx]
# Off-diagonal (i != j) class pairs. Boolean indexing reads them in row-major
# order, the same order as a nested i/j loop that skips i == j.
off_diag = ~np.eye(num_classes, dtype=bool)
# The raw-feature baseline does not depend on delta, so fit it only once.
acc, pred = model_function(X, y)
cm_mlp = confusion_matrix(y, pred, labels=class_labels)

for delta in noise_s:
    # Step 1: perturbation magnitudes |N(0, delta^2)| with random signs. Each row
    # takes the first num_classes entries of a shuffled [-1, -1, -1, 1, 1, 1]
    # (so this assumes num_classes <= 6).
    node_neighborhood_perturbation = np.random.randn(num_nodes, num_classes) * delta
    signs = [-1, -1, -1, 1, 1, 1]
    random_signs_matrix = np.array(
        [np.random.permutation(signs) for _ in range(num_nodes)]
    )[:, :num_classes]
    node_neighborhood_perturbation = np.abs(node_neighborhood_perturbation) * random_signs_matrix

    # Step 2: edge (i, j) exists iff edge_prob[i, j] < M[y_i, y_j] + perturbation[i, y_j].
    perturbation_prob = node_neighborhood_perturbation[row_idx, y_idx]
    A = edge_prob - perturbation_prob < base_edge_prob
    deg = A.sum(axis=1)
    # Row-normalise by broadcasting instead of the dense diag(1/deg) @ A matmul,
    # which is O(n^3). Isolated nodes (deg == 0) keep an all-zero row instead of
    # producing 0/0 = NaN rows in S (and hence in S @ X).
    S = A / np.maximum(deg, 1)[:, np.newaxis]

    # Step 3: empirical class-to-class neighbourhood distribution. Row i of
    # S @ y_one_hot is node i's fraction of neighbours in each class; rows of the
    # means/stds below aggregate those fractions over the nodes of class k.
    modified_node_local_distribution = S @ y_one_hot
    per_class = [modified_node_local_distribution[y == k] for k in range(num_classes)]
    modified_m_hat_means = np.array([rows.mean(0) for rows in per_class])
    modified_m_hat_stds = np.array([rows.std(0) for rows in per_class])

    GCN_X = S @ X
    gc1_acc, gc1_pred = model_function(GCN_X, y)
    cm_gcn = confusion_matrix(y, gc1_pred, labels=class_labels)
    # The smallest per-class std of the neighbour-class fractions is used as the
    # spread term of the distance.
    empirical_delta = modified_m_hat_stds.min()
    gc1_class_dis = cal_heterophilous_ratio(modified_m_hat_means, avg_D, empirical_delta, 1)
    # Change in the confusion matrix when moving from raw to aggregated features.
    cm_diff = cm_gcn - cm_mlp

    # Step 4: compare off-diagonal entries only (predicted class-pair distance
    # vs. confusion change).
    gc1_class_dis_flatten = gc1_class_dis[off_diag]
    cm_diff_flatten = cm_diff[off_diag]

    correlation_coefficient = np.corrcoef(gc1_class_dis_flatten, cm_diff_flatten)[0, 1]
    print(f"delta: {delta:.4f}, acc: {gc1_acc*100:.2f}, min_dis:{gc1_class_dis_flatten.min():.4f}, correlation: {correlation_coefficient:.4f}")
    output_delta.append(delta)
    output_acc.append(gc1_acc)
    output_min_dis.append(gc1_class_dis_flatten.min())
    output_max_dis.append(gc1_class_dis_flatten.max())
    output_cor.append(correlation_coefficient)
output_delta = np.array(output_delta)
output_acc = np.array(output_acc)
output_min_dis = np.array(output_min_dis)
output_max_dis = np.array(output_max_dis)
output_cor = np.array(output_cor)
np.savez("noise_accs.npz", delta=output_delta, acc=output_acc, min_dis=output_min_dis, max_dis=output_max_dis, cor=output_cor)