from sigl_tools import *
from torch_geometric.data import DataLoader
from gin import *
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, GCNConv, SAGEConv
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
from sklearn.cluster import KMeans
import torch
import torch.nn.functional as F
import pickle
from arguments import arg_parse
# Preferred CUDA device index for this script. Hosts with fewer GPUs than
# that fall back to GPU 0, and hosts without CUDA fall back to the CPU, so
# `device` always names a device that actually exists.
_PREFERRED_GPU = 2
if not torch.cuda.is_available():
    device = torch.device('cpu')
elif _PREFERRED_GPU < torch.cuda.device_count():
    device = torch.device(f'cuda:{_PREFERRED_GPU}')
else:
    device = torch.device('cuda:0')

def plot_tsne_comparison(emb_2d, true_labels, cluster_labels, closest_points, DS=None):
    """
    Plots side-by-side scatter plots of 2-D embeddings and saves the figure to disk.

    The left panel colors each point by its true label, the right panel by its
    cluster assignment. When DS is None, the points indexed by closest_points
    are additionally drawn in black on the right panel.

    Args:
        emb_2d (ndarray or Tensor): [N, 2] graph embeddings.
        true_labels (ndarray, Tensor or list): [N,] true class labels.
        cluster_labels (ndarray, Tensor or list): [N,] cluster assignments.
        closest_points (iterable): one collection of row indices into emb_2d
            per cluster; only used when DS is None.
        DS (str, optional): dataset name used in the output file name. When
            None the figure is saved as "Plots/Models/emb_comparison_sim.jpg",
            otherwise as "Plots/Models/emb_comparison_<DS>.jpg".

    Returns:
        None. The figure is written to disk and then closed.
    """
    import os  # local import: only needed to make sure the output folder exists

    def _as_numpy(values):
        # Tensors may carry gradients or live on another device; everything
        # else goes through np.asarray so plain lists support boolean masks.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    emb_2d = _as_numpy(emb_2d)
    true_labels = _as_numpy(true_labels)
    cluster_labels = _as_numpy(cluster_labels)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    fig.suptitle("t-SNE Comparison: Labels vs Clusters", fontsize=16)

    # Left: one scatter series per true label
    for label in np.unique(true_labels):
        idx = true_labels == label
        axes[0].scatter(emb_2d[idx, 0], emb_2d[idx, 1], label=f"Class {label}", s=40, alpha=0.8)
    axes[0].set_title("Colored by True Labels")
    axes[0].set_xlabel("TSNE-1")
    axes[0].set_ylabel("TSNE-2")
    axes[0].legend(title="Class")
    axes[0].grid(True)

    # Right: one scatter series per cluster
    for cluster in np.unique(cluster_labels):
        idx = cluster_labels == cluster
        axes[1].scatter(emb_2d[idx, 0], emb_2d[idx, 1], label=f"Cluster {cluster}", s=40, alpha=0.8)
    if DS is None:
        # Highlight the selected per-cluster points on top of the clusters
        for closest_idxs in closest_points:
            axes[1].scatter(
                emb_2d[closest_idxs, 0], emb_2d[closest_idxs, 1],
                color='black', marker='o', s=40,)
    axes[1].set_title("Colored by Cluster Assignments")
    axes[1].set_xlabel("TSNE-1")
    axes[1].set_ylabel("TSNE-2")
    axes[1].legend(title="Cluster")
    axes[1].grid(True)

    plt.tight_layout()
    try:
        # savefig does not create missing folders, so make sure it exists
        os.makedirs("Plots/Models", exist_ok=True)
        if DS is None:
            fig.savefig("Plots/Models/emb_comparison_sim.jpg")
        else:
            fig.savefig("Plots/Models/emb_comparison_" + DS + ".jpg")
    finally:
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)


class GCN(torch.nn.Module):
    """Stack of GCNConv + BatchNorm + ReLU layers with mean pooling and a linear head.

    Args:
        in_channels: size of the input node features (input of the first conv).
        hidden_dim: output size of every conv layer, and therefore the size of
            the pooled graph embedding (also exposed as ``embedding_dim``).
        num_classes: output size of the final linear layer.
        num_layers: number of GCNConv/BatchNorm pairs.
    """

    def __init__(self, in_channels, hidden_dim, num_classes, num_layers):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()

        for layer in range(num_layers):
            # Only the first layer sees the raw node features
            conv_in = in_channels if layer == 0 else hidden_dim
            self.convs.append(GCNConv(conv_in, hidden_dim))
            self.bns.append(torch.nn.BatchNorm1d(hidden_dim))  # Batch normalization for stability

        self.embedding_dim = hidden_dim
        # Spelled out as torch.nn.Linear (like the other layers) so the class
        # does not depend on a bare `Linear` leaking in through a wildcard import.
        self.lin = torch.nn.Linear(hidden_dim, num_classes)  # Classification layer

    def _encode(self, data):
        """Run the conv stack on ``data`` and mean-pool node features per graph.

        ``data`` must expose ``x``, ``edge_index`` and ``batch``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        for conv, bn in zip(self.convs, self.bns):
            x = F.relu(bn(conv(x, edge_index)))

        return global_mean_pool(x, batch)  # Graph-level embedding

    def forward(self, data):
        """Return the output of the linear head applied to the pooled graph embedding."""
        return self.lin(self._encode(data))

    def get_graph_embeddings(self, data):
        """Extract graph-level embeddings before classification.

        Returns:
            A tuple ``(graph_emb, graph_c)``: the pooled graph embedding and the
            softmax (over dim 1) of the linear head applied to it.
        """
        graph_emb = self._encode(data)
        graph_c = F.softmax(self.lin(graph_emb), dim=1)
        return graph_emb, graph_c


def convert_to_pyg_data(graphs, num_node_features, label):
    """Wrap adjacency matrices as ``Data`` objects that all share one label.

    For every matrix in ``graphs`` an edge (row, col) is created for each
    strictly positive entry, and every node receives an all-ones feature
    vector of length ``num_node_features``.

    Args:
        graphs: iterable of square adjacency matrices.
        num_node_features: number of (constant) features per node.
        label: integer class label stored in ``y`` of every graph.

    Returns:
        List of ``Data`` objects, one per input matrix, in input order.
    """

    def _to_data(adj):
        # Row/column positions of the positive entries form the edge list
        rows_cols = np.array(np.where(adj > 0))
        edges = torch.tensor(rows_cols, dtype=torch.long)
        # Constant (all-ones) node features
        features = torch.ones((adj.shape[0], num_node_features), dtype=torch.float)
        return Data(x=features, edge_index=edges, y=torch.tensor([label], dtype=torch.long))

    return [_to_data(adj) for adj in graphs]

def add_noise(adj_list, noise_level):
    """Return noisy copies of the adjacency matrices in ``adj_list``.

    For each matrix, a uniform random number is drawn per entry from NumPy's
    global RNG; entries whose draw is below ``noise_level`` are replaced by
    ``1 - entry``, all others are kept. The inputs are not modified.

    Args:
        adj_list: iterable of NumPy adjacency matrices.
        noise_level: threshold compared against the uniform draws.

    Returns:
        List of new arrays, one per input matrix, in input order.
    """

    def _flip(adj):
        flip_mask = np.random.rand(*adj.shape) < noise_level
        return np.where(flip_mask, 1 - adj, adj)

    return [_flip(adj) for adj in adj_list]

def setup_seed(seed):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    Also asks cuDNN for deterministic behavior and disables its benchmark
    mode so runs with the same seed are repeatable.

    Args:
        seed: integer seed applied to every generator.
    """
    # Imported here so the function does not depend on `random` being leaked
    # into the module namespace by a wildcard import.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # ignored when CUDA is unavailable
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may pick different kernels between runs
    torch.backends.cudnn.benchmark = False

# ---- Run configuration -------------------------------------------------
isSim = 0  # truthy: build simulated graphs; falsy: load the TU dataset named by args.DS
num_node_features = 1  # length of the all-ones feature vector given to every node
addNoise = 1  # 1 for adding noise to the graphs for simulated datasets

args = arg_parse()
# Fall back to a random seed when none was supplied
if args.seed is None:
    seed = random.randint(1, 10000)
else:
    seed = args.seed

if isSim:
    # ---- Simulated data: sample graphs from several graphons -----------
    # NOTE(review): DS is only bound in the real-dataset branch below, but it
    # is read again later in the file when loading the motif densities.
    # Graphon ids: 1:np.exp(-(u ** 0.7 + v ** 0.7)) , 8: np.log(1 + 0.5 * np.maximum(u, v)), 11: SBM, 14: ER
    graphon_types = [1, 8, 11, 14]
    num_graphons = len(graphon_types)
    n_graph = 250
    offset = 150

    All_graphons_small = []
    all_graphs = []
    for i, type_idx in enumerate(graphon_types):
        graphon = synthesize_graphon(r=1000, type_idx=type_idx)
        graphon_small = synthesize_graphon(r=100, type_idx=type_idx)
        All_graphons_small.append(graphon_small)

        graphs = simulate_graphs(w=graphon, graph_size='vary', num_graphs=n_graph, offset=offset)
        if addNoise:
            graphs = add_noise(graphs, noise_level=0.1)
        # The graphon's position in graphon_types serves as the graph label
        graphs = convert_to_pyg_data(graphs, num_node_features, i)
        all_graphs.extend(graphs)

    # Number of clusters: ceil(ln(number of graphs))
    num_classes = np.ceil(np.log(len(all_graphs))).astype(int)
    print(f"Number of classes: {num_classes}")

else:
    # ---- Real data: TU dataset -----------------------------------------
    DS = args.DS
    dataset = TUDataset(root='data/TUDataset', name=DS)
    num_data_classes = dataset.num_classes
    dataset = list(dataset)
    len_data = len(dataset)

    # Number of clusters: user-supplied, otherwise ceil(ln(number of graphs))
    if args.nCluster is None:
        num_classes = np.ceil(np.log(len_data)).astype(int)
    else:
        num_classes = args.nCluster
    print(f"Number of classes: {num_classes}")

    # Overwrite the node features of every graph with all ones
    for graph in dataset:
        graph.x = torch.ones((graph.num_nodes, num_node_features), dtype=torch.float)

    all_graphs = dataset

# Load the per-graph motif-density embeddings; the clustering step further
# down cannot run without them.
# NOTE(review): DS is only assigned in the real-dataset branch, so with isSim
# enabled this line raises NameError - confirm the intended file for that mode.
motif_path = 'data/MGCL/' + DS + '_motif_densities.npy'
try:
    all_embeddings_motif = np.load(motif_path)
except FileNotFoundError as err:
    # Fail here with an actionable message instead of printing, continuing,
    # and hitting a NameError on `all_embeddings_motif` later on.
    raise FileNotFoundError(
        f"No motif densities found for the dataset: expected '{motif_path}'."
    ) from err

# Downstream code treats the embeddings as a 2-D (samples, features) array
# (KMeans input, norm over axis=1), so promote a lower-rank array to a
# single feature column.
if all_embeddings_motif.ndim < 2:
    all_embeddings_motif = all_embeddings_motif.reshape(-1, 1)
num_node_features = all_embeddings_motif.shape[1]

# batch_size = 100
# hidden_dim=32
# lr=0.01
# num_gc_layers=3
# dataloader = DataLoader(all_graphs, batch_size=batch_size)

# model = GCN(in_channels=num_node_features, hidden_dim=hidden_dim, num_classes=num_classes, num_layers=num_gc_layers).to(device)
# print(model)
# embeddings = []
# labels = []
# model.eval()
# with torch.no_grad():
#     for data in dataloader:
#         data = data.to(device)
#         graph_emb, graph_class = model.get_graph_embeddings(data)
#         embeddings.append(graph_emb.cpu())
#         labels.extend(data.y.cpu().tolist())  # Collect labels
# all_embeddings = torch.cat(embeddings, dim=0).numpy()

# labels = [graph.y.item() for graph in all_graphs]
# labels = np.array(labels)

# Cluster the graphs on their motif-density embeddings
clustering_emb = all_embeddings_motif
kmeans = KMeans(n_clusters=num_classes, random_state=42)
cluster_labels = kmeans.fit_predict(clustering_emb)  # one cluster id per graph

# For every cluster keep the (at most) kk members nearest to its center
centers = kmeans.cluster_centers_
kk = 10 if args.J is None else args.J  # default is 10
closest_points = []
for i, center in enumerate(centers):
    # Global indices of the graphs assigned to cluster i
    cluster_indices = np.where(cluster_labels == i)[0]
    # Euclidean distance of each member to the center
    dists = np.linalg.norm(clustering_emb[cluster_indices] - center, axis=1)
    # Positions of the nearest members, translated back to global indices
    closest_idx = cluster_indices[np.argsort(dists)[:kk]]
    print(f"Number of closest points to center {i}: {len(closest_idx)}")
    print(closest_idx)
    closest_points.append(closest_idx)

# all_embeddings_2d = TSNE(n_components=2, random_state=42).fit_transform(all_embeddings)
# if isSim:
#     plot_tsne_comparison(all_embeddings_2d, labels, cluster_labels, closest_points)
# else:
#     plot_tsne_comparison(all_embeddings_2d, labels, cluster_labels, closest_points, DS)

setup_seed(seed)

# Hyper-parameters passed to the coordinate / graphon training helpers
gnn_dim_hidden = [8, 8]
n_epochs_inr = 20
epoch_show = 10
inr_dim_hidden = [20, 20]
batch_size_inr = 1024
lr_inr = 0.01
w0 = 10

# Per-cluster results, indexed by cluster id
graphons_est = []
graphon_models = []
coords_models = []


def _dense_adjacency(graph_data):
    """Return a dense num_nodes x num_nodes matrix with ones at the graph's edges."""
    adj = torch.zeros((graph_data.num_nodes, graph_data.num_nodes))
    sources, targets = graph_data.edge_index[0], graph_data.edge_index[1]
    adj[sources, targets] = 1
    return adj


for i in range(num_classes):
    # Dense adjacency matrices of the graphs closest to the center of cluster i
    graphs_c = [_dense_adjacency(all_graphs[idx]) for idx in closest_points[i]]
    print("Number of graphs to estimate graphon: ", len(graphs_c))

    # First stage runs for twice as many epochs as the graphon training below
    model_ISGL, _ = coords_prediction(inr_dim_hidden, gnn_dim_hidden, int(2*n_epochs_inr), epoch_show, w0, graphs_c, lr_inr)
    coords_models.append(model_ISGL)

    # Total node count over the selected graphs
    num_nodes_all = sum(adj.shape[0] for adj in graphs_c)
    X_all, y_all, w_all, _ = graph2XY(graphs_c, num_nodes_all, model_ISGL, sortDeg=False)

    trained_inr_c = train_graphon(inr_dim_hidden, w0, X_all, y_all, w_all, int(n_epochs_inr), epoch_show, lr_inr, batch_size_inr)
    graphon_models.append(trained_inr_c)

    # NOTE(review): gpu=True is hard-coded here - confirm it behaves on CPU-only hosts
    graphons_est.append(get_graphon(100, trained_inr_c, gpu=True))



# Persist the cluster index of each graph together with the fitted graphon
# and coordinate models; simulated runs additionally store the graphs.
if isSim:
    out_path = "data/MGCL/cluster_labels_sim.pkl"
    payload = [all_graphs, cluster_labels, graphon_models, coords_models]
else:
    out_path = "data/MGCL/cluster_labels_" + DS + ".pkl"
    payload = [cluster_labels, graphon_models, coords_models]

with open(out_path, 'wb') as f:
    pickle.dump(payload, f)


# # get loss between true graphons and estimated graphons
# if isSim:
#     plt.figure(figsize=(8, int(3*num_classes)))
# else:
#     plt.figure(figsize=(int(3*num_classes), 3))
# gw_loss_s = []
# for i in range(num_classes):
#     est_graphon = graphons_est[i]

#     if isSim:
#         gw_loss_temp = [gw_distance(true_graphon_temp, est_graphon) for true_graphon_temp in All_graphons_small]
#         gw_loss = min(gw_loss_temp)
#         true_graphon = All_graphons_small[np.argmin(gw_loss_temp)]
#         gw_loss_s.append(gw_loss)
#         print(f"GW Distance between true and estimated graphon {i}: {gw_loss}")
#         # plot true and estimated graphons in each row of the figure
#         plt.subplot(num_classes, 2, 2*i+1)
#         plt.imshow(true_graphon, cmap='viridis')
#         plt.title(f"True Graphon {i}")
#         plt.axis('off')

#         plt.subplot(num_classes, 2, 2*i+2)
#         plt.imshow(est_graphon, cmap='viridis')
#         plt.title(f"Estimated Graphon {i}, Gw Loss: {gw_loss:.3f}")
#         plt.axis('off')
#     else:
#         # just plot the estimated graphons
#         plt.subplot(1, num_classes, i+1)
#         plt.imshow(est_graphon, cmap='viridis')
#         plt.title(f"Estimated Graphon {i}")
#         plt.axis('off')


# plt.tight_layout()
# if isSim:
#     plt.savefig("Plots/Models/Estimated_graphons_sim.jpg")
# else:
#     plt.savefig("Plots/Models/Estimated_graphons_" + DS + ".jpg")


# # calculate the gw between pair of estimated graphons:
# gw_loss_pair = np.zeros((num_classes, num_classes))
# for i in range(num_classes):
#     try:
#         gw_loss_pair[i, i] = gw_loss_s[i]
#     except:
#         gw_loss_pair[i, i] = 0
#     for j in range(i+1, num_classes):
#         gw_loss_pair[i, j] = gw_distance(graphons_est[i], graphons_est[j])
#         gw_loss_pair[j, i] = gw_loss_pair[i, j]
        
# # plot the gw_loss_pair and save the matrix plot
# plt.figure(figsize=(int(num_classes), int(num_classes)))
# plt.imshow(gw_loss_pair, cmap='viridis')
# # add the numbers to the cells
# for i in range(num_classes):
#     for j in range(num_classes):
#         plt.text(j, i, f"{gw_loss_pair[i, j]:.3f}", ha='center', va='center', color='black')
# plt.title("GW Loss Pair")
# plt.xlabel("Estimated Graphons")
# plt.ylabel("Estimated Graphons")
# plt.colorbar()
# plt.tight_layout()
# if isSim:
#     plt.savefig("Plots/Models/gw_loss_pair_sim.jpg")
# else:
#     plt.savefig("Plots/Models/gw_loss_pair_" + DS + ".jpg")




# print("Graphon clusters and estimated graphons saved successfully")