{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "e1bc3506-f509-4b37-8179-15e7a944c9ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys\n",
    "sys.path.append('/home/quickjkee/projects/Light-GCOT')\n",
    "\n",
    "from scipy import linalg\n",
    "import warnings\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import anndata as ad\n",
    "import scanpy as sc\n",
    "from sklearn.preprocessing import Normalizer\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import wandb\n",
    "import moscot.plotting as mtp\n",
    "import scipy\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from moscot import datasets\n",
    "from moscot.problems.cross_modality import TranslationProblem\n",
    "from sklearn import preprocessing as pp\n",
    "\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.from_loader import PairedLoaderSampler\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.models.light_gcot import LightGCOT\n",
    "from sklearn.manifold import TSNE\n",
    "tsne = TSNE(n_components=2, random_state=50)\n",
    "\n",
    "#https://moscot.readthedocs.io/en/latest/notebooks/tutorials/600_tutorial_translation.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e94208ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.costs.lse import MLPLSECost\n",
    "from src.models.gmm_based import GMMEOT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a9a6ad3",
   "metadata": {},
   "source": [
    "# Data preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2366fd9",
   "metadata": {},
   "source": [
    "### PS\n",
    "1) Source:  \n",
    "$X \\in \\mathbb{R}^{N \\times d_{1}}$, where $N$ is the number of locations and $d_{1}$ is the feature dimension. \\\n",
    "Each row $x = (\\mu, \\sigma)$ concatenates the per-feature mean and standard deviation of the measurements of a given location in June. \\\n",
    "$N = 1396, d_{1} = 188 = 2 d_{2}$ \n",
    "\n",
    "2) Target: $Y \\in \\mathbb{R}^{N \\times M \\times d_{2}}$, where $N$ is the number of locations, $M$ is the number of daily measurements for a given location in January, and $d_{2}$ is the feature dimension. \\\n",
    "$M \\in [1, 31], d_{2} = 94$ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "02965c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################################\n",
    "#-------------- RAW DATA -----------------\n",
    "##########################################\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "root = '../tabred/kal/weather'\n",
    "\n",
    "data_raw = np.load(f'{root}/X_num.npy')\n",
    "# Drop rows containing NaNs. The SAME row mask must be applied to every array that is\n",
    "# row-aligned with X_num: re-testing the already-filtered `data` would keep every row of\n",
    "# `meta` up to len(data) and silently misalign it with `data`.\n",
    "valid_rows = ~np.isnan(data_raw).any(axis=1)\n",
    "data = data_raw[valid_rows]\n",
    "data_csv = pd.read_csv(f'{root}/csv/X_num.csv')\n",
    "\n",
    "# NOTE(review): `target` is not filtered; index it with `valid_rows` before using it with `data`.\n",
    "target = np.load(f'{root}/Y.npy')\n",
    "meta = np.load(f'{root}/X_meta.npy')[valid_rows]\n",
    "meta_csv = pd.read_csv(f'{root}/csv/X_meta.csv')\n",
    "\n",
    "# Append meta column -2 to the numeric features as a trailing 'location' column.\n",
    "names = list(data_csv.columns)\n",
    "names.append('location')\n",
    "data_new = np.concatenate((data, meta[:, -2].reshape(-1, 1)), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "eb1dee44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1653 1578\n"
     ]
    }
   ],
   "source": [
    "#########################################################\n",
    "#--------------- Month/location splitted ---------------- \n",
    "#########################################################\n",
    "scaler = StandardScaler()\n",
    "\n",
    "def group_by_location(rows, month, eps=1e-1):\n",
    "    \"\"\"Group rows of one month by location and min-max scale each location's rows.\n",
    "\n",
    "    A row is selected when its column -2 equals ``month``; column -1 is the location key.\n",
    "    The last 7 columns (which include those two) are dropped from the stored features.\n",
    "    Locations with a single row are discarded. Each remaining location is scaled per\n",
    "    feature as (x - min) / (max - min + eps). Keys keep first-appearance order.\n",
    "    \"\"\"\n",
    "    grouped = {}\n",
    "    for row in rows:\n",
    "        if row[-2] == month:\n",
    "            grouped.setdefault(row[-1], []).append(row[:-7])\n",
    "\n",
    "    scaled = {}\n",
    "    for key, items in grouped.items():\n",
    "        item = np.stack(items)\n",
    "        if item.shape[0] > 1:\n",
    "            item_min, item_max = np.min(item, axis=0), np.max(item, axis=0)\n",
    "            scaled[key] = (item - item_min) / (item_max - item_min + eps)\n",
    "    return scaled\n",
    "\n",
    "\n",
    "dict_location_src = group_by_location(data_new, 1.0)\n",
    "dict_location_trg = group_by_location(data_new, 6.0)\n",
    "\n",
    "print(len(dict_location_trg), len(dict_location_src))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c5c56fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#########################################################\n",
    "#--------------------- X, Y paired ----------------------\n",
    "#########################################################\n",
    "\n",
    "def summarize_location(item):\n",
    "    \"\"\"Per-feature mean followed by per-feature std over a location's rows -> shape (2 * D,).\"\"\"\n",
    "    return np.concatenate([np.mean(item, axis=0), np.std(item, axis=0)])\n",
    "\n",
    "\n",
    "# The first 200 target locations are reserved for the paired split; a set gives O(1) lookups.\n",
    "chosen_locs = list(dict_location_trg.keys())[:200]\n",
    "chosen_set = set(chosen_locs)\n",
    "\n",
    "# Paired: chosen locations that also exist in the source month (same order as dict_location_src).\n",
    "paired_keys = [key for key in dict_location_src if key in chosen_set]\n",
    "X_pair_orig = np.stack([summarize_location(dict_location_src[key]) for key in paired_keys])\n",
    "Y_pair_orig = [dict_location_trg[key] for key in paired_keys]  # per-location samples\n",
    "\n",
    "\n",
    "#########################################################\n",
    "#----------------------- X, Y ---------------------------\n",
    "#########################################################\n",
    "\n",
    "# N x 1 x 2D - src\n",
    "# N x M x D - trg\n",
    "\n",
    "# sampling:\n",
    "# b x 1 x 2D,\n",
    "# b x M x D -> sample -> b x 1 x D\n",
    "\n",
    "# Unpaired: every location outside the chosen set, independently for source and target.\n",
    "X_orig = np.stack([\n",
    "    summarize_location(item)\n",
    "    for key, item in dict_location_src.items()\n",
    "    if key not in chosen_set\n",
    "])\n",
    "Y_orig = [item for key, item in dict_location_trg.items() if key not in chosen_set]  # per-location samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b44ab264",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1386, 188) 1453 (192, 188) 192\n"
     ]
    }
   ],
   "source": [
    "# Sizes: unpaired source matrix, number of unpaired target locations,\n",
    "# paired source matrix, number of paired target locations.\n",
    "print(X_orig.shape, len(Y_orig), X_pair_orig.shape, len(Y_pair_orig))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11b6332b",
   "metadata": {},
   "source": [
    "# Running"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f0a1831a-9e8f-480c-abb1-da3f6dc56ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: data dimensions and training hyper-parameters.\n",
    "source_data = X_orig\n",
    "target_data = Y_orig[0]  # one target location's rows; only used to read the feature dimension\n",
    "X_DIM = source_data.shape[1]\n",
    "Y_DIM = target_data.shape[1]\n",
    "#X_DIM = data_set[\"features\"].shape[1]\n",
    "#Y_DIM = data_set[\"features\"].shape[1]\n",
    "assert X_DIM > 1\n",
    "assert Y_DIM > 1\n",
    "\n",
    "# NOTE(review): several values below are reassigned by later cells before training:\n",
    "# OUTPUT_SEED (-> 50), M_POTENTIALS / N_POTENTIALS (-> 10 / 10), L_PAIRED_SAMPLES (-> 90),\n",
    "# MAX_STEPS (-> 10000). The optimizer uses a hard-coded lr=4e-3 rather than D_LR.\n",
    "OUTPUT_SEED = 42\n",
    "\n",
    "N_POTENTIALS = 10\n",
    "M_POTENTIALS = 1 #10\n",
    "EPSILON = 0.01  # only recorded in EXP_NAME / config in this notebook\n",
    "A_DIAGONAL_INIT = 0.5\n",
    "L_PAIRED_SAMPLES = len(X_pair_orig)\n",
    "M_X_UNPAIRED_SAMPLES = 0\n",
    "N_Y_UNPAIRED_SAMPLES = 0\n",
    "\n",
    "BATCH_SIZE = 32\n",
    "SAMPLING_BATCH_SIZE = 128\n",
    "\n",
    "D_LR = 3e-4  # 1e-3 for eps 0.1, 0.01 and 3e-4 for eps 0.002\n",
    "D_GRADIENT_MAX_NORM = float(\"inf\")  # recorded in config only; no gradient clipping is applied\n",
    "\n",
    "NUM_LABELED = 10\n",
    "TRAIN_SUBSET_SIZE = 2\n",
    "\n",
    "PLOT_EVERY = 1000\n",
    "MAX_STEPS = 20000\n",
    "CONTINUE = -1  # > -1 resumes the optimizer state saved as D_opt_{CONTINUE}.pt in OUTPUT_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "dc8ded78-a9b3-4053-972e-e0fbed555f64",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_COST = \"MLP_deep_deep\"\n",
    "EXP_COST_INCLUDED = True  # toggles the paired-data branch of the training loop\n",
    "EXP_META_INFO = \"\"\n",
    "\n",
    "# Run name encodes the main hyper-parameters and doubles as the checkpoint directory name.\n",
    "EXP_NAME = \"\".join([\n",
    "    \"Light-GCOT_Batch_Effect_\",\n",
    "    f\"EPSILON_{EPSILON}_\",\n",
    "    f\"N_{N_POTENTIALS}_\",\n",
    "    f\"M_{M_POTENTIALS}_\",\n",
    "    f\"with_{EXP_COST}_\",\n",
    "    f\"cost_included_{EXP_COST_INCLUDED}_\",\n",
    "    f\"N_PAIRED_{NUM_LABELED}_\",\n",
    "    f\"M_UNPAIRED_{len(source_data)}_\",\n",
    "    EXP_META_INFO,\n",
    "])\n",
    "OUTPUT_PATH = os.path.join(\"../checkpoints\", EXP_NAME)\n",
    "\n",
    "config = dict(\n",
    "    X_DIM=X_DIM,\n",
    "    Y_DIM=Y_DIM,\n",
    "    D_LR=D_LR,\n",
    "    BATCH_SIZE=BATCH_SIZE,\n",
    "    EPSILON=EPSILON,\n",
    "    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,\n",
    "    N_POTENTIALS=N_POTENTIALS,\n",
    "    M_POTENTIALS=M_POTENTIALS,\n",
    "    A_DIAGONAL_INIT=A_DIAGONAL_INIT,\n",
    "    N_PAIRED_SAMPLES=NUM_LABELED,\n",
    "    M_UNPAIRED_SAMPLES=len(source_data),\n",
    ")\n",
    "\n",
    "# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.\n",
    "os.makedirs(OUTPUT_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "cb4506b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pytorch_total_params = sum(p.numel() for p in D.parameters())\n",
    "#pytorch_total_params"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55d32db3-1805-4e51-b492-e1eaaba959e6",
   "metadata": {},
   "source": [
    "## Ablation Study"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "86eea788",
   "metadata": {},
   "outputs": [],
   "source": [
    "def paired_sampler(X_pair, Y_pair, b_size, device='cuda'):\n",
    "    \"\"\"Sample a batch of paired (x, y) examples.\n",
    "\n",
    "    Draws ``b_size`` location indices uniformly with replacement from ``X_pair``; for each,\n",
    "    ``y`` is a uniformly chosen row of the matching ``Y_pair[idx]`` array.\n",
    "\n",
    "    Returns:\n",
    "        (x_pair_batch, y_pair_batch) as float32 tensors on ``device``.\n",
    "    \"\"\"\n",
    "    # numpy's `high` is exclusive: use len(X_pair) so the last location can be drawn.\n",
    "    idxs = np.random.randint(low=0, high=len(X_pair), size=b_size)\n",
    "    y_rows = [Y_pair[idx][random.randint(0, len(Y_pair[idx]) - 1)] for idx in idxs]\n",
    "    x_pair_batch = torch.as_tensor(X_pair[idxs], dtype=torch.float32, device=device)\n",
    "    y_pair_batch = torch.as_tensor(np.stack(y_rows), dtype=torch.float32, device=device)\n",
    "    return x_pair_batch, y_pair_batch\n",
    "\n",
    "def unpaired_sampler(X, Y, b_size, device='cuda'):\n",
    "    \"\"\"Sample a batch of unpaired source rows and target samples.\n",
    "\n",
    "    Source indices are drawn uniformly with replacement; each is mirrored to the target\n",
    "    index ``len(X) - 1 - idx`` and ``y`` is a uniformly chosen row of ``Y[idx_y]``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``len(Y) < len(X)`` (mirrored indices would fall outside ``Y``).\n",
    "    \"\"\"\n",
    "    if len(Y) < len(X):\n",
    "        raise ValueError(f\"unpaired_sampler needs len(Y) >= len(X), got {len(Y)} < {len(X)}\")\n",
    "    # numpy's `high` is exclusive: use len(X) so every source row can be drawn.\n",
    "    idxs = np.random.randint(low=0, high=len(X), size=b_size)\n",
    "    idxs_y = len(X) - idxs - 1\n",
    "    y_rows = [Y[idx][random.randint(0, len(Y[idx]) - 1)] for idx in idxs_y]\n",
    "    x_batch = torch.as_tensor(X[idxs], dtype=torch.float32, device=device)\n",
    "    y_batch = torch.as_tensor(np.stack(y_rows), dtype=torch.float32, device=device)\n",
    "    return x_batch, y_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "012173da",
   "metadata": {},
   "outputs": [],
   "source": [
    "from configs.gmm_based.cost import MLPLSECostConfig\n",
    "from configs.gmm_based.train import TrainConfig\n",
    "\n",
    "# NOTE: these override the run-configuration cell (where M_POTENTIALS = 1), so EXP_NAME /\n",
    "# OUTPUT_PATH built from the earlier values do not describe the model trained below.\n",
    "M_POTENTIALS = 10\n",
    "N_POTENTIALS = 10\n",
    "\n",
    "# Source features are the concatenated per-feature mean and std of the target features.\n",
    "assert X_DIM == 2 * Y_DIM, (X_DIM, Y_DIM)\n",
    "\n",
    "LOG_V_M_HIDDEN_CHANNELS = [M_POTENTIALS]\n",
    "B_M_HIDDEN_CHANNELS = [M_POTENTIALS * Y_DIM]\n",
    "\n",
    "cost_config = MLPLSECostConfig(\n",
    "    m_potentials=M_POTENTIALS,\n",
    "    log_v_m_hidden_channels=LOG_V_M_HIDDEN_CHANNELS,\n",
    "    b_m_hidden_channels=B_M_HIDDEN_CHANNELS,\n",
    "    x_dim=X_DIM, y_dim=Y_DIM,\n",
    ")\n",
    "\n",
    "device = torch.device(f\"cuda:{torch.cuda.current_device()}\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.set_default_device(device)\n",
    "dtype = torch.float64\n",
    "torch.set_default_dtype(dtype)\n",
    "\n",
    "# NOTE(review): the training cell reads only train_config.unpaired_batch_size; steps_to is unused there.\n",
    "train_config = TrainConfig(\n",
    "    steps_to=8000, paired_batch_size=BATCH_SIZE, unpaired_batch_size=BATCH_SIZE\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "81e92995",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _frechet_distance(a, b):\n",
    "    \"\"\"Frechet distance between Gaussians fitted to the rows of ``a`` and ``b`` (real part).\"\"\"\n",
    "    mu_a, sigma_a = np.mean(a, axis=0), np.cov(a, rowvar=False)\n",
    "    mu_b, sigma_b = np.mean(b, axis=0), np.cov(b, rowvar=False)\n",
    "    diff = mu_a - mu_b\n",
    "    covmean, _ = linalg.sqrtm(sigma_a.dot(sigma_b), disp=False)\n",
    "    fid = diff.dot(diff) + np.trace(sigma_a) + np.trace(sigma_b) - 2 * np.trace(covmean)\n",
    "    return fid.real\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def forward_test(model, batched_x: torch.Tensor, Y_pair_test, device=None):\n",
    "    \"\"\"Evaluate ``model`` on held-out paired locations.\n",
    "\n",
    "    For each location ``i`` the conditional distribution is built from ``batched_x[i:i+1]``\n",
    "    and compared with the observed rows ``Y_pair_test[i]`` via:\n",
    "\n",
    "    * the mean negative log-likelihood of the observed rows,\n",
    "    * the Frechet distance (FID formula) between ``len(Y_pair_test[i])`` model samples and\n",
    "      the observed rows,\n",
    "    * that distance divided by the variance of all observed values (\"vFID\").\n",
    "\n",
    "    Args:\n",
    "        model: exposes ``log_w_n()``, ``a_n()``, ``A_n()`` and\n",
    "            ``get_conditional_distribution(x, log_w_n, a_n, A_n)``.\n",
    "        batched_x: source features, one row per test location.\n",
    "        Y_pair_test: per-location arrays of observed target rows (each needs >= 2 rows for\n",
    "            the covariance estimate).\n",
    "        device: where to place the target tensors; defaults to ``batched_x.device``.\n",
    "\n",
    "    Returns:\n",
    "        (samples, mean_nll, mean_fid, mean_vfid). ``samples`` holds the draws for the LAST\n",
    "        location only; the other three values are averages over all locations.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = batched_x.device\n",
    "\n",
    "    # Mixture parameters are fixed during evaluation, so compute them once.\n",
    "    log_w_n = model.log_w_n().to(torch.float64)\n",
    "    a_n = model.a_n().to(torch.float64)\n",
    "    A_n = model.A_n().to(torch.float64)\n",
    "\n",
    "    nlls, fids, vfids = [], [], []\n",
    "    samples = []\n",
    "    for i, y_obs in enumerate(Y_pair_test):\n",
    "        x_i = batched_x[i : i + 1].to(torch.float64)\n",
    "        y_i = torch.as_tensor(y_obs, device=device)\n",
    "        cond_distr = model.get_conditional_distribution(x_i, log_w_n, a_n, A_n)\n",
    "        nlls.append(-cond_distr.log_prob(y_i).mean().item())\n",
    "\n",
    "        samples = [cond_distr.sample() for _ in range(len(y_i))]\n",
    "        generated = torch.cat(samples, dim=0).cpu().numpy()\n",
    "        observed = y_i.cpu().numpy()\n",
    "\n",
    "        fid = _frechet_distance(generated, observed)\n",
    "        fids.append(fid)\n",
    "        vfids.append(fid / np.var(observed))\n",
    "\n",
    "    return samples, np.mean(nlls), np.mean(fids), np.mean(vfids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "204b5b99-b714-42b4-b1a6-118174f4898d",
   "metadata": {
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n",
      "500 500 90 90 100 100\n",
      "Training with number of labeled: 90\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                         | 8/10000 [00:02<40:16,  4.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-458.10428349935637 119.11975650649933 1201.7808238754387\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▊                                  | 1009/10000 [00:21<13:09, 11.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.22818229782202326 8.26295863658546 79.338925248303\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▋                              | 2011/10000 [00:39<15:18,  8.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12.379478089546787 7.555595385412721 72.36015986294215\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███████████▍                          | 3007/10000 [00:56<10:43, 10.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23.124370534084523 7.399682906579766 70.70552076959538\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|███████████████▏                      | 4009/10000 [01:14<07:48, 12.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25.532624895130063 7.36934813219669 70.45657210299092\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|███████████████████                   | 5010/10000 [01:31<06:59, 11.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21.569082889064184 7.202467667617113 68.90034967890072\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████████████████████▊               | 6011/10000 [01:50<05:45, 11.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "33.54522557631462 7.196683493841608 68.7275240346769\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|██████████████████████████▋           | 7014/10000 [02:08<03:58, 12.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "34.542427629196624 7.162836443857995 68.46135678566746\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|██████████████████████████████▍       | 8008/10000 [02:25<02:53, 11.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "33.32422445393114 7.228328343750709 69.09230092181564\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|██████████████████████████████████▏   | 9008/10000 [02:43<01:49,  9.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32.60801949368183 7.140584061101573 68.38626798633726\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 10000/10000 [02:58<00:00, 56.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  L_PAIRED_SAMPLES  FOSCTTM_Score       fid       vfid\n",
      "0               90      34.665192  7.118523  68.082497\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/tmp/ipykernel_476081/354169610.py:100: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n",
      "  results_df = pd.concat([results_df, new_row], ignore_index=True)\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "OUTPUT_SEED = 50\n",
    "random.seed(OUTPUT_SEED)\n",
    "torch.manual_seed(OUTPUT_SEED)\n",
    "np.random.seed(OUTPUT_SEED)\n",
    "\n",
    "MAX_STEPS = 10000\n",
    "EVAL_EVERY = 1000\n",
    "\n",
    "# Splitting: disjoint paired train / paired test slices, plus unpaired source and target sets.\n",
    "L_PAIRED_SAMPLES = 90\n",
    "L_UNPAIRED_SAMPLES = 500\n",
    "N_TEST_PAIRED = 100\n",
    "assert L_PAIRED_SAMPLES + N_TEST_PAIRED <= len(X_pair_orig), \"paired train/test slices overlap\"\n",
    "X, Y = X_orig[:L_UNPAIRED_SAMPLES], Y_orig[-L_UNPAIRED_SAMPLES:]\n",
    "X_pair, Y_pair = X_pair_orig[:L_PAIRED_SAMPLES], Y_pair_orig[:L_PAIRED_SAMPLES]\n",
    "X_pair_test, Y_pair_test = X_pair_orig[-N_TEST_PAIRED:], Y_pair_orig[-N_TEST_PAIRED:]\n",
    "\n",
    "print(len(X), len(Y), len(X_pair), len(Y_pair), len(X_pair_test), len(Y_pair_test))\n",
    "print(\"Training with number of labeled:\", L_PAIRED_SAMPLES)\n",
    "\n",
    "cost = MLPLSECost(**cost_config.model_dump()).to(dtype)\n",
    "model = GMMEOT(\n",
    "    y_dim=Y_DIM,\n",
    "    n_potentials=N_POTENTIALS,\n",
    "    cost=cost,\n",
    ").to(dtype)\n",
    "model.to('cuda')\n",
    "\n",
    "D_opt = torch.optim.Adam(model.parameters(), lr=4e-3)\n",
    "# NOTE(review): scheduler.step() is never called, so the learning rate stays at 4e-3.\n",
    "scheduler = lr_scheduler.StepLR(D_opt, step_size=1000, gamma=0.5) #0.87\n",
    "\n",
    "if CONTINUE > -1:\n",
    "    # NOTE(review): only the optimizer state is restored here; model weights are not loaded.\n",
    "    D_opt.load_state_dict(torch.load(os.path.join(OUTPUT_PATH, f\"D_opt_{CONTINUE}.pt\")))\n",
    "\n",
    "x_test = torch.tensor(X_pair_test).to('cuda').to(torch.float64)\n",
    "B = train_config.unpaired_batch_size\n",
    "\n",
    "for step in tqdm(range(CONTINUE + 1, MAX_STEPS)):\n",
    "    D_opt.zero_grad()\n",
    "\n",
    "    x_batch, y_batch = unpaired_sampler(X, Y, B)\n",
    "    x_batch = x_batch.to(dtype)\n",
    "    y_batch = y_batch.to(dtype)\n",
    "\n",
    "    log_w_n = model.log_w_n().to(torch.float64)\n",
    "    a_n = model.a_n().to(torch.float64)\n",
    "    A_n = model.A_n().to(torch.float64)\n",
    "\n",
    "    # Unpaired term: -mean_i log((1/B) * sum_j p(y_i | x_j)).\n",
    "    # x.repeat tiles the batch (row k -> x_{k % B}) while y.repeat_interleave repeats each row\n",
    "    # (row k -> y_{k // B}), so after the reshape fwd[i, j] = log p(y_i | x_j). Tiling both\n",
    "    # tensors the same way would only ever evaluate the B matched (x_j, y_j) pairs.\n",
    "    cond_distr_unpaired = model.get_conditional_distribution(\n",
    "        x_batch.repeat(B, 1).to(torch.float64), log_w_n, a_n, A_n\n",
    "    )\n",
    "    fwd = cond_distr_unpaired.log_prob(y_batch.repeat_interleave(B, dim=0)).reshape(B, B)\n",
    "    # logsumexp - log(B) equals log(mean(exp(.))) but does not underflow for very negative log-probs.\n",
    "    D_loss_unpaired = -(torch.logsumexp(fwd, dim=-1) - float(np.log(B))).mean()\n",
    "\n",
    "    D_loss = D_loss_unpaired\n",
    "    if EXP_COST_INCLUDED:\n",
    "        x_pair_batch, y_pair_batch = paired_sampler(X_pair, Y_pair, BATCH_SIZE)\n",
    "        cond_distr_paired = model.get_conditional_distribution(\n",
    "            x_pair_batch.to(torch.float64), log_w_n, a_n, A_n\n",
    "        )\n",
    "        D_loss_paired = -cond_distr_paired.log_prob(y_pair_batch).mean()\n",
    "        D_loss = D_loss + D_loss_paired\n",
    "\n",
    "    # Always take an optimizer step, so training also runs when the paired term is disabled.\n",
    "    D_loss.backward()\n",
    "    D_opt.step()\n",
    "\n",
    "    if step % EVAL_EVERY == 0:\n",
    "        translated, total_probs, fids, fids2 = forward_test(model, x_test, Y_pair_test)\n",
    "        print(-total_probs, fids, fids2)\n",
    "\n",
    "translated, total_probs, fids, fids2 = forward_test(model, x_test, Y_pair_test)\n",
    "foscttm_score = -total_probs\n",
    "\n",
    "# NOTE: the 'FOSCTTM_Score' column keeps its historical name but holds the mean test log-likelihood.\n",
    "results_df = pd.DataFrame([{\n",
    "    'L_PAIRED_SAMPLES': L_PAIRED_SAMPLES,\n",
    "    'FOSCTTM_Score': foscttm_score,\n",
    "    'fid': fids,\n",
    "    'vfid': fids2,\n",
    "}])\n",
    "print(results_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "44671c00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.32000000000000006 0.0282842712474619\n",
      "7.176666666666667 0.0740870359029763\n",
      "68.33333333333333 0.4714045207910317\n"
     ]
    }
   ],
   "source": [
    "a = [0.34, 0.28, 0.34]\n",
    "b = [7.14, 7.28, 7.11]\n",
    "c = [68, 69, 68]\n",
    "# Mean and population standard deviation (ddof=0) of each metric's recorded values.\n",
    "for metric_values in (a, b, c):\n",
    "    print(np.mean(metric_values), np.std(metric_values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "838f932e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#     |  Ours         | cGAn          | uGAn         |  CNF          |  Regres.\n",
    "# FID | 7.21 +- 0.04  | 15.79 +- 1.11 | 15.44 +- 1.89| 18.72 +- 0.09 | 8.29 +- 0.044\n",
    "# vFID| 72 +- 1       | 156 +- 11     | 152 +- 19    | 184 +- 1      | 81.0 +- 0.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf4b1f75",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
