{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.distributions as TD\n",
    "\n",
    "from matplotlib import collections  as mc\n",
    "import numpy as np\n",
    "import random\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.legend_handler import HandlerTuple\n",
    "%matplotlib inline\n",
    "\n",
    "from typing import Dict, Any, Literal, List, Tuple, Union, Optional\n",
    "from tqdm import tqdm\n",
    "import itertools\n",
    "from copy import deepcopy\n",
    "\n",
    "from IPython.display import clear_output\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from src.utils import Config, make_f_pot, freeze, unfreeze\n",
    "from src.models import linear_model\n",
    "from src.cost import strong_cost\n",
    "\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='6'\n",
    "\n",
    "def seed_everything(\n",
    "    seed: int,\n",
    "    *,\n",
    "    avoid_benchmark_noise: bool = False,\n",
    "    only_deterministic_algorithms: bool = False\n",
    "):\n",
    "    \"\"\"Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.\n",
    "\n",
    "    Args:\n",
    "        seed: value passed to every generator.\n",
    "        avoid_benchmark_noise: if True, disable cuDNN autotuning\n",
    "            (`torch.backends.cudnn.benchmark`); otherwise it is enabled.\n",
    "        only_deterministic_algorithms: forwarded to `torch.use_deterministic_algorithms`\n",
    "            with `warn_only=True`, so nondeterministic ops warn instead of raising.\n",
    "    \"\"\"\n",
    "    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)\n",
    "    for seeder in seeders:\n",
    "        seeder(seed)\n",
    "\n",
    "    torch.backends.cudnn.benchmark = not avoid_benchmark_noise\n",
    "    torch.use_deterministic_algorithms(\n",
    "        only_deterministic_algorithms,\n",
    "        warn_only=True,\n",
    "    )\n",
    "\n",
    "def sample_gauss(mu, cov, n):\n",
    "    \"\"\"Draw `n` i.i.d. samples from the multivariate normal N(mu, cov).\n",
    "\n",
    "    mu - torch.Size([2])\n",
    "    cov - torch.Size([2,2])\n",
    "    n - int (amount of samples)\n",
    "\n",
    "    Returns a tensor of shape (n, 2) for the shapes above.\n",
    "    \"\"\"\n",
    "    gaussian = TD.MultivariateNormal(loc=mu, covariance_matrix=cov)\n",
    "    return gaussian.sample((n,))\n",
    "\n",
    "def plot_initial_data(mus, covs, n):\n",
    "    \"\"\"Scatter-plot `n` samples from each Gaussian N(mus[i], covs[i]) on the current axes.\n",
    "\n",
    "    mus - list of torch.Size([2])\n",
    "    covs - list of torch.Size([2,2])\n",
    "    n - int (amount of samples per distribution)\n",
    "    \"\"\"\n",
    "    for idx, (mu, cov) in enumerate(zip(mus, covs)):\n",
    "        d = sample_gauss(mu, cov, n)\n",
    "        plt.scatter(d[:, 0], d[:, 1], edgecolor='black', label=f'distribution {idx+1}')\n",
    "    # `plt.grid()` without arguments toggles the grid, so calling it once per\n",
    "    # distribution hides it again for an even number of distributions.\n",
    "    plt.grid(True)\n",
    "    plt.legend()\n",
    "\n",
    "\n",
    "class MinValue(nn.Module):\n",
    "    \"\"\"Module holding the single learnable scalar `m`.\n",
    "\n",
    "    `m` is a parameter of shape (1,), initialised to zero and moved to `device`;\n",
    "    calling the module returns the parameter itself.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, device):\n",
    "        super().__init__()\n",
    "        initial = torch.zeros(1).to(device)\n",
    "        self.m = nn.Parameter(initial)\n",
    "\n",
    "    def forward(self):\n",
    "        return self.m\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, *hidden_dims: int):\n",
    "        \"\"\"Fully-connected network with ReLU between consecutive Linear layers.\n",
    "\n",
    "        `hidden_dims` lists every width including both ends: the first entry is\n",
    "        the input dimension and the last one the output dimension, so\n",
    "        `len(hidden_dims) - 1` Linear layers are built. No activation follows\n",
    "        the last layer.\n",
    "        \"\"\"\n",
    "        assert len(hidden_dims) >= 2\n",
    "        super().__init__()\n",
    "\n",
    "        layers = []\n",
    "        widths = list(zip(hidden_dims[:-1], hidden_dims[1:]))\n",
    "        for i, (fan_in, fan_out) in enumerate(widths):\n",
    "            if i > 0:\n",
    "                layers.append(nn.ReLU(inplace=True))\n",
    "            layers.append(nn.Linear(fan_in, fan_out))\n",
    "        self._layers = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self._layers(x)\n",
    "\n",
    "class OTMap(nn.Module):\n",
    "    \"\"\"Base class for (possibly stochastic) transport maps; subclasses implement `forward`.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        inp_dim: int = None,\n",
    "        hidden_dims: List[int] = None,\n",
    "        out_dim: int = None,\n",
    "        *args, **kwargs,\n",
    "    ):\n",
    "        \"\"\"Initialize the OT map base class.\n",
    "\n",
    "        All arguments are accepted and ignored here; only `nn.Module`\n",
    "        initialisation is performed.\n",
    "\n",
    "        Args:\n",
    "            inp_dim: dimensionality of the source space.\n",
    "            hidden_dims: hidden dimensions.\n",
    "            out_dim: dimensionality of the target space.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        x: torch.FloatTensor,\n",
    "        reg: bool = False,\n",
    "    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:\n",
    "        \"\"\"Apply the map to a batch.\n",
    "\n",
    "        If the map is weak (stochastic), one sample is returned per input item.\n",
    "\n",
    "        Args:\n",
    "            x: tensor of shape (bs, inp_dim).\n",
    "            reg: whether to also return the regularization term.\n",
    "\n",
    "        Returns:\n",
    "            tensor of shape (bs, out_dim) [and the regularization term].\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: always; subclasses must override this method.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "class DeterministicMap(OTMap):\n",
    "    \"\"\"Deterministic transport map x -> MLP(x); its regularization term is always zero.\"\"\"\n",
    "\n",
    "    def __init__(self, inp_dim: int, hidden_dims: List[int], out_dim: int):\n",
    "        super().__init__()\n",
    "        self._bb = MLP(inp_dim, *hidden_dims, out_dim)\n",
    "\n",
    "    def forward(self, x, reg: bool = False):\n",
    "        mapped = self._bb(x)\n",
    "        if not reg:\n",
    "            return mapped\n",
    "        return mapped, torch.tensor(0.0, device=x.device)\n",
    "\n",
    "class NoiseInputMap(OTMap):\n",
    "    \"\"\"Stochastic transport map: Gaussian noise is concatenated to the input of an MLP.\n",
    "\n",
    "    Args:\n",
    "        inp_dim: dimensionality of the source space.\n",
    "        hidden_dims: hidden layer widths of the MLP backbone.\n",
    "        out_dim: dimensionality of the target space.\n",
    "        prior: distribution used by the energy-distance regularizer; its samples are\n",
    "            subtracted from the map outputs, so its event shape must broadcast with (out_dim,).\n",
    "        noise_dim: size of the standard-normal noise vector; falls back to `inp_dim`\n",
    "            when None (or 0).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        inp_dim: int,\n",
    "        hidden_dims: List[int],\n",
    "        out_dim: int,\n",
    "        prior: torch.distributions.Distribution,\n",
    "        noise_dim: Optional[int] = None,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self._noise_dim = noise_dim or inp_dim\n",
    "        self._prior = prior\n",
    "        self._bb = MLP(inp_dim + self._noise_dim, *hidden_dims, out_dim)\n",
    "\n",
    "    def forward(self, x, reg: bool = False):\n",
    "        noise = torch.randn(x.shape[0], self._noise_dim, device=x.device)\n",
    "        out = self._bb(torch.cat((x, noise), dim=-1))\n",
    "        if not reg:\n",
    "            # The regularizer costs two prior samples per item plus two norms;\n",
    "            # only compute it when the caller asks for it.\n",
    "            return out\n",
    "        return out, self.energy_dist_reg_sample(out)\n",
    "\n",
    "    def energy_dist_reg_sample(\n",
    "        self,\n",
    "        sample: torch.FloatTensor,\n",
    "    ):\n",
    "        \"\"\"Per-item sample estimate used as the energy-distance regularizer.\n",
    "\n",
    "        Draws two independent prior samples y, y' per item and returns\n",
    "        `2 * ||x - y|| - ||y - y'||`; the `-E||x - x'||` term of the full energy\n",
    "        distance is not included.\n",
    "\n",
    "        Args:\n",
    "            sample: map outputs of shape (bs, d); the prior is `self._prior`.\n",
    "\n",
    "        Returns:\n",
    "            tensor of shape (bs,)\n",
    "        \"\"\"\n",
    "        pr_sample_1, pr_sample_2 = self._prior.sample((2, *sample.shape[:-1]))\n",
    "        # norms over the last (feature) axis, so extra leading batch dims also work\n",
    "        l12 = (sample - pr_sample_1).norm(dim=-1)\n",
    "        l11 = (pr_sample_1 - pr_sample_2).norm(dim=-1)\n",
    "        return 2 * l12 - l11\n",
    "\n",
    "class Pots(nn.Module):\n",
    "    \"\"\"Pair of dual potentials sharing one network g = MLP(*dims).\n",
    "\n",
    "    `pots[0](x, m) = g(x) + m / K / lambda_0` and `pots[1](x, m) = -g(x) + m / K / lambda_1`,\n",
    "    where K = len(bary_weights). Only indices 0 and 1 are supported.\n",
    "    \"\"\"\n",
    "    # TODO: optimize when 2 potentials\n",
    "    def __init__(self, bary_weights, *dims):\n",
    "        assert len(bary_weights) > 1\n",
    "        super().__init__()\n",
    "        self._lambdas = bary_weights\n",
    "        self._net = MLP(*dims)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        assert 0 <= idx < 2 # only when there are two prob\n",
    "\n",
    "        def f_pot(x, m): # include m\n",
    "            g = self._net(x)\n",
    "            shift = m / len(self._lambdas) / self._lambdas[idx]\n",
    "            return (g if idx == 0 else -g) + shift\n",
    "\n",
    "        return f_pot\n",
    "\n",
    "def get_opt_sched(model, lr, total_steps):\n",
    "    \"\"\"Build Adam over `model.parameters()` and a OneCycleLR schedule peaking at `lr`.\n",
    "\n",
    "    OneCycleLR expects to be stepped at most `total_steps` times.\n",
    "    \"\"\"\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr)\n",
    "    scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
    "        optimizer, max_lr=lr, total_steps=total_steps\n",
    "    )\n",
    "    return optimizer, scheduler\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration, data and training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG = Config()\n",
    "\n",
    "CONFIG.GPU_DEVICE = 0  # index among the GPUs left visible by CUDA_VISIBLE_DEVICES\n",
    "assert torch.cuda.is_available()\n",
    "CONFIG.DEVICE = f'cuda:{CONFIG.GPU_DEVICE}'\n",
    "\n",
    "CONFIG.K = 2  # amount of distributions\n",
    "CONFIG.LAMBDAS = [0.5,0.5]  # barycenter weights, one per distribution\n",
    "CONFIG.DIM = 2\n",
    "CONFIG.INPUT_DIM = CONFIG.DIM\n",
    "CONFIG.HIDDEN_DIMS = [128,128]\n",
    "CONFIG.OUTPUT_DIM_POT = 1\n",
    "CONFIG.OUTPUT_DIM_MAP = CONFIG.DIM\n",
    "CONFIG.LR = 1e-3\n",
    "CONFIG.NUM_SAMPLES = 10_000\n",
    "CONFIG.NUM_EPOCHS = 1200\n",
    "CONFIG.BATCH_SIZE= 1024\n",
    "CONFIG.INNER_ITERATIONS = 3  # map updates per potential update\n",
    "\n",
    "CONFIG.PRIOR_MEAN = torch.tensor([5., 5.], device=CONFIG.DEVICE)\n",
    "CONFIG.PRIOR_COV = 2 * torch.eye(2, device=CONFIG.DEVICE)\n",
    "CONFIG.CONDITIONAL_COV = .1 * torch.eye(2, device=CONFIG.DEVICE)\n",
    "\n",
    "# `Pots.__getitem__` only supports two potentials, and LAMBDAS is indexed per distribution.\n",
    "assert CONFIG.K == len(CONFIG.LAMBDAS) == 2, 'Pots supports exactly two weighted distributions'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define data\n",
    "from toy_data import get_toydataset, ToySampler\n",
    "data_name1, data_name2 = 'moon', 'spiral'\n",
    "dataset1 = get_toydataset(data_name1, 2)\n",
    "sampler1 = ToySampler(dataset1)\n",
    "dataset2 = get_toydataset(data_name2, 2)\n",
    "sampler2 = ToySampler(dataset2)\n",
    "\n",
    "datas = [sampler1, sampler2]\n",
    "assert len(datas) == CONFIG.K\n",
    "\n",
    "def plot_initial_data(n):\n",
    "    \"\"\"Scatter-plot `n` samples from each of the CONFIG.K toy samplers in `datas`.\n",
    "\n",
    "    n - int (amount of samples per distribution)\n",
    "\n",
    "    Note: this replaces the Gaussian `plot_initial_data(mus, covs, n)` from the first cell.\n",
    "    \"\"\"\n",
    "    for k in range(CONFIG.K):\n",
    "        d = datas[k].sample([n])\n",
    "        plt.scatter(d[:, 0].cpu(), d[:, 1].cpu(), edgecolor='black', label=f'distribution {k+1}')\n",
    "    plt.axis(\"equal\")\n",
    "    # `plt.grid()` without arguments toggles the grid, so calling it once per\n",
    "    # distribution hides it again for an even number of distributions.\n",
    "    plt.grid(True)\n",
    "    plt.legend()\n",
    "\n",
    "plot_initial_data(2_000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_DIVERGENCES = ('balanced', 'kl')\n",
    "\n",
    "\n",
    "def _plot_train_progress(maps, mvalue, losses, n_samples=1_000):\n",
    "    \"\"\"Plot current map outputs against fresh data samples, next to the loss history.\n",
    "\n",
    "    Reads the notebook-level `datas` samplers and `CONFIG`.\n",
    "    \"\"\"\n",
    "    vis_data = [datas[k].sample([n_samples]).to(CONFIG.DEVICE) for k in range(CONFIG.K)]\n",
    "\n",
    "    clear_output(wait=True)\n",
    "    # Printed after clear_output so the value is not wiped when the figure arrives.\n",
    "    print(mvalue())\n",
    "    fig, (ax, ax_l) = plt.subplots(1, 2, figsize=(12.8, 4.8))\n",
    "    with torch.no_grad():\n",
    "        for k in range(CONFIG.K):\n",
    "            d = maps[k](vis_data[k]).cpu()\n",
    "            ax.scatter(vis_data[k][:, 0].cpu(), vis_data[k][:, 1].cpu(), edgecolor='black', label=f'data {k+1}')\n",
    "            ax.scatter(d[:, 0], d[:, 1], edgecolor='black', label=f'barycenter {k+1}')\n",
    "    # `ax.grid()` without arguments toggles the grid, so calling it once per\n",
    "    # distribution hides it again for an even number of distributions.\n",
    "    ax.grid(True)\n",
    "    ax.legend()\n",
    "    ax_l.plot(losses)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def train(\n",
    "    maps: OTMap, maps_opt, maps_sched, \n",
    "    pots: Pots, pots_opt, pots_sched,\n",
    "    mvalue, mvalue_opt, mvalue_sched,\n",
    "    reg_coeff: float = 0.0, tau: float = 1.0,\n",
    "    divergence1: str = 'balanced',\n",
    "    divergence2: str = 'balanced',\n",
    "):\n",
    "    \"\"\"Alternating training of the transport maps against the potentials and `m`.\n",
    "\n",
    "    Each epoch runs CONFIG.INNER_ITERATIONS map updates (potentials and `m`\n",
    "    frozen), then one update of the potentials and `m` (maps frozen) that\n",
    "    minimises the negated objective on the last inner-loop batch. Every 200\n",
    "    epochs the current maps and the loss history are plotted.\n",
    "\n",
    "    Args:\n",
    "        maps: CONFIG.K maps, called as `maps[k](x, reg=True)`.\n",
    "        maps_opt, maps_sched: maps optimizer/scheduler, stepped once per inner iteration.\n",
    "        pots: potentials, called as `pots[k](y, m)`.\n",
    "        pots_opt, pots_sched: potentials optimizer/scheduler, stepped once per epoch.\n",
    "        mvalue, mvalue_opt, mvalue_sched: module returning `m`, its optimizer/scheduler.\n",
    "        reg_coeff: weight of the regularization term returned by the maps.\n",
    "        tau: temperature used by the 'kl' divergence term.\n",
    "        divergence1, divergence2: divergence for distribution 1 / 2, 'balanced' or 'kl'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: on an unknown divergence name or CONFIG.INNER_ITERATIONS < 1.\n",
    "    \"\"\"\n",
    "    divergences = [divergence1, divergence2]\n",
    "    unknown = [d for d in divergences if d not in _DIVERGENCES]\n",
    "    if unknown:\n",
    "        raise ValueError(f'unknown divergence(s) {unknown}; expected one of {_DIVERGENCES}')\n",
    "    if CONFIG.INNER_ITERATIONS < 1:\n",
    "        # the potential step reuses the last inner-loop batch, so at least one is needed\n",
    "        raise ValueError('CONFIG.INNER_ITERATIONS must be >= 1')\n",
    "\n",
    "    losses = []\n",
    "    for epoch in tqdm(range(CONFIG.NUM_EPOCHS)):\n",
    "        # --- map step(s): potentials and m frozen ---\n",
    "        freeze(pots)\n",
    "        freeze(mvalue)\n",
    "        unfreeze(maps)\n",
    "\n",
    "        for _ in range(CONFIG.INNER_ITERATIONS):\n",
    "            data = [\n",
    "                datas[k].sample([CONFIG.BATCH_SIZE]).to(CONFIG.DEVICE)\n",
    "                for k in range(CONFIG.K)\n",
    "            ]\n",
    "\n",
    "            maps_opt.zero_grad()\n",
    "            loss = 0\n",
    "            for k in range(CONFIG.K):\n",
    "                mapped_x_k, reg = maps[k](data[k], reg=True)  # [B, N]\n",
    "                cost = strong_cost(data[k], mapped_x_k)  # [B, 1]\n",
    "                m = mvalue()\n",
    "                cost -= pots[k](mapped_x_k, m)  # [B, 1]\n",
    "                cost += reg_coeff * torch.unsqueeze(reg, -1)\n",
    "                cost = cost.mean(dim=0)\n",
    "                loss += CONFIG.LAMBDAS[k] * cost\n",
    "            loss.backward()\n",
    "            maps_opt.step()\n",
    "            maps_sched.step()\n",
    "\n",
    "        # --- potential step: maps frozen, potentials and m trainable ---\n",
    "        unfreeze(pots)\n",
    "        unfreeze(mvalue)\n",
    "        freeze(maps)\n",
    "\n",
    "        pots_opt.zero_grad()\n",
    "        mvalue_opt.zero_grad()\n",
    "        loss = 0\n",
    "\n",
    "        m = mvalue()\n",
    "        for k in range(CONFIG.K):\n",
    "            mapped_x_k, reg = maps[k](data[k], reg=True)  # [B, N]\n",
    "            cost = strong_cost(data[k], mapped_x_k)  # [B, 1]\n",
    "            cost -= pots[k](mapped_x_k, m)  # [B, 1]\n",
    "            if divergences[k] == 'kl':\n",
    "                # presumably the negated convex conjugate of tau * KL (verify); 'balanced' keeps the cost\n",
    "                cost = - tau * (torch.exp(-cost/tau) - 1)\n",
    "            cost += m\n",
    "            cost += reg_coeff * torch.unsqueeze(reg, -1)\n",
    "            cost = cost.mean(dim=0)\n",
    "            loss += CONFIG.LAMBDAS[k] * cost\n",
    "\n",
    "        # negate: the optimizers minimise, the potentials and m maximise the objective\n",
    "        loss = -loss\n",
    "        losses.append(loss.item())\n",
    "        loss.backward()\n",
    "        pots_opt.step()\n",
    "        pots_sched.step()\n",
    "        mvalue_opt.step()\n",
    "        mvalue_sched.step()\n",
    "\n",
    "        if epoch % 200 == 0:\n",
    "            _plot_train_progress(maps, mvalue, losses)\n",
    "\n",
    "Tau = 5  # temperature of the 'kl' divergence term\n",
    "# Notebook-level so the evaluation cell can read the same divergences.\n",
    "divergence1, divergence2 = 'balanced', 'kl'\n",
    "\n",
    "seed_everything(0, avoid_benchmark_noise=True)\n",
    "CONFIG.NUM_EPOCHS = 10000\n",
    "\n",
    "maps_ur = nn.ModuleList([\n",
    "    DeterministicMap(CONFIG.INPUT_DIM, CONFIG.HIDDEN_DIMS, CONFIG.OUTPUT_DIM_MAP)\n",
    "    for _ in range(CONFIG.K)\n",
    "]).to(CONFIG.DEVICE)\n",
    "# the maps take INNER_ITERATIONS scheduler steps per epoch\n",
    "maps_opt, maps_sched = get_opt_sched(maps_ur, CONFIG.LR, CONFIG.NUM_EPOCHS * CONFIG.INNER_ITERATIONS)\n",
    "\n",
    "# the potentials are evaluated on mapped points, hence OUTPUT_DIM_MAP inputs\n",
    "pots_ur = Pots(\n",
    "    CONFIG.LAMBDAS,\n",
    "    CONFIG.OUTPUT_DIM_MAP,\n",
    "    *CONFIG.HIDDEN_DIMS,\n",
    "    CONFIG.OUTPUT_DIM_POT\n",
    ").to(CONFIG.DEVICE)\n",
    "pots_opt, pots_sched = get_opt_sched(pots_ur, CONFIG.LR, CONFIG.NUM_EPOCHS)\n",
    "\n",
    "# scalar m: Adam with betas=(0, 0.9), i.e. no first-moment averaging\n",
    "mvalue = MinValue(CONFIG.DEVICE)\n",
    "mvalue_opt = torch.optim.Adam(mvalue.parameters(), CONFIG.LR, (0, 0.9))\n",
    "mvalue_sched = torch.optim.lr_scheduler.OneCycleLR(\n",
    "    mvalue_opt,\n",
    "    CONFIG.LR,\n",
    "    total_steps=CONFIG.NUM_EPOCHS,\n",
    ")\n",
    "\n",
    "train(\n",
    "    maps_ur, maps_opt, maps_sched,\n",
    "    pots_ur, pots_opt, pots_sched,\n",
    "    mvalue, mvalue_opt, mvalue_sched,\n",
    "    tau=Tau, divergence1=divergence1, divergence2=divergence2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save transport maps (one state_dict per distribution)\n",
    "parent = f'ckpts/COMPARISONS_{data_name1}_{data_name2}_Tau{Tau}'\n",
    "EXP_DIR = os.path.join(parent, 'UOTbary')\n",
    "os.makedirs(EXP_DIR, exist_ok=True)\n",
    "\n",
    "# `map_k`, not `map`: a loop variable here would shadow the builtin notebook-wide\n",
    "for k, map_k in enumerate(maps_ur):\n",
    "    ckpt_path = os.path.join(EXP_DIR, f'net{k}_epoch_{CONFIG.NUM_EPOCHS}.pth')\n",
    "    torch.save(map_k.state_dict(), ckpt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Tau=%d'%Tau)\n",
    "\n",
    "# Expects notebook-level `divergence1` / `divergence2` equal to the values passed to `train`.\n",
    "divergences = [divergence1, divergence2]\n",
    "data = [datas[k].sample([1000]).to(CONFIG.DEVICE) for k in range(CONFIG.K)]\n",
    "\n",
    "TRAINED_BARYCENTERS = {}\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6.4, 4.8))\n",
    "with torch.no_grad():  # evaluation only; no autograd graph needed\n",
    "    for k in range(CONFIG.K):\n",
    "        U = torch.rand(len(data[k]))\n",
    "        mapped_x_k, _ = maps_ur[k](data[k], reg=True)  # [B, N]\n",
    "        m = mvalue()\n",
    "        f_c = strong_cost(data[k], mapped_x_k) - pots_ur[k](mapped_x_k, m)  # [B, 1]\n",
    "\n",
    "        if divergences[k] == 'balanced':\n",
    "            acc = data[k]\n",
    "        elif divergences[k] == 'kl':\n",
    "            # rejection sampling: keep x with probability exp(-f_c(x)/Tau) / max exp(-f_c/Tau)\n",
    "            weights = torch.exp(-f_c/Tau).flatten().cpu()\n",
    "            acc = data[k][U < 1/weights.max() * weights]\n",
    "        else:\n",
    "            raise ValueError(f'unknown divergence {divergences[k]!r}')\n",
    "\n",
    "        mapped_x_k_acc, _ = maps_ur[k](acc, reg=True)\n",
    "\n",
    "        TRAINED_BARYCENTERS[f'mu_{k}'] = acc.cpu().numpy()\n",
    "        TRAINED_BARYCENTERS[f'barycenter_{k}'] = mapped_x_k_acc.cpu().numpy()\n",
    "\n",
    "        ax.scatter(data[k][:, 0].cpu(), data[k][:, 1].cpu(), edgecolor='black', label=f'data {k+1}')\n",
    "        ax.scatter(acc[:, 0].cpu(), acc[:, 1].cpu(), color='black', edgecolor='black', label=f'accepted {k+1}')\n",
    "        ax.scatter(mapped_x_k_acc[:, 0].cpu(), mapped_x_k_acc[:, 1].cpu(), edgecolor='black', label=f'barycenter {k+1}')\n",
    "        # segments x -> T_k(x) for the first 64 accepted points\n",
    "        lines = list(zip(acc[:64].cpu().numpy(), mapped_x_k_acc[:64].cpu().numpy()))\n",
    "        ax.add_collection(mc.LineCollection(lines, linewidths=0.5, color='white', zorder=4))\n",
    "\n",
    "# `ax.grid()` without arguments toggles the grid; enable it once, explicitly.\n",
    "ax.grid(True)\n",
    "ax.legend()\n",
    "plt.show()\n",
    "print(mvalue())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
