{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00b69bdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from src.distributions import LoaderSampler, TensorSampler\n",
    "\n",
    "import time\n",
    "\n",
    "from sklearn.metrics.pairwise import pairwise_distances\n",
    "\n",
    "import ot\n",
    "import torch\n",
    "import torch.distributions as TD\n",
    "import numpy as np\n",
    "\n",
    "import wandb\n",
    "\n",
    "class DiscreteEOT_l2sq_sampler:\n",
    "\n",
    "    @staticmethod\n",
    "    def discrete_sample_conditional(Y, G, i_x, n_pts, return_indices=False):\n",
    "        probs = G[i_x] / torch.sum(G[i_x])\n",
    "        distrib = TD.Categorical(probs = probs)\n",
    "        numbers = distrib.sample((n_pts,))\n",
    "        if not return_indices:\n",
    "            return Y[numbers]\n",
    "        return numbers\n",
    "\n",
    "    def __init__(self, X, Y, G, device='cpu'):\n",
    "        self.X = torch.tensor(X).float().to(device)\n",
    "        self.Y = torch.tensor(Y).float().to(device)\n",
    "        self.G = torch.tensor(G).float().to(device)\n",
    "        self.device = device\n",
    "\n",
    "    def sample(self, x_samples):\n",
    "        raise NotImplementedError()\n",
    "\n",
    "    def sample_by_indices(self, x_indices, return_indices=False):\n",
    "        spls = []\n",
    "        for x_idx in x_indices:\n",
    "            spls.append(self.discrete_sample_conditional(self.Y, self.G, x_idx, 1, return_indices=return_indices))\n",
    "        return torch.cat(spls, dim=0)\n",
    "\n",
    "    def sample_by_index(self, x_index, n, return_indices=False):\n",
    "        return self.discrete_sample_conditional(self.Y, self.G, x_index, n, return_indices=return_indices)\n",
    "\n",
    "def store_discrete_ot(path, model):\n",
    "    data = {\n",
    "        'X': model.X.detach().cpu(),\n",
    "        'Y': model.Y.detach().cpu(),\n",
    "        'G': model.G.detach().cpu(),\n",
    "    }\n",
    "    torch.save(data, path)\n",
    "\n",
    "def load_discrete_ot(path, device='cpu'):\n",
    "    CP = torch.load(path)\n",
    "    return DiscreteEOT_l2sq_sampler(CP['X'], CP['Y'], CP['G'], device=device)\n",
    "\n",
    "class DiscreteEOT_l2sq:\n",
    "\n",
    "    def _cast(self, x):\n",
    "        if self.dtype == 'torch32':\n",
    "            return torch.tensor(x).float().to(self.device)\n",
    "        if self.dtype == 'torch64':\n",
    "            return torch.tensor(x).double().to(self.device)\n",
    "        if self.dtype == 'np32':\n",
    "            return make_numpy(x).astype('float32')\n",
    "        if self.dtype == 'np64':\n",
    "            return make_numpy(x).astype('float64')\n",
    "\n",
    "    def __init__(\n",
    "        self, \n",
    "        verbose=False,\n",
    "        log=False,\n",
    "        method='sinkhorn_log', \n",
    "        stopThr=1e-09,\n",
    "        numItermax=10000,\n",
    "        dtype='torch32',\n",
    "        device='cpu',\n",
    "    ):\n",
    "        self.verbose = verbose\n",
    "        self.method = method\n",
    "        self.stopThr = stopThr\n",
    "        self.numItermax = numItermax\n",
    "        self.dtype = dtype\n",
    "        self.device = device\n",
    "        self.log = log\n",
    "        self.logs = None\n",
    "\n",
    "    def solve(self, X, Y, eps):\n",
    "        _X, _Y = self._cast(X), self._cast(Y)\n",
    "        M = 0.5 * ot.dist(_X, _Y)\n",
    "        xL, yL = X.shape[0], Y.shape[0]\n",
    "        wX, wY = self._cast(np.ones(xL)/xL), self._cast(np.ones(yL)/yL)\n",
    "        \n",
    "        ans = ot.sinkhorn(\n",
    "            wX, wY, M, eps, \n",
    "            method=self.method, \n",
    "            numItermax=self.numItermax, \n",
    "            stopThr=self.stopThr, \n",
    "            verbose=self.verbose,\n",
    "            log=self.log\n",
    "        )\n",
    "        \n",
    "        if self.log:\n",
    "            self.G, self.logs = ans\n",
    "        else:\n",
    "            self.G = ans\n",
    "            \n",
    "        return DiscreteEOT_l2sq_sampler(_X, _Y, self.G, device=self.device)\n",
    "    \n",
    "def ed(x, y):\n",
    "    Kxx = pairwise_distances(x, x)\n",
    "    Kyy = pairwise_distances(y, y)\n",
    "    Kxy = pairwise_distances(x, y)\n",
    "\n",
    "    m = x.shape[0]\n",
    "    n = y.shape[0]\n",
    "    \n",
    "    c1 = 1 / ( m * (m - 1))\n",
    "    A = np.sum(Kxx - np.diag(np.diagonal(Kxx)))\n",
    "\n",
    "    # Term II\n",
    "    c2 = 1 / (n * (n - 1))\n",
    "    B = np.sum(Kyy - np.diag(np.diagonal(Kyy)))\n",
    "\n",
    "    # Term III\n",
    "    c3 = 1 / (m * n)\n",
    "    C = np.sum(Kxy)\n",
    "\n",
    "    # estimate MMD\n",
    "    mmd_est = -0.5*c1*A - 0.5*c2*B + c3*C\n",
    "    \n",
    "    return mmd_est"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f27a6272",
   "metadata": {},
   "source": [
    "## Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40db80f3",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "DIM = 100\n",
    "assert DIM > 1\n",
    "\n",
    "SEED = 42\n",
    "BATCH_SIZE = 128\n",
    "EPSILON = 0.1\n",
    "\n",
    "DAY_START = 3\n",
    "DAY_END = 7\n",
    "DAY_EVAL = 4\n",
    "SERIES_ID = 0\n",
    "\n",
    "CONTINUE = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c069b3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(SEED); np.random.seed(SEED)\n",
    "EPS = EPSILON\n",
    "EPSILON_END = EPSILON"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "721fe5c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_NAME = f'Sinkhorn_Single_Cell_full_CITE_cell_DIM_{DIM}_DAY_EVAL_{DAY_EVAL}_EPSILON_{EPSILON}_SEED_{SEED}'\n",
    "OUTPUT_PATH = '../checkpoints/{}'.format(EXP_NAME)\n",
    "\n",
    "config = dict(\n",
    "    SERIES_ID=SERIES_ID,\n",
    "    DAY_START=DAY_START,\n",
    "    DAY_END=DAY_END,\n",
    "    DAY_EVAL=DAY_EVAL,\n",
    "    DIM=DIM,\n",
    "    EPSILON=EPSILON,\n",
    "    SEED=SEED,\n",
    ")\n",
    "\n",
    "if not os.path.exists(OUTPUT_PATH):\n",
    "    os.makedirs(OUTPUT_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f32f9151",
   "metadata": {},
   "source": [
    "## Data loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99c5d0d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {}\n",
    "for day in [2, 3, 4, 7]:\n",
    "    data[day] = np.load(f\"../data/full_cite_pcas_{DIM}_day_{day}.npy\")\n",
    "    \n",
    "eval_data = data[DAY_EVAL]\n",
    "start_data = data[DAY_START]\n",
    "end_data = data[DAY_END]\n",
    "\n",
    "constant_scale = np.concatenate([start_data, end_data, eval_data]).std(axis=0).mean()\n",
    "\n",
    "eval_data_scaled = eval_data/constant_scale\n",
    "start_data_scaled = start_data/constant_scale\n",
    "end_data_scaled = end_data/constant_scale\n",
    "\n",
    "eval_data = torch.tensor(eval_data).float()\n",
    "start_data = torch.tensor(start_data_scaled).float()\n",
    "end_data = torch.tensor(end_data_scaled).float()\n",
    "\n",
    "X_sampler = TensorSampler(torch.tensor(start_data).float(), device=\"cpu\")\n",
    "Y_sampler = TensorSampler(torch.tensor(end_data).float(), device=\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7acd5691",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "wandb.init(name=EXP_NAME, config=config)\n",
    "\n",
    "distance = torch.cdist(start_data, end_data).double().square().cuda()/2\n",
    "ot_solver = DiscreteEOT_l2sq(verbose=True, log=True, device=\"cuda\", dtype=\"torch64\", stopThr=1e-8, numItermax=100000)\n",
    "sampler = ot_solver.solve(start_data.cuda(), end_data.cuda(), eps=0.1)\n",
    "\n",
    "logs = torch.stack(ot_solver.logs['err']).detach().cpu()\n",
    "\n",
    "for i, log in enumerate(logs):\n",
    "    wandb.log({\"tol\": log}, step=i)\n",
    "\n",
    "sinkhorn_samples = sampler.sample_by_indices(np.arange(start_data.shape[0])).cpu()\n",
    "\n",
    "predict_time = torch.tensor((DAY_EVAL - DAY_START)/(DAY_END-DAY_START))\n",
    "predict = torch.sqrt(predict_time*(1-predict_time)*EPSILON)*torch.randn_like(start_data) + (1-predict_time)*start_data + predict_time*sinkhorn_samples\n",
    "predict = predict * constant_scale\n",
    "\n",
    "ED = ed(predict, eval_data)\n",
    "\n",
    "wandb.log({\"ED\": ED}, step=i)\n",
    "\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50bca897",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
