{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e1bc3506-f509-4b37-8179-15e7a944c9ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys\n",
    "sys.path.append('/home/quickjkee/projects/Light-GCOT')\n",
    "\n",
    "import warnings\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import anndata as ad\n",
    "import scanpy as sc\n",
    "from sklearn.preprocessing import Normalizer\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import wandb\n",
    "import moscot.plotting as mtp\n",
    "import scipy\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from moscot import datasets\n",
    "from moscot.problems.cross_modality import TranslationProblem\n",
    "from sklearn import preprocessing as pp\n",
    "\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.from_loader import PairedLoaderSampler\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.models.light_gcot import LightGCOT\n",
    "from sklearn.manifold import TSNE\n",
    "tsne = TSNE(n_components=2, random_state=50)\n",
    "\n",
    "#https://moscot.readthedocs.io/en/latest/notebooks/tutorials/600_tutorial_translation.html"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a9a6ad3",
   "metadata": {},
   "source": [
    "# Data preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2366fd9",
   "metadata": {},
   "source": [
    "### Problem setup\n",
    "1) Source  \n",
    "$X \\in \\mathbb{R}^{N \\times d_{1}}$, where $N$ is the number of locations and $d_{1}$ is the feature dimension. \\\n",
    "$x = (\\mu, \\sigma)$: the per-feature mean and standard deviation for a given location in June. \\\n",
    "$N = 1396, d_{1} = 188$ \n",
    "\n",
    "2) Target  \n",
    "$Y \\in \\mathbb{R}^{N \\times M \\times d_{2}}$, where $N$ is the number of locations and $M$ is the number of daily measurements for a given location in January. \\\n",
    "$M \\in [1, 31], d_{2} = 94$ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "02965c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################################\n",
    "#-------------- RAW DATA -----------------\n",
    "##########################################\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "root = '../tabred/kal/weather'\n",
    "\n",
    "# Drop every feature row that contains a NaN. The row mask is computed once on\n",
    "# the raw array and reused for `meta`, so meta rows stay aligned with data rows.\n",
    "data_raw = np.load(f'{root}/X_num.npy')\n",
    "valid_rows = ~np.isnan(data_raw).any(axis=1)\n",
    "data = data_raw[valid_rows]\n",
    "data_csv = pd.read_csv(f'{root}/csv/X_num.csv')\n",
    "\n",
    "# NOTE(review): `target` is not filtered by `valid_rows`; apply the same mask before using it.\n",
    "target = np.load(f'{root}/Y.npy')\n",
    "meta = np.load(f'{root}/X_meta.npy')[valid_rows]\n",
    "meta_csv = pd.read_csv(f'{root}/csv/X_meta.csv')\n",
    "\n",
    "# data_new = numeric features + a trailing 'location' column taken from meta[:, -2].\n",
    "names = list(data_csv.columns)\n",
    "names.append('location')\n",
    "data_new = np.concatenate((data, meta[:, -2].reshape(-1, 1)), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eb1dee44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1653 1578\n"
     ]
    }
   ],
   "source": [
    "#########################################################\n",
    "#---------------- Month/location split -----------------\n",
    "#########################################################\n",
    "\n",
    "def _group_by_location(rows, month):\n",
    "    \"\"\"Group the rows of one month by location and min-max scale each group.\n",
    "\n",
    "    Each row is laid out like `data_new`: row[-1] is the location id, row[-2] is\n",
    "    compared with `month`, and row[:-7] is kept as the feature vector.\n",
    "\n",
    "    Returns {location_id: array (n_rows, n_features)} for locations with more\n",
    "    than one row. Each feature is scaled to (v - min) / (max - min + 0.1) within\n",
    "    its location; the +0.1 keeps the division finite for constant features.\n",
    "    \"\"\"\n",
    "    grouped = {}\n",
    "    for row in rows:\n",
    "        if row[-2] == month:\n",
    "            grouped.setdefault(row[-1], []).append(row[:-7])\n",
    "\n",
    "    scaled = {}\n",
    "    for loc, loc_rows in grouped.items():\n",
    "        block = np.stack(loc_rows)\n",
    "        if block.shape[0] > 1:  # locations with a single row are dropped\n",
    "            lo, hi = block.min(axis=0), block.max(axis=0)\n",
    "            scaled[loc] = (block - lo) / (hi - lo + 1e-1)\n",
    "    return scaled\n",
    "\n",
    "\n",
    "dict_location_src = _group_by_location(data_new, 1.0)\n",
    "dict_location_trg = _group_by_location(data_new, 6.0)\n",
    "\n",
    "print(len(dict_location_trg), len(dict_location_src))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c5c56fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#########################################################\n",
    "#--------------------- X, Y paired ----------------------\n",
    "#########################################################\n",
    "\n",
    "def _location_summary(item):\n",
    "    \"\"\"Concatenate the per-feature mean and std of a location's rows.\"\"\"\n",
    "    return np.concatenate([np.mean(item, axis=0), np.std(item, axis=0)])\n",
    "\n",
    "\n",
    "# The first 200 target locations form the paired set; a set makes the\n",
    "# membership tests below O(1) instead of scanning the list.\n",
    "chosen_locs = list(dict_location_trg.keys())[:200]\n",
    "chosen_set = set(chosen_locs)\n",
    "\n",
    "# Paired: chosen locations that also exist in the source month.\n",
    "paired_keys = [key for key in dict_location_src if key in chosen_set]\n",
    "X_pair_orig = np.stack([_location_summary(dict_location_src[key]) for key in paired_keys])  # mean, std\n",
    "Y_pair_orig = [dict_location_trg[key] for key in paired_keys]  # all target rows per location\n",
    "\n",
    "\n",
    "#########################################################\n",
    "#----------------------- X, Y ---------------------------\n",
    "#########################################################\n",
    "\n",
    "# N x 1 x 2D - src\n",
    "# N x M x D - trg\n",
    "\n",
    "# sampling:\n",
    "# b x 1 x 2D,\n",
    "# b x M x D -> sample -> b x 1 x D\n",
    "\n",
    "# Unpaired: every non-chosen location. X_orig and Y_orig come from different\n",
    "# dicts, so their rows are not aligned by location.\n",
    "X_orig = np.stack([\n",
    "    _location_summary(item)\n",
    "    for key, item in dict_location_src.items()\n",
    "    if key not in chosen_set\n",
    "])\n",
    "Y_orig = [item for key, item in dict_location_trg.items() if key not in chosen_set]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b44ab264",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1386, 188) 1453 (192, 188) 192\n"
     ]
    }
   ],
   "source": [
    "# unpaired X shape, #unpaired Y locations, paired X shape, #paired Y locations\n",
    "sizes = (X_orig.shape, len(Y_orig), X_pair_orig.shape, len(Y_pair_orig))\n",
    "print(*sizes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11b6332b",
   "metadata": {},
   "source": [
    "# Running"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "f0a1831a-9e8f-480c-abb1-da3f6dc56ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "source_data = X_orig\n",
    "target_data = Y_orig[0]  # one location's target rows; only its width is used\n",
    "X_DIM = source_data.shape[1]\n",
    "Y_DIM = target_data.shape[1]\n",
    "assert X_DIM > 1\n",
    "assert Y_DIM > 1\n",
    "\n",
    "# Seed python `random`, numpy and torch (all devices) so model initialisation,\n",
    "# batch sampling and latent noise in the cells below are reproducible.\n",
    "OUTPUT_SEED = 50\n",
    "random.seed(OUTPUT_SEED)\n",
    "np.random.seed(OUTPUT_SEED)\n",
    "torch.manual_seed(OUTPUT_SEED)\n",
    "\n",
    "N_POTENTIALS = 10\n",
    "M_POTENTIALS = 1  # alternative: 10\n",
    "EPSILON = 1\n",
    "A_DIAGONAL_INIT = 0.5\n",
    "L_PAIRED_SAMPLES = len(X_pair_orig)  # NOTE: the training cell overrides this with 90\n",
    "M_X_UNPAIRED_SAMPLES = 0\n",
    "N_Y_UNPAIRED_SAMPLES = 0\n",
    "\n",
    "BATCH_SIZE = 128\n",
    "SAMPLING_BATCH_SIZE = 128\n",
    "\n",
    "D_LR = 3e-4  # 1e-3 for eps 0.1, 0.01 and 3e-4 for eps 0.002\n",
    "D_GRADIENT_MAX_NORM = float(\"inf\")\n",
    "\n",
    "NUM_LABELED = 10\n",
    "TRAIN_SUBSET_SIZE = 2\n",
    "\n",
    "PLOT_EVERY = 1000\n",
    "MAX_STEPS = 20000  # NOTE: the training cell overrides this with 10000\n",
    "CONTINUE = -1  # training starts at step CONTINUE + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "dc8ded78-a9b3-4053-972e-e0fbed555f64",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_COST = \"MLP_deep_deep\"\n",
    "EXP_COST_INCLUDED = True\n",
    "EXP_META_INFO = \"\"\n",
    "\n",
    "# The run name encodes the main hyper-parameters and names the checkpoint folder.\n",
    "EXP_NAME = (\n",
    "    \"Light-GCOT_Batch_Effect_\"\n",
    "    f\"EPSILON_{EPSILON}_N_{N_POTENTIALS}_M_{M_POTENTIALS}_\"\n",
    "    f\"with_{EXP_COST}_cost_included_{EXP_COST_INCLUDED}_\"\n",
    "    f\"N_PAIRED_{NUM_LABELED}_M_UNPAIRED_{len(source_data)}_\"\n",
    "    f\"{EXP_META_INFO}\"\n",
    ")\n",
    "OUTPUT_PATH = f\"../checkpoints/{EXP_NAME}\"\n",
    "\n",
    "# NOTE(review): N_PAIRED/M_UNPAIRED record NUM_LABELED and len(source_data), but the\n",
    "# training cell uses L_PAIRED_SAMPLES = 90 and L_UNPAIRED_SAMPLES = 500 -- confirm.\n",
    "config = {\n",
    "    \"X_DIM\": X_DIM,\n",
    "    \"Y_DIM\": Y_DIM,\n",
    "    \"D_LR\": D_LR,\n",
    "    \"BATCH_SIZE\": BATCH_SIZE,\n",
    "    \"EPSILON\": EPSILON,\n",
    "    \"D_GRADIENT_MAX_NORM\": D_GRADIENT_MAX_NORM,\n",
    "    \"N_POTENTIALS\": N_POTENTIALS,\n",
    "    \"M_POTENTIALS\": M_POTENTIALS,\n",
    "    \"A_DIAGONAL_INIT\": A_DIAGONAL_INIT,\n",
    "    \"N_PAIRED_SAMPLES\": NUM_LABELED,\n",
    "    \"M_UNPAIRED_SAMPLES\": len(source_data),\n",
    "}\n",
    "\n",
    "if not os.path.exists(OUTPUT_PATH):\n",
    "    os.makedirs(OUTPUT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "cb4506b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pytorch_total_params = sum(p.numel() for p in D.parameters())\n",
    "#pytorch_total_params"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55d32db3-1805-4e51-b492-e1eaaba959e6",
   "metadata": {},
   "source": [
    "## Ablation Study"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "86eea788",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pick_random_rows(Y, idxs):\n",
    "    \"\"\"Stack one uniformly chosen row of Y[i] for every i in idxs.\"\"\"\n",
    "    return np.stack([Y[i][np.random.randint(len(Y[i]))] for i in idxs])\n",
    "\n",
    "\n",
    "def paired_sampler(X_pair, Y_pair, b_size, device='cuda'):\n",
    "    \"\"\"Sample b_size (source, target) pairs that share a location.\n",
    "\n",
    "    Args:\n",
    "        X_pair: indexable array; row i is the source vector of location i.\n",
    "        Y_pair: sequence; Y_pair[i] is a 2-D array of target rows of location i.\n",
    "        b_size: number of pairs to draw (with replacement).\n",
    "        device: torch device of the returned tensors (default 'cuda').\n",
    "\n",
    "    Returns:\n",
    "        (x, y) float32 tensors; y[j] is a random row of Y_pair[k] where x[j] = X_pair[k].\n",
    "    \"\"\"\n",
    "    # np.random.randint's `high` is exclusive: len(X_pair) keeps the last location reachable.\n",
    "    idxs = np.random.randint(low=0, high=len(X_pair), size=b_size)\n",
    "    x_pair_batch = torch.as_tensor(X_pair[idxs], dtype=torch.float32, device=device)\n",
    "    y_pair_batch = torch.as_tensor(_pick_random_rows(Y_pair, idxs), dtype=torch.float32, device=device)\n",
    "    return x_pair_batch, y_pair_batch\n",
    "\n",
    "\n",
    "def unpaired_sampler(X, Y, b_size, device='cuda'):\n",
    "    \"\"\"Sample b_size source vectors and b_size target rows independently.\n",
    "\n",
    "    Locations for X and Y are drawn separately, so len(X) and len(Y) may differ.\n",
    "\n",
    "    Returns:\n",
    "        (x, y) float32 tensors on `device`; y rows are random rows of random Y[k].\n",
    "    \"\"\"\n",
    "    x_idxs = np.random.randint(low=0, high=len(X), size=b_size)\n",
    "    y_idxs = np.random.randint(low=0, high=len(Y), size=b_size)\n",
    "    x_batch = torch.as_tensor(X[x_idxs], dtype=torch.float32, device=device)\n",
    "    y_batch = torch.as_tensor(_pick_random_rows(Y, y_idxs), dtype=torch.float32, device=device)\n",
    "    return x_batch, y_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "7d300dd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models.models import MyDiscriminator, MyGenerator\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.utils.discrete_ot import OTPlanSampler\n",
    "from src.utils.paired import generate_paired_data, get_GT_points, get_paired_sampler\n",
    "import torch.nn.functional as F\n",
    "from scipy import linalg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "0f272ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent noise size and the hidden widths shared by both networks.\n",
    "nz = 10\n",
    "HIDDEN_LAYERS = [256, 256, 256]\n",
    "\n",
    "# netG(x, z) -> Y_DIM output; netD scores Y_DIM vectors. Each gets its own list copy.\n",
    "netG = MyGenerator(x_dim=X_DIM, out_dim=Y_DIM, z_dim=nz, layers=list(HIDDEN_LAYERS)).to('cuda')\n",
    "netD = MyDiscriminator(x_dim=Y_DIM, layers=list(HIDDEN_LAYERS)).to('cuda')\n",
    "\n",
    "# Adam with default hyper-parameters (D_LR from the config cell is not used here).\n",
    "optimizerD, optimizerG = (torch.optim.Adam(net.parameters()) for net in (netD, netG))\n",
    "\n",
    "# NOTE(review): T_max=5000 but the training cell runs MAX_STEPS=10000 steps, so the\n",
    "# cosine LR bottoms out at step 5000 and then climbs back -- confirm this is intended.\n",
    "schedulerG, schedulerD = (\n",
    "    torch.optim.lr_scheduler.CosineAnnealingLR(opt, 5000, eta_min=1e-5)\n",
    "    for opt in (optimizerG, optimizerD)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50f7dd7b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "2131aaab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss history dict (nothing below appends to it; the training cell uses `D_loss`).\n",
    "history = dict(D_loss=[], G_loss=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "f5391bb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<>:42: SyntaxWarning: \"is\" with a literal. Did you mean \"==\"?\n",
      "<>:42: SyntaxWarning: \"is\" with a literal. Did you mean \"==\"?\n",
      "/var/tmp/ipykernel_53469/2634283775.py:42: SyntaxWarning: \"is\" with a literal. Did you mean \"==\"?\n",
      "  if 1 is None or step % 1 == 0:\n",
      "  0%|                                                 | 0/10000 [00:00<?, ?it/s]/var/tmp/ipykernel_53469/2634283775.py:134: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  loc = torch.tensor(samples.cpu()).mean(dim=0)\n",
      "/var/tmp/ipykernel_53469/2634283775.py:135: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  scale = torch.tensor(samples.cpu()).std(dim=0)\n",
      "  0%|                                        | 11/10000 [00:01<13:44, 12.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8634169.508643506\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▊                                  | 1013/10000 [00:13<05:06, 29.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5600019.874016167\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▋                              | 2016/10000 [00:25<04:35, 28.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4174097.6236821334\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███████████▍                          | 3011/10000 [00:38<04:10, 27.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4104029.0933012725\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|███████████████▏                      | 4013/10000 [00:50<03:27, 28.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3973471.3155766325\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|███████████████████                   | 5014/10000 [01:02<02:52, 28.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3864315.4599917093\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████████████████████▊               | 6011/10000 [01:14<02:13, 29.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3642760.5593549428\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|██████████████████████████▋           | 7014/10000 [01:27<01:48, 27.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4528568.900386699\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|██████████████████████████████▍       | 8008/10000 [01:39<01:35, 20.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4327468.715367576\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|██████████████████████████████████▏   | 9012/10000 [01:51<00:34, 28.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4823665.808403021\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 10000/10000 [02:02<00:00, 81.41it/s]\n"
     ]
    }
   ],
   "source": [
    "MAX_STEPS = 10000  # NOTE: overrides the config cell (20000)\n",
    "from torch.distributions.normal import Normal\n",
    "\n",
    "EVAL_EVERY = 1000  # evaluate the held-out NLL every EVAL_EVERY steps\n",
    "R1_EVERY = 1       # apply the R1 penalty every R1_EVERY steps (None = every step)\n",
    "R1_GAMMA = 0.01    # penalty added to D's loss: R1_GAMMA / 2 * mean ||dD/dy||^2\n",
    "\n",
    "stats = []\n",
    "D_loss = []\n",
    "fids = []  # per-location held-out NLL values, accumulated over all evaluation rounds\n",
    "fids2 = []\n",
    "device = 'cuda'\n",
    "\n",
    "# Splitting (NOTE: overrides L_PAIRED_SAMPLES from the config cell)\n",
    "L_PAIRED_SAMPLES = 90\n",
    "L_UNPAIRED_SAMPLES = 500\n",
    "X, Y = X_orig[:L_UNPAIRED_SAMPLES], Y_orig[-L_UNPAIRED_SAMPLES:]\n",
    "X_pair, Y_pair = X_pair_orig[:L_PAIRED_SAMPLES], Y_pair_orig[:L_PAIRED_SAMPLES]\n",
    "X_pair_test, Y_pair_test = X_pair_orig[-100:], Y_pair_orig[-100:]\n",
    "\n",
    "\n",
    "for step in tqdm(range(CONTINUE + 1, MAX_STEPS)):\n",
    "    #########################\n",
    "    # Discriminator training\n",
    "    #########################\n",
    "    for p in netD.parameters():\n",
    "        p.requires_grad = True\n",
    "    netD.zero_grad()\n",
    "\n",
    "    # Real target rows; requires_grad lets the R1 penalty differentiate D w.r.t. its input.\n",
    "    X_unpaired, Y_unpaired = unpaired_sampler(X, Y, BATCH_SIZE)\n",
    "    Y_unpaired.requires_grad = True\n",
    "\n",
    "    D_real = netD(Y_unpaired)\n",
    "    errD_real = F.softplus(-D_real).mean()\n",
    "    errD_real.backward(retain_graph=True)  # D_real's graph is reused by the R1 term\n",
    "\n",
    "    # R_1(\\phi) regularization\n",
    "    if R1_EVERY is None or step % R1_EVERY == 0:\n",
    "        grad_real = torch.autograd.grad(\n",
    "            outputs=D_real.sum(),\n",
    "            inputs=Y_unpaired,\n",
    "            create_graph=True,\n",
    "        )[0]\n",
    "        grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()\n",
    "        grad_penalty = R1_GAMMA / 2 * grad_penalty\n",
    "        grad_penalty.backward()\n",
    "\n",
    "    # Fake rows from the generator. They are detached before D so this step only\n",
    "    # builds gradients for netD (netG's grads are zeroed before its own update anyway).\n",
    "    latent_z = torch.randn(BATCH_SIZE, nz, device=device)\n",
    "    x_0_predict = netG(X_unpaired.detach(), latent_z)\n",
    "    output = netD(x_0_predict.detach()).view(-1)\n",
    "    errD_fake = F.softplus(output).mean()\n",
    "    errD_fake.backward()\n",
    "\n",
    "    errD = errD_real + errD_fake\n",
    "    D_loss.append(errD.item())\n",
    "    optimizerD.step()\n",
    "\n",
    "    #########################\n",
    "    # Generator training\n",
    "    #########################\n",
    "    for p in netD.parameters():\n",
    "        p.requires_grad = False\n",
    "    netG.zero_grad()\n",
    "\n",
    "    unp_sample, _ = unpaired_sampler(X, Y, BATCH_SIZE)\n",
    "    X_paired, Y_paired = paired_sampler(X_pair, Y_pair, BATCH_SIZE)\n",
    "\n",
    "    latent_z = torch.randn(BATCH_SIZE, nz, device=device)\n",
    "    latent_z0 = torch.randn(BATCH_SIZE, nz, device=device)\n",
    "    x_paired_predict = netG(X_paired.detach(), latent_z)\n",
    "    x_unp_predict = netG(unp_sample.detach(), latent_z0)\n",
    "\n",
    "    # Adversarial loss on unpaired inputs + MSE against the paired targets.\n",
    "    output = netD(x_unp_predict).view(-1)\n",
    "    errG = F.softplus(-output)\n",
    "    errG_mse = F.mse_loss(Y_paired, x_paired_predict)\n",
    "    errG = (errG + errG_mse).mean()\n",
    "    errG.backward()\n",
    "    optimizerG.step()\n",
    "\n",
    "    # LR-Scheduling step\n",
    "    schedulerG.step()\n",
    "    schedulerD.step()\n",
    "\n",
    "    # Held-out evaluation: for each test location, draw one generator sample per\n",
    "    # target row, fit a diagonal Normal to the samples and score the true rows.\n",
    "    if step % EVAL_EVERY == 0:\n",
    "        with torch.no_grad():\n",
    "            round_nll = []\n",
    "            for x, y in zip(X_pair_test, Y_pair_test):\n",
    "                n_rows = len(y)\n",
    "                x_rep = torch.as_tensor(x, dtype=torch.float32, device=device).unsqueeze(0).repeat(n_rows, 1)\n",
    "                latent_z = torch.randn(n_rows, nz, device=device)\n",
    "                samples = netG(x_rep, latent_z).cpu()\n",
    "                try:\n",
    "                    nll = -Normal(samples.mean(dim=0), samples.std(dim=0)).log_prob(torch.as_tensor(y)).sum()\n",
    "                except ValueError:\n",
    "                    # Normal rejects invalid scales, e.g. a zero std or the NaN std of a single sample.\n",
    "                    continue\n",
    "                round_nll.append(nll.item())\n",
    "            fids.extend(round_nll)\n",
    "            # Mean over this evaluation round only (not a running mean over all rounds).\n",
    "            print(np.mean(round_nll))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "8d146def",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-5980.333333333333, 2570.964842665536)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): hard-coded scores, presumably copied from separate runs -- record their source.\n",
    "a = [-3574, -9544, -4823]\n",
    "summary = (np.mean(a), np.std(a))\n",
    "summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f41063b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b6e5e5b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f87f12bf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82bd73f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efad248f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
