{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fb2b2856",
   "metadata": {},
   "source": [
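    "# UOT-WFM: majority visualization\n",
    "\n",
    "Trains a plain flow-matching baseline, then a flow model with a selectable minibatch coupling (`icfm`, `exact`, `uot_fm`, `uot_wfm`) from a single Gaussian to an imbalanced two-mode Gaussian target (weights 0.01 / 0.99), and prints per-mode ratios for the raw target samples, the OT-resampled targets, and the endpoints of the learned flow."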
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "278f89ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "\n",
    "# Project checkout, presumably providing the `torchcfm` package used below;\n",
    "# inserted at the front so it takes precedence over any installed copy.\n",
    "# NOTE(review): relative to the kernel's working directory, not to this notebook.\n",
    "PROJECT_ROOT = './Desktop/wfm_project/uot-wfm'\n",
    "# A different checkout that must not shadow PROJECT_ROOT on import.\n",
    "EXCLUDED_PATH = '../../conditional-flow-matching'\n",
    "\n",
    "sys.path.insert(0, PROJECT_ROOT)\n",
    "# Deduplicate while keeping the search order. `list(set(sys.path))` scrambles the\n",
    "# order (str hashing is randomized per process), so PROJECT_ROOT would not\n",
    "# reliably come first and the imported `torchcfm` could differ between runs.\n",
    "sys.path[:] = [p for p in dict.fromkeys(sys.path) if p != EXCLUDED_PATH]\n",
    "\n",
    "def sample_conditional_pt(x0, x1, t, sigma):\n",
    "    \"\"\"\n",
    "    Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x0 : Tensor, shape (bs, *dim)\n",
    "        represents the source minibatch\n",
    "    x1 : Tensor, shape (bs, *dim)\n",
    "        represents the target minibatch\n",
    "    t : FloatTensor, shape (bs)\n",
    "        interpolation time per pair; reshaped to (bs, 1, ..., 1) to broadcast over *dim\n",
    "    sigma : float\n",
    "        standard deviation of the Gaussian noise added around the interpolated mean\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    xt : Tensor, shape (bs, *dim)\n",
    "        mu_t + sigma * epsilon, with epsilon ~ N(0, I) drawn with the shape of x0\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Anonymous et al.\n",
    "    \"\"\"\n",
    "    # Local import: this cell runs before the notebook's import cell, so the\n",
    "    # function must not rely on a module-level `torch` that is defined later.\n",
    "    import torch\n",
    "\n",
    "    t = t.reshape(-1, *([1] * (x0.dim() - 1)))\n",
    "    mu_t = t * x1 + (1 - t) * x0\n",
    "    epsilon = torch.randn_like(x0)\n",
    "    return mu_t + sigma * epsilon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec3460f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_conditional_vector_field(x0, x1):\n",
    "    \"\"\"\n",
    "    Return the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].\n",
    "\n",
    "    For the straight-line mean path mu_t = t * x1 + (1 - t) * x0, the velocity\n",
    "    d mu_t / dt does not depend on t and equals the displacement x1 - x0.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x0 : Tensor, shape (bs, *dim)\n",
    "        source minibatch\n",
    "    x1 : Tensor, shape (bs, *dim)\n",
    "        target minibatch, paired row-by-row with x0\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ut : Tensor, shape (bs, *dim)\n",
    "        elementwise difference x1 - x0\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Anonymous et al.\n",
    "    \"\"\"\n",
    "    displacement = x1 - x0\n",
    "    return displacement"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd2a2b32",
   "metadata": {},
   "source": [
    "# Flow Matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2035a615",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ot as pot\n",
    "import torch\n",
    "import torchdyn\n",
    "from torchdyn.core import NeuralODE\n",
    "from torchdyn.datasets import generate_moons\n",
    "\n",
    "\n",
    "from torchcfm.conditional_flow_matching import *\n",
    "from torchcfm.models.models import *\n",
    "from torchcfm.utils import *\n",
    "\n",
    "# Directory for saved artifacts; exist_ok makes re-running this cell a no-op.\n",
    "# NOTE(review): `savedir` is not referenced elsewhere in this notebook.\n",
    "savedir = \"models/unbalanced-gaussian\"\n",
    "os.makedirs(name=savedir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf18883",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "sigma = 0.01  # std of the Gaussian probability path (the value the loss is trained with)\n",
    "dim = 2\n",
    "batch_size = 256\n",
    "n_iters = 1  # baseline smoke run; raise (e.g. to 20000) for a real fit\n",
    "# Log/plot every 5000 iterations but at least once at the end: with a fixed\n",
    "# `% 5000` check and n_iters < 5000 the report below would never run.\n",
    "log_interval = max(1, min(5000, n_iters))\n",
    "model = MLP(dim=dim, time_varying=True)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "start = time.time()\n",
    "for k in range(n_iters):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # Unpaired minibatches: 1-component source at the origin, 3-component target.\n",
    "    x0 = sample_unbalanced_kgaussians(batch_size, 1, [(0,0)], [(1,1)], [1])\n",
    "    x1 = sample_unbalanced_kgaussians(batch_size, 3, [(4,-4), (8,3), (9,-5)], [(1,1), (1,1), (1,1)], [0.1, 0.7, 0.2])\n",
    "\n",
    "    t = torch.rand(x0.shape[0]).type_as(x0)\n",
    "    xt = sample_conditional_pt(x0, x1, t, sigma=sigma)\n",
    "    ut = compute_conditional_vector_field(x0, x1)\n",
    "\n",
    "    vt = model(torch.cat([xt, t[:, None]], dim=-1))\n",
    "    loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (k + 1) % log_interval == 0:\n",
    "        end = time.time()\n",
    "        print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
    "        start = end\n",
    "        node = NeuralODE(\n",
    "            torch_wrapper(model), solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4\n",
    "        )\n",
    "        with torch.no_grad():\n",
    "            traj = node.trajectory(\n",
    "                sample_unbalanced_kgaussians(1024, 1, [(0,0)], [(1,1)], [1]),\n",
    "                t_span=torch.linspace(0, 1, 100),\n",
    "            )\n",
    "            plot_trajectories(traj.cpu().numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2d9e10d",
   "metadata": {},
   "source": [
    "# UOT-WFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7831d15",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "from torchcfm.optimal_transport import OTPlanSampler\n",
    "\n",
    "# --- Configuration ---------------------------------------------------------\n",
    "plan = 'icfm' # icfm, exact, uot_fm, uot_wfm\n",
    "weight_type = \"none\" # inv_tnu, none\n",
    "epochs = 20000  # number of training iterations (one fresh minibatch each)\n",
    "visual_interval = epochs  # log/plot once, after the final iteration\n",
    "\n",
    "VALID_PLANS = ('icfm', 'exact', 'uot_fm', 'uot_wfm')\n",
    "VALID_WEIGHT_TYPES = ('none', 'inv_tnu')\n",
    "# Validate before building the model: an unknown plan would otherwise surface as a\n",
    "# NameError on `x0_ot` inside the loop, and an unknown weight_type would silently\n",
    "# train unweighted.\n",
    "if plan not in VALID_PLANS:\n",
    "    raise ValueError(f\"unknown plan {plan!r}; expected one of {VALID_PLANS}\")\n",
    "if weight_type not in VALID_WEIGHT_TYPES:\n",
    "    raise ValueError(f\"unknown weight_type {weight_type!r}; expected one of {VALID_WEIGHT_TYPES}\")\n",
    "if weight_type == \"inv_tnu\" and plan != 'uot_wfm':\n",
    "    raise ValueError(\"inv_tnu is only supported for uot_wfm\")\n",
    "\n",
    "if plan == 'icfm':\n",
    "    ot_sampler = None  # independent coupling: minibatch pairs are used as drawn\n",
    "elif plan == 'exact':\n",
    "    ot_sampler = OTPlanSampler(method=\"exact\")\n",
    "elif plan == 'uot_fm':\n",
    "    ot_sampler = OTPlanSampler(method=\"sinkhorn\", reg=1.0)\n",
    "else:  # 'uot_wfm'\n",
    "    ot_sampler = OTPlanSampler(method=\"unbalanced_knopp\", reg=0.1, reg_m=(float(\"inf\"), float(\"inf\")))\n",
    "\n",
    "sigma = 0.01  # std of the Gaussian probability path used by sample_conditional_pt\n",
    "dim = 2\n",
    "batch_size = 256\n",
    "model = MLP(dim=dim, time_varying=True)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "weight_power_factor = 1.0  # exponent applied to the inverse-mass weights\n",
    "\n",
    "source_centers = [(0,0)]\n",
    "source_variances = [(1,1)]\n",
    "source_weights = [1]\n",
    "\n",
    "# Imbalanced target: a narrow minority mode at (10, -10) and a majority mode at (10, 10).\n",
    "target_centers = [(10,-10), (10, 10)]\n",
    "target_variances = [(0.01,0.01), (1,1)]\n",
    "target_weights = [0.01, 0.99]\n",
    "\n",
    "start = time.time()\n",
    "for k in range(epochs):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    x0 = sample_unbalanced_kgaussians(batch_size, 1, source_centers, source_variances, source_weights)\n",
    "    x1 = sample_unbalanced_kgaussians(batch_size, 2, target_centers, target_variances, target_weights)\n",
    "\n",
    "    # Pair source and target samples according to the chosen coupling.\n",
    "    if plan == 'icfm':\n",
    "        x0_ot, x1_ot = x0, x1\n",
    "    elif plan == 'exact':\n",
    "        x0_ot, x1_ot = ot_sampler.sample_plan(x0, x1)\n",
    "    else:  # 'uot_fm' and 'uot_wfm' share the (unbalanced) OT sampling call\n",
    "        x0_ot, x1_ot, pi, u, v, i, j = ot_sampler.sample_plan_with_weights_and_indices(x0, x1, fixed_source=True)\n",
    "\n",
    "    # Per-pair loss weights. inv_tnu is only reachable with plan == 'uot_wfm' (checked above).\n",
    "    if weight_type == \"inv_tnu\":\n",
    "        # Column sums of pi, rescaled so a uniform 1/batch share maps to 1.0.\n",
    "        # NOTE(review): assumes pi is laid out (source, target) -- confirm in OTPlanSampler.\n",
    "        tnu = pi.sum(dim=0)\n",
    "        tnu = tnu.reshape(tnu.size(0), 1)\n",
    "        tnu = tnu / (1/x1.size(0))\n",
    "        fm_weight = 1 / tnu.detach()  # inverse weight: low-mass (minority) targets count more\n",
    "        fm_weight = fm_weight[j]  # weight of the target sample paired with each row\n",
    "        fm_weight = fm_weight ** weight_power_factor\n",
    "    else:\n",
    "        fm_weight = 1.0\n",
    "\n",
    "    t = torch.rand(x0_ot.shape[0]).type_as(x0_ot)\n",
    "    xt = sample_conditional_pt(x0_ot, x1_ot, t, sigma=sigma)\n",
    "    ut = compute_conditional_vector_field(x0_ot, x1_ot)\n",
    "\n",
    "    vt = model(torch.cat([xt, t[:, None]], dim=-1))\n",
    "    loss = torch.mean(((vt - ut) ** 2) * fm_weight)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (k + 1) % visual_interval == 0:\n",
    "        end = time.time()\n",
    "        print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
    "        start = end\n",
    "        node = NeuralODE(\n",
    "            torch_wrapper(model), solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4\n",
    "        )\n",
    "        with torch.no_grad():\n",
    "            traj = node.trajectory(\n",
    "                sample_unbalanced_kgaussians(batch_size*100, 1, source_centers, source_variances, source_weights),\n",
    "                t_span=torch.linspace(0, 1, 100),\n",
    "            )\n",
    "            plot_trajectories(traj.cpu().numpy(), vis_grid=True)\n",
    "            plot_sample_points(x0, x1, x0_ot, x1_ot)\n",
    "            print_majority_ratios(x1, centers=target_centers, names=[\"(10, -10) origin\",  \"(10, 10) origin\"])\n",
    "            print_majority_ratios(x1_ot, centers=target_centers, names=[\"(10, -10) uot_resample\", \"(10, 10) uot_resample\"])\n",
    "            print_majority_ratios(traj[-1], centers=target_centers, names=[\"(10, -10) traj_end\", \"(10, 10) traj_end\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76d09989",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7348050",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "120c0132",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41411446",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchcfm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}