from torch_geometric.data import DataLoader
from gin import *
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
from sklearn.cluster import KMeans
import torch
import torch.nn.functional as F
import pickle
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
# import random forest
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, homogeneity_completeness_v_measure
import random
from torch_geometric.data import Data
from arguments import arg_parse
from sigl_tools import coords_prediction, train_graphon, get_graphon, graph2XY
from sigl_tools import align_graphs, universal_svd



# Select the compute device. GPU index 2 is preferred, but on a host with fewer
# than three visible GPUs that ordinal does not exist, so clamp the index to the
# highest one available. Without CUDA, fall back to the CPU.
if torch.cuda.is_available():
    _gpu_index = min(2, torch.cuda.device_count() - 1)
    device = torch.device(f'cuda:{_gpu_index}')
else:
    device = torch.device('cpu')



def plot_tsne_comparison_3d(emb_3d, true_labels, cluster_labels, closest_points, DS=None, incorrect_idx=None):
    """
    Draw two side-by-side 3D scatter plots of one embedding, save and show them.

    The left panel colors points by ``true_labels`` and optionally outlines the
    points listed in ``incorrect_idx``. The right panel colors points by
    ``cluster_labels`` and, only when ``DS`` is None, overlays the points listed
    in ``closest_points`` in black.

    Args:
        emb_3d: (N, 3) array or torch tensor of embeddings.
        true_labels: (N,) array or torch tensor of class labels.
        cluster_labels: (N,) array or torch tensor of integer cluster ids
            (shown 1-based in the legend labels).
        closest_points: iterable of index arrays/lists to highlight on the right
            panel when ``DS`` is None; ``None`` disables the overlay.
        DS: dataset name used in the output file name; ``None`` selects the
            ``_sim`` file name.
        incorrect_idx: indices of points to outline on the left panel, or None.

    Side effects:
        Updates the global matplotlib ``rcParams`` (enables ``text.usetex``),
        creates the ``Plots`` directory if needed, writes
        ``Plots/emb3d_comparison_sim.jpg`` or ``Plots/emb3d_comparison_{DS}.jpg``,
        and calls ``plt.show()``.
    """
    import os  # local import: only needed to create the output directory

    # LaTeX-rendered serif text for every label in the figure.
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Computer Modern"],
        "axes.labelsize": 14,
        "font.size": 14,
        "legend.fontsize": 10,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
    })

    def _to_numpy(values):
        """Convert a torch tensor to a NumPy array; pass anything else through."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return values

    def _decorate(ax, title, azim):
        """Apply the title, axis labels, grid and camera angle shared by both panels."""
        ax.set_title(title)
        ax.set_xlabel(r"{t-SNE-1}")
        ax.set_ylabel(r"{t-SNE-2}")
        ax.set_zlabel(r"{t-SNE-3}")
        ax.grid(True)
        ax.view_init(elev=20, azim=azim)

    emb_3d = _to_numpy(emb_3d)
    true_labels = _to_numpy(true_labels)
    cluster_labels = _to_numpy(cluster_labels)
    if incorrect_idx is not None:
        # Tensors are moved to the CPU first so GPU index tensors also work.
        incorrect_idx = np.asarray(_to_numpy(incorrect_idx))

    fig = plt.figure(figsize=(14, 6))
    ax_true = fig.add_subplot(1, 2, 1, projection='3d')
    ax_clust = fig.add_subplot(1, 2, 2, projection='3d')

    # Left panel: one scatter call per true class.
    for label in np.unique(true_labels):
        idx = (true_labels == label)
        ax_true.scatter(
            emb_3d[idx, 0], emb_3d[idx, 1], emb_3d[idx, 2],
            label=fr"{{Class {label}}}", s=30, alpha=0.85
        )

    # Outline the flagged points with hollow black markers.
    if incorrect_idx is not None and incorrect_idx.size > 0:
        ax_true.scatter(
            emb_3d[incorrect_idx, 0], emb_3d[incorrect_idx, 1], emb_3d[incorrect_idx, 2],
            facecolors='none', edgecolors='black', s=60, linewidths=1.0, label=r"{Incorrect}"
        )

    _decorate(ax_true, r"{Colored by True Labels}", azim=35)
    ax_true.legend(title=r"{Class}")

    # Right panel: one scatter call per cluster id.
    for cluster in np.unique(cluster_labels):
        idx = (cluster_labels == cluster)
        ax_clust.scatter(
            emb_3d[idx, 0], emb_3d[idx, 1], emb_3d[idx, 2],
            label=fr"{{Graphon {cluster+1}}}", s=30, alpha=0.85
        )

    # The overlay is drawn only for the DS-is-None case; a missing
    # closest_points list simply means there is nothing to overlay.
    if DS is None and closest_points is not None:
        for closest_idxs in closest_points:
            closest_idxs = np.asarray(closest_idxs)
            if closest_idxs.size > 0:
                ax_clust.scatter(
                    emb_3d[closest_idxs, 0], emb_3d[closest_idxs, 1], emb_3d[closest_idxs, 2],
                    color='black', marker='o', s=36
                )

    # No legend on the right panel; the camera angle differs from the left one.
    _decorate(ax_clust, r"{Colored by Cluster Assignments}", azim=-40)

    plt.tight_layout()
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs("Plots", exist_ok=True)
    if DS is None:
        plt.savefig("Plots/emb3d_comparison_sim.jpg", dpi=300, bbox_inches='tight')
    else:
        plt.savefig(f"Plots/emb3d_comparison_{DS}.jpg", dpi=300, bbox_inches='tight')
    plt.show()


def setup_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), and puts cuDNN into deterministic mode with
    benchmarking switched off.

    Args:
        seed: Seed value handed to each generator.
    """
    # The cuDNN switches are independent of the seeding calls below.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# ---- Command-line configuration -------------------------------------------
args = arg_parse()
emb_type, DS = args.embType, args.DS
setup_seed(args.seed)
# Number of graphs nearest to each cluster center that are used to fit that
# cluster's graphon; falls back to 10 when --J is not supplied.
kk = 10 if args.J is None else args.J

# ---- Fixed hyperparameters --------------------------------------------------
# These are passed to coords_prediction / train_graphon later in the script.
inr_dim_hidden = [20, 20]
gnn_dim_hidden = [8, 8]
n_epochs_inr, epoch_show = 20, 10
batch_size_inr, lr_inr = 1024, 0.01
w0 = 10
# Load the precomputed motif-density embeddings for this dataset, if present.
try:
    all_embeddings_motif = np.load('Models/motif_densities_' + DS + '.npy')
except FileNotFoundError:
    print("No motif densities found for the dataset.")
    # Bind the name explicitly: leaving it undefined would surface below as an
    # opaque NameError instead of the script's own "Motif embeddings required"
    # check.
    all_embeddings_motif = None
    num_node_features = 1
else:
    # One feature per column; a 1-D array counts as a single feature.
    num_node_features = all_embeddings_motif.shape[1] if all_embeddings_motif.ndim > 1 else 1

# Materialize the TU dataset as a plain list of graphs; the class count is read
# first because it is an attribute of the dataset object, not of the list.
dataset = TUDataset(root='data/TUDataset', name=DS)
num_classes_data = dataset.num_classes
dataset = list(dataset)
len_data = len(dataset)

# Dataset-wide cluster count: ceil(sqrt(N)) unless --nCluster overrides it.
# NOTE(review): in the visible part of the script this value is only printed;
# the per-class clustering derives its own k via choose_k().
nCluster = np.ceil(np.sqrt(len_data)).astype(int) if args.nCluster==-1 else args.nCluster
print(f"Number of clusters: {nCluster}")
print(f"Number of features per node: {num_node_features}")

all_graphs = dataset
# One scalar label per graph, read from graph.y.
labels = np.array([graph.y.item() for graph in all_graphs])
all_embeddings = all_embeddings_motif

# # load train_val_idx
# with open("Gmixup/Models/train_val_idx_" + DS + ".pkl", 'rb') as f:
#     train_val_idx = pickle.load(f)
# Every graph in the dataset is treated as part of the train/val split.
train_val_idx = range(len_data)

# The motif embeddings, the graph list and the label vector must describe the
# same graphs, one row/entry each.
assert all_embeddings_motif is not None, "Motif embeddings required for this step."
assert all_embeddings_motif.shape[0] == len(all_graphs) == labels.shape[0], "Counts must align."

# Embeddings are used as loaded (no normalization is applied here).
emb = all_embeddings_motif
n_all = emb.shape[0]

# Accept either a boolean mask or a list of integer indices and reduce both to
# a 1-D integer index array.
train_val_idx = np.asarray(train_val_idx)
tv_idx = train_val_idx.nonzero()[0] if train_val_idx.dtype == bool else train_val_idx.astype(int)
assert tv_idx.ndim == 1, "train_val_idx must be 1D indices or boolean mask."
print(len(tv_idx))

# dataset index -> position of that graph inside tv_idx
tv_pos = dict(zip(map(int, tv_idx), range(len(tv_idx))))

# (label, local_cluster_id) -> (trained_inr, coords_model, est_graphon)
cluster_models = {}
# One slot per train/val graph, in tv_idx order; each slot is later filled with
# a (label_vec, est_graphon, trained_inr) tuple.
graphons_gmixup_trainval = [None for _ in range(len(tv_idx))]

# Choose number of clusters per label subset.
def choose_k(n, user_k=args.nCluster if hasattr(args, "nCluster") else -1):
    """Return the number of KMeans clusters to use for a subset of ``n`` graphs.

    Args:
        n: Number of graphs in the subset being clustered.
        user_k: Requested cluster count. The default is ``args.nCluster``
            (evaluated once, when this function is defined), or -1 if ``args``
            has no such attribute. ``None`` and -1 both mean "choose
            automatically".

    Returns:
        int: ``user_k`` limited to at most ``n`` and at least 1 when a count
        was requested; otherwise ``ceil(ln(n))`` with a floor of 1.
    """
    automatic = user_k is None or user_k == -1
    if automatic:
        # NOTE(review): the dataset-wide `nCluster` computed at startup uses
        # ceil(sqrt(n)) while this per-subset default uses the natural log —
        # confirm the difference is intentional.
        return int(max(1, np.ceil(np.log(n))))
    return int(max(1, min(user_k, n)))

# Seed taken from the CLI arguments (42 when `args` carries no `seed` attribute).
seed = getattr(args, "seed", 42)


def _dense_adjacency(graph_data):
    """Build a dense, symmetric 0/1 adjacency matrix from ``graph_data.edge_index``."""
    adj = torch.zeros((graph_data.num_nodes, graph_data.num_nodes))
    ei = graph_data.edge_index
    adj[ei[0], ei[1]] = 1
    adj[ei[1], ei[0]] = 1  # mirror every edge so the matrix is symmetric
    return adj


# dataset index -> (class label, local cluster id) for every clustered graph
per_graph_cc = {}

for c in range(num_classes_data):
    # Train/val graphs whose label is c, kept in tv_idx order.
    idx_c = [i for i in tv_idx if labels[i] == c]
    if not idx_c:
        continue
    idx_c_arr = np.array(idx_c)

    sub_emb = emb[idx_c]
    print(sub_emb.shape)
    n_k = choose_k(len(idx_c))

    # Cluster this class's embeddings independently of the other classes.
    # NOTE(review): random_state is fixed at 42 here rather than tied to
    # `seed` — confirm that is intended.
    km = KMeans(n_clusters=n_k, random_state=42).fit(sub_emb)
    clabels_local = km.labels_
    centers = km.cluster_centers_

    for j, center in enumerate(centers):
        cluster_member_mask = clabels_local == j
        member_idx_global = idx_c_arr[cluster_member_mask]
        if member_idx_global.size == 0:
            continue

        # Rank members by distance to the cluster center and keep at most kk.
        dists = np.linalg.norm(sub_emb[cluster_member_mask] - center, axis=1)
        order = np.argsort(dists)
        kk_j = int(min(kk, len(order)))
        chosen_global = member_idx_global[order[:kk_j]]

        # Remember which (class, cluster) every member belongs to.
        per_graph_cc.update(
            {int(g_idx): (int(c), int(j)) for g_idx in member_idx_global}
        )

        # Adjacency matrices of the graphs closest to the center.
        graphs_c = [_dense_adjacency(all_graphs[g_idx]) for g_idx in chosen_global]

        # Fit the coordinate model on those graphs, then turn the graphs into
        # training samples for the graphon model.
        model_ISGL, _ = coords_prediction(
            inr_dim_hidden, gnn_dim_hidden,
            int(2 * n_epochs_inr), epoch_show, w0, graphs_c, lr_inr
        )
        num_nodes_all = sum(g.shape[0] for g in graphs_c)
        print(num_nodes_all)
        X_all, y_all, w_all, _ = graph2XY(graphs_c, num_nodes_all, model_ISGL, sortDeg=False)

        trained_inr = train_graphon(
            inr_dim_hidden, w0, X_all, y_all, w_all,
            int(n_epochs_inr), epoch_show, lr_inr, batch_size_inr
        )
        est_graphon = get_graphon(1000, trained_inr, gpu=True)

        cluster_models[(c, j)] = (trained_inr, model_ISGL, est_graphon)

        # Every member of this (class, cluster) shares the learned graphon; each
        # graph gets its own one-hot label vector.
        for g_idx in member_idx_global:
            label_vec = np.zeros(num_classes_data, dtype=np.float32)
            label_vec[c] = 1.0
            graphons_gmixup_trainval[tv_pos[int(g_idx)]] = (label_vec, est_graphon, trained_inr)

# Every train/val slot must have received a tuple.
missing = [pos for pos, tpl in enumerate(graphons_gmixup_trainval) if tpl is None]
assert not missing, f"Some train/val graphs were not assigned a graphon: positions {missing}"

# Save artifacts
import os  # local import: only needed to create the output directories

_cluster_models_path = f"Models/label_cluster_models_{DS}_trainval.pkl"
_graphons_path = f"Gmixup/Models/graphons_gmixup_{DS}_trainval.pkl"

# open(..., "wb") does not create parent directories; create them up front so a
# missing folder cannot discard the results of the training above.
for _out_path in (_cluster_models_path, _graphons_path):
    os.makedirs(os.path.dirname(_out_path), exist_ok=True)

# NOTE: the model objects themselves are pickled, which ties the file to the
# current class definitions; switching to state_dicts would be more portable.
with open(_cluster_models_path, "wb") as f:
    pickle.dump(
        {
            "tv_idx": tv_idx,
            "cluster_models_keys": list(cluster_models.keys()),
            "cluster_models": cluster_models,
        },
        f,
    )

# List aligned with tv_idx order: [(label_vec, est_graphon, trained_inr), ...]
with open(_graphons_path, "wb") as f:
    pickle.dump(graphons_gmixup_trainval, f)

print(f"Prepared {len(graphons_gmixup_trainval)} tuples for train/val graphs.")



# Collect the estimated graphons per class, together with their local cluster ids.
import os  # local import: only needed to create the output directory

graphons_by_class = {}
for (label, local_cid), (_inr, _coords, est_graphon) in cluster_models.items():
    graphons_by_class.setdefault(int(label), []).append((int(local_cid), est_graphon))

# Order each class's clusters by local cluster id. The ids are kept in a
# parallel dict so the panel titles show the real id even if some id has no
# trained model (and the ids are therefore not contiguous).
cluster_ids_by_class = {}
for lbl in graphons_by_class:
    ordered = sorted(graphons_by_class[lbl], key=lambda t: t[0])
    cluster_ids_by_class[lbl] = [cid for cid, _ in ordered]
    graphons_by_class[lbl] = [g for _, g in ordered]

classes_sorted = sorted(graphons_by_class.keys())
n_rows = len(classes_sorted)
n_cols = max(len(graphons_by_class[lbl]) for lbl in classes_sorted)

# squeeze=False always yields a 2-D array of axes, whatever the grid shape.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)

for r, lbl in enumerate(classes_sorted):
    row_graphons = graphons_by_class[lbl]
    row_ids = cluster_ids_by_class[lbl]
    for c in range(n_cols):
        ax = axes[r, c]
        if c < len(row_graphons):
            ax.imshow(row_graphons[c], cmap='viridis')
            ax.set_title(f"Class {lbl} – Cluster {row_ids[c]}")
        if c == 0:
            # Turning the whole axis off would also hide the row label, so the
            # first column only hides its ticks and frame.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.set_ylabel(f"Class {lbl}", rotation=90, fontsize=12)
        else:
            ax.axis('off')

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("Plots", exist_ok=True)
plt.savefig(f"Plots/Estimated_graphons_{DS}_byclass.jpg", dpi=700)
plt.close(fig)

# import matplotlib.pyplot as plt
# import numpy as np
# import os

# os.makedirs("Plots", exist_ok=True)

# for r, lbl in enumerate(classes_sorted):
#     row_graphons = graphons_by_class[lbl]
#     for c, graphon in enumerate(row_graphons):
#         # Create a fresh figure for each graphon
#         fig, ax = plt.subplots(figsize=(5, 5))
#         ax.imshow(graphon, cmap='viridis')
#         ax.set_title(f"Class {lbl} – Cluster {c}")
#         ax.axis('off')

#         # Save each figure separately
#         plt.tight_layout()
#         plt.savefig(f"Plots/class/Estimated_graphon_{DS}_class{lbl}_cluster{c}.jpg", dpi=600)
#         plt.close(fig)




# Build embeddings and (class, cluster) labels in train/val order.
import os  # local import: only needed to create the output directory

emb_tv = emb[tv_idx]
cc_pairs = [per_graph_cc[int(i)] for i in tv_idx]  # list of (class, local_cluster)

# Map each distinct (class, cluster) pair to an integer code used for coloring
# and for the legend text.
pairs_sorted = sorted(set(cc_pairs))
pair2code = {p: k for k, p in enumerate(pairs_sorted)}
code2name = {k: f"class {c}, cluster {j}" for (c, j), k in pair2code.items()}
y_code = np.array([pair2code[p] for p in cc_pairs], dtype=int)

# Keep the perplexity between 5 and 30 for larger sets and below the sample
# count for tiny ones.
n_tv = emb_tv.shape[0]
perp = max(5, min(30, n_tv // 3)) if n_tv > 10 else max(2, min(5, n_tv - 1))

# The iteration count is left at the library default of 1000. The keyword for
# it was renamed from `n_iter` to `max_iter` in scikit-learn 1.5, so passing
# either name explicitly would break on one side of that change.
tsne3 = TSNE(
    n_components=3,
    perplexity=min(perp, n_tv - 1),
    learning_rate="auto",
    init="pca",
    random_state=seed,
)
Z = tsne3.fit_transform(emb_tv)  # [n_tv, 3]

# Marker per class (cycled if there are more classes than marker shapes).
unique_classes = sorted(set(c for c, _ in pairs_sorted))
marker_cycle = ['o', '^', 's', 'D', 'v', 'P', '*', 'X', '<', '>','h','H','8','p']
markers_by_class = {cls: marker_cycle[i % len(marker_cycle)] for i, cls in enumerate(unique_classes)}

# Color per (class, cluster). plt.get_cmap is used because the equivalent
# matplotlib.cm.get_cmap entry point was removed in matplotlib 3.9.
cmap = plt.get_cmap('tab20', len(pairs_sorted))
code2color = {code: cmap(code) for code in range(len(pairs_sorted))}

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")

# One scatter call per (class, cluster): color from the pair, marker from the class.
for (c, j), code in pair2code.items():
    mask = (y_code == code)
    if not np.any(mask):
        continue
    ax.scatter(
        Z[mask, 0], Z[mask, 1], Z[mask, 2],
        s=14, alpha=0.9,
        marker=markers_by_class[c],
        color=code2color[code],
        label=code2name[code],
    )

ax.set_xlabel("t-SNE 1")
ax.set_ylabel("t-SNE 2")
ax.set_zlabel("t-SNE 3")
ax.set_title(f"3D t-SNE of motif embeddings (train/val) — class=marker, cluster=color — {DS}")

# The legend lists every (class, cluster) pair and is placed outside the axes.
ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1.0), borderaxespad=0., fontsize=9)

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("Plots", exist_ok=True)
plt.savefig(f"Plots/TSNE3D_{DS}.jpg", dpi=300, bbox_inches="tight")
plt.close(fig)