import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.manifold import TSNE

from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression, Perceptron
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import seaborn as sns

import warnings
from sklearn.exceptions import DataConversionWarning

# Suppress specific warnings that would otherwise clutter the experiment output.
for _noisy_module in ("sklearn", "matplotlib"):
    warnings.filterwarnings("ignore", category=FutureWarning, module=_noisy_module)
del _noisy_module
# scikit-learn's DataConversionWarning (input dtype/shape coercion notices).
warnings.filterwarnings("ignore", category=DataConversionWarning)


def cal_heterophilous_ratio(M_hat, avg_D, l=1):
    """Return the pairwise Euclidean distances between scaled rows of ``M_hat**l``.

    Each row of the l-th matrix power of ``M_hat`` is multiplied by
    ``sqrt(avg_D / 2)`` (broadcast against the rows), and the result is the
    square matrix whose ``[i, j]`` entry is the distance between scaled rows
    ``i`` and ``j``.

    Args:
        M_hat: square matrix to be raised to the power ``l``.
        avg_D: scaling term, broadcast against the rows of ``M_hat**l``
            (the caller passes a column of shape ``(K, 1)``).
        l: non-negative integer power applied to ``M_hat``.

    Returns:
        ``(K, K)`` array of pairwise row distances (zero on the diagonal).
    """
    scaled_rows = np.linalg.matrix_power(M_hat, l) * np.sqrt(avg_D / 2)
    # (K, 1, d) - (1, K, d) -> (K, K, d): every pairwise row difference.
    pairwise = scaled_rows[:, np.newaxis, :] - scaled_rows[np.newaxis, :, :]
    return np.linalg.norm(pairwise, axis=-1)

def softmax(x):
    """Return the row-wise softmax of a 2-D array.

    Each row of ``x`` is mapped to a probability vector that sums to 1.
    The row maximum is subtracted before exponentiating. Softmax is invariant
    to a per-row shift, so the result is mathematically unchanged. The shift
    keeps ``exp`` from overflowing to ``inf`` on large logits, which would
    otherwise produce ``inf / inf = NaN``.

    Args:
        x: array-like of shape ``(n, k)``; softmax is taken along axis 1.

    Returns:
        ndarray of shape ``(n, k)`` whose rows sum to 1.
    """
    x = np.asarray(x)
    # Shift so the largest entry in every row is 0; exp() is then at most 1.
    shifted = x - np.max(x, axis=1, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)

def LR_test(X, y):
    """Fit a default LogisticRegression on (X, y) and score it on that same data.

    Args:
        X: feature matrix passed to scikit-learn.
        y: target labels.

    Returns:
        ``(acc, y_pred)``: training-set accuracy and the labels predicted for X.
    """
    classifier = LogisticRegression()
    y_hat = classifier.fit(X, y).predict(X)
    return accuracy_score(y, y_hat), y_hat

def Perceptron_test(X, y):
    """Fit a default Perceptron on (X, y) and report accuracy on the training data.

    Args:
        X: feature matrix passed to scikit-learn.
        y: target labels.

    Returns:
        ``(acc, y_pred)``: training-set accuracy and the labels predicted for X.
    """
    fitted = Perceptron().fit(X, y)
    predictions = fitted.predict(X)
    score = accuracy_score(y, predictions)
    return score, predictions

def SVM_test(X, y):
    """Fit a default LinearSVC on (X, y) and measure accuracy on those same samples.

    Args:
        X: feature matrix passed to scikit-learn.
        y: target labels.

    Returns:
        ``(acc, y_pred)``: training-set accuracy and the labels predicted for X.
    """
    svc = LinearSVC()
    svc.fit(X, y)
    labels_hat = svc.predict(X)
    return (accuracy_score(y, labels_hat), labels_hat)

def Bayes_test(X, y, M_hat, l, class_means=None):
    """Classify X with the linear Bayes rule built from l-step-propagated class means.

    The propagated mean of class k is row k of ``M_hat**l @ class_means``.
    For Gaussian classes sharing an isotropic covariance, the Bayes decision
    is ``argmax_k  x . mu_k - ||mu_k||^2 / 2``. Note the *squared* norm. The
    common 1/sigma^2 factor does not change the argmax and is omitted.

    Args:
        X: ``(n, d)`` feature matrix.
        y: length-``n`` true labels, compared against predicted class indices.
        M_hat: ``(K, K)`` matrix whose l-th power propagates the class means.
        l: non-negative integer number of propagation steps.
        class_means: optional ``(K, d)`` matrix of un-propagated class means.
            When omitted, the module-level ``mus`` is used (the original
            behaviour, kept for backward compatibility).

    Returns:
        ``(acc, y_pred)``: fraction of samples with ``y_pred == y`` and the
        predicted class indices.
    """
    means = mus if class_means is None else np.asarray(class_means)

    # mu_k^{(l)}: class means after l propagation steps.
    mus_l = np.linalg.matrix_power(M_hat, l) @ means
    # Log-likelihood offset uses the squared norm. A plain norm mis-weights
    # classes whose propagated means have different lengths.
    bias = -0.5 * np.sum(mus_l ** 2, axis=1)

    # argmax of the linear discriminant equals argmax of its softmax, without
    # the risk of exp() overflowing to NaN on large scores.
    scores = X @ mus_l.T + bias
    y_pred = np.argmax(scores, axis=1)

    acc = float(np.mean(np.asarray(y) == y_pred))
    return acc, y_pred

# Synthetic node features: 5 Gaussian classes with num_nodes_per_class samples
# each. Class k has mean e_k (row k of the identity) and isotropic std 0.6.
num_nodes_per_class = 500 * 2
num_nodes = num_nodes_per_class * 5
mus = np.eye(5)  # row k is the mean vector of class k
# Draw one class at a time, c1 first, so the global random stream is consumed
# in class order.
c1_X, c2_X, c3_X, c4_X, c5_X = [
    np.random.normal(mu, 0.6, (num_nodes_per_class, 5)) for mu in mus
]
# Float label vectors: class k gets num_nodes_per_class copies of float(k).
c1_y, c2_y, c3_y, c4_y, c5_y = [
    np.full(num_nodes_per_class, float(label)) for label in range(5)
]
X = np.vstack((c1_X, c2_X, c3_X, c4_X, c5_X))
y = np.hstack((c1_y, c2_y, c3_y, c4_y, c5_y))

# Class-level edge template: M_hat[p, q] scales the probability of an edge
# from a class-p node to a class-q node. Every row holds a, b = 2a and three
# copies of c = 1/3 - a, so it sums to 1. Keeping c >= 0 requires a in [0, 1/3].
a = 0.13
b = 2 * a
c = 1 / 3 - a
# Circulant layout: row k is [a, b, c, c, c] rotated right by k positions.
M_hat = np.array([np.roll([a, b, c, c, c], k) for k in range(5)])

# Degree sweep. For each target average degree avg_d_i:
#   1. Sample a directed graph where the edge i -> j appears with probability
#      M_hat[label_i, label_j] * avg_d_i / num_nodes_per_class.
#   2. Compare a linear classifier on the raw features against the same
#      classifier on mean-aggregated neighbour features (S = D^-1 A, one
#      propagation step).
#   3. Save the results to an .npz file.
avg_D_s = [10, 15, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100, 120, 140, 160, 180, 200, 220, 240, 260, 280, 300, 325, 350]
output_degree = []
output_acc = []
output_max_dis = []
output_min_dis = []
output_cor = []

model_function = LR_test
num_classes = M_hat.shape[0]
labels = y.astype(np.int64)
# Loop-invariant work is done once instead of on every sweep step.
# First, the per-node-pair template M_hat[label_i, label_j] ...
edge_template = M_hat[labels[:, np.newaxis], labels]
# ... and second, the raw-feature baseline, since X and y never change here.
acc, pred = model_function(X, y)
cm_mlp = confusion_matrix(y, pred)
# Mask selecting off-diagonal (i != j) entries, read out in row-major order.
off_diag = ~np.eye(num_classes, dtype=bool)

for avg_d_i in avg_D_s:
    if avg_d_i == 0:
        # No edges means nothing to aggregate. The raw-feature baseline `acc`
        # is the only meaningful result, and no GCN row is recorded.
        continue
    avg_d = avg_d_i / num_nodes_per_class
    avg_D = np.full((num_classes, 1), avg_d_i)
    edge_prob = np.random.uniform(0, 1, size=(num_nodes, num_nodes))
    A = edge_prob < edge_template * avg_d

    # Row-normalise by broadcasting instead of building a dense
    # diag(1/deg) @ A product (an O(n^3) matmul on an n x n matrix).
    # An isolated node (deg == 0, possible at the lowest degrees) would give
    # 1/0 -> a NaN feature row, which scikit-learn estimators reject.
    # Such nodes get an all-zero aggregated row instead.
    deg = A.sum(axis=1)
    S = A / np.maximum(deg, 1)[:, np.newaxis]
    GCN_X = S @ X

    gc1_acc, gc1_pred = model_function(GCN_X, y)
    cm_gcn = confusion_matrix(y, gc1_pred)
    gc1_class_dis = cal_heterophilous_ratio(M_hat, avg_D, 1)

    # Change in every off-diagonal confusion count caused by aggregation,
    # correlated with the matching pairwise distance from cal_heterophilous_ratio.
    cm_diff = cm_gcn - cm_mlp
    gc1_class_dis_flatten = gc1_class_dis[off_diag]
    cm_diff_flatten = cm_diff[off_diag]

    correlation_coefficient = np.corrcoef(gc1_class_dis_flatten, cm_diff_flatten)[0, 1]
    print(f"avg_D: {avg_d_i:.2f}, acc: {gc1_acc*100:.2f}, min_dis:{gc1_class_dis_flatten.min():.4f}, correlation: {correlation_coefficient:.4f}")
    output_degree.append(avg_d_i)
    output_acc.append(gc1_acc)
    output_min_dis.append(gc1_class_dis_flatten.min())
    output_max_dis.append(gc1_class_dis_flatten.max())
    output_cor.append(correlation_coefficient)
output_degree = np.array(output_degree)
output_acc = np.array(output_acc)
output_min_dis = np.array(output_min_dis)
output_max_dis = np.array(output_max_dis)
output_cor = np.array(output_cor)
np.savez("degree_accs_22.npz", degree=output_degree, acc=output_acc, min_dis=output_min_dis, max_dis=output_max_dis, cor=output_cor)