from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
from sklearn.cluster import KMeans
import torch
import torch.nn.functional as F
import pickle
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
# import random forest
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, homogeneity_completeness_v_measure
import random
from torch_geometric.data import Data
from sigl_tools import coords_prediction, train_graphon, get_graphon, graph2XY
from sigl_tools import align_graphs, universal_svd



# Torch device for this script: GPU index 2 when CUDA is available, otherwise CPU.
# NOTE(review): 'cuda:2' is hard-coded and assumes at least three visible GPUs — confirm
# this is intended before running on other machines. Not referenced by the functions in
# this section.
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')


def setup_seed(seed):
    """Seed every RNG used by this script and force deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to Python's ``random``, ``numpy.random`` and
            torch (CPU generator plus all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _load_motif_embeddings(DS):
    """Load the per-graph motif-density matrix for dataset ``DS`` as a 2-D array."""
    path = 'Models/motif_densities_' + DS + '.npy'
    try:
        emb = np.load(path)
    except FileNotFoundError as err:
        # Without embeddings there is nothing to cluster on; fail here with a clear
        # message instead of an UnboundLocalError further down.
        raise FileNotFoundError(
            f"No motif densities found for the dataset {DS!r} (expected {path})."
        ) from err
    if emb.ndim == 1:
        # One feature per graph: KMeans requires (n_samples, n_features) input.
        emb = emb.reshape(-1, 1)
    return emb


def _coerce_indices(train_val_idx, n_graphs):
    """Turn an index list/array or a per-graph boolean mask into 1-D integer indices."""
    tv_idx = np.asarray(train_val_idx)
    if tv_idx.dtype == bool:
        assert tv_idx.shape == (n_graphs,), \
            "Boolean train_val_idx mask must have exactly one entry per dataset graph."
        tv_idx = np.flatnonzero(tv_idx)
    tv_idx = tv_idx.astype(int)
    assert tv_idx.ndim == 1, "train_val_idx must be 1D indices or boolean mask."
    return tv_idx


def _dense_symmetric_adjacency(g):
    """Build a dense, symmetrized 0/1 adjacency matrix from a PyG graph's edge_index."""
    adj = torch.zeros((g.num_nodes, g.num_nodes))
    ei = g.edge_index
    adj[ei[0], ei[1]] = 1
    adj[ei[1], ei[0]] = 1  # symmetrize
    return adj


def build_graphons_for_trainval(train_val_idx, DS, seed, J, nCluster):
    """
    Estimate one graphon per (label, motif-cluster) group and assign it to every
    graph of that group selected by ``train_val_idx``.

    For each class label, the selected graphs are clustered with KMeans on their
    motif-density embeddings (loaded from ``Models/motif_densities_<DS>.npy``).
    For each cluster, the (up to) ``J`` graphs closest to the cluster center are
    used to fit node coordinates with ``coords_prediction`` and an INR graphon
    with ``train_graphon``. The graphon is then evaluated with
    ``get_graphon(100, ...)`` and shared by all members of that cluster.

    Args:
        train_val_idx: 1-D integer indices into the dataset, or a boolean mask
            with exactly one entry per dataset graph. Duplicate indices are
            allowed; each occurrence gets its own output entry.
        DS: TUDataset name (loaded from ``data/TUDataset``). Also selects the
            motif-density file.
        seed: ``random_state`` passed to KMeans.
        J: maximum number of center-nearest graphs per cluster used to train
            the graphon. ``None`` means 10.
        nCluster: number of KMeans clusters per label, clipped to ``[1, n]``
            for a label with ``n`` selected graphs. ``None`` or ``-1`` means
            ``max(1, ceil(ln(max(n, 2))))``.

    Returns:
        list with one entry per element of ``train_val_idx``, in the same order.
        Each entry is ``(label_onehot, est_graphon, trained_inr)``, where
        ``label_onehot`` is a float32 one-hot vector of length ``num_classes``.

    Raises:
        FileNotFoundError: the motif-density file for ``DS`` does not exist.
        AssertionError: embeddings and dataset sizes disagree, the indices are
            malformed, or some selected graph could not be assigned a graphon
            (e.g. a label outside ``range(num_classes)``).
    """
    # Hyperparameters for coordinate prediction and INR graphon training.
    gnn_dim_hidden = [8, 8]
    n_epochs_inr = 20
    epoch_show = 10
    inr_dim_hidden = [20, 20]
    batch_size_inr = 1024
    lr_inr = 0.01
    w0 = 10
    graphon_resolution = 100  # first argument passed to get_graphon

    # setup_seed(seed)
    kk = J if J is not None else 10  # default is 10

    emb = _load_motif_embeddings(DS)
    num_node_features = emb.shape[1]

    dataset = TUDataset(root='data/TUDataset', name=DS)
    num_classes_data = dataset.num_classes
    all_graphs = list(dataset)
    print(f"Number of features per node: {num_node_features}")
    labels = np.array([graph.y.item() for graph in all_graphs])

    # --- sanity checks
    assert emb.shape[0] == len(all_graphs) == labels.shape[0], "Counts must align."

    tv_idx = _coerce_indices(train_val_idx, len(all_graphs))
    tv_labels = labels[tv_idx]

    # One slot per requested position; filled cluster by cluster below.
    graphons_gmixup_trainval = [None] * len(tv_idx)

    def choose_k(n):
        # An explicit nCluster (anything but None / -1) is clipped to [1, n];
        # otherwise fall back to ceil(ln n), guarding against ln(1) == 0.
        if nCluster is not None and nCluster != -1:
            return int(max(1, min(nCluster, n)))
        return int(max(1, np.ceil(np.log(max(n, 2)))))

    gpu_flag = torch.cuda.is_available()

    for c in range(num_classes_data):
        # Positions (in train_val order) of the selected graphs with label c.
        # Working on positions rather than global ids keeps duplicates distinct.
        pos_c = np.flatnonzero(tv_labels == c)
        if pos_c.size == 0:
            continue
        idx_c = tv_idx[pos_c]

        sub_emb = emb[idx_c]
        n_k = choose_k(len(idx_c))

        km = KMeans(n_clusters=n_k, random_state=seed)
        clabels_local = km.fit_predict(sub_emb)

        for j, center in enumerate(km.cluster_centers_):
            member_mask = clabels_local == j
            if not member_mask.any():
                continue
            member_pos = pos_c[member_mask]
            member_idx_global = idx_c[member_mask]

            # Train the INR/graphon on up to kk members nearest the cluster center.
            dists = np.linalg.norm(sub_emb[member_mask] - center, axis=1)
            order = np.argsort(dists)
            kk_j = int(min(kk, len(order)))
            chosen_global = member_idx_global[order[:kk_j]]

            graphs_c = [_dense_symmetric_adjacency(all_graphs[int(g_idx)])
                        for g_idx in chosen_global]

            # Learn node coordinates, then fit the graphon on the resulting samples.
            model_ISGL, _ = coords_prediction(
                inr_dim_hidden, gnn_dim_hidden,
                int(2 * n_epochs_inr), epoch_show, w0, graphs_c, lr_inr
            )

            num_nodes_all = sum(G.shape[0] for G in graphs_c)
            X_all, y_all, w_all, _ = graph2XY(graphs_c, num_nodes_all, model_ISGL, sortDeg=False)

            trained_inr = train_graphon(
                inr_dim_hidden, w0, X_all, y_all, w_all,
                int(n_epochs_inr), epoch_show, lr_inr, batch_size_inr
            )
            est_graphon = get_graphon(graphon_resolution, trained_inr, gpu=gpu_flag)

            # Share this cluster's graphon with ALL members of (label=c, cluster=j).
            # Each entry gets its own one-hot vector so caller-side mutation of one
            # entry cannot leak into another.
            for pos in member_pos:
                label_vec = np.zeros(num_classes_data, dtype=np.float32)
                label_vec[c] = 1.0
                graphons_gmixup_trainval[int(pos)] = (label_vec, est_graphon, trained_inr)

    # ensure everything was filled
    missing = [i for i, v in enumerate(graphons_gmixup_trainval) if v is None]
    assert not missing, f"Unassigned graphs at positions: {missing}"

    return graphons_gmixup_trainval
