{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d5a94b8-49c3-4442-a5ae-8a8fc94f11ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# remove warning\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=UserWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f0d2549-a4aa-4711-995e-c72a0a1d6201",
   "metadata": {},
   "outputs": [],
   "source": [
    "# previous code\n",
    "\n",
    "from sagman_utils.utils import spade,hnsw,construct_adj, spectral_embedding_eig,SPF,construct_weighted_adj,spade_nonetworkx\n",
    "\n",
    "import numpy as np\n",
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "def to_unweighted_csr(adj_matrix_csr):\n",
    "    # Get the indices and indptr from the original matrix\n",
    "    indices = adj_matrix_csr.indices\n",
    "    indptr = adj_matrix_csr.indptr\n",
    "    # Create an array of 1s for the values (all edges have weight 1)\n",
    "    unweighted_data = np.ones(len(indices), dtype=np.float64)\n",
    "    # Create the unweighted adjacency matrix in CSR format\n",
    "    unweighted_adj_matrix = csr_matrix((unweighted_data, indices, indptr), shape=adj_matrix_csr.shape)\n",
    "\n",
    "    return unweighted_adj_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05132267-8812-46b8-91bd-9b2bfc648ef2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# universal input number, can take it separately later\n",
    "\n",
    "input_num = 2000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b49fa390-c694-4a6e-bf02-4314cfd225e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mnist data\n",
    "\n",
    "from keras.datasets import mnist\n",
    "(train_X, train_y), (test_X, test_y) = mnist.load_data()\n",
    "# print(test_X.shape)\n",
    "# input_num = 2000\n",
    "spec_embed_mnist = train_X[:input_num].reshape(input_num, 784)\n",
    "print(spec_embed_mnist.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96549362-d4d5-4732-a4ea-d96b9994d09e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fashion data\n",
    "\n",
    "from keras.datasets import fashion_mnist\n",
    "(train_X, train_y), (test_X, test_y) = fashion_mnist.load_data()\n",
    "# print(train_X.shape)\n",
    "# input_num = 2000\n",
    "spec_embed_fashion = train_X[:input_num].reshape(input_num, 784)\n",
    "print(spec_embed_fashion.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5725ce23-accd-426b-931c-61bd14e61268",
   "metadata": {},
   "outputs": [],
   "source": [
    "# kmnist data\n",
    "# downloaded it, then extracted it, reduced one dimension [60000, 28, 28, 1] to [60000, 28, 28]\n",
    "# matched the shaape of train_X and test_X of mnist and fashion\n",
    "# then resized to 784 shape\n",
    "\n",
    "import tensorflow_datasets as tfds\n",
    "dataset, info = tfds.load('kmnist', with_info=True, as_supervised=True)\n",
    "\n",
    "# Extract train and test datasets\n",
    "train_dataset = dataset['train']\n",
    "test_dataset = dataset['test']\n",
    "\n",
    "# Print dataset info\n",
    "# print(info)\n",
    "\n",
    "# Example: Iterating through the train dataset\n",
    "# for image, label in train_dataset.take(1):  # Just take the first example\n",
    "#     print(f\"Image shape: {image.shape}, Label: {label.numpy()}\")\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Convert train dataset to NumPy arrays\n",
    "train_X = []\n",
    "train_y = []\n",
    "\n",
    "for image, label in train_dataset:\n",
    "    train_X.append(image.numpy())  # Convert tensor to numpy array\n",
    "    train_y.append(label.numpy())\n",
    "\n",
    "train_X = np.array(train_X)\n",
    "train_y = np.array(train_y)\n",
    "\n",
    "# Squeeze the shape of the train_X array from (60000, 28, 28, 1) to (60000, 28, 28)\n",
    "train_X = np.squeeze(train_X, axis=-1)\n",
    "\n",
    "# Similarly, for the test set:\n",
    "test_X = []\n",
    "test_y = []\n",
    "\n",
    "for image, label in test_dataset:\n",
    "    test_X.append(image.numpy())\n",
    "    test_y.append(label.numpy())\n",
    "\n",
    "test_X = np.array(test_X)\n",
    "test_y = np.array(test_y)\n",
    "\n",
    "test_X = np.squeeze(test_X, axis=-1)\n",
    "\n",
    "# Check the shapes\n",
    "# print(f\"Training data shape: {train_X.shape}\")\n",
    "# print(f\"Test data shape: {test_y.shape}\")\n",
    "\n",
    "# now at here, train_X and test_X and train test_y are equivalent to what we get in mnist and fashion\n",
    "\n",
    "spec_embed_kmnist = train_X[:input_num].reshape(input_num, 784)\n",
    "print(spec_embed_kmnist.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31d3e996-9dce-4936-9e41-aaee0c88b785",
   "metadata": {},
   "outputs": [],
   "source": [
    "import deeplake\n",
    "ds = deeplake.load('hub://activeloop/usps-test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aae3343a-132c-4f54-a5b8-c32813af0a21",
   "metadata": {},
   "outputs": [],
   "source": [
    "import h5py\n",
    "import numpy as np\n",
    "\n",
    "# Path to the offline USPS dataset file\n",
    "path = \"usps.h5\"  # Replace this with the correct path to your usps.h5 file\n",
    "\n",
    "# Load the USPS dataset from the .h5 file\n",
    "with h5py.File(path, 'r') as hf:\n",
    "    # Extract train data and labels\n",
    "    train = hf.get('train')\n",
    "    train_X = train.get('data')[:]  # Shape: (7291, 256)\n",
    "    train_y = train.get('target')[:]  # Shape: (7291,)\n",
    "    \n",
    "    # Extract test data and labels\n",
    "    test = hf.get('test')\n",
    "    test_X = test.get('data')[:]  # Shape: (2007, 256)\n",
    "    test_y = test.get('target')[:]  # Shape: (2007,)\n",
    "\n",
    "# Normalize the data to the range [-1, 1]\n",
    "train_X = train_X / 127.5 - 1\n",
    "test_X = test_X / 127.5 - 1\n",
    "\n",
    "# Reshape the data to (num_samples, 16, 16)\n",
    "train_X = train_X.reshape(-1, 16, 16)\n",
    "test_X = test_X.reshape(-1, 16, 16)\n",
    "\n",
    "# Convert labels to integers\n",
    "train_y = train_y.astype(int)\n",
    "test_y = test_y.astype(int)\n",
    "\n",
    "# Check the shape of the train and test datasets\n",
    "print(f\"Training data shape: {train_X.shape}\")  # Should be (7291, 16, 16)\n",
    "print(f\"Test data shape: {test_X.shape}\")      # Should be (2007, 16, 16)\n",
    "\n",
    "# Verify the first training sampl\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52a727cf-fafe-48ba-b80c-96df259a6204",
   "metadata": {},
   "outputs": [],
   "source": [
    "# USPS\n",
    "# downloaded it, extracted it\n",
    "# the output here is in [7291, 16, 16] sized which is normalized in range of [-1,1]\n",
    "\n",
    "from sklearn.datasets import fetch_openml\n",
    "import numpy as np\n",
    "\n",
    "print(\"1\")\n",
    "\n",
    "# Load USPS dataset from OpenML\n",
    "usps = fetch_openml('USPS', version=1)\n",
    "\n",
    "print(\"1\")\n",
    "# Extract the data and labels\n",
    "X = usps['data']\n",
    "y = usps['target']\n",
    "\n",
    "# Convert the data to numpy array and reshape it to (num_samples, 16, 16)\n",
    "X = X.to_numpy().reshape(-1, 16, 16)\n",
    "\n",
    "# Split into train and test sets\n",
    "# In the USPS dataset, the first 7291 samples are the training data and the rest are test data\n",
    "train_X = X[:7291]\n",
    "# print(train_X[0])\n",
    "test_X = X[7291:]\n",
    "\n",
    "# Convert labels to integers (use `int` instead of `np.int`)\n",
    "train_y = y[:7291].astype(int)\n",
    "test_y = y[7291:].astype(int)\n",
    "\n",
    "# Check the shape of the train and test datasets\n",
    "print(f\"Training data shape: {train_X.shape}\")  # Should be (7291, 16, 16)\n",
    "print(f\"Test data shape: {test_X.shape}\")      # Should be (2007, 16, 16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0daa69c4-beb7-4670-b13b-a04b8ed1c577",
   "metadata": {},
   "outputs": [],
   "source": [
    "# USPS\n",
    "# [7291, 16, 16] data example plot, which is normalized in range [-1,1]\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Plot one image from the train_X dataset\n",
    "plt.imshow(train_X[0], cmap='gray')  # Use the first image in the dataset\n",
    "plt.title(f\"Label: {train_y[0]}\")  # Display the label of the first image\n",
    "plt.axis('off')  # Hide the axis\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7bac2c3-d26f-4fcd-ae1c-57dec971399e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# USPS\n",
    "# resized the data to match the shape of mnist, fashion, kmnist\n",
    "# resized [-, 16, 16] to [-, 28, 28]\n",
    "# Still normalized between [-1,1]\n",
    "\n",
    "import numpy as np\n",
    "from skimage.transform import resize\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Resize USPS images from 16x16 to 28x28 using skimage\n",
    "train_X_resized = np.array([resize(image, (28, 28), mode='reflect', anti_aliasing=True) for image in train_X])\n",
    "test_X_resized = np.array([resize(image, (28, 28), mode='reflect', anti_aliasing=True) for image in test_X])\n",
    "\n",
    "# Check the shape of the resized images\n",
    "print(f\"Resized training data shape: {train_X_resized.shape}\")  # Should be (7291, 28, 28)\n",
    "print(f\"Resized test data shape: {test_X_resized.shape}\")      # Should be (2007, 28, 28)\n",
    "\n",
    "# Plot one resized image from the train_X_resized dataset\n",
    "plt.imshow(train_X_resized[0], cmap='gray')\n",
    "plt.title(f\"Label: {train_y[0]}\")\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "899c64b5-e5a6-4671-977c-964ed5dcd2d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# USPS\n",
    "# De-normalized it from [-1,1] to [0, 255]\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Assuming train_X_resized and test_X_resized are the images between -1 and 1\n",
    "train_X_scaled = np.clip(((train_X_resized + 1) / 2) * 255, 0, 255).astype(np.uint8)\n",
    "test_X_scaled = np.clip(((test_X_resized + 1) / 2) * 255, 0, 255).astype(np.uint8)\n",
    "\n",
    "# Check the shape and the pixel value range\n",
    "print(f\"Scaled training data shape: {train_X_scaled.shape}\")  # Should be (7291, 28, 28)\n",
    "print(f\"Scaled test data shape: {test_X_scaled.shape}\")      # Should be (2007, 28, 28)\n",
    "\n",
    "# Check the pixel range for the first image\n",
    "print(f\"Pixel range of the first image: {train_X_scaled[0].min()} to {train_X_scaled[0].max()}\")\n",
    "\n",
    "# Plot one scaled image from the train_X_scaled dataset\n",
    "plt.imshow(train_X_scaled[0], cmap='gray')\n",
    "plt.title(f\"Label: {train_y[0]}\")\n",
    "plt.axis('off')\n",
    "plt.show()\n",
    "\n",
    "# Here, the train_X_scaled, test_X_scaled, train_y, test_y are equivalent to train_X and stuffs we get at mnist and fashion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38998eb1-069e-4887-b840-52bae3a8b83d",
   "metadata": {},
   "outputs": [],
   "source": [
    "spec_embed_usps = train_X_scaled[:input_num].reshape(input_num, 784)\n",
    "print(spec_embed_usps.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23f7f94f-d762-498a-a548-2a15a2b5cd58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(spec_embed_usps[0])\n",
    "# print(spec_embed_fashion[0])\n",
    "# print(spec_embed_kmnist[0])\n",
    "# print(spec_embed_mnist[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aadf5c4-0d3e-479b-8b52-e56f6ca5f7c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# modules for ggd value\n",
    "\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "import torch\n",
    "import argparse\n",
    "import torch_geometric\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.loader import DataLoader\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from scipy.stats import pearsonr\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e624b60-a530-4a08-9e65-78ce58f0fd45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# functions for ggd\n",
    "\n",
    "def calculate_eigenvalues_sum(a, b):\n",
    "    # Calculate m = a^-1 * b\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    \n",
    "    # Calculate eigenvalues of m\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    \n",
    "    # Take log of eigenvalues, square them, and sum\n",
    "    eigenvalues_sum = np.sum(np.log(eigenvalues)**2)\n",
    "    \n",
    "    return eigenvalues_sum\n",
    "\n",
    "def calculate_eigenvalues(a, b):\n",
    "    # Calculate m = a^-1 * b\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    \n",
    "    # Calculate eigenvalues of m\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    \n",
    "    return eigenvalues"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c680365e-acc7-493b-85cc-e6b9e48c34cf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cf0f679-6c5e-45ec-a50b-c74909c430d6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "568a6374-83f3-4267-a98d-d1203b1ed60d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# simple function that extracts the node and edge numbers from the dataset connected graph\n",
    "\n",
    "def graph_info(adj_matrix):\n",
    "    # Calculate the number of nodes\n",
    "    num_nodes = adj_matrix.shape[0]\n",
    "    \n",
    "    # Calculate the number of edges\n",
    "    # If the graph is undirected, divide by 2 to avoid double-counting edges\n",
    "    is_directed = not np.array_equal(adj_matrix, adj_matrix.T)\n",
    "    num_edges = np.sum(adj_matrix) // (1 if is_directed else 2)\n",
    "    \n",
    "    return num_nodes, num_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d885083b-c3a3-4841-bf66-f8992364d4d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57c45050-6e25-4b3f-95b0-35ed14911c99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ggd\n",
    "\n",
    "\n",
    "# Ensure plots are displayed inline in Jupyter Notebook\n",
    "%matplotlib inline\n",
    "\n",
    "k = 11 #k-nn parameter\n",
    "levels = 8 #Sparsification Level\n",
    "\n",
    "# settings for this loop\n",
    "the_k = k\n",
    "print(\"Value of K:\", the_k)\n",
    "print(\"Value of L:\", level)\n",
    "\n",
    "\n",
    "# mnist [1]\n",
    "\n",
    "# adj is in csr matrix format, features is numpy array \n",
    "# spec_embed = spectral_embedding_eig(adj,features,use_feature=True,adj_norm=False) # spectral_embedding_eig,spectral_embedding\n",
    "# this spec_embed is the a n*m matrix with n samples and m features\n",
    "neighs, distance = hnsw(spec_embed_mnist, k=the_k)\n",
    "embed_adj_mtx = construct_weighted_adj(neighs, distance)# construct_weighted_adj,construct_adj\n",
    "embed_adj_mtx, inter_edge_adj = SPF(embed_adj_mtx, level)\n",
    "# if you want the unweighted embed adj mtx, run the following\n",
    "embed_adj_mtx = to_unweighted_csr(embed_adj_mtx)\n",
    "\n",
    "adj_matrix_1 = embed_adj_mtx.toarray()\n",
    "nodes_1, edges_1 = graph_info(adj_matrix_1)\n",
    "\n",
    "print(\"Number of nodes:\", nodes_1)\n",
    "# print(\"MNIST: Number of edges:\", edges_1)\n",
    "degree_matrix_1 = np.diag(np.sum(adj_matrix_1, axis=1))\n",
    "l1 = degree_matrix_1 - adj_matrix_1\n",
    "\n",
    "# kmnist [2]\n",
    "\n",
    "neighs, distance = hnsw(spec_embed_kmnist, k=the_k)\n",
    "embed_adj_mtx = construct_weighted_adj(neighs, distance)# construct_weighted_adj,construct_adj\n",
    "embed_adj_mtx, inter_edge_adj = SPF(embed_adj_mtx, level)\n",
    "# if you want the unweighted embed adj mtx, run the following\n",
    "embed_adj_mtx = to_unweighted_csr(embed_adj_mtx)\n",
    "\n",
    "adj_matrix_2 = embed_adj_mtx.toarray()\n",
    "nodes_2, edges_2 = graph_info(adj_matrix_2)\n",
    "# print(\"Value of K:\", the_k)\n",
    "# print(\"Number of nodes:\", nodes)\n",
    "# print(\"KMNIST: Number of edges:\", edges_2)\n",
    "# print(\"Ratio of MNSIT/FashionMNIST\", edges1/edges2)\n",
    "degree_matrix_2 = np.diag(np.sum(adj_matrix_2, axis=1))\n",
    "l2 = degree_matrix_2 - adj_matrix_2\n",
    "\n",
    "\n",
    "\n",
    "# fashion [3]\n",
    "\n",
    "neighs, distance = hnsw(spec_embed_fashion, k=the_k)\n",
    "embed_adj_mtx = construct_weighted_adj(neighs, distance)# construct_weighted_adj,construct_adj\n",
    "embed_adj_mtx, inter_edge_adj = SPF(embed_adj_mtx, level)\n",
    "# if you want the unweighted embed adj mtx, run the following\n",
    "embed_adj_mtx = to_unweighted_csr(embed_adj_mtx)\n",
    "\n",
    "adj_matrix_3 = embed_adj_mtx.toarray()\n",
    "nodes_3, edges_3 = graph_info(adj_matrix_3)\n",
    "# print(\"Value of K:\", the_k)\n",
    "# print(\"Number of nodes:\", nodes)\n",
    "# print(\"Fashion: Number of edges:\", edges_3)\n",
    "# print(\"Ratio of MNSIT/FashionMNIST\", edges1/edges2)\n",
    "degree_matrix_3 = np.diag(np.sum(adj_matrix_3, axis=1))\n",
    "l3 = degree_matrix_3 - adj_matrix_3\n",
    "\n",
    "\n",
    "# usps [4]\n",
    "\n",
    "neighs, distance = hnsw(spec_embed_usps, k=the_k)\n",
    "embed_adj_mtx = construct_weighted_adj(neighs, distance)# construct_weighted_adj,construct_adj\n",
    "embed_adj_mtx, inter_edge_adj = SPF(embed_adj_mtx, level)\n",
    "# if you want the unweighted embed adj mtx, run the following\n",
    "embed_adj_mtx = to_unweighted_csr(embed_adj_mtx)\n",
    "\n",
    "adj_matrix_4 = embed_adj_mtx.toarray()\n",
    "nodes_4, edges_4 = graph_info(adj_matrix_4)\n",
    "# print(\"Value of K:\", the_k)\n",
    "# print(\"Number of nodes:\", nodes)\n",
    "# print(\"USPS: Number of edges:\", edges_4)\n",
    "# print(\"Ratio of MNSIT/FashionMNIST\", edges1/edges2)\n",
    "degree_matrix_4 = np.diag(np.sum(adj_matrix_4, axis=1))\n",
    "l4 = degree_matrix_4 - adj_matrix_4\n",
    "\n",
    "\n",
    "# add epsilon value to diagonal values of Laplacian\n",
    "\n",
    "shape = l1.shape\n",
    "a = 0.001  # You can change this to any value you want\n",
    "diagonal_matrix = np.diag([a] * min(shape))\n",
    "\n",
    "# as all the matrices are of same sized, we do not need to calculate diagonal matrix for each Laplacian separately\n",
    "nl1 = l1 + diagonal_matrix\n",
    "nl2 = l2 + diagonal_matrix\n",
    "nl3 = l3 + diagonal_matrix\n",
    "nl4 = l4 + diagonal_matrix\n",
    "\n",
    "\n",
    "ggd_mnist_kmnist = calculate_eigenvalues_sum(nl1, nl2)\n",
    "ggd_mnist_fashion = calculate_eigenvalues_sum(nl1, nl3)\n",
    "ggd_mnist_usps = calculate_eigenvalues_sum(nl1, nl4)\n",
    "ggd_kmnist_fashion = calculate_eigenvalues_sum(nl2, nl3)\n",
    "ggd_kmnist_usps = calculate_eigenvalues_sum(nl2, nl4)\n",
    "ggd_fashion_usps = calculate_eigenvalues_sum(nl3, nl4)\n",
    "\n",
    "print(\"Numbers of Edges: MNIST, KMNIST, Fashion, USPS: \", edges_1, edges_2, edges_3, edges_4)\n",
    "edgelist = [edges_1, edges_2, edges_3, edges_4]\n",
    "print(\"Normalized numbers of Edges: \", edgelist/max(edgelist))\n",
    "\n",
    "print(\"GGD: MNIST-KMNIST: \", ggd_mnist_kmnist)\n",
    "print(\"GGD: MNIST-Fashion: \", ggd_mnist_fashion)\n",
    "print(\"GGD: MNIST-USPS: \", ggd_mnist_usps)\n",
    "print(\"GGD: KMNIST-Fashion: \", ggd_kmnist_fashion)\n",
    "print(\"GGD: KMNIST-USPS: \", ggd_kmnist_usps)\n",
    "print(\"GGD: Fashion-USPS: \", ggd_fashion_usps)\n",
    "\n",
    "distances = {\n",
    "    ('MNIST', 'KMNIST'): ggd_mnist_kmnist,\n",
    "    ('MNIST', 'FashionMNIST'): ggd_mnist_fashion,\n",
    "    ('MNIST', 'USPS'): ggd_mnist_usps,\n",
    "    ('KMNIST', 'FashionMNIST'): ggd_kmnist_fashion,\n",
    "    ('KMNIST', 'USPS'): ggd_kmnist_usps,\n",
    "    ('FashionMNIST', 'USPS'): ggd_fashion_usps,\n",
    "}\n",
    "\n",
    "# Initialize the dataset names and the distance matrix\n",
    "datasets = ['MNIST', 'KMNIST', 'FashionMNIST', 'USPS']\n",
    "num_datasets = len(datasets)\n",
    "distance_matrix = np.zeros((num_datasets, num_datasets))\n",
    "\n",
    "# Fill the distance matrix with the calculated pairwise distances\n",
    "for (d1, d2), value in distances.items():\n",
    "    i, j = datasets.index(d1), datasets.index(d2)\n",
    "    distance_matrix[i, j] = value\n",
    "    distance_matrix[j, i] = value  # Symmetric matrix\n",
    "\n",
    "# Normalize the distance matrix by dividing by the maximum value\n",
    "max_distance = np.max(distance_matrix)\n",
    "normalized_distance_matrix = distance_matrix / max_distance\n",
    "\n",
    "# Create and display the heatmap\n",
    "plt.figure(figsize=(8, 6))\n",
    "sns.heatmap(normalized_distance_matrix, annot=True, fmt=\".3f\", xticklabels=datasets, yticklabels=datasets, cmap=\"coolwarm\")\n",
    "plt.title(\"Pairwise GGD Heatmap\")\n",
    "plt.xlabel(\"Datasets\")\n",
    "plt.ylabel(\"Datasets\")\n",
    "\n",
    "\n",
    "# Save the heatmap to a file\n",
    "plt.savefig(\"without_sparse_ggd_normalized_neigh_k_\"+ str(the_k)+\"_sparselevel_\" + str(level) + \".png\")\n",
    "plt.show()\n",
    "\n",
    "# Optionally close the plot to avoid overwriting in the next iterations\n",
    "plt.close()\n",
    "\n",
    "\n",
    "print(\"##################################\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b42dcae2-7827-4845-80b4-932ed89355cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# function to make graph to data structure like MUTAG - without node features\n",
    "# input, adjancency matrix, output Data with Data.x having [1.] for all noes\n",
    "\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import numpy as np\n",
    "\n",
    "def get_dataGraph(adj_matrix):\n",
    "    # Step 2: Extract row, col indices where adj_matrix is non-zero\n",
    "    row, col = np.where(adj_matrix == 1)\n",
    "    \n",
    "    # Step 3: Convert to edge_index format\n",
    "    edge_index = torch.tensor([row, col], dtype=torch.long)\n",
    "    \n",
    "    # Step 4: Create the node features tensor (size: [num_nodes, 1], default to 1 for all nodes)\n",
    "    num_nodes = adj_matrix.shape[0]\n",
    "    x = torch.ones((num_nodes, 1), dtype=torch.float)  # Placeholder node features (1 feature per node)\n",
    "    \n",
    "    # Step 5: Create the Data object\n",
    "    data = Data(edge_index=edge_index, x=x)\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f128553-e234-4d04-926a-868a02c1b42f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "368fbd63-20e7-4d4a-a503-72d30d7ea8a4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "303401fc-e788-4b02-a9e3-336cda89dfd5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9 (torch)",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
