import random
import argparse
import numpy as np
import torch
import matplotlib.markers as mmarkers
import matplotlib.pyplot as plt

from sklearn.manifold import TSNE
from sklearn import metrics
from matplotlib import rcParams

def setup_seed(seed):
    """Seed every random number generator used by this module.

    The same ``seed`` is given to Python's ``random``, NumPy, PyTorch's
    default generator and all CUDA generators. cuDNN is then put into
    deterministic mode with benchmarking switched off.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for repeatable kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def arg_parse(argv=None):
    """Build the experiment's command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            function did before this parameter existed. Passing a list lets
            the function be used from notebooks, tests or other scripts.

    Returns:
        The parsed ``argparse.Namespace``. Its ``device`` attribute is always
        recomputed after parsing (see the note at the end of the function).
    """
    parser = argparse.ArgumentParser()

    # Data, seed and optimiser settings.
    parser.add_argument('--dataset', type=str, default='computers', help='40/NTU')
    parser.add_argument('--f_dim', type=int, default=64)
    parser.add_argument('--seed', type=int, default=777)
    parser.add_argument('--lrate', type=float, default=0.001)
    parser.add_argument('--wdecay', type=float, default=0.00)

    # Layer sizes and edge counts.
    parser.add_argument('--in_dim', type=int, default=128)
    parser.add_argument('--out_dim', type=int, default=2)
    parser.add_argument('--hid_dim', type=int, default=128)
    parser.add_argument('--num_edges', type=int, default=100)
    parser.add_argument('--min_num_edges', type=int, default=64)

    # Training loop and hardware selection.
    parser.add_argument('--k', type=int, default=5)
    parser.add_argument('--cuda', type=str, default='0', help='0/1/2/3')
    parser.add_argument('--drop_rate', type=float, default=0.2)
    parser.add_argument('--patience', type=int, default=100)
    parser.add_argument('--epoch', type=int, default=500)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--model', type=str, default='dhl')

    parser.add_argument('--edges', type=str, default='h')
    parser.add_argument('--mask', type=int, default=1)
    parser.add_argument('--cf', type=str, default='x')
    parser.add_argument('--merge', type=str, default='cat', help='cat/plus')
    parser.add_argument('--stage', type=str, default='train', help='train/val/test')

    parser.add_argument('--conv_number', type=int, default=1)
    parser.add_argument('--k_n', type=int, default=10, help='number of nodes to choose')
    parser.add_argument('--k_e', type=int, default=10, help='number of edges to choose')
    parser.add_argument('--k_m', type=int, default=20, help='number of edges to choose')
    parser.add_argument('--low_bound', type=float, default=0.9)
    parser.add_argument('--up_bound', type=float, default=0.95)

    parser.add_argument('--backbone', type=str, default='linear')
    parser.add_argument('--namuda', type=int, default=30)
    parser.add_argument('--namuda2', type=float, default=10)

    parser.add_argument('--splits', type=int, default=1)
    parser.add_argument('--fts', type=str, default='all', help='MVCNN/GVCNN')

    parser.add_argument('--split_ratio', type=float, default=0.8)
    parser.add_argument('--runs', type=int, default=1)
    parser.add_argument('--transfer', type=int, default=1)
    parser.add_argument('--use_feats', type=int, default=1)

    # Attention-model options.
    parser.add_argument('--hidden_channels', type=int, default=128)
    parser.add_argument('--out_channels', type=int, default=128)
    parser.add_argument('--local_layers', type=int, default=7,
                        help='number of layers for local attention')
    parser.add_argument('--global_layers', type=int, default=2,
                        help='number of layers for global attention')
    parser.add_argument('--num_heads', type=int, default=5,
                        help='number of heads for attention')
    parser.add_argument('--beta', type=float, default=-1.0,
                        help='Polynormer beta initialization')
    parser.add_argument('--pre_ln', action='store_true')

    parser.add_argument('--in_dropout', type=float, default=0.15)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--global_dropout', type=float, default=None)

    args = parser.parse_args(argv)

    # NOTE: whatever was passed via --device is replaced here. The device is
    # derived from --cuda when CUDA is available and is 'cpu' otherwise.
    args.device = 'cuda:{}'.format(args.cuda) if torch.cuda.is_available() else 'cpu'

    return args



def mscatter(x, y, ax=None, m=None, **kw):
    """Scatter ``x`` against ``y``, optionally with one marker per point.

    Args:
        x: Point x-coordinates, forwarded to ``ax.scatter``.
        y: Point y-coordinates, forwarded to ``ax.scatter``.
        ax: Object to draw on. Only ``None`` selects the current pyplot axes
            via ``plt.gca()``; any other object is used as given.
        m: Optional sequence of markers, one per point. Each entry is either
            a ``MarkerStyle`` instance or a value accepted by the
            ``MarkerStyle`` constructor. The markers are applied only when
            ``len(m) == len(x)``; otherwise they are ignored.
        **kw: Extra keyword arguments forwarded to ``ax.scatter``.

    Returns:
        The object returned by ``ax.scatter``.
    """
    # Compare against None explicitly so that a valid but falsy axes-like
    # object is never silently swapped for the current pyplot axes.
    if ax is None:
        ax = plt.gca()
    sc = ax.scatter(x, y, **kw)

    if m is not None and len(m) == len(x):
        paths = []
        for marker in m:
            if isinstance(marker, mmarkers.MarkerStyle):
                style = marker
            else:
                style = mmarkers.MarkerStyle(marker)
            # Store the marker outline with the marker's own transform applied.
            paths.append(style.get_path().transformed(style.get_transform()))
        sc.set_paths(paths)

    return sc

# def plot_embedding_2d(X, labels, fname, title=None):
#     config = {
#     "font.size": 20,
#     "mathtext.fontset":'stix'
# }
#
#     rcParams.update(config)
#
#
#     """Plot an embedding X with the class label y colored by the domain d."""
#     x_min, x_max = np.min(X, 0), np.max(X, 0)
#     X = (X - x_min) / (x_max - x_min)
#
#     # Plot colors numbers
#     num_of_labels = labels.max().item() + 1
#
#     # fig = plt.figure(figsize=(16,10))
#
#     plt.figure(figsize=(16,10))
#
#     # plt.margins(0.001)
#
#     colors_space = np.linspace(0, 1, num_of_labels)           # build the color space
#     label_to_color = {}                                 # map each label to a color
#     for i in range(num_of_labels):
#         label_to_color[i] = colors_space[i]
#
#     colors = []
#     for label in labels:
#         colors.append(label_to_color[label.item()])
#
#     sc = plt.scatter(X[:, 0], X[:, 1], c=colors, s=20)
#     # scatter = mscatter(X[:, 0], X[:, 1], c='r', ax=ax)
#
#     plt.xticks([]), plt.yticks([])
#     if title is not None:
#         plt.title(title)
#
#     cb=plt.colorbar(sc)
#     cb.ax.tick_params(labelsize=32)  # set the colorbar tick label font size
#
#     plt.tight_layout(rect=(0, 0, 1.06, 1))
#
#     plt.savefig(f'./{fname}.eps')
#     plt.savefig(f'./{fname}.png')
#     # plt.show()



def plot_embedding_2d(X, labels, fname, title=None):
    """Draw a 2-D embedding as a per-class scatter plot and save it to disk.

    Each column of ``X`` is min-max scaled to [0, 1]. Every class is drawn
    with its own ``tab20`` colour and its own marker, and a two-column legend
    is anchored at the lower right. The figure is written to
    ``./{fname}_with_legend.eps`` and ``./{fname}_with_legend.png``. It is
    not closed, so it remains the current pyplot figure afterwards.

    Args:
        X: 2-D array-like of points; columns 0 and 1 are plotted.
        labels: Array-like of class ids, one per row of ``X``. Classes are
            taken to be ``0 .. labels.max()``.
        fname: Output file name stem (no extension).
        title: Optional figure title.
    """
    rcParams.update({
        "font.size": 20,
        "mathtext.fontset": 'stix',
    })

    X = np.asarray(X)
    labels = np.asarray(labels)

    x_min, x_max = np.min(X, 0), np.max(X, 0)
    span = x_max - x_min
    # A constant column has zero range; divide by 1 there so the column maps
    # to 0 instead of turning into NaN.
    span = np.where(span == 0, 1, span)
    X = (X - x_min) / span

    num_of_labels = int(labels.max()) + 1
    plt.figure(figsize=(16, 10))
    ax = plt.gca()

    # Colour map resampled to one entry per class. ``plt.get_cmap`` is used
    # because ``matplotlib.cm.get_cmap`` is no longer available in recent
    # Matplotlib releases.
    cmap = plt.get_cmap("tab20", num_of_labels)
    colors = [cmap(i) for i in range(num_of_labels)]

    # Common markers; reused cyclically when there are more classes than markers.
    marker_list = ['o', 's', 'v', '^', '<', '>', 'P', '*', 'X', 'D', 'h', '8', '+', 'x', '|', '_', '1', '2', '3', '4']

    # One scatter call per class so each class gets its own legend entry.
    for label in range(num_of_labels):
        idx = (labels == label)
        ax.scatter(
            X[idx, 0], X[idx, 1],
            c=[colors[label]],
            marker=marker_list[label % len(marker_list)],
            label=str(label),
            s=40,
            edgecolors='k',  # thin black outline to make points easier to tell apart
            linewidths=0.3,
        )

    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title)

    # Legend at the lower right of the axes.
    plt.legend(
        loc='lower right',
        title="Class",
        bbox_to_anchor=(1.02, 0),
        borderaxespad=0.1,
        fontsize=12,
        ncol=2,  # raise to 3 or more columns to make the legend shorter
        title_fontsize=14,
    )

    plt.tight_layout(rect=(0, 0, 1.1, 1))
    plt.savefig(f'./{fname}_with_legend.eps')
    plt.savefig(f'./{fname}_with_legend.png')
    # plt.show()

from sklearn.cluster import KMeans
from sklearn import metrics
def draw_TSNE(X, labels, fname, title=None):
    """Project ``X`` to two dimensions with t-SNE and plot the result.

    The projection uses PCA initialisation and ``random_state=0``. The
    projected points are handed to ``plot_embedding_2d`` together with
    ``labels``, ``fname`` and ``title``.
    """
    projector = TSNE(n_components=2, init='pca', random_state=0)
    plot_embedding_2d(projector.fit_transform(X), labels, fname, title)

# def visualization(model, data, args, title=None):
#     mask = data['train_idx']
#
#     out, x, H, H_raw ,edges= model(data,args)
#
#     # X = x.detach().to('cpu')
#     # labels = data['lbls'].detach().to('cpu')
#
#     X = x[mask].detach().to('cpu')
#     labels = data['lbls'][mask].detach().to('cpu')
#
#     fname = f'{args.model}_{args.dataset}_{args.fts}'
#     draw_TSNE(X, labels, fname, title=None)
#     Silhouette_score = metrics.silhouette_score(X, labels)
#     print("Silhouette_score is: ", Silhouette_score)


def visualization(model, data, args, title=None):
    """Plot a t-SNE view of the model's embeddings and report cluster metrics.

    Runs ``model(data, args)``, keeps the rows of the second output selected
    by ``data['train_idx']``, and draws them with ``draw_TSNE``. It then
    prints the silhouette score against the true labels, plus the NMI and ARI
    of a KMeans clustering with one cluster per distinct true label.

    Args:
        model: Callable invoked as ``model(data, args)`` that returns five
            values; the second one is used as the embedding.
        data: Mapping that provides ``'train_idx'`` (index or mask of the
            rows to use) and ``'lbls'`` (labels).
        args: Namespace providing ``model``, ``dataset`` and ``fts``, which
            are combined into the output file name stem.
        title: Optional plot title passed on to ``draw_TSNE``.

    Returns:
        A dict with the keys ``'silhouette'``, ``'nmi'`` and ``'ari'``
        holding the printed metric values.
    """
    # To use every node instead, drop the mask (or invert it).
    mask = data['train_idx']

    # Only detached outputs are used below, so no autograd graph is needed.
    # NOTE(review): the model's train/eval mode is left as the caller set it.
    with torch.no_grad():
        _out, x, _H, _H_raw, _edges = model(data, args)
    print(x.shape)

    # Select embeddings and labels, converting both to NumPy once so that
    # scikit-learn receives plain arrays.
    X = x[mask].detach().cpu().numpy()
    labels_true = data['lbls'][mask].detach().cpu().numpy()

    # Visualisation.
    fname = f'{args.model}_{args.dataset}_{args.fts}'
    draw_TSNE(X, labels_true, fname, title=title)

    # Silhouette score of the embeddings under the true labels.
    silhouette = metrics.silhouette_score(X, labels_true)
    print("Silhouette Score:", silhouette)

    # KMeans with one cluster per distinct true label.
    n_clusters = int(np.unique(labels_true).size)
    print("Number of clusters:", n_clusters)
    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X)
    labels_pred = kmeans.labels_

    # Agreement between the KMeans assignment and the true labels.
    nmi = metrics.normalized_mutual_info_score(labels_true, labels_pred)
    ari = metrics.adjusted_rand_score(labels_true, labels_pred)

    print("NMI:", nmi)
    print("ARI:", ari)

    return {'silhouette': silhouette, 'nmi': nmi, 'ari': ari}