{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e1bc3506-f509-4b37-8179-15e7a944c9ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys\n",
    "sys.path.append('/home/quickjkee/projects/Light-GCOT')\n",
    "\n",
    "import warnings\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import anndata as ad\n",
    "import scanpy as sc\n",
    "from torch.distributions.normal import Normal \n",
    "from sklearn.preprocessing import Normalizer\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import wandb\n",
    "import moscot.plotting as mtp\n",
    "import scipy\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from moscot import datasets\n",
    "from moscot.problems.cross_modality import TranslationProblem\n",
    "from sklearn import preprocessing as pp\n",
    "\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.from_loader import PairedLoaderSampler\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.models.light_gcot import LightGCOT\n",
    "from sklearn.manifold import TSNE\n",
    "tsne = TSNE(n_components=2, random_state=50)\n",
    "\n",
    "#https://moscot.readthedocs.io/en/latest/notebooks/tutorials/600_tutorial_translation.html"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a9a6ad3",
   "metadata": {},
   "source": [
    "# Data preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2366fd9",
   "metadata": {},
   "source": [
    "### PS\n",
    "1) Source  \n",
    "$X \\in \\mathbb{R}^{N \\times d_{1}}, N - \\text{number of locations}, d_{1} - \\text{features dim}$ \\\n",
    "$x = (\\mu, \\sigma) - \\text{for a given location in June}$ \\\n",
    "$N = 1396, d_{1} = 188$ \n",
    "\n",
    "2) $Y \\in \\mathbb{R}^{N \\times M \\times d_{2}}, N - \\text{number of locations}, M - \\text{measurements for a given location in January by day}$ \\\n",
    "$M = [1, 31], d_{2} = 94$ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "02965c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################################\n",
    "#-------------- RAW DATA -----------------\n",
    "##########################################\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "root = '../tabred/kal/weather'\n",
    "\n",
    "data = np.load(f'{root}/X_num.npy')\n",
    "data = np.stack([d for d in data if sum(np.isnan(d)) == 0])\n",
    "data_csv = pd.read_csv(f'{root}/csv/X_num.csv')\n",
    "#train_data = data[train_idx]\n",
    "#test_data = data[test_idx]\n",
    "\n",
    "target = np.load(f'{root}/Y.npy')\n",
    "meta = np.load(f'{root}/X_meta.npy')\n",
    "meta = np.stack([meta[i] for i, d in enumerate(data) if sum(np.isnan(d)) == 0])\n",
    "meta_csv = pd.read_csv(f'{root}/csv/X_meta.csv')\n",
    "\n",
    "names = list(data_csv.columns)\n",
    "names.append('location')\n",
    "data_new = np.concatenate((data, meta[:, -2].reshape(-1, 1)), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eb1dee44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1653 1578\n"
     ]
    }
   ],
   "source": [
    "#########################################################\n",
    "#--------------- Month/location splitted ---------------- \n",
    "#########################################################\n",
    "scaler = StandardScaler()\n",
    "\n",
    "dict_location_src = {}\n",
    "for d in data_new:\n",
    "    if d[-2] == 1.0:\n",
    "        d_new = d[:-7]\n",
    "        try:\n",
    "            dict_location_src[d[-1]].append(d_new)\n",
    "        except KeyError:\n",
    "            dict_location_src[d[-1]] = []\n",
    "            dict_location_src[d[-1]].append(d_new)\n",
    "     \n",
    "\n",
    "dict_location_src_new = {}\n",
    "for key in dict_location_src.keys():\n",
    "    item = dict_location_src[key]\n",
    "    item = np.stack(item)\n",
    "    if item.shape[0] > 1:\n",
    "        item = (item - np.min(item, axis=0)) / (np.max(item, axis=0) - np.min(item, axis=0) + 1e-1)\n",
    "        dict_location_src_new[key] = item\n",
    "dict_location_src = dict_location_src_new\n",
    "# ------------------------------------------------------------\n",
    "\n",
    "dict_location_trg = {}\n",
    "for d in data_new:\n",
    "    if d[-2] == 6.0:\n",
    "        d_new = d[:-7]\n",
    "        try:\n",
    "            dict_location_trg[d[-1]].append(d_new)\n",
    "        except KeyError:\n",
    "            dict_location_trg[d[-1]] = []\n",
    "            dict_location_trg[d[-1]].append(d_new)\n",
    "    \n",
    "dict_location_trg_new = {}\n",
    "for key in dict_location_trg.keys():\n",
    "    item = dict_location_trg[key]\n",
    "    item = np.stack(item)\n",
    "    if item.shape[0] > 1:\n",
    "        item = (item - np.min(item, axis=0)) / (np.max(item, axis=0) - np.min(item, axis=0) + 1e-1)\n",
    "        dict_location_trg_new[key] = item\n",
    "dict_location_trg = dict_location_trg_new\n",
    "\n",
    "print(len(dict_location_trg), len(dict_location_src))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c5c56fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#########################################################\n",
    "#--------------------- X, Y paired ----------------------\n",
    "#########################################################\n",
    "\n",
    "chosen_locs = list(dict_location_trg.keys())[:200]\n",
    "X_pair_orig, Y_pair_orig = [], []\n",
    "for key in dict_location_src.keys():\n",
    "    if key not in chosen_locs:\n",
    "        continue\n",
    "    item_src = dict_location_src[key] \n",
    "    x = np.concatenate([np.mean(item_src, axis=0), np.std(item_src, axis=0)]) # mean, std\n",
    "    X_pair_orig.append(x)\n",
    "    item_trg = dict_location_trg[key]\n",
    "    Y_pair_orig.append(item_trg) # sample\n",
    "X_pair_orig = np.stack(X_pair_orig)\n",
    "\n",
    "\n",
    "#########################################################\n",
    "#----------------------- X, Y ---------------------------\n",
    "#########################################################\n",
    "\n",
    "# N x 1 x 2D - src\n",
    "# N x M x D - trg\n",
    " \n",
    "# sampling: \n",
    "# b x 1 x 2D,\n",
    "# b x M x D -> sample -> b x 1 x D\n",
    "\n",
    "X_orig = []\n",
    "for key in dict_location_src.keys():\n",
    "    if key in chosen_locs:\n",
    "        continue\n",
    "    item_src = dict_location_src[key] \n",
    "    x = np.concatenate([np.mean(item_src, axis=0), np.std(item_src, axis=0)]) # mean, std\n",
    "    X_orig.append(x)\n",
    "X_orig = np.stack(X_orig)\n",
    "\n",
    "Y_orig = []\n",
    "for key in dict_location_trg.keys():\n",
    "    if key in chosen_locs:\n",
    "        continue\n",
    "    item_trg = dict_location_trg[key]\n",
    "    Y_orig.append(item_trg) # sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b44ab264",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1386, 188) 1453 (192, 188) 192\n"
     ]
    }
   ],
   "source": [
    "print(X_orig.shape, len(Y_orig), X_pair_orig.shape, len(Y_pair_orig))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11b6332b",
   "metadata": {},
   "source": [
    "# Running"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f0a1831a-9e8f-480c-abb1-da3f6dc56ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "source_data = X_orig\n",
    "target_data = Y_orig[0]\n",
    "X_DIM = source_data.shape[1]\n",
    "Y_DIM = target_data.shape[1]\n",
    "#X_DIM = data_set[\"features\"].shape[1]\n",
    "#Y_DIM = data_set[\"features\"].shape[1]\n",
    "assert X_DIM > 1\n",
    "assert Y_DIM > 1\n",
    "\n",
    "OUTPUT_SEED = 42\n",
    "\n",
    "N_POTENTIALS = 10\n",
    "M_POTENTIALS = 1 #10\n",
    "EPSILON = 1\n",
    "A_DIAGONAL_INIT = 0.5\n",
    "L_PAIRED_SAMPLES = len(X_pair_orig)\n",
    "M_X_UNPAIRED_SAMPLES = 0\n",
    "N_Y_UNPAIRED_SAMPLES = 0\n",
    "\n",
    "BATCH_SIZE = 128\n",
    "SAMPLING_BATCH_SIZE = 128\n",
    "\n",
    "D_LR = 3e-4  # 1e-3 for eps 0.1, 0.01 and 3e-4 for eps 0.002\n",
    "D_GRADIENT_MAX_NORM = float(\"inf\")\n",
    "\n",
    "NUM_LABELED = 10\n",
    "TRAIN_SUBSET_SIZE = 2\n",
    "\n",
    "PLOT_EVERY = 1000\n",
    "MAX_STEPS = 20000\n",
    "CONTINUE = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "dc8ded78-a9b3-4053-972e-e0fbed555f64",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_COST = \"MLP_deep_deep\"\n",
    "EXP_COST_INCLUDED = True\n",
    "EXP_META_INFO = \"\"\n",
    "EXP_NAME = (\n",
    "    f\"Light-GCOT_Batch_Effect_\"\n",
    "    + f\"EPSILON_{EPSILON}_\"\n",
    "    + f\"N_{N_POTENTIALS}_\"\n",
    "    + f\"M_{M_POTENTIALS}_\"\n",
    "    + f\"with_{EXP_COST}_\"\n",
    "    + f\"cost_included_{EXP_COST_INCLUDED}_\"\n",
    "    + f\"N_PAIRED_{NUM_LABELED}_\"\n",
    "    + f\"M_UNPAIRED_{len(source_data)}_\"\n",
    "    + EXP_META_INFO\n",
    ")\n",
    "OUTPUT_PATH = \"../checkpoints/{}\".format(EXP_NAME)\n",
    "\n",
    "config = dict(\n",
    "    X_DIM=X_DIM,\n",
    "    Y_DIM=Y_DIM,\n",
    "    D_LR=D_LR,\n",
    "    BATCH_SIZE=BATCH_SIZE,\n",
    "    EPSILON=EPSILON,\n",
    "    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,\n",
    "    N_POTENTIALS=N_POTENTIALS,\n",
    "    M_POTENTIALS=M_POTENTIALS,\n",
    "    A_DIAGONAL_INIT=A_DIAGONAL_INIT,\n",
    "    N_PAIRED_SAMPLES=NUM_LABELED,\n",
    "    M_UNPAIRED_SAMPLES=len(source_data),\n",
    ")\n",
    "\n",
    "if not os.path.exists(OUTPUT_PATH):\n",
    "    os.makedirs(OUTPUT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "cb4506b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pytorch_total_params = sum(p.numel() for p in D.parameters())\n",
    "#pytorch_total_params"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55d32db3-1805-4e51-b492-e1eaaba959e6",
   "metadata": {},
   "source": [
    "## Ablation Study"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "86eea788",
   "metadata": {},
   "outputs": [],
   "source": [
    "def paired_sampler(X_pair, Y_pair, b_size):\n",
    "    idxs = np.random.randint(low=0, high=len(X_pair)-1, size=b_size)\n",
    "    x_pair_batch = torch.tensor(X_pair[idxs]).to('cuda')\n",
    "    y_pair_batch = np.stack([Y_pair[idx][random.randint(0, len(Y_pair[idx])-1)] for idx in idxs])\n",
    "    y_pair_batch = torch.tensor(y_pair_batch).to('cuda')\n",
    "    return x_pair_batch.to(torch.float32), y_pair_batch.to(torch.float32)\n",
    "\n",
    "def unpaired_sampler(X, Y, b_size):\n",
    "    # UNPAIRED SAMPLER\n",
    "    idxs = np.random.randint(low=0, high=len(X)-1, size=b_size)\n",
    "    idxs_y = np.array([len(X) - idx - 1 for idx in idxs])\n",
    "\n",
    "    x_batch = torch.tensor(X[idxs]).to('cuda')\n",
    "    y_batch = np.stack([Y[idx][random.randint(0, len(Y[idx])-1)] for idx in idxs_y])\n",
    "    y_batch = torch.tensor(y_batch).to('cuda')\n",
    "    return x_batch.to(torch.float32), y_batch.to(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "7d300dd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models.models import MyCDiscriminator, MyCGenerator\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.utils.discrete_ot import OTPlanSampler\n",
    "from src.utils.paired import generate_paired_data, get_GT_points, get_paired_sampler\n",
    "import torch.nn.functional as F\n",
    "from scipy import linalg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "0f272ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "nz = 10\n",
    "netG = MyCGenerator(\n",
    "        x_dim=X_DIM,\n",
    "        t_dim=2,\n",
    "        n_t=1,\n",
    "        out_dim=Y_DIM,\n",
    "        z_dim=nz,\n",
    "        layers=[256, 256, 256],\n",
    "    ).to('cuda')\n",
    "\n",
    "netD = MyCDiscriminator(x_dim=141, \n",
    "                        t_dim=2, \n",
    "                        n_t=1, \n",
    "                        layers=[256, 256, 256]).to('cuda')\n",
    "optimizerD = torch.optim.Adam(netD.parameters())\n",
    "optimizerG = torch.optim.Adam(netG.parameters())\n",
    "\n",
    "schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, 5000, eta_min=1e-5)\n",
    "schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, 5000, eta_min=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "2131aaab",
   "metadata": {},
   "outputs": [],
   "source": [
    "history = {\n",
    "        \"D_loss\": [],\n",
    "        \"G_loss\": [],\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "f5391bb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<>:46: SyntaxWarning: \"is\" with a literal. Did you mean \"==\"?\n",
      "<>:46: SyntaxWarning: \"is\" with a literal. Did you mean \"==\"?\n",
      "/var/tmp/ipykernel_80809/2381323863.py:46: SyntaxWarning: \"is\" with a literal. Did you mean \"==\"?\n",
      "  if 1 is None or step % 1 == 0:\n",
      "  0%|                                                 | 0/10000 [00:00<?, ?it/s]/var/tmp/ipykernel_80809/2381323863.py:138: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  loc = torch.tensor(samples.cpu()).mean(dim=0)\n",
      "/var/tmp/ipykernel_80809/2381323863.py:139: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  scale = torch.tensor(samples.cpu()).std(dim=0)\n",
      "  0%|                                        | 11/10000 [00:01<15:44, 10.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "157763539.3492658\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▊                                  | 1011/10000 [00:12<05:15, 28.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "80024517.47125977\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▋                              | 2015/10000 [00:24<04:11, 31.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "54640010.88279938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███████████▍                          | 3017/10000 [00:35<03:42, 31.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "43133209.1237389\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|███████████████▎                      | 4019/10000 [00:47<03:10, 31.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "35860098.03935559\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|███████████████████                   | 5010/10000 [00:58<03:28, 23.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "31006797.81325001\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████████████████████▉               | 6020/10000 [01:10<02:09, 30.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "27276569.19902625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|██████████████████████████▋           | 7019/10000 [01:22<01:42, 29.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "26261945.0222122\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|██████████████████████████████▍       | 8011/10000 [01:33<01:06, 29.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23925506.725139845\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|██████████████████████████████████▏   | 9013/10000 [01:45<00:33, 29.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22124866.535106517\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 10000/10000 [01:55<00:00, 86.30it/s]\n"
     ]
    }
   ],
   "source": [
    "MAX_STEPS = 10000\n",
    "stats = []\n",
    "D_loss = []\n",
    "fids = []\n",
    "fids2 = []\n",
    "device = 'cuda'\n",
    "\n",
    "# Splitting\n",
    "L_PAIRED_SAMPLES = 90\n",
    "L_UNPAIRED_SAMPLES = 500\n",
    "X, Y = X_orig[:L_UNPAIRED_SAMPLES], Y_orig[-L_UNPAIRED_SAMPLES:]\n",
    "X_pair, Y_pair = X_pair_orig[:L_PAIRED_SAMPLES], Y_pair_orig[:L_PAIRED_SAMPLES]\n",
    "X_pair_test, Y_pair_test = X_pair_orig[-100:], Y_pair_orig[-100:]\n",
    "\n",
    "\n",
    "for step in tqdm(range(CONTINUE + 1, MAX_STEPS)):\n",
    "        #########################\n",
    "        # Discriminator training\n",
    "        #########################\n",
    "        for p in netD.parameters():\n",
    "            p.requires_grad = True\n",
    "\n",
    "        netD.zero_grad()\n",
    "\n",
    "        ###################################\n",
    "        # Sample real data\n",
    "        X_paired, Y_paired = paired_sampler(X_pair, Y_pair, BATCH_SIZE)\n",
    "\n",
    "        ###################################\n",
    "        # Sample timesteps\n",
    "        t = torch.randint(0, 1, (X_paired.size(0),), device=device)\n",
    "\n",
    "        Y_paired.requires_grad = True\n",
    "\n",
    "        ###################################\n",
    "        # Optimizing loss on real data\n",
    "        D_real = netD(Y_paired, t, X_paired.detach())\n",
    "\n",
    "        errD_real = F.softplus(-D_real)\n",
    "        errD_real = errD_real.mean()\n",
    "\n",
    "        errD_real.backward(retain_graph=True)\n",
    "\n",
    "        ###################################\n",
    "        # R_1(\\phi) regularization\n",
    "        if 1 is None or step % 1 == 0:\n",
    "            grad_real = torch.autograd.grad(\n",
    "                outputs=D_real.sum(),\n",
    "                inputs=Y_paired,\n",
    "                create_graph=True,\n",
    "            )[0]\n",
    "            grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()\n",
    "\n",
    "            grad_penalty = 0.01 / 2 * grad_penalty\n",
    "            grad_penalty.backward()\n",
    "\n",
    "        ###################################\n",
    "        # Sample vector from latent space for generation\n",
    "        latent_z = torch.randn(BATCH_SIZE, nz, device=device)\n",
    "\n",
    "        ###################################\n",
    "        # Sample fake output\n",
    "        x_0_predict = netG(X_paired.detach(), t, latent_z)\n",
    "\n",
    "        ###################################\n",
    "        # Optimize loss on fake data\n",
    "        output = netD(x_0_predict, t, X_paired.detach()).view(-1)\n",
    "\n",
    "        errD_fake = F.softplus(output)\n",
    "        errD_fake = errD_fake.mean()\n",
    "        errD_fake.backward()\n",
    "\n",
    "        errD = errD_real + errD_fake\n",
    "\n",
    "        D_loss.append(errD.item())\n",
    "        #print(f'D Loss {errD.item()}')\n",
    "\n",
    "        ###################################\n",
    "        # Update weights of netD\n",
    "        optimizerD.step()\n",
    "\n",
    "        #############################################################\n",
    "\n",
    "        #########################\n",
    "        # Generator training\n",
    "        #########################\n",
    "        for p in netD.parameters():\n",
    "            p.requires_grad = False\n",
    "        netG.zero_grad()\n",
    "\n",
    "        ###################################\n",
    "        # Sample timesteps\n",
    "        t = torch.randint(0, 1, (X_paired.size(0),), device=device)\n",
    "\n",
    "        ###################################\n",
    "        # Sample pairs for training\n",
    "        unp_sample, _ = unpaired_sampler(X, Y, BATCH_SIZE)\n",
    "\n",
    "        ###################################\n",
    "        # Sample vector from latent space for generation\n",
    "        latent_z = torch.randn(BATCH_SIZE, nz, device=device)\n",
    "\n",
    "        ###################################\n",
    "        # Sample fake output\n",
    "        x_0_predict = netG(unp_sample.detach(), t, latent_z)\n",
    "\n",
    "        ###################################\n",
    "        # Optimize loss on fake data\n",
    "        output = netD(x_0_predict, t, unp_sample.detach()).view(-1)\n",
    "\n",
    "        ###################################\n",
    "        # Update weights of netG\n",
    "        errG = F.softplus(-output)\n",
    "        errG = errG.mean()\n",
    "\n",
    "        errG.backward()\n",
    "        optimizerG.step()\n",
    "\n",
    "        #print(f'G Loss {errG.item()}')\n",
    "        # LR-Scheduling step\n",
    "        schedulerG.step()\n",
    "        schedulerD.step()\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            if step % 1000 == 0:\n",
    "                for x, y in zip(X_pair_test, Y_pair_test):\n",
    "                    x = torch.tensor(x).unsqueeze(0).to(device)\n",
    "                    y = torch.tensor(y).to(device)\n",
    "                    samples = []\n",
    "                    for _ in range(len(y)):\n",
    "                        latent_z = torch.randn(len(x), nz, device=device)\n",
    "                        t = torch.randint(0, 1, (latent_z.size(0),), device=device)\n",
    "                        sample = netG(x.detach().float(), t, latent_z)\n",
    "                        samples.append(sample.cpu())\n",
    "                    \n",
    "                    try:\n",
    "                        samples = torch.cat(samples, dim=0).cpu()\n",
    "                        loc = torch.tensor(samples.cpu()).mean(dim=0)\n",
    "                        scale = torch.tensor(samples.cpu()).std(dim=0)\n",
    "                        loss = -Normal(loc, scale).log_prob(y.cpu()).sum()\n",
    "                        fids.append(loss)\n",
    "                    except ValueError:\n",
    "                        continue\n",
    "                    \n",
    "                    \"\"\"\n",
    "                    fid_samples = np.array(torch.cat(samples, dim=0).cpu())\n",
    "                    fid_samples_2 = np.array(y.cpu())\n",
    "\n",
    "                    mu1 = np.mean(fid_samples, axis=0)\n",
    "                    sigma1 = np.cov(fid_samples, rowvar=False)\n",
    "                    mu2 = np.mean(fid_samples_2, axis=0)\n",
    "                    sigma2 = np.cov(fid_samples_2, rowvar=False)\n",
    "\n",
    "                    diff = mu1 - mu2\n",
    "                    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n",
    "                    tr_covmean = np.trace(covmean)\n",
    "                    fid = (diff.dot(diff) + np.trace(sigma1) +  np.trace(sigma2) - 2 * tr_covmean)\n",
    "                    fids.append(fid.real)\n",
    "                    fids2.append(fid.real / np.var(fid_samples_2))\n",
    "                    \"\"\"\n",
    "                    \n",
    "                print(np.mean(fids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8d146def",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-8026.0, 550.0)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = [-7476, -8576] \n",
    "np.mean(a), np.std(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f41063b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b6e5e5b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f87f12bf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82bd73f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efad248f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
