from torch_geometric.data import DataLoader
from gin import *
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
from sklearn.cluster import KMeans
import torch
import torch.nn.functional as F
import pickle
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
# import random forest
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, homogeneity_completeness_v_measure
import random
from torch_geometric.data import Data
from arguments import arg_parse
from sigl_tools import coords_prediction, train_graphon, get_graphon, graph2XY
from sigl_tools import align_graphs, universal_svd



# Compute device: the third CUDA GPU ("cuda:2") when CUDA is available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:2')
else:
    device = torch.device('cpu')



def plot_tsne_comparison_3d(emb_3d, true_labels, cluster_labels, closest_points, DS=None, incorrect_idx=None):
    """Draw two side-by-side 3-D t-SNE scatter plots, save them, and show them.

    The left panel colours points by ``true_labels`` and can outline the points
    listed in ``incorrect_idx``. The right panel colours points by
    ``cluster_labels`` and, only when ``DS`` is None, overlays every index group
    in ``closest_points`` in black.

    Args:
        emb_3d: (N, 3) NumPy array or torch tensor of embedded points.
        true_labels: (N,) integer labels (array or tensor).
        cluster_labels: (N,) integer cluster ids (array or tensor).
        closest_points: iterable of index arrays/lists to highlight on the right
            panel; only used when ``DS`` is None.
        DS: dataset name inserted into the output file name. When None the
            figure is written to ``Plots/emb3d_comparison_sim.jpg``.
        incorrect_idx: optional indices of points to outline on the left panel.

    Side effects:
        Updates the global ``plt.rcParams`` (LaTeX text rendering), writes a
        JPEG under ``Plots/`` and calls ``plt.show()``.
    """

    def _to_numpy(values):
        # Tensors are detached and copied to host memory; anything else is returned unchanged.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return values

    def _xyz(selection):
        # Split the selected rows of the (already converted) embedding into x, y, z columns.
        return emb_3d[selection, 0], emb_3d[selection, 1], emb_3d[selection, 2]

    def _label_axes(ax):
        ax.set_xlabel(r"{t-SNE-1}")
        ax.set_ylabel(r"{t-SNE-2}")
        ax.set_zlabel(r"{t-SNE-3}")

    # Global matplotlib settings: LaTeX text rendering with a serif (Computer Modern) font.
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Computer Modern"],
        "axes.labelsize": 14,
        "font.size": 14,
        "legend.fontsize": 10,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
    })

    emb_3d, true_labels, cluster_labels = (
        _to_numpy(emb_3d), _to_numpy(true_labels), _to_numpy(cluster_labels)
    )
    if incorrect_idx is not None:
        incorrect_idx = np.asarray(incorrect_idx)

    fig = plt.figure(figsize=(14, 6))
    ax_true, ax_clust = [fig.add_subplot(1, 2, pos, projection='3d') for pos in (1, 2)]

    # Left panel: one scatter series per ground-truth class, then hollow rings
    # around the "incorrect" points (drawn last so they sit on top).
    for label in np.unique(true_labels):
        ax_true.scatter(*_xyz(true_labels == label), label=fr"{{Class {label}}}", s=30, alpha=0.85)
    if incorrect_idx is not None and incorrect_idx.size > 0:
        ax_true.scatter(
            *_xyz(incorrect_idx),
            facecolors='none', edgecolors='black', s=60, linewidths=1.0, label=r"{Incorrect}",
        )
    ax_true.set_title(r"{Colored by True Labels}")
    _label_axes(ax_true)
    ax_true.legend(title=r"{Class}")
    ax_true.grid(True)
    ax_true.view_init(elev=20, azim=35)

    # Right panel: one series per cluster id; the representative points are
    # overlaid in black only when DS is None. No legend is drawn here.
    for cluster in np.unique(cluster_labels):
        ax_clust.scatter(*_xyz(cluster_labels == cluster), label=fr"{{Graphon {cluster+1}}}", s=30, alpha=0.85)
    if DS is None:
        for group in map(np.asarray, closest_points):
            if group.size:
                ax_clust.scatter(*_xyz(group), color='black', marker='o', s=36)
    ax_clust.set_title(r"{Colored by Cluster Assignments}")
    _label_axes(ax_clust)
    ax_clust.grid(True)
    ax_clust.view_init(elev=20, azim=-40)  # viewed from the opposite side to the left panel

    plt.tight_layout()
    out_path = "Plots/emb3d_comparison_sim.jpg" if DS is None else f"Plots/emb3d_comparison_{DS}.jpg"
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()


def setup_seed(seed):
    """Seed every random generator the script uses and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global generator, and torch (CPU
    and all CUDA devices). It then turns off cuDNN benchmarking and requests
    deterministic cuDNN algorithms so repeated runs give the same results.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# --- Configuration ---------------------------------------------------------
args = arg_parse()
emb_type = args.embType  # NOTE(review): parsed but not referenced below
DS = args.DS
setup_seed(args.seed)

# The per-graph motif densities are the only embedding this script clusters
# on, so a missing file is fatal: fail fast with a clear message rather than
# leaving `all_embeddings_motif` unbound.
motif_path = 'Models/motif_densities_' + DS + '.npy'
try:
    all_embeddings_motif = np.load(motif_path)
except FileNotFoundError as err:
    print("No motif densities found for the dataset.")
    raise FileNotFoundError(
        f"Motif densities are required for clustering but '{motif_path}' was not found."
    ) from err

if all_embeddings_motif.ndim < 2:
    # Treat a 1-D array as one scalar feature per graph: KMeans, t-SNE and
    # LogisticRegression all expect a 2-D (n_samples, n_features) array.
    all_embeddings_motif = all_embeddings_motif.reshape(-1, 1)
num_node_features = all_embeddings_motif.shape[1]

# NOTE(review): TUDataset is not imported explicitly in this file; presumably
# it arrives via `from gin import *`.
dataset = TUDataset(root='data/TUDataset', name=DS)
num_classes_data = dataset.num_classes
dataset = list(dataset)  # materialize once so graphs can be indexed repeatedly
len_data = len(dataset)
# When args.nCluster == -1, default to ceil(ln(number of graphs)) clusters.
nCluster = np.ceil(np.log(len_data)).astype(int) if args.nCluster == -1 else args.nCluster
print(f"Number of clusters: {nCluster}")
print(f"Number of features per node: {num_node_features}")
all_graphs = dataset
labels = np.array([graph.y.item() for graph in all_graphs])
all_embeddings = all_embeddings_motif



# --- Cluster the embeddings and compare clusters with the class labels ---
print("size all_embeddings: ", all_embeddings.shape)
kmeans = KMeans(n_clusters=nCluster, random_state=42)
cluster_labels = kmeans.fit(all_embeddings).labels_  # same as fit_predict

# External clustering-quality scores against the ground-truth labels.
ari = adjusted_rand_score(labels, cluster_labels)
nmi = normalized_mutual_info_score(labels, cluster_labels)
homogeneity, completeness, v_measure = homogeneity_completeness_v_measure(labels, cluster_labels)
print(f"Adjusted Rand Index (ARI): {ari:.4f}")
print(f"Normalized Mutual Information (NMI): {nmi:.4f}")
print(f"Homogeneity: {homogeneity:.4f}, Completeness: {completeness:.4f}, V-measure: {v_measure:.4f}")

# --- Representatives: the kk cluster members nearest to each k-means centre ---
centers = kmeans.cluster_centers_
kk = 10 if args.J is None else args.J  # representatives per cluster (default 10)
closest_points = []
for i, center in enumerate(centers):
    # Global indices of the graphs assigned to cluster i.
    members = np.flatnonzero(cluster_labels == i)
    # Euclidean distance of each member to the centre; keep the kk smallest.
    member_dists = np.linalg.norm(all_embeddings[members] - center, axis=1)
    closest_idx = members[np.argsort(member_dists)[:kk]]
    print(f"Number of closest points to center {i}: {len(closest_idx)}")
    print(closest_idx)
    closest_points.append(closest_idx)


# --- 3-D t-SNE projection of the embeddings, plotted against labels/clusters ---
# Perplexity must be smaller than the number of samples, hence the N - 1 cap.
# NOTE(review): scikit-learn 1.5 deprecated TSNE's `n_iter` in favour of
# `max_iter` (removed in 1.7); switch the keyword when upgrading.
tsne = TSNE(
    n_components=3,
    perplexity=min(30, all_embeddings.shape[0] - 1),
    learning_rate='auto',
    n_iter=1000,
    init='pca',
    random_state=42,
)
all_embeddings_3d = tsne.fit_transform(all_embeddings)

plot_tsne_comparison_3d(all_embeddings_3d, labels, cluster_labels, closest_points, DS, incorrect_idx=None)

# --- Estimate one graphon per cluster with SIGL, from that cluster's representatives ---
# Hyper-parameters passed to the SIGL helpers (GNN / INR sizes, epochs, learning rate).
gnn_dim_hidden = [8, 8]
n_epochs_inr = 20
epoch_show = 10
inr_dim_hidden = [20, 20]
batch_size_inr = 1024
lr_inr = 0.01
w0 = 10


def _dense_adjacency(graph):
    """Return a dense float (num_nodes, num_nodes) matrix with 1 at every (src, dst) pair of `graph.edge_index`."""
    adj = torch.zeros((graph.num_nodes, graph.num_nodes))
    adj[graph.edge_index[0], graph.edge_index[1]] = 1
    return adj


graphons_est, graphon_models, coords_models = [], [], []
for i in range(nCluster):
    graphs_c = [_dense_adjacency(all_graphs[idx]) for idx in closest_points[i]]
    print("Number of graphs to estimate graphon: ", len(graphs_c))

    # Coordinate model, trained for twice the INR epoch count.
    model_ISGL, _ = coords_prediction(inr_dim_hidden, gnn_dim_hidden, int(2 * n_epochs_inr), epoch_show, w0, graphs_c, lr_inr)
    coords_models.append(model_ISGL)

    # Convert the graphs to (X, y, w) training data and fit the graphon INR.
    num_nodes_all = sum(adj.shape[0] for adj in graphs_c)
    X_all, y_all, w_all, _ = graph2XY(graphs_c, num_nodes_all, model_ISGL, sortDeg=False)
    trained_inr_c = train_graphon(inr_dim_hidden, w0, X_all, y_all, w_all, int(n_epochs_inr), epoch_show, lr_inr, batch_size_inr)
    graphon_models.append(trained_inr_c)

    # NOTE(review): gpu=True is hard-coded although `device` falls back to CPU;
    # confirm get_graphon works on CPU-only machines.
    graphons_est.append(get_graphon(1000, trained_inr_c, gpu=True))

# Persist each graph's cluster id together with the trained graphon and coordinate models.
with open("Models/cluster_labels_" + DS + ".pkl", 'wb') as f:
    pickle.dump([cluster_labels, graphon_models, coords_models], f)

# --- Export (one-hot label, estimated graphon, graphon INR) per graph for G-Mixup ---
graphons_gmixup = []
for graph, cluster_i in zip(all_graphs, cluster_labels):
    label_i_vector = np.zeros(num_classes_data)
    label_i_vector[graph.y.item()] = 1
    graphons_gmixup.append((label_i_vector, graphons_est[cluster_i], graphon_models[cluster_i]))

# NOTE(review): the file name contains a space ("graphons_ gmixup_"). It is kept
# as-is in case the G-Mixup loader reads this exact path; fix both sides together.
with open("Gmixup/Models/graphons_ gmixup_" + DS + ".pkl", 'wb') as f:
    pickle.dump(graphons_gmixup, f)
    

# --- Plot the estimated graphons side by side, one panel per cluster ---
# Only the SIGL estimates are visualized; no estimation loss is computed here.
plt.figure(figsize=(int(5 * nCluster), 5))
for i in range(nCluster):
    plt.subplot(1, nCluster, i + 1)
    plt.imshow(graphons_est[i], cmap='viridis')
    plt.title(f"Estimated Graphon {i+1}")
    plt.axis('off')
plt.tight_layout()
plt.savefig("Plots/Estimated_graphons_" + DS + ".jpg", dpi=700)
# The figure is 5*nCluster inches wide and saved at dpi=700; release it once written.
plt.close()
    
# gw_loss_s = []
# for i in range(nCluster):
#     est_graphon = graphons_est[i]
    
#     # Create a new figure for each graphon
#     plt.figure(figsize=(5, 5))
#     plt.imshow(est_graphon, cmap='viridis')
#     plt.title(f"Estimated Graphon {i+1}")
#     plt.axis('off')
#     plt.tight_layout()
    
#     # Save each subplot as its own file
#     plt.savefig(f"Plots/one/Estimated_graphon_{DS}_{i+1}.jpg")
#     plt.close()  # Close the figure to free memory


# --- Linear probe: fit logistic regression on the motif densities ---
if all_embeddings_motif is not None:
    print("Using motif densities for classification.")
    clf = LogisticRegression(random_state=42).fit(all_embeddings_motif, labels)

    # Scored on the same samples used for fitting, i.e. this is training accuracy.
    acc = clf.score(all_embeddings_motif, labels)
    print(f"Logistic Regression Accuracy on Motif Densities: {acc:.4f}")

    # |coef_| flattened row-major: for multi-class problems the per-class rows
    # are concatenated into one vector.
    feature_importances = clf.coef_
    print("Feature importances:", np.abs(feature_importances).flatten())


# # get pairwise mse distance of graphs
# all_graph_adj = []
# for idx in range(len_data):
#     graph_i = all_graphs[idx]
#     adj_i = torch.zeros((graph_i.num_nodes, graph_i.num_nodes))
#     edge_index = graph_i.edge_index
#     adj_i[edge_index[0], edge_index[1]] = 1
#     adj_i = adj_i.numpy()
#     degs = adj_i.sum(axis=0)
#     # sorted_indices = np.argsort(degs)[::-1]  # Sort in descending order
#     # adj_i = adj_i[sorted_indices, :][:, sorted_indices]  # Reorder the adjacency matrix
#     all_graph_adj.append(adj_i)

# resolution = 100  # resolution for the graphon
# align_graphs_list, normalized_node_degrees, max_num, min_num = align_graphs(all_graph_adj, padding=True, N=resolution)
# graphons = [universal_svd([align_graphs_list[i]], threshold=0.2) for i in range(len(align_graphs_list))]

# # calculate the pairwise MSE distance of graphons
# distances = np.zeros((len(graphons), len(graphons)))
# for i in range(len(graphons)):
#     for j in range(i + 1, len(graphons)):
#         dist = np.mean((graphons[i] - graphons[j]) ** 2)
#         distances[i, j] = dist
#         distances[j, i] = dist

# # # save the distances with pickle
# # with open(f"Models/graphon_distances_{DS}.pkl", 'wb') as f:
# #     pickle.dump(distances, f)

# # # plot the distances as a heatmap
# # plt.figure(figsize=(8, 6))
# # plt.imshow(distances, cmap='viridis', interpolation='nearest')
# # plt.colorbar(label='MSE Distance')
# # plt.title(f"Pairwise MSE Distances of Graphons ({DS})")
# # plt.xlabel("Graphons")
# # plt.ylabel("Graphons")
# # plt.tight_layout()
# # plt.savefig(f"Plots/graphon_distances_{DS}.jpg", dpi=300)




# # # --- clustering & visualization ---
# # n = distances.shape[0]
# # # Hierarchical linkage on your distance matrix
# # condensed = squareform(distances, checks=False)
# # Z = linkage(condensed, method='average')  # 'average' works well with precomputed dists

# # # Reorder heatmap by dendrogram leaf order
# # leaf_order = leaves_list(Z)
# # distances_reordered = distances[leaf_order][:, leaf_order]


# # # Plots: dendrogram and clustered heatmap
# # plt.figure(figsize=(10, 5))
# # dendrogram(Z, color_threshold=None)
# # plt.title(f"Hierarchical Clustering Dendrogram ({DS})")
# # plt.xlabel("Graphon index")
# # plt.ylabel("Linkage distance")
# # plt.tight_layout()
# # plt.savefig(f"Plots/graphon_dendrogram_{DS}.jpg", dpi=300)

# # # Plot the clustered heatmap
# # plt.figure(figsize=(8, 6))
# # im = plt.imshow(distances_reordered, cmap='viridis', interpolation='nearest')
# # plt.colorbar(im, label='MSE Distance')
# # plt.title(f"Pairwise MSE Distances of Graphons (labeled)")
# # plt.xlabel("Graphons (reordered)")
# # plt.ylabel("Graphons (reordered)")
# # plt.tight_layout()
# # plt.savefig(f"Plots/graphon_distances_clustered_{DS}.jpg", dpi=300)

# # 7) 3D MDS embedding for a quick scatter view colored by labels
# mds = MDS(n_components=3, dissimilarity='precomputed', random_state=0)
# X3 = mds.fit_transform(distances)  # shape (n, 3)
# fig = plt.figure(figsize=(8, 6))
# ax = fig.add_subplot(111, projection='3d')
# sc = ax.scatter(X3[:, 0], X3[:, 1], X3[:, 2], c=labels, s=40, depthshade=True)
# handles, legend_labels = sc.legend_elements()
# ax.legend(handles, legend_labels, title="Labels", loc="best")

# ax.set_title(f"Graphons (3D MDS from distances)")
# ax.set_xlabel("MDS-1")
# ax.set_ylabel("MDS-2")
# ax.set_zlabel("MDS-3")
# plt.tight_layout()
# plt.savefig(f"Plots/graphon_mds3d_{DS}.jpg", dpi=300)
# plt.show()