import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.manifold import TSNE

from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression, Perceptron
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import seaborn as sns
import time

import warnings
from sklearn.exceptions import DataConversionWarning

# Suppress specific warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
warnings.filterwarnings("ignore", category=FutureWarning, module="matplotlib")
warnings.filterwarnings("ignore", category=DataConversionWarning)

# Floating-point precision of the experiment. Exactly one (cur_dtype, CDT)
# pair is meant to be active at a time: `cur_dtype` is the NumPy dtype the
# features, labels and adjacency matrices are cast to, and `CDT` is the short
# tag embedded in the run label printed below and in the output file name.
# cur_dtype = np.float16
# CDT = 'fp16'

# cur_dtype = np.float32
# CDT = 'fp32'

# cur_dtype = np.float64
# CDT = 'fp64'

# NOTE(review): np.float128 is only defined on platforms whose C long double
# is stored in 128 bits (np.longdouble is the portable spelling) — confirm the
# target platform before running elsewhere.
cur_dtype = np.float128
CDT = 'fp128'

# Presumably sample outputs of the `1 + eps` print below for some of the
# dtypes above — TODO confirm which line belongs to which dtype.
# 1.0000000000000000001
# 1.0000001
# 1.0000000000000002
# 1.0000000000000000001
# `a` is the diagonal entry of the 5x5 class-connectivity matrix M_hat built
# further down; the other entries are derived from it (b = 2 * a, c = 1/3 - a).
a = 0.20


# Number of samples drawn for each of the five classes.
num_nodes_per_class = 200
# Number of neighbours sampled for every node (split across the classes
# according to the rounded M_hat).
num_deg = 50
# Run label; the same pattern (plus ".npz") is used for the results file.
print(f"layer_accs_a{a}_{CDT}_n{num_nodes_per_class}_d{num_deg}")
# print(1 + np.finfo(np.longdouble).eps)
# Show the smallest representable value above 1 for the selected precision.
print(1 + np.finfo(cur_dtype).eps)

def cal_heterophilous_ratio_MultiLayerVersion(M_hat, S_power_l, l=1):
    """Return class-to-class distances under ``M_hat**l``, scaled by ``num_nodes / Q_norm``.

    Args:
        M_hat: Square class-connectivity matrix; it is raised to the power
            ``l`` inside this function.
        S_power_l: 2-D array with one row per node. It is used as given (no
            power is applied here). Accumulation happens in this array's dtype
            so the caller's working precision is respected.
        l: Integer exponent applied to ``M_hat``.

    Returns:
        A matrix with the same shape as ``M_hat`` whose entry ``[p, q]`` is the
        Euclidean distance between rows ``p`` and ``q`` of ``M_hat**l``,
        multiplied by ``num_nodes / Q_norm``. ``Q_norm`` is the square root of
        the sum of squared Euclidean distances over all ordered pairs of rows
        of ``S_power_l``. If all rows of ``S_power_l`` are identical,
        ``Q_norm`` is zero and the result contains inf/nan.

    Side effects:
        Prints the working dtypes and a few diagnostic estimates.
    """
    M_power_l = np.linalg.matrix_power(M_hat, l)
    diff = M_power_l[:, np.newaxis] - M_power_l
    dist_matrix = np.sqrt(np.sum(diff**2, axis=-1))

    num_nodes = S_power_l.shape[0]
    diff_squared = np.zeros((num_nodes, num_nodes), dtype=S_power_l.dtype)
    # Fill one row of pairwise squared distances per iteration. Differences
    # are formed explicitly (rather than via the ||x||^2 + ||y||^2 - 2 x.y
    # expansion) to avoid cancellation when the rows are nearly identical.
    for i in range(num_nodes):
        diff_squared[i] = np.sum((S_power_l[i] - S_power_l) ** 2, axis=1)
    print(diff_squared.dtype)
    print(S_power_l.dtype)
    Q_norm = np.sqrt(np.sum(diff_squared))
    # Cheap estimate of Q_norm that only uses the small class-level matrix;
    # it is reported for diagnostics and does not affect the return value.
    Q_norm_approx = np.sqrt(np.sum(diff**2) * num_nodes / 5)
    print("Est STD approx:", 0.6 * Q_norm_approx / num_nodes / np.sqrt(2))
    print("Est Mean Dis.:", np.mean(dist_matrix))
    print("Q_norm_approx", Q_norm_approx)
    return dist_matrix * num_nodes / Q_norm

def softmax(x):
    """Return the row-wise softmax of a 2-D array.

    The maximum of each row is subtracted before exponentiating. This leaves
    the result mathematically unchanged but prevents ``np.exp`` from
    overflowing to ``inf`` (and the quotient from becoming ``nan``) when the
    scores are large.

    Args:
        x: 2-D array-like of scores, one sample per row.

    Returns:
        Array of the same shape whose rows are non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Shift so the largest score in every row is 0.
    shifted = x - np.max(x, axis=1, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)

def LR_test(X, y):
    """Fit a logistic-regression classifier on ``(X, y)`` and score it on the same data.

    Returns:
        Tuple ``(accuracy, predictions)`` where both are computed on ``X``
        itself, i.e. this is the training accuracy.
    """
    classifier = LogisticRegression(max_iter=1000).fit(X, y)
    predictions = classifier.predict(X)
    return accuracy_score(y, predictions), predictions

def Perceptron_test(X, y):
    """Fit a perceptron on ``(X, y)`` and score it on the same data.

    Returns:
        Tuple ``(accuracy, predictions)`` where both are computed on ``X``
        itself, i.e. this is the training accuracy.
    """
    classifier = Perceptron().fit(X, y)
    predictions = classifier.predict(X)
    return accuracy_score(y, predictions), predictions

def SVM_test(X, y):
    """Fit a linear SVM on ``(X, y)`` and score it on the same data.

    Returns:
        Tuple ``(accuracy, predictions)`` where both are computed on ``X``
        itself, i.e. this is the training accuracy.
    """
    classifier = LinearSVC().fit(X, y)
    predictions = classifier.predict(X)
    return accuracy_score(y, predictions), predictions

def Bayes_test(X, y, M_hat, l):
    """Classify each row of ``X`` by a nearest-class-mean linear discriminant.

    The class means after ``l`` propagation steps are taken to be
    ``M_hat**l @ mus``, where ``mus`` is the module-level matrix of original
    class means. For classes with equal isotropic covariance and equal priors
    the discriminant for class ``k`` is ``x . mu_k - 0.5 * ||mu_k||**2``, so
    the bias uses the *squared* norm of each mean. (The un-squared norm gives
    the same predictions only when all means happen to have equal norm.)

    Args:
        X: 2-D array of features, one sample per row.
        y: True labels, compared against the arg-max class index.
        M_hat: Square class-connectivity matrix.
        l: Integer exponent applied to ``M_hat``.

    Returns:
        Tuple ``(accuracy, predictions)``.
    """
    # Class means after l propagation steps: mu_k^{(l)}.
    mus_l = np.linalg.matrix_power(M_hat, l) @ mus
    bias = -0.5 * np.sum(mus_l ** 2, axis=1)

    # Predict the class with the highest discriminant score.
    y_score = softmax(X @ mus_l.T + bias)
    y_pred = np.argmax(y_score, axis=1)

    return accuracy_score(y, y_pred), y_pred


# Synthetic data: five classes, `num_nodes_per_class` samples each. Every
# class is an isotropic Gaussian (std 0.6) centred on one row of `mus`, i.e.
# on a one-hot vector. Samples are drawn class by class, in class order.
# (The original also kept a disabled variant in which every class was centred
# on the all-zero vector instead.)
num_nodes = 5 * num_nodes_per_class
mus = np.eye(5)

c1_X, c2_X, c3_X, c4_X, c5_X = [
    np.random.normal(class_mean, 0.6, (num_nodes_per_class, 5))
    for class_mean in mus
]

# Labels 0..4 stored as floats, one constant vector per class.
c1_y, c2_y, c3_y, c4_y, c5_y = [
    np.full(num_nodes_per_class, float(label)) for label in range(5)
]

# Stack the per-class pieces and cast to the working precision.
X = np.vstack((c1_X, c2_X, c3_X, c4_X, c5_X)).astype(cur_dtype)
y = np.hstack((c1_y, c2_y, c3_y, c4_y, c5_y)).astype(cur_dtype)

# Ideal class-connectivity matrix: row k is (a, b, c, c, c) cyclically
# shifted right by k positions, so `a` sits on the diagonal and `b` on the
# next class (wrapping around).
b = 2 * a
c = 1 / 3 - a
M_hat = np.array([np.roll([a, b, c, c, c], shift) for shift in range(5)])
print(M_hat)
print()

# Turn the proportions into integer neighbour counts that add up to
# `num_deg` per node. Truncation can leave the total short; the shortfall is
# repaired by giving one extra neighbour to `deg_a` and, if that is still not
# enough, one more to `deg_b`.
deg_a, deg_b, deg_c = (int(share * num_deg) for share in (a, b, c))
if deg_a + deg_b + 3 * deg_c != num_deg:
    deg_a += 1
    if deg_a + deg_b + 3 * deg_c != num_deg:
        deg_b += 1

assert deg_a + deg_b + 3 * deg_c == num_deg

# Proportions actually realised by the integer counts.
new_a, new_b, new_c = (count / num_deg for count in (deg_a, deg_b, deg_c))

# M_hat is replaced by the realised proportions; M_hat_deg holds the matching
# integer counts in the same cyclic layout.
M_hat = np.array(
    [np.roll([new_a, new_b, new_c, new_c, new_c], shift) for shift in range(5)]
)
M_hat_deg = np.array(
    [np.roll([deg_a, deg_b, deg_c, deg_c, deg_c], shift) for shift in range(5)]
)
avg_D = np.full((5, 1), num_deg)

# Sample the graph: node i of class k_1 receives exactly M_hat_deg[k_1, k_2]
# distinct neighbours from class k_2, chosen uniformly at random within that
# class (a node may pick itself). Nodes are laid out class by class.
row_index = np.arange(0, num_nodes_per_class, 1)

print(row_index.dtype)
A = np.zeros((num_nodes, num_nodes))
print(A.shape)
for k_1 in range(5):
    i_base = k_1 * num_nodes_per_class
    for i in range(num_nodes_per_class):
        for k_2 in range(5):
            neighbor_base = k_2 * num_nodes_per_class
            neighbor_idx = np.random.permutation(row_index)[:M_hat_deg[k_1, k_2]] + neighbor_base
            A[i + i_base, neighbor_idx] = 1
A = A.astype(cur_dtype)

# Row-normalise: S = D^{-1} A. Dividing each row by its degree via
# broadcasting gives the same values as multiplying by a dense diag(1/deg)
# matrix (A is binary), without the O(n^3) matrix product.
S = A / np.sum(A, axis=1, keepdims=True)
S = S.astype(cur_dtype)


# b = 2 * a
# c = 1 / 3 - a
# M_hat = np.array([[a, b, c, c, c],
#             [c, a, b, c, c],
#             [c, c, a, b, c],
#             [c, c, c, a, b],
#             [b, c, c, c, a]])
# M_hat = M_hat.astype(cur_dtype)
# avg_d = num_deg / num_nodes_per_class
# M = M_hat * avg_d
# avg_D = np.array([[num_nodes_per_class * avg_d], [num_nodes_per_class * avg_d], [num_nodes_per_class * avg_d], 
#                 [num_nodes_per_class * avg_d], [num_nodes_per_class * avg_d]])
# edge_prob = np.random.uniform(0, 1, size=(num_nodes, num_nodes))
# A = np.where(edge_prob < M[y.astype(np.int64)[:, np.newaxis], y.astype(np.int64)], 1, 0)
# A = A.astype(cur_dtype)
# inv_D = np.diag(1./ np.sum(A, axis=1))
# S = inv_D @ A
# S = S.astype(cur_dtype)

# Propagation depths to evaluate: 1, then every even depth up to 100.
n_layers = [1] + np.arange(2, 101, 2).tolist()

# Per-depth results, appended in the order of `n_layers`.
output_layer = []
output_acc = []
output_max_dis = []
output_min_dis = []
output_cor = []
output_mean_class_std = []
output_mean_abs_node_feature = []
output_mean_class_mean_dis = []

model_function = LR_test

# Baseline on the raw (un-propagated) features. It does not depend on the
# depth, so it is fitted once instead of once per layer.
acc, pred = model_function(X, y)
cm_mlp = confusion_matrix(y, pred)

# Selects the off-diagonal class pairs in row-major order.
off_diagonal = ~np.eye(5, dtype=bool)

for n_l in n_layers:
    S_power_l = np.linalg.matrix_power(S, n_l)
    GCN_X_l = S_power_l @ X

    # Number of non-zero entries per row of S^l; only feeds `c_degs`, which is
    # kept for ad-hoc inspection.
    deg_hop_L = np.sum(S_power_l != 0, axis=1)
    class_averages = []
    class_stds = []
    c_degs = []
    # Per-class mean and standard deviation of the propagated features.
    for y_index in range(5):
        in_class = y == y_index
        class_X = GCN_X_l[in_class]
        class_averages.append(np.mean(class_X, axis=0))
        class_stds.append(np.std(class_X, axis=0))
        c_degs.append(deg_hop_L[in_class].mean(0))
    class_averages = np.array(class_averages)
    class_stds = np.array(class_stds)

    mean_class_std = np.mean(class_stds)
    mean_abs_node_feature = np.mean(np.abs(GCN_X_l))
    print("mean_abs_features:", mean_abs_node_feature)
    print("ratio:", mean_abs_node_feature / mean_class_std)

    # Pairwise Euclidean distances between the class means.
    mean_diff = class_averages[:, np.newaxis] - class_averages
    class_diff_matrix = np.sqrt(np.sum(mean_diff**2, axis=-1))
    mean_class_mean_dis = np.mean(class_diff_matrix)
    print("mean_class_std:", mean_class_std, mean_class_std.dtype)
    print("mean_class_mean_dis:", mean_class_mean_dis)
    output_mean_class_std.append(mean_class_std)
    output_mean_abs_node_feature.append(mean_abs_node_feature)
    output_mean_class_mean_dis.append(mean_class_mean_dis)
    # Empirical alternative to `gc1_class_dis` (see the disabled line below).
    F_approx = class_diff_matrix / mean_class_std / np.sqrt(2) * 0.6

    ######################################################################

    # Standardise every feature column before fitting.
    # NOTE(review): a column whose std is exactly 0 would produce nan/inf
    # here — confirm this cannot happen at the precisions being tested.
    GCN_X_normalized = (GCN_X_l - np.mean(GCN_X_l, axis=0)) / np.std(GCN_X_l, axis=0)

    gc1_acc, gc1_pred = model_function(GCN_X_normalized, y)
    cm_gcn = confusion_matrix(y, gc1_pred)
    gc1_class_dis = cal_heterophilous_ratio_MultiLayerVersion(M_hat, S_power_l, n_l)
    # gc1_class_dis = F_approx

    # Change in the confusion matrix relative to the raw-feature baseline.
    cm_diff = cm_gcn - cm_mlp

    # Keep only the class pairs with i != j and correlate the predicted class
    # distance with the change in confusions.
    gc1_class_dis_flatten = np.asarray(gc1_class_dis)[off_diagonal]
    cm_diff_flatten = np.asarray(cm_diff)[off_diagonal]

    correlation_coefficient = np.corrcoef(gc1_class_dis_flatten, cm_diff_flatten)[0, 1]
    print(f"layer: {n_l:.2f}, acc: {gc1_acc*100:.2f}, min_dis:{gc1_class_dis_flatten.min():.4f}, max_dis:{gc1_class_dis_flatten.max():.4f}, correlation: {correlation_coefficient:.4f}")
    output_layer.append(n_l)
    output_acc.append(gc1_acc)
    output_min_dis.append(gc1_class_dis_flatten.min())
    output_max_dis.append(gc1_class_dis_flatten.max())
    output_cor.append(correlation_coefficient)

# Convert the accumulated results to arrays and store them in one .npz file
# named after the run configuration.
output_layer = np.array(output_layer)
output_acc = np.array(output_acc)
output_min_dis = np.array(output_min_dis)
output_max_dis = np.array(output_max_dis)
output_cor = np.array(output_cor)
output_mean_class_std = np.array(output_mean_class_std)
output_mean_abs_node_feature = np.array(output_mean_abs_node_feature)
output_mean_class_mean_dis = np.array(output_mean_class_mean_dis)
file_name = f"layer_accs_a{a}_{CDT}_n{num_nodes_per_class}_d{num_deg}.npz"
np.savez(file_name, layer=output_layer, acc=output_acc, min_dis=output_min_dis, max_dis=output_max_dis, cor=output_cor,
        mean_class_std=output_mean_class_std, mean_abs_node_feature=output_mean_abs_node_feature, mean_class_mean_dis=output_mean_class_mean_dis)
print(f"Saved in: {file_name}")