import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from dataset import process_dataset, get_dataset
from models import *
from utils import *
from runner import *
from config import args
from sklearn.manifold import TSNE
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, classification_report
from sklearn.feature_selection import mutual_info_classif
from scipy.stats import pointbiserialr
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score

def analyze_feature_label_correlation(data):
    """Print and return per-feature association scores against the labels.

    Only the rows selected by ``data.train_mask`` are used.

    Args:
        data: Object exposing tensors ``x`` and ``y`` plus a mask
            ``train_mask`` that can index both. The selected ``x`` is
            treated as a 2-D matrix (one column per feature).

    Returns:
        tuple: ``(mi, corr)``. ``mi`` is the array returned by
        ``mutual_info_classif``; ``corr`` is an array with the point-biserial
        correlation between each feature column and the labels.
    """
    x = data.x[data.train_mask].detach().cpu().numpy()
    y = data.y[data.train_mask].detach().cpu().numpy()

    # Seeded with the same value the other analyses in this file use, so
    # repeated runs print the same mutual-information estimates.
    mi = mutual_info_classif(x, y, random_state=42)

    corr_list = []
    for col in range(x.shape[1]):
        r_value, _ = pointbiserialr(x[:, col], y)
        corr_list.append(r_value)
    corr = np.array(corr_list)

    print("每个特征与标签的互信息:", mi)
    print("每个特征与标签的点二列相关:", corr)
    return mi, corr



def compute_silhouette(data):
    """Compute, print and return the silhouette score of the training nodes.

    The node features selected by ``data.train_mask`` are the samples and the
    matching entries of ``data.y`` are used as the cluster labels.

    Args:
        data: Object exposing tensors ``x`` and ``y`` plus a mask
            ``train_mask`` that can index both.

    Returns:
        The value returned by ``silhouette_score``.
    """
    features = data.x[data.train_mask].detach().cpu().numpy()
    labels = data.y[data.train_mask].detach().cpu().numpy()
    score = silhouette_score(features, labels)
    print(f"Silhouette score: {score:.4f}")
    return score


def pca_visualization(data, save_path="pca_node_feature_by_label.png"):
    """Project the training-node features to 2-D with PCA and plot them by label.

    Args:
        data: Object exposing tensors ``x`` and ``y`` plus a mask
            ``train_mask`` that can index both.
        save_path: Where to write the figure. A falsy value skips saving and
            leaves the figure open on the current pyplot state.
    """
    x = data.x[data.train_mask].detach().cpu().numpy()
    y = data.y[data.train_mask].detach().cpu().numpy()

    x_pca = PCA(n_components=2, random_state=42).fit_transform(x)

    plt.figure(figsize=(8, 6))
    # np.unique returns the labels sorted, so the legend order and the colour
    # assigned to each label are the same on every run.
    for lab in np.unique(y):
        idx = y == lab
        plt.scatter(x_pca[idx, 0], x_pca[idx, 1], label=f"label={lab}", s=8)
    plt.legend()
    plt.title("PCA of Node Features by Label")
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=200)

def tsne_anaylysis(data, save_path="tsne_node_feature_by_label.png"):
    """Embed the training-node features in 2-D with t-SNE and plot them by label.

    Args:
        data: Object exposing tensors ``x`` and ``y`` plus a mask
            ``train_mask`` that can index both.
        save_path: Where to write the figure. A falsy value skips saving and
            leaves the figure open on the current pyplot state.
    """
    x_np = data.x[data.train_mask].detach().cpu().numpy()  # node features
    y_np = data.y[data.train_mask].detach().cpu().numpy()  # node labels

    # scikit-learn requires perplexity < n_samples; keep the usual 30 but
    # shrink it for small selections instead of failing.
    perplexity = min(30, max(1, x_np.shape[0] - 1))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    x_2d = tsne.fit_transform(x_np)

    plt.figure(figsize=(8, 6))
    # np.unique returns the labels sorted, so the legend order and the colour
    # assigned to each label are the same on every run.
    for lab in np.unique(y_np):
        idx = y_np == lab
        plt.scatter(x_2d[idx, 0], x_2d[idx, 1], label=f"label={lab}", s=8)
    plt.legend()
    plt.title("t-SNE of Node Features by Label")
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=200)


def _print_binary_metrics(title, y_true, y_prob):
    """Print AUC, accuracy, F1 and a classification report for one model."""
    # Hard predictions come from thresholding the positive-class score at 0.5.
    y_pred = (y_prob > 0.5).astype(int)
    print(title)
    print("AUC:", roc_auc_score(y_true, y_prob))
    print("Accuracy:", accuracy_score(y_true, y_pred))
    print("F1:", f1_score(y_true, y_pred))
    print(classification_report(y_true, y_pred, digits=3))


def baseline_analysis(data):
    """Fit feature-only baselines on the train nodes and report test metrics.

    A class-balanced logistic regression and a class-balanced decision tree
    are trained on the rows selected by ``data.train_mask`` and evaluated on
    the rows selected by ``data.test_mask``.

    Args:
        data: Object exposing tensors ``x``, ``y``, ``train_mask`` and
            ``test_mask``. The masks index the full node arrays.

    Note:
        Column 1 of ``predict_proba`` is used as the score and thresholded at
        0.5, so this presumably expects binary 0/1 labels -- TODO confirm.
    """
    # Keep the full node arrays here: the masks below are defined over all
    # nodes, so the arrays must not be pre-filtered by train_mask (doing so
    # makes the boolean masks mismatch the array length).
    x_np = data.x.detach().cpu().numpy()
    y_np = data.y.detach().cpu().numpy()

    train_idx = data.train_mask.cpu().numpy()
    test_idx = data.test_mask.cpu().numpy()

    X_train, y_train = x_np[train_idx], y_np[train_idx]
    X_test, y_test = x_np[test_idx], y_np[test_idx]

    logreg = LogisticRegression(max_iter=1000, class_weight="balanced").fit(X_train, y_train)
    y_pred_prob_logreg = logreg.predict_proba(X_test)[:, 1]

    tree = DecisionTreeClassifier(class_weight="balanced", random_state=42).fit(X_train, y_train)
    y_pred_prob_tree = tree.predict_proba(X_test)[:, 1]

    _print_binary_metrics("=== Logistic Regression ===", y_test, y_pred_prob_logreg)
    _print_binary_metrics("=== Decision Tree ===", y_test, y_pred_prob_tree)

def fairness_analysis(data):
    """Print label / sensitive-attribute statistics for the training nodes.

    For the rows selected by ``data.train_mask`` this prints the label
    distribution and, for every sensitive-attribute column, its distribution,
    the conditional label rates per group, a demographic-parity gap and the
    disparate-impact ratio between the groups encoded as 1 and 0.

    Args:
        data: Object exposing tensors ``x``, ``y``, ``sens_labels`` and a
            mask ``train_mask`` that can index all three. ``sens_labels`` may
            select to one or two dimensions; a 1-D result is treated as a
            single column.
    """
    x_np = data.x[data.train_mask].detach().cpu().numpy()
    y_np = data.y[data.train_mask].detach().cpu().numpy()
    sens_np = data.sens_labels[data.train_mask].detach().cpu().numpy()
    N = x_np.shape[0]
    print(f"数据形状: x {x_np.shape}, y {y_np.shape}, sens {sens_np.shape}")

    # Label distribution.
    unique_y, counts_y = np.unique(y_np, return_counts=True)
    print("标签分布:")
    for val, count in zip(unique_y, counts_y):
        print(f"label={val}: {count} ({count/N:.3f})")

    # Treat a single sensitive attribute as a one-column matrix.
    if sens_np.ndim == 1:
        sens_np = sens_np.reshape(-1, 1)

    for idx in range(sens_np.shape[1]):
        sens_col = sens_np[:, idx]
        unique_s, counts_s = np.unique(sens_col, return_counts=True)
        print(f"\n敏感属性 sens_{idx} 分布:")
        for val, count in zip(unique_s, counts_s):
            print(f"  sens={val}: {count} ({count/N:.3f})")

        # Conditional label rate inside every sensitive group.
        print(f"\n标签-敏感属性联合分布 (sens_{idx}):")
        for s, group_size in zip(unique_s, counts_s):
            in_group = sens_col == s
            for y in unique_y:
                n_joint = np.sum(in_group & (y_np == y))
                print(f"  P(y={y}|sens={s}) = {n_joint / group_size:.3f} ({n_joint} samples)")

        # Demographic parity: spread of the mean label across the groups.
        group_rates = [np.mean(y_np[sens_col == s]) for s in unique_s]
        if len(unique_s) == 2:
            gap = np.abs(group_rates[0] - group_rates[1])
            print(f"Demographic Parity (sens_{idx}): {gap:.4f}")
        else:
            # np.ptp rejects an empty sequence, so report NaN in that case.
            gap = np.ptp(group_rates) if group_rates else np.nan
            print(f"DP={gap:.4f}")

        # Disparate impact between the groups encoded as 1 and 0. A missing
        # group yields NaN instead of taking the mean of an empty slice.
        is_one = sens_col == 1
        is_zero = sens_col == 0
        p1 = np.mean(y_np[is_one]) if is_one.any() else np.nan
        p0 = np.mean(y_np[is_zero]) if is_zero.any() else np.nan
        disparate_impact = p1 / p0 if p0 > 0 else np.nan
        print(f"Disparate Impact: {disparate_impact:.3f}")

def edge_analysis(data):
    """Print node/edge counts, degree statistics and sensitive-group edge mixing.

    Args:
        data: Object exposing tensors ``edge_index`` (two rows: source and
            target node ids) and ``sens_labels`` (indexed by node id).
    """
    edge_index = data.edge_index.detach().cpu().numpy()
    sens_labels = data.sens_labels.detach().cpu().numpy()
    src, tgt = edge_index
    num_edges = edge_index.shape[1]

    # Count nodes from the per-node sensitive labels as well as the edge list,
    # so isolated nodes with high ids are included and an empty edge list does
    # not break the max() call.
    largest_id_plus_one = int(edge_index.max()) + 1 if num_edges else 0
    N = max(int(sens_labels.shape[0]), largest_id_plus_one)

    # Every edge contributes to the degree of both of its endpoints.
    degrees = np.bincount(src, minlength=N) + np.bincount(tgt, minlength=N)

    print(f"节点总数: {N}")
    print(f"边总数: {num_edges}")
    print(f"节点度均值: {degrees.mean():.2f}, 最大度: {degrees.max()}, 最小度: {degrees.min()}")

    same_group = sens_labels[src] == sens_labels[tgt]
    diff_group = ~same_group

    total = len(same_group)
    # With no edges the ratios are undefined; report NaN rather than divide by zero.
    same_ratio = same_group.sum() / total if total else float("nan")
    diff_ratio = diff_group.sum() / total if total else float("nan")
    print(f"同敏感组边比例: {same_ratio:.3f}, 跨敏感组边比例: {diff_ratio:.3f}")

if __name__ == '__main__':
    # Script entry point: load the source and target datasets and print the
    # edge and fairness statistics for each of them.
    # seed_everything(args.seed)
    # Load the source data (per the original note: it carries a small number of labels).
    source_data = get_dataset(args, args.inid)
    # Load the target data (per the original note: it has no labels).
    target_data = get_dataset(args, args.outid)
    # Preprocess the dataset.
    # NOTE(review): the original comment said both datasets are processed, but
    # only target_data is passed to process_dataset here -- confirm whether
    # source_data needs the same call.
    process_dataset(args, target_data)
    print("============source_data==============")
    edge_analysis(source_data)
    fairness_analysis(source_data)
    print("============target_data==============")
    edge_analysis(target_data)
    fairness_analysis(target_data)
    # Optional analyses, currently disabled:
    # tsne_anaylysis(target_data)
    # baseline_analysis(target_data)
    # analyze_feature_label_correlation(target_data)
    # pca_visualization(target_data)
    # compute_silhouette(target_data)