{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c80aef7-5a73-4e45-a144-53b596459538",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from gpcam import GPOptimizer\n",
    "#import matplotlib.pyplot as plt\n",
    "import random\n",
    "import ot\n",
    "from ot import fused_gromov_wasserstein2\n",
    "import dask\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06a33ca7-a599-4fec-a2f6-83e3097a7d84",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "\n",
    "# Load the two pickled datasets from ./data (presumably the MNIST train and test images - confirm).\n",
    "# NOTE(review): pickle.load can execute arbitrary code from the file; only load pickles you trust.\n",
    "with open('./data/x_train_MNIST.pkl', 'rb') as f:\n",
    "    x_train = pickle.load(f)\n",
    "with open('./data/x_test_MNIST.pkl', 'rb') as f:\n",
    "    x_test = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "802f86cc-e629-45d1-9fc8-4e4f06622cc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine both splits into a single collection.\n",
    "# NOTE(review): `+` concatenates only if both objects are Python lists; for NumPy arrays it\n",
    "# would add elementwise instead - confirm the pickled type.\n",
    "all_x = x_train + x_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e118ac8d-979c-443c-8cd8-d61d940b8f84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the per-split names; the cells below only use all_x.\n",
    "del x_train\n",
    "del x_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efbe2389-2578-4ed3-91cb-5c99c638c1c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dask\n",
    "from dask.distributed import Client\n",
    "import os\n",
    "import time\n",
    "\n",
    "# Path of the scheduler connection file; raises KeyError if $SCRATCH is not set.\n",
    "scheduler_file = os.path.join(os.environ[\"SCRATCH\"], \"scheduler_fileMNIST.json\")\n",
    "\n",
    "# Template used to build the dashboard link (routes through the JupyterHub proxy prefix).\n",
    "dask.config.config[\"distributed\"][\"dashboard\"][\"link\"] = \"{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status\" \n",
    "\n",
    "# Poll every 2 seconds until the scheduler file exists, then connect to it.\n",
    "# NOTE(review): there is no timeout, so this loops forever if the file never appears.\n",
    "while True:\n",
    "    time.sleep(2)\n",
    "    if os.path.isfile(scheduler_file):\n",
    "        print(\"file found\")\n",
    "        # Short extra pause before connecting (presumably to let the file finish being written - confirm).\n",
    "        time.sleep(1)\n",
    "        client = Client(scheduler_file=scheduler_file)\n",
    "        break\n",
    "print(\"waiting for workers\")\n",
    "# Block until 256 workers have connected to the scheduler.\n",
    "client.wait_for_workers(256)\n",
    "print(client)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0210b405-7f4f-49c3-b4fa-4a5352eb74b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the notebook's rich summary of the connected Dask client.\n",
    "client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5036fd1e-f54a-4132-a6ed-f6cc9e15b309",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize every image in place so that its entries sum to 1.\n",
    "# NOTE(review): an all-zero image would divide by zero here - confirm none exist.\n",
    "for i in range(len(all_x)):\n",
    "    all_x[i] = all_x[i] / np.sum(all_x[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b530ae0d-e63d-40d0-9248-baefd518a4ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-cost matrix between pixel positions, used by the optimal-transport helpers below.\n",
    "n = len(all_x[0])  # Side length of the first image; assumes square images of identical size - TODO confirm\n",
    "y, x = np.mgrid[0:n, 0:n]\n",
    "# One (row, col) coordinate per pixel, in flattened order.\n",
    "locations = np.column_stack([y.flatten(), x.flatten()])\n",
    "# Pairwise cost between pixel coordinates, using ot.dist's default metric\n",
    "# (squared Euclidean according to the POT documentation - confirm). Cast to float before dividing.\n",
    "M = ot.dist(locations, locations).astype(np.float64)\n",
    "M /= M.max()  # Normalize cost so the largest entry is 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eda5a4b2-2372-41c6-817e-0b85eefda7f7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b9e7a59-9474-43c4-8467-ec55a40e2cfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "#def compute_w2_distance(scatterf, i ,j):\n",
    "\n",
    "# Tile size: compute_pairwise_distances slices the image collections into chunks of N\n",
    "# images and submits one Dask task per pair of chunks.\n",
    "N = 200\n",
    "\n",
    "def compute_w2_distance(image1, image2, M):\n",
    "    \"\"\"\n",
    "    Compute the optimal-transport cost between two images with ot.emd2.\n",
    "    Parameters:\n",
    "        image1 (numpy.ndarray): First image; flattened and used as the source weights\n",
    "        image2 (numpy.ndarray): Second image; flattened and used as the target weights\n",
    "        M (numpy.ndarray): Ground-cost matrix passed straight to ot.emd2; presumably\n",
    "            square with one row/column per flattened pixel - confirm against caller\n",
    "    Returns:\n",
    "        float: The value returned by ot.emd2(a, b, M)\n",
    "        NOTE(review): with a squared-distance cost matrix this is the squared\n",
    "            Wasserstein-2 distance rather than its square root - confirm intent\n",
    "    \"\"\"\n",
    "    a = image1.flatten()\n",
    "    b = image2.flatten()\n",
    "    return ot.emd2(a, b, M)\n",
    "    \n",
    "\n",
    "def comp_dist_block(imageBlock1, imageBlock2, M):\n",
    "    \"\"\"\n",
    "    Fill a (len(imageBlock1), len(imageBlock2)) matrix with compute_w2_distance\n",
    "    evaluated on every pair taken from the two blocks, using cost matrix M.\n",
    "    \"\"\"\n",
    "    rows, cols = len(imageBlock1), len(imageBlock2)\n",
    "    block = np.zeros((rows, cols))\n",
    "    for r, left in enumerate(imageBlock1):\n",
    "        for c, right in enumerate(imageBlock2):\n",
    "            block[r, c] = compute_w2_distance(left, right, M)\n",
    "    return block\n",
    "\n",
    "    \n",
    "from scipy.spatial.distance import cdist, cosine\n",
    "def comp_dist_block1D(imageBlock1, imageBlock2, index_i , index_j):\n",
    "    \"\"\"\n",
    "    Compute the L1 distance (sum of absolute differences) between every image\n",
    "    in imageBlock1 and every image in imageBlock2.\n",
    "    Parameters:\n",
    "        imageBlock1: sequence of equally shaped numeric arrays (rows of the result)\n",
    "        imageBlock2: sequence of equally shaped numeric arrays (columns of the result)\n",
    "        index_i: row offset of this tile in the full matrix; returned unchanged\n",
    "        index_j: column offset of this tile in the full matrix; returned unchanged\n",
    "    Returns:\n",
    "        tuple: (index_i, index_j, distances) where distances is a float array of\n",
    "            shape (len(imageBlock1), len(imageBlock2))\n",
    "    \"\"\"\n",
    "    n1 = len(imageBlock1)\n",
    "    n2 = len(imageBlock2)\n",
    "    t1_start = time.perf_counter()\n",
    "\n",
    "    if n1 == 0 or n2 == 0:\n",
    "        # Nothing to compare; reshape(0, -1) would be ambiguous, so short-circuit.\n",
    "        l1_distances = np.zeros((n1, n2))\n",
    "    else:\n",
    "        # One flattened row per image. cdist's cityblock metric is sum(|a - b|),\n",
    "        # computed in compiled code instead of one NumPy call per image pair.\n",
    "        # (Other metrics that were tried here: Wasserstein on marginals, cosine.)\n",
    "        flat1 = np.asarray(imageBlock1, dtype=np.float64).reshape(n1, -1)\n",
    "        flat2 = np.asarray(imageBlock2, dtype=np.float64).reshape(n2, -1)\n",
    "        l1_distances = cdist(flat1, flat2, metric=\"cityblock\")\n",
    "\n",
    "    t1_stop = time.perf_counter()\n",
    "    print(\"Elapsed time:\", t1_stop - t1_start, flush = True)\n",
    "    return index_i, index_j, l1_distances\n",
    "\n",
    "def compute_pairwise_distances(images1, images2, M):\n",
    "    \"\"\"\n",
    "    Build the pairwise distance matrix between two image collections by\n",
    "    submitting one comp_dist_block1D task per N x N tile to the Dask client.\n",
    "    Parameters:\n",
    "        images1: sliceable sequence of images (rows of the result)\n",
    "        images2: sliceable sequence of images (columns of the result)\n",
    "        M: not used by the submitted L1 kernel; kept so existing calls keep working\n",
    "    Returns:\n",
    "        numpy.ndarray: float32 matrix of shape (len(images1), len(images2))\n",
    "    \"\"\"\n",
    "    n1 = len(images1)\n",
    "    n2 = len(images2)\n",
    "    # The result is symmetric only when both arguments are the very same\n",
    "    # collection. Only then is it valid to compute the upper-triangular tiles\n",
    "    # and mirror them; for two different collections every tile is computed.\n",
    "    symmetric = images1 is images2\n",
    "\n",
    "    w2_futures = []\n",
    "    # Submit one task per tile.\n",
    "    for i in range(0, n1, N):\n",
    "        block1 = images1[i: i+N]\n",
    "        first_j = i if symmetric else 0\n",
    "        for j in range(first_j, n2, N):\n",
    "            block2 = images2[j: j+N]\n",
    "            w2_futures.append(client.submit(comp_dist_block1D, block1, block2, i, j))\n",
    "    print(\"all submitted\")\n",
    "\n",
    "    w2_distances = np.zeros((n1, n2), dtype=np.float32)\n",
    "    for future in w2_futures:\n",
    "        index_i, index_j, matrix = future.result()\n",
    "        # Slices past the end are clipped, so partial edge tiles fit as well.\n",
    "        w2_distances[index_i:index_i + N, index_j:index_j + N] = matrix\n",
    "        if symmetric:\n",
    "            w2_distances[index_j:index_j + N, index_i:index_i + N] = matrix.T\n",
    "    return w2_distances\n",
    "\n",
    "\n",
    "def wasserstein_1d(a, b):\n",
    "    \"\"\"\n",
    "    Wasserstein-1 distance between two 1-D histograms defined on the same\n",
    "    unit-spaced bins.\n",
    "    Parameters:\n",
    "        a (array-like): non-negative bin masses of the first histogram\n",
    "        b (array-like): non-negative bin masses of the second histogram (same length)\n",
    "    Returns:\n",
    "        float: sum over bins of |CDF_a - CDF_b| after normalizing both inputs\n",
    "            to unit mass (NaN if an input sums to zero)\n",
    "    \"\"\"\n",
    "    a = np.asarray(a, dtype=np.float64)\n",
    "    b = np.asarray(b, dtype=np.float64)\n",
    "    a = a/np.sum(a)\n",
    "    b = b/np.sum(b)\n",
    "    # In 1-D the W1 distance is the area between the two cumulative distributions.\n",
    "    # The bin masses must not be sorted: sorting discards where the mass sits, so\n",
    "    # two histograms that differ only by a shift would get distance 0.\n",
    "    return np.sum(np.abs(np.cumsum(a) - np.cumsum(b)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd9b5cac-eb6c-4a33-b4be-dbb1857e6b6c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Pairwise distance matrix of all images against themselves. M is passed along,\n",
    "# but the kernel submitted by compute_pairwise_distances (comp_dist_block1D) does not take it.\n",
    "w2 = compute_pairwise_distances(all_x, all_x, M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d05c10c1-21f1-4ea9-9327-abe7f04a3bdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the matrix; np.save appends the .npy extension, so this writes distance_l1.npy.\n",
    "np.save(\"distance_l1\", w2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f04a1c2-7ee8-4ce8-92f8-67f8a43832ce",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MNISTenv",
   "language": "python",
   "name": "mnistenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
