{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e1bc3506-f509-4b37-8179-15e7a944c9ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys\n",
    "sys.path.append('/home/quickjkee/projects/Light-GCOT')\n",
    "\n",
    "import warnings\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import anndata as ad\n",
    "import scanpy as sc\n",
    "from sklearn.preprocessing import Normalizer\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import wandb\n",
    "import moscot.plotting as mtp\n",
    "import scipy\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from moscot import datasets\n",
    "from moscot.problems.cross_modality import TranslationProblem\n",
    "from sklearn import preprocessing as pp\n",
    "\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.from_loader import PairedLoaderSampler\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.models.light_gcot import LightGCOT\n",
    "from sklearn.manifold import TSNE\n",
    "tsne = TSNE(n_components=2, random_state=50)\n",
    "\n",
    "#https://moscot.readthedocs.io/en/latest/notebooks/tutorials/600_tutorial_translation.html"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a9a6ad3",
   "metadata": {},
   "source": [
    "# Data preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2366fd9",
   "metadata": {},
   "source": [
    "### PS\n",
    "1) Source  \n",
    "$X \\in \\mathbb{R}^{N \\times d_{1}}, N - \\text{number of locations}, d_{1} - \\text{features dim}$ \\\n",
    "$x = (\\mu, \\sigma) - \\text{for a given location in June}$ \\\n",
    "$N = 1396, d_{1} = 188$ \n",
    "\n",
    "2) $Y \\in \\mathbb{R}^{N \\times M \\times d_{2}}, N - \\text{number of locations}, M - \\text{measurements for a given location in January by day}$ \\\n",
    "$M = [1, 31], d_{2} = 94$ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "02965c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################################\n",
    "#-------------- RAW DATA -----------------\n",
    "##########################################\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "root = '../tabred/kal/weather'\n",
    "\n",
    "data = np.load(f'{root}/X_num.npy')\n",
    "data = np.stack([d for d in data if sum(np.isnan(d)) == 0])\n",
    "data_csv = pd.read_csv(f'{root}/csv/X_num.csv')\n",
    "#train_data = data[train_idx]\n",
    "#test_data = data[test_idx]\n",
    "\n",
    "target = np.load(f'{root}/Y.npy')\n",
    "meta = np.load(f'{root}/X_meta.npy')\n",
    "meta = np.stack([meta[i] for i, d in enumerate(data) if sum(np.isnan(d)) == 0])\n",
    "meta_csv = pd.read_csv(f'{root}/csv/X_meta.csv')\n",
    "\n",
    "names = list(data_csv.columns)\n",
    "names.append('location')\n",
    "data_new = np.concatenate((data, meta[:, -2].reshape(-1, 1)), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eb1dee44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1653 1578\n"
     ]
    }
   ],
   "source": [
    "#########################################################\n",
    "#--------------- Month/location splitted ---------------- \n",
    "#########################################################\n",
    "scaler = StandardScaler()\n",
    "\n",
    "dict_location_src = {}\n",
    "for d in data_new:\n",
    "    if d[-2] == 1.0:\n",
    "        d_new = d[:-7]\n",
    "        try:\n",
    "            dict_location_src[d[-1]].append(d_new)\n",
    "        except KeyError:\n",
    "            dict_location_src[d[-1]] = []\n",
    "            dict_location_src[d[-1]].append(d_new)\n",
    "     \n",
    "\n",
    "dict_location_src_new = {}\n",
    "for key in dict_location_src.keys():\n",
    "    item = dict_location_src[key]\n",
    "    item = np.stack(item)\n",
    "    if item.shape[0] > 1:\n",
    "        item = (item - np.min(item, axis=0)) / (np.max(item, axis=0) - np.min(item, axis=0) + 1e-1)\n",
    "        dict_location_src_new[key] = item\n",
    "dict_location_src = dict_location_src_new\n",
    "# ------------------------------------------------------------\n",
    "\n",
    "dict_location_trg = {}\n",
    "for d in data_new:\n",
    "    if d[-2] == 6.0:\n",
    "        d_new = d[:-7]\n",
    "        try:\n",
    "            dict_location_trg[d[-1]].append(d_new)\n",
    "        except KeyError:\n",
    "            dict_location_trg[d[-1]] = []\n",
    "            dict_location_trg[d[-1]].append(d_new)\n",
    "    \n",
    "dict_location_trg_new = {}\n",
    "for key in dict_location_trg.keys():\n",
    "    item = dict_location_trg[key]\n",
    "    item = np.stack(item)\n",
    "    if item.shape[0] > 1:\n",
    "        item = (item - np.min(item, axis=0)) / (np.max(item, axis=0) - np.min(item, axis=0) + 1e-1)\n",
    "        dict_location_trg_new[key] = item\n",
    "dict_location_trg = dict_location_trg_new\n",
    "\n",
    "print(len(dict_location_trg), len(dict_location_src))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c5c56fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#########################################################\n",
    "#--------------------- X, Y paired ----------------------\n",
    "#########################################################\n",
    "\n",
    "chosen_locs = list(dict_location_trg.keys())[:200]\n",
    "X_pair_orig, Y_pair_orig = [], []\n",
    "for key in dict_location_src.keys():\n",
    "    if key not in chosen_locs:\n",
    "        continue\n",
    "    item_src = dict_location_src[key] \n",
    "    x = np.concatenate([np.mean(item_src, axis=0), np.std(item_src, axis=0)]) # mean, std\n",
    "    X_pair_orig.append(x)\n",
    "    item_trg = dict_location_trg[key]\n",
    "    Y_pair_orig.append(item_trg) # sample\n",
    "X_pair_orig = np.stack(X_pair_orig)\n",
    "\n",
    "\n",
    "#########################################################\n",
    "#----------------------- X, Y ---------------------------\n",
    "#########################################################\n",
    "\n",
    "# N x 1 x 2D - src\n",
    "# N x M x D - trg\n",
    " \n",
    "# sampling: \n",
    "# b x 1 x 2D,\n",
    "# b x M x D -> sample -> b x 1 x D\n",
    "\n",
    "X_orig = []\n",
    "for key in dict_location_src.keys():\n",
    "    if key in chosen_locs:\n",
    "        continue\n",
    "    item_src = dict_location_src[key] \n",
    "    x = np.concatenate([np.mean(item_src, axis=0), np.std(item_src, axis=0)]) # mean, std\n",
    "    X_orig.append(x)\n",
    "X_orig = np.stack(X_orig)\n",
    "\n",
    "Y_orig = []\n",
    "for key in dict_location_trg.keys():\n",
    "    if key in chosen_locs:\n",
    "        continue\n",
    "    item_trg = dict_location_trg[key]\n",
    "    Y_orig.append(item_trg) # sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b44ab264",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1386, 188) 1453 (192, 188) 192\n"
     ]
    }
   ],
   "source": [
    "print(X_orig.shape, len(Y_orig), X_pair_orig.shape, len(Y_pair_orig))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11b6332b",
   "metadata": {},
   "source": [
    "# Running"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "f0a1831a-9e8f-480c-abb1-da3f6dc56ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "source_data = X_orig\n",
    "target_data = Y_orig[0]\n",
    "X_DIM = source_data.shape[1]\n",
    "Y_DIM = target_data.shape[1]\n",
    "#X_DIM = data_set[\"features\"].shape[1]\n",
    "#Y_DIM = data_set[\"features\"].shape[1]\n",
    "assert X_DIM > 1\n",
    "assert Y_DIM > 1\n",
    "\n",
    "OUTPUT_SEED = 50\n",
    "\n",
    "N_POTENTIALS = 10\n",
    "M_POTENTIALS = 1 #10\n",
    "EPSILON = 1\n",
    "A_DIAGONAL_INIT = 0.5\n",
    "L_PAIRED_SAMPLES = len(X_pair_orig)\n",
    "M_X_UNPAIRED_SAMPLES = 0\n",
    "N_Y_UNPAIRED_SAMPLES = 0\n",
    "\n",
    "BATCH_SIZE = 128\n",
    "SAMPLING_BATCH_SIZE = 128\n",
    "\n",
    "D_LR = 3e-4  # 1e-3 for eps 0.1, 0.01 and 3e-4 for eps 0.002\n",
    "D_GRADIENT_MAX_NORM = float(\"inf\")\n",
    "\n",
    "NUM_LABELED = 10\n",
    "TRAIN_SUBSET_SIZE = 2\n",
    "\n",
    "PLOT_EVERY = 1000\n",
    "MAX_STEPS = 20000\n",
    "CONTINUE = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "dc8ded78-a9b3-4053-972e-e0fbed555f64",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_COST = \"MLP_deep_deep\"\n",
    "EXP_COST_INCLUDED = True\n",
    "EXP_META_INFO = \"\"\n",
    "EXP_NAME = (\n",
    "    f\"Light-GCOT_Batch_Effect_\"\n",
    "    + f\"EPSILON_{EPSILON}_\"\n",
    "    + f\"N_{N_POTENTIALS}_\"\n",
    "    + f\"M_{M_POTENTIALS}_\"\n",
    "    + f\"with_{EXP_COST}_\"\n",
    "    + f\"cost_included_{EXP_COST_INCLUDED}_\"\n",
    "    + f\"N_PAIRED_{NUM_LABELED}_\"\n",
    "    + f\"M_UNPAIRED_{len(source_data)}_\"\n",
    "    + EXP_META_INFO\n",
    ")\n",
    "OUTPUT_PATH = \"../checkpoints/{}\".format(EXP_NAME)\n",
    "\n",
    "config = dict(\n",
    "    X_DIM=X_DIM,\n",
    "    Y_DIM=Y_DIM,\n",
    "    D_LR=D_LR,\n",
    "    BATCH_SIZE=BATCH_SIZE,\n",
    "    EPSILON=EPSILON,\n",
    "    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,\n",
    "    N_POTENTIALS=N_POTENTIALS,\n",
    "    M_POTENTIALS=M_POTENTIALS,\n",
    "    A_DIAGONAL_INIT=A_DIAGONAL_INIT,\n",
    "    N_PAIRED_SAMPLES=NUM_LABELED,\n",
    "    M_UNPAIRED_SAMPLES=len(source_data),\n",
    ")\n",
    "\n",
    "if not os.path.exists(OUTPUT_PATH):\n",
    "    os.makedirs(OUTPUT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "cb4506b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pytorch_total_params = sum(p.numel() for p in D.parameters())\n",
    "#pytorch_total_params"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55d32db3-1805-4e51-b492-e1eaaba959e6",
   "metadata": {},
   "source": [
    "## Ablation Study"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "86eea788",
   "metadata": {},
   "outputs": [],
   "source": [
    "def paired_sampler(X_pair, Y_pair, b_size):\n",
    "    idxs = np.random.randint(low=0, high=len(X_pair)-1, size=b_size)\n",
    "    x_pair_batch = torch.tensor(X_pair[idxs]).to('cuda')\n",
    "    y_pair_batch = np.stack([Y_pair[idx][random.randint(0, len(Y_pair[idx])-1)] for idx in idxs])\n",
    "    y_pair_batch = torch.tensor(y_pair_batch).to('cuda')\n",
    "    return x_pair_batch.to(torch.float32), y_pair_batch.to(torch.float32)\n",
    "\n",
    "def unpaired_sampler(X, Y, b_size):\n",
    "    # UNPAIRED SAMPLER\n",
    "    idxs = np.random.randint(low=0, high=len(X)-1, size=b_size)\n",
    "    idxs_y = np.array([len(X) - idx - 1 for idx in idxs])\n",
    "\n",
    "    x_batch = torch.tensor(X[idxs]).to('cuda')\n",
    "    y_batch = np.stack([Y[idx][random.randint(0, len(Y[idx])-1)] for idx in idxs_y])\n",
    "    y_batch = torch.tensor(y_batch).to('cuda')\n",
    "    return x_batch.to(torch.float32), y_batch.to(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "7d300dd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models.models import MyCDiscriminator, MyCGenerator\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.utils.discrete_ot import OTPlanSampler\n",
    "from src.utils.paired import generate_paired_data, get_GT_points, get_paired_sampler\n",
    "import torch.nn.functional as F\n",
    "from src.models.models import ConditionalRealNVP, MLPnet\n",
    "from scipy import linalg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "0f272ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'\n",
    "T = MLPnet(input_size=X_DIM, hidden_size=Y_DIM, num_hidden_layers=1).to(device)\n",
    "\n",
    "T_opt_paired = torch.optim.Adam(T.parameters(), lr=3e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "2131aaab",
   "metadata": {},
   "outputs": [],
   "source": [
    "history = {\n",
    "        \"D_loss\": [],\n",
    "        \"G_loss\": [],\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "f5391bb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▏                                       | 52/10000 [00:01<02:58, 55.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28.52428177045584\n",
      "286.6372916119782\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|███▉                                 | 1060/10000 [00:04<00:59, 149.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17.51593466264451\n",
      "174.67337380725266\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 21%|███████▋                             | 2067/10000 [00:07<00:57, 138.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13.807844453113132\n",
      "136.97441838744032\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 31%|███████████▎                         | 3072/10000 [00:11<00:50, 137.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11.909871560291197\n",
      "117.7053563036969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 41%|██████████████▉                      | 4054/10000 [00:14<00:45, 130.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.733211996349372\n",
      "105.73053550539424\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 51%|██████████████████▋                  | 5053/10000 [00:18<00:35, 137.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.932359580779792\n",
      "97.60321023270286\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 61%|██████████████████████▌              | 6090/10000 [00:21<00:27, 143.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.342642106402359\n",
      "91.61265993764118\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|██████████████████████████           | 7039/10000 [00:24<00:27, 109.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.892257655836964\n",
      "87.03693288910888\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 81%|█████████████████████████████▊       | 8073/10000 [00:27<00:13, 146.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.53606527207051\n",
      "83.41189754064465\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 91%|█████████████████████████████████▍   | 9052/10000 [00:31<00:06, 145.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.253329190098174\n",
      "80.5195295109213\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████| 10000/10000 [00:33<00:00, 302.31it/s]\n"
     ]
    }
   ],
   "source": [
    "MAX_STEPS = 10000\n",
    "stats = []\n",
    "D_loss = []\n",
    "fids = []\n",
    "fids2 = []\n",
    "device = 'cuda'\n",
    "\n",
    "# Splitting\n",
    "L_PAIRED_SAMPLES = 90\n",
    "L_UNPAIRED_SAMPLES = 500\n",
    "X, Y = X_orig[:L_UNPAIRED_SAMPLES], Y_orig[-L_UNPAIRED_SAMPLES:]\n",
    "X_pair, Y_pair = X_pair_orig[:L_PAIRED_SAMPLES], Y_pair_orig[:L_PAIRED_SAMPLES]\n",
    "X_pair_test, Y_pair_test = X_pair_orig[-100:], Y_pair_orig[-100:]\n",
    "\n",
    "\n",
    "for step in tqdm(range(CONTINUE + 1, MAX_STEPS)):\n",
    "        T_opt_paired.zero_grad()\n",
    "        X_paired, Y_paired = paired_sampler(X_pair, Y_pair, BATCH_SIZE)\n",
    "        T_loss = F.mse_loss(Y_paired, T(X_paired))\n",
    "        T_loss.backward()\n",
    "        T_opt_paired.step()\n",
    "        #print(f\"Loss: {T_loss}\")\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            if step % 1000 == 0:\n",
    "                for x, y in zip(X_pair_test, Y_pair_test):\n",
    "                    x = torch.tensor(x).unsqueeze(0).to(device)\n",
    "                    y = torch.tensor(y).to(device)\n",
    "                    samples = []\n",
    "                    for _ in range(len(y)):\n",
    "                        sample = T(x.to(torch.float) + 0.1 * torch.randn_like(x.to(torch.float))).squeeze()\n",
    "                        samples.append(sample)\n",
    "                    sample = torch.stack(samples)\n",
    "                    fid_samples = np.array(sample.cpu()) #np.array(torch.cat(samples, dim=0).cpu())\n",
    "                    fid_samples_2 = np.array(y.cpu())\n",
    "\n",
    "                    mu1 = np.mean(fid_samples, axis=0)\n",
    "                    sigma1 = np.cov(fid_samples, rowvar=False)\n",
    "                    mu2 = np.mean(fid_samples_2, axis=0)\n",
    "                    sigma2 = np.cov(fid_samples_2, rowvar=False)\n",
    "\n",
    "                    diff = mu1 - mu2\n",
    "                    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n",
    "                    tr_covmean = np.trace(covmean)\n",
    "                    fid = (diff.dot(diff) + np.trace(sigma1) +  np.trace(sigma2) - 2 * tr_covmean)\n",
    "                    fids.append(fid.real)\n",
    "                    fids2.append(fid.real / np.var(fid_samples_2))\n",
    "                    \n",
    "                print(np.mean(fids))\n",
    "                print(np.mean(fids2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f41063b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b6e5e5b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f87f12bf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82bd73f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efad248f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
