{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "373d4936",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import ot\n",
    "import torch\n",
    "from joblib import Parallel, delayed\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "from risgw import risgw_gpu\n",
    "from SEINT_numpy import SEINT\n",
    "from sgw_numpy import sgw_cpu\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced files.\n",
    "with open('modelnet40_rotated.pkl', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "# Sequence of point clouds; element i is passed to every distance function below as sample i.\n",
    "train_3D_small = data['train_3D_small']\n",
    "# Flatten labels to 1-D: an (N, 1) label array would make `y_pred == y_val`\n",
    "# broadcast to an (n_val, n_val) matrix and silently corrupt the accuracy.\n",
    "train_label_small = np.asarray(data['train_label_small']).ravel()\n",
    "\n",
    "K = 5            # number of cross-validation folds\n",
    "n_neighbors = 5  # k of the k-NN classifier\n",
    "seed = 42        # random seed for the stochastic parts of the pipeline\n",
    "N = len(train_label_small)\n",
    "\n",
    "assert len(train_3D_small) == N, \"number of point clouds and number of labels differ\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc812c21",
   "metadata": {},
   "source": [
    "### RISGW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dce84854",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repetition count forwarded to risgw_gpu as its 4th argument (presumably the\n",
    "# number of random repetitions inside RISGW -- confirm in risgw.py).\n",
    "risgw_rep = 50\n",
    "t_all = time.time()\n",
    "\n",
    "def compute_pair(i, j):\n",
    "    \"\"\"Return (i, j, RISGW distance) between point clouds i and j, evaluated on CPU.\"\"\"\n",
    "    A_t = torch.from_numpy(np.asarray(train_3D_small[i], dtype=np.float32))\n",
    "    B_t = torch.from_numpy(np.asarray(train_3D_small[j], dtype=np.float32))\n",
    "    dist = risgw_gpu(A_t, B_t, 'cpu', risgw_rep)\n",
    "    return i, j, float(dist)\n",
    "\n",
    "\n",
    "# Only the strict upper triangle (N*(N-1)/2 pairs) is computed; the matrix is\n",
    "# mirrored below and its diagonal left at 0.\n",
    "pairs = [(i, j) for i in range(N) for j in range(i + 1, N)]\n",
    "\n",
    "print(f\"[INFO] Start precomputing full {N}x{N} RISGW distance matrix with {len(pairs)} pairs ...\")\n",
    "t0 = time.time()\n",
    "results = Parallel(n_jobs=-1, backend='loky', verbose=1)(\n",
    "    delayed(compute_pair)(i, j) for i, j in pairs\n",
    ")\n",
    "\n",
    "risgw_dist_full = np.zeros((N, N), dtype=np.float64)\n",
    "for i, j, dist in results:\n",
    "    risgw_dist_full[i, j] = dist\n",
    "    risgw_dist_full[j, i] = dist\n",
    "print(f\"[INFO] Done precomputing. Time: {time.time() - t0:.2f}s\")\n",
    "\n",
    "# Use the configured seed (not a hard-coded 42) so editing `seed` also changes the folds.\n",
    "kf = KFold(n_splits=K, shuffle=True, random_state=seed)\n",
    "\n",
    "acc_list = []\n",
    "fold_times = []\n",
    "\n",
    "for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(N)), start=1):\n",
    "    t_fold = time.time()\n",
    "    print(f\"\\n===== Fold {fold}/{K} =====\")\n",
    "\n",
    "    y_train = train_label_small[train_idx]\n",
    "    y_val = train_label_small[val_idx]\n",
    "\n",
    "    # Precomputed k-NN: fit on train-vs-train distances, predict from val-vs-train distances.\n",
    "    D_train = risgw_dist_full[np.ix_(train_idx, train_idx)]\n",
    "    D_val = risgw_dist_full[np.ix_(val_idx, train_idx)]\n",
    "\n",
    "    clf = KNeighborsClassifier(metric='precomputed', n_neighbors=n_neighbors)\n",
    "    clf.fit(D_train, y_train)\n",
    "\n",
    "    y_pred = clf.predict(D_val)\n",
    "    acc = (y_pred == y_val).mean() * 100\n",
    "    acc_list.append(acc)\n",
    "\n",
    "    fold_time = time.time() - t_fold\n",
    "    fold_times.append(fold_time)\n",
    "\n",
    "    print(f\"Fold {fold} Val Accuracy: {acc:.2f}% | Time: {fold_time:.2f}s\")\n",
    "\n",
    "# np.std is the population (ddof=0) standard deviation across folds.\n",
    "print(\"\\n===== Final CV Result =====\")\n",
    "print(f\"Avg Acc: {np.mean(acc_list):.2f}% ± {np.std(acc_list):.2f}%\")\n",
    "print(f\"Avg Fold Time: {np.mean(fold_times):.2f}s | Total Time: {time.time() - t_all:.2f}s\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af09133d",
   "metadata": {},
   "source": [
    "### SGW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d9736d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repetition count forwarded to sgw_cpu as its 3rd argument (presumably the\n",
    "# number of random projections -- confirm in sgw_numpy.py).\n",
    "sgw_rep = 50\n",
    "\n",
    "t_all = time.time()\n",
    "\n",
    "def compute_pair(i, j):\n",
    "    \"\"\"Return (i, j, SGW distance) between point clouds i and j.\"\"\"\n",
    "    dist = sgw_cpu(train_3D_small[i], train_3D_small[j], sgw_rep)\n",
    "    return i, j, dist\n",
    "\n",
    "\n",
    "# Only the strict upper triangle (N*(N-1)/2 pairs) is computed; the matrix is\n",
    "# mirrored below and its diagonal left at 0.\n",
    "pairs = [(i, j) for i in range(N) for j in range(i + 1, N)]\n",
    "\n",
    "print(f\"[INFO] Start precomputing full {N}x{N} SGW distance matrix with {len(pairs)} pairs ...\")\n",
    "t0 = time.time()\n",
    "results = Parallel(n_jobs=-1, backend='loky', verbose=1)(\n",
    "    delayed(compute_pair)(i, j) for i, j in pairs\n",
    ")\n",
    "\n",
    "sgw_dist_full = np.zeros((N, N), dtype=np.float64)\n",
    "for i, j, dist in results:\n",
    "    sgw_dist_full[i, j] = dist\n",
    "    sgw_dist_full[j, i] = dist\n",
    "print(f\"[INFO] Done precomputing. Time: {time.time() - t0:.2f}s\")\n",
    "\n",
    "# Use the configured seed (not a hard-coded 42) so editing `seed` also changes the folds.\n",
    "kf = KFold(n_splits=K, shuffle=True, random_state=seed)\n",
    "\n",
    "acc_list = []\n",
    "fold_times = []\n",
    "\n",
    "for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(N)), start=1):\n",
    "    t_fold = time.time()\n",
    "    print(f\"\\n===== Fold {fold}/{K} =====\")\n",
    "\n",
    "    y_train = train_label_small[train_idx]\n",
    "    y_val = train_label_small[val_idx]\n",
    "\n",
    "    # Precomputed k-NN: fit on train-vs-train distances, predict from val-vs-train distances.\n",
    "    D_train = sgw_dist_full[np.ix_(train_idx, train_idx)]\n",
    "    D_val = sgw_dist_full[np.ix_(val_idx, train_idx)]\n",
    "\n",
    "    clf = KNeighborsClassifier(metric='precomputed', n_neighbors=n_neighbors)\n",
    "    clf.fit(D_train, y_train)\n",
    "\n",
    "    y_pred = clf.predict(D_val)\n",
    "    acc = (y_pred == y_val).mean() * 100\n",
    "    acc_list.append(acc)\n",
    "\n",
    "    fold_time = time.time() - t_fold\n",
    "    fold_times.append(fold_time)\n",
    "\n",
    "    print(f\"Fold {fold} Val Accuracy: {acc:.2f}% | Time: {fold_time:.2f}s\")\n",
    "\n",
    "# np.std is the population (ddof=0) standard deviation across folds.\n",
    "print(\"\\n===== Final CV Result =====\")\n",
    "print(f\"Avg Acc: {np.mean(acc_list):.2f}% ± {np.std(acc_list):.2f}%\")\n",
    "print(f\"Avg Fold Time: {np.mean(fold_times):.2f}s | Total Time: {time.time() - t_all:.2f}s\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3907f11f",
   "metadata": {},
   "source": [
    "### SEINT (`maxed=True`) and iSEINT (`maxed=False`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0975a831",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword arguments forwarded unchanged to SEINT (see SEINT_numpy.SEINT).\n",
    "riot_kwargs = dict(\n",
    "    rd=None,\n",
    "    initial=True,\n",
    "    lp=2,\n",
    "    rep=50,\n",
    "    maxed=True,\n",
    "    determin=False,\n",
    "    set_seed=seed,\n",
    "    rd_rad=2,\n",
    "    acc=True,\n",
    ")\n",
    "\n",
    "# On-disk cache for the SEINT distance matrix. The cache is NOT keyed on\n",
    "# riot_kwargs or on the dataset -- delete the file after changing either.\n",
    "rot_cache_path = 'rot_dist_full.pkl'\n",
    "\n",
    "t_all = time.time()\n",
    "\n",
    "def compute_pair(i, j):\n",
    "    \"\"\"Return (i, j, SEINT distance) between point clouds i and j.\"\"\"\n",
    "    dist = SEINT(train_3D_small[i], train_3D_small[j], **riot_kwargs)\n",
    "    return i, j, dist\n",
    "\n",
    "# Load the matrix from the cache when it exists; otherwise compute the strict\n",
    "# upper triangle in parallel, mirror it (diagonal stays 0) and write the cache.\n",
    "if os.path.exists(rot_cache_path):\n",
    "    # NOTE: pickle.load executes arbitrary code -- only load caches you produced.\n",
    "    with open(rot_cache_path, 'rb') as f:\n",
    "        rot_dist_full = np.asarray(pickle.load(f)['rot_dist_full'], dtype=np.float64)\n",
    "    print(f\"[INFO] Loaded cached distance matrix from {rot_cache_path}\")\n",
    "else:\n",
    "    pairs = [(i, j) for i in range(N) for j in range(i + 1, N)]\n",
    "\n",
    "    print(f\"[INFO] Start precomputing full {N}x{N} distance matrix with {len(pairs)} pairs ...\")\n",
    "    t0 = time.time()\n",
    "    results = Parallel(n_jobs=-1, backend='loky', verbose=1)(\n",
    "        delayed(compute_pair)(i, j) for i, j in pairs\n",
    "    )\n",
    "    rot_dist_full = np.zeros((N, N), dtype=np.float64)\n",
    "    for i, j, dist in results:\n",
    "        rot_dist_full[i, j] = dist\n",
    "        rot_dist_full[j, i] = dist\n",
    "    print(f\"[INFO] Done precomputing. Time: {time.time() - t0:.2f}s\")\n",
    "\n",
    "    with open(rot_cache_path, 'wb') as f:\n",
    "        pickle.dump({'rot_dist_full': rot_dist_full}, f)\n",
    "\n",
    "# Guard against a stale cache built for a different dataset size.\n",
    "assert rot_dist_full.shape == (N, N), f\"distance matrix has shape {rot_dist_full.shape}, expected {(N, N)}\"\n",
    "\n",
    "# Use the configured seed (not a hard-coded 42) so editing `seed` also changes the folds.\n",
    "kf = KFold(n_splits=K, shuffle=True, random_state=seed)\n",
    "\n",
    "acc_list = []\n",
    "fold_times = []\n",
    "\n",
    "for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(N)), start=1):\n",
    "    t_fold = time.time()\n",
    "    print(f\"\\n===== Fold {fold}/{K} =====\")\n",
    "\n",
    "    y_train = train_label_small[train_idx]\n",
    "    y_val = train_label_small[val_idx]\n",
    "\n",
    "    # Precomputed k-NN: fit on train-vs-train distances, predict from val-vs-train distances.\n",
    "    D_train = rot_dist_full[np.ix_(train_idx, train_idx)]\n",
    "    D_val = rot_dist_full[np.ix_(val_idx, train_idx)]\n",
    "\n",
    "    clf = KNeighborsClassifier(metric='precomputed', n_neighbors=n_neighbors)\n",
    "    clf.fit(D_train, y_train)\n",
    "\n",
    "    y_pred = clf.predict(D_val)\n",
    "    acc = (y_pred == y_val).mean() * 100\n",
    "    acc_list.append(acc)\n",
    "\n",
    "    fold_time = time.time() - t_fold\n",
    "    fold_times.append(fold_time)\n",
    "\n",
    "    print(f\"Fold {fold} Val Accuracy: {acc:.2f}% | Time: {fold_time:.2f}s\")\n",
    "\n",
    "# np.std is the population (ddof=0) standard deviation across folds.\n",
    "print(\"\\n===== Final CV Result =====\")\n",
    "print(f\"Avg Acc: {np.mean(acc_list):.2f}% ± {np.std(acc_list):.2f}%\")\n",
    "print(f\"Avg Fold Time: {np.mean(fold_times):.2f}s | Total Time: {time.time() - t_all:.2f}s\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80584635",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword arguments forwarded unchanged to SEINT; identical to the SEINT\n",
    "# configuration except maxed=False (the iSEINT variant).\n",
    "iseint_kwargs = dict(\n",
    "    rd=None,\n",
    "    initial=True,\n",
    "    lp=2,\n",
    "    rep=50,\n",
    "    maxed=False,\n",
    "    determin=False,\n",
    "    set_seed=seed,\n",
    "    rd_rad=2,\n",
    "    acc=True,\n",
    ")\n",
    "\n",
    "t_all = time.time()\n",
    "\n",
    "def compute_pair(i, j):\n",
    "    \"\"\"Return (i, j, iSEINT distance) between point clouds i and j.\"\"\"\n",
    "    dist = SEINT(train_3D_small[i], train_3D_small[j], **iseint_kwargs)\n",
    "    return i, j, dist\n",
    "\n",
    "# Only the strict upper triangle (N*(N-1)/2 pairs) is computed; the matrix is\n",
    "# mirrored below and its diagonal left at 0. N comes from the config cell.\n",
    "pairs = [(i, j) for i in range(N) for j in range(i + 1, N)]\n",
    "\n",
    "print(f\"[INFO] Start precomputing full {N}x{N} distance matrix with {len(pairs)} pairs ...\")\n",
    "t0 = time.time()\n",
    "results = Parallel(n_jobs=-1, backend='loky', verbose=1)(\n",
    "    delayed(compute_pair)(i, j) for i, j in pairs\n",
    ")\n",
    "iseint_dist_full = np.zeros((N, N), dtype=np.float64)\n",
    "for i, j, dist in results:\n",
    "    iseint_dist_full[i, j] = dist\n",
    "    iseint_dist_full[j, i] = dist\n",
    "print(f\"[INFO] Done precomputing. Time: {time.time() - t0:.2f}s\")\n",
    "\n",
    "# Use the configured seed (not a hard-coded 42) so editing `seed` also changes the folds.\n",
    "kf = KFold(n_splits=K, shuffle=True, random_state=seed)\n",
    "\n",
    "acc_list = []\n",
    "fold_times = []\n",
    "\n",
    "for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(N)), start=1):\n",
    "    t_fold = time.time()\n",
    "    print(f\"\\n===== Fold {fold}/{K} =====\")\n",
    "\n",
    "    y_train = train_label_small[train_idx]\n",
    "    y_val = train_label_small[val_idx]\n",
    "\n",
    "    # Precomputed k-NN: fit on train-vs-train distances, predict from val-vs-train distances.\n",
    "    D_train = iseint_dist_full[np.ix_(train_idx, train_idx)]\n",
    "    D_val = iseint_dist_full[np.ix_(val_idx, train_idx)]\n",
    "\n",
    "    clf = KNeighborsClassifier(metric='precomputed', n_neighbors=n_neighbors)\n",
    "    clf.fit(D_train, y_train)\n",
    "\n",
    "    y_pred = clf.predict(D_val)\n",
    "    acc = (y_pred == y_val).mean() * 100\n",
    "    acc_list.append(acc)\n",
    "\n",
    "    fold_time = time.time() - t_fold\n",
    "    fold_times.append(fold_time)\n",
    "\n",
    "    print(f\"Fold {fold} Val Accuracy: {acc:.2f}% | Time: {fold_time:.2f}s\")\n",
    "\n",
    "# np.std is the population (ddof=0) standard deviation across folds.\n",
    "print(\"\\n===== Final CV Result =====\")\n",
    "print(f\"Avg Acc: {np.mean(acc_list):.2f}% ± {np.std(acc_list):.2f}%\")\n",
    "print(f\"Avg Fold Time: {np.mean(fold_times):.2f}s | Total Time: {time.time() - t_all:.2f}s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16ddd07e",
   "metadata": {},
   "source": [
    "### $W_2$ (2-Wasserstein distance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17822bd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "use_sinkhorn = False   # True: entropic OT via ot.sinkhorn2; False: exact OT via ot.emd2\n",
    "sinkhorn_reg = 1e-2    # entropic regularization strength (only used when use_sinkhorn=True)\n",
    "\n",
    "t_all = time.time()\n",
    "\n",
    "def w2_distance(Xi, Xj):\n",
    "    \"\"\"Return the 2-Wasserstein distance between two uniformly weighted point clouds.\n",
    "\n",
    "    The ground cost is the squared Euclidean distance. The exact solver ot.emd2 is\n",
    "    used unless use_sinkhorn is True, in which case the entropic approximation\n",
    "    ot.sinkhorn2 with regularization sinkhorn_reg is used instead.\n",
    "    \"\"\"\n",
    "    Xi = np.asarray(Xi, dtype=float)\n",
    "    Xj = np.asarray(Xj, dtype=float)\n",
    "    n, m = len(Xi), len(Xj)\n",
    "    a = np.ones((n,)) / n\n",
    "    b = np.ones((m,)) / m\n",
    "    M = ot.dist(Xi, Xj, metric='euclidean') ** 2  # (n, m) squared Euclidean costs\n",
    "\n",
    "    if use_sinkhorn:\n",
    "        # NOTE: with a small reg and unnormalized coordinates Sinkhorn can underflow;\n",
    "        # if NaNs appear, increase sinkhorn_reg or rescale the point clouds.\n",
    "        cost = ot.sinkhorn2(a, b, M, sinkhorn_reg)\n",
    "    else:\n",
    "        cost = ot.emd2(a, b, M)\n",
    "    # Depending on the POT version sinkhorn2 may return a 0-d or length-1 array.\n",
    "    return float(np.sqrt(np.asarray(cost, dtype=float).item()))\n",
    "\n",
    "def compute_pair(i, j):\n",
    "    \"\"\"Return (i, j, W2 distance) between point clouds i and j.\"\"\"\n",
    "    dist = w2_distance(train_3D_small[i], train_3D_small[j])\n",
    "    return i, j, dist\n",
    "\n",
    "# Only the strict upper triangle (N*(N-1)/2 pairs) is computed; the matrix is\n",
    "# mirrored below and its diagonal left at 0.\n",
    "pairs = [(i, j) for i in range(N) for j in range(i + 1, N)]\n",
    "\n",
    "solver_name = \"sinkhorn\" if use_sinkhorn else \"emd\"\n",
    "print(f\"[INFO] Start precomputing full {N}x{N} W2 distance matrix ({solver_name}) with {len(pairs)} pairs ...\")\n",
    "t0 = time.time()\n",
    "results = Parallel(n_jobs=-1, backend='loky', verbose=1)(\n",
    "    delayed(compute_pair)(i, j) for i, j in pairs\n",
    ")\n",
    "\n",
    "w2_dist_full = np.zeros((N, N), dtype=np.float64)\n",
    "for i, j, dist in results:\n",
    "    w2_dist_full[i, j] = dist\n",
    "    w2_dist_full[j, i] = dist\n",
    "print(f\"[INFO] Done precomputing. Time: {time.time() - t0:.2f}s\")\n",
    "\n",
    "# Use the configured seed (not a hard-coded 42) so editing `seed` also changes the folds.\n",
    "kf = KFold(n_splits=K, shuffle=True, random_state=seed)\n",
    "acc_list, fold_times = [], []\n",
    "\n",
    "for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(N)), start=1):\n",
    "    t_fold = time.time()\n",
    "    print(f\"\\n===== Fold {fold}/{K} =====\")\n",
    "\n",
    "    y_train = train_label_small[train_idx]\n",
    "    y_val = train_label_small[val_idx]\n",
    "\n",
    "    # Precomputed k-NN: fit on train-vs-train distances, predict from val-vs-train distances.\n",
    "    D_train = w2_dist_full[np.ix_(train_idx, train_idx)]\n",
    "    D_val = w2_dist_full[np.ix_(val_idx, train_idx)]\n",
    "\n",
    "    clf = KNeighborsClassifier(metric='precomputed', n_neighbors=n_neighbors)\n",
    "    clf.fit(D_train, y_train)\n",
    "\n",
    "    y_pred = clf.predict(D_val)\n",
    "    acc = (y_pred == y_val).mean() * 100\n",
    "    acc_list.append(acc)\n",
    "\n",
    "    fold_time = time.time() - t_fold\n",
    "    fold_times.append(fold_time)\n",
    "    print(f\"Fold {fold} Val Accuracy: {acc:.2f}% | Time: {fold_time:.2f}s\")\n",
    "\n",
    "# np.std is the population (ddof=0) standard deviation across folds.\n",
    "print(\"\\n===== Final CV Result (W2) =====\")\n",
    "print(f\"Avg Acc: {np.mean(acc_list):.2f}% ± {np.std(acc_list):.2f}%\")\n",
    "print(f\"Avg Fold Time: {np.mean(fold_times):.2f}s | Total Time: {time.time() - t_all:.2f}s\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80ccfc50",
   "metadata": {},
   "source": [
    "### SW (sliced Wasserstein distance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb6adc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of random 1-D projections used by the sliced Wasserstein estimator.\n",
    "num_projections = 50\n",
    "\n",
    "t_all = time.time()\n",
    "\n",
    "def compute_pair(i, j):\n",
    "    \"\"\"Return (i, j, sliced Wasserstein distance) between point clouds i and j.\n",
    "\n",
    "    Projections are drawn with the shared `seed`, so the estimate is reproducible.\n",
    "    Uses the top-level POT function, so no extra import is needed in this cell.\n",
    "    \"\"\"\n",
    "    Xi = np.asarray(train_3D_small[i], dtype=float)\n",
    "    Xj = np.asarray(train_3D_small[j], dtype=float)\n",
    "    dist = ot.sliced_wasserstein_distance(\n",
    "        Xi, Xj, n_projections=num_projections, seed=seed\n",
    "    )\n",
    "    return i, j, float(dist)\n",
    "\n",
    "# Only the strict upper triangle (N*(N-1)/2 pairs) is computed; the matrix is\n",
    "# mirrored below and its diagonal left at 0.\n",
    "pairs = [(i, j) for i in range(N) for j in range(i + 1, N)]\n",
    "\n",
    "print(f\"[INFO] Start precomputing full {N}x{N} SW distance matrix with {len(pairs)} pairs ...\")\n",
    "t0 = time.time()\n",
    "results = Parallel(n_jobs=-1, backend='loky', verbose=1)(\n",
    "    delayed(compute_pair)(i, j) for i, j in pairs\n",
    ")\n",
    "sw_dist_full = np.zeros((N, N), dtype=np.float64)\n",
    "for i, j, dist in results:\n",
    "    sw_dist_full[i, j] = dist\n",
    "    sw_dist_full[j, i] = dist\n",
    "print(f\"[INFO] Done precomputing. Time: {time.time() - t0:.2f}s\")\n",
    "\n",
    "# Same `seed` as the projections above (previously a hard-coded 42), so one knob controls both.\n",
    "kf = KFold(n_splits=K, shuffle=True, random_state=seed)\n",
    "acc_list, fold_times = [], []\n",
    "\n",
    "for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(N)), start=1):\n",
    "    t_fold = time.time()\n",
    "    print(f\"\\n===== Fold {fold}/{K} =====\")\n",
    "\n",
    "    y_train = train_label_small[train_idx]\n",
    "    y_val = train_label_small[val_idx]\n",
    "\n",
    "    # Precomputed k-NN: fit on train-vs-train distances, predict from val-vs-train distances.\n",
    "    D_train = sw_dist_full[np.ix_(train_idx, train_idx)]\n",
    "    D_val = sw_dist_full[np.ix_(val_idx, train_idx)]\n",
    "\n",
    "    clf = KNeighborsClassifier(metric='precomputed', n_neighbors=n_neighbors)\n",
    "    clf.fit(D_train, y_train)\n",
    "\n",
    "    y_pred = clf.predict(D_val)\n",
    "    acc = (y_pred == y_val).mean() * 100\n",
    "    acc_list.append(acc)\n",
    "\n",
    "    fold_time = time.time() - t_fold\n",
    "    fold_times.append(fold_time)\n",
    "    print(f\"Fold {fold} Val Accuracy: {acc:.2f}% | Time: {fold_time:.2f}s\")\n",
    "\n",
    "# np.std is the population (ddof=0) standard deviation across folds.\n",
    "print(\"\\n===== Final CV Result (SW) =====\")\n",
    "print(f\"Avg Acc: {np.mean(acc_list):.2f}% ± {np.std(acc_list):.2f}%\")\n",
    "print(f\"Avg Fold Time: {np.mean(fold_times):.2f}s | Total Time: {time.time() - t_all:.2f}s\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "HW",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
