{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f3bd7692",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd972eac",
   "metadata": {},
   "source": [
    "## 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6261775b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import wandb\n",
    "from src.models.light_gcot import LightGCOT\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.utils.discrete_ot import OTPlanSampler\n",
    "from src.utils.paired import generate_paired_data, get_GT_points, get_paired_sampler\n",
    "from src.utils.plotting.distributions import plot_swiss_roll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5ea4743c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the currently selected CUDA device when CUDA is available, otherwise the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(f\"cuda:{torch.cuda.current_device()}\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "53e79cc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make `device` and double precision the defaults for newly created tensors.\n",
    "torch.set_default_device(device)\n",
    "dtype = torch.float64\n",
    "torch.set_default_dtype(dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f036daa0",
   "metadata": {},
   "source": [
    "## 2. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d8508840",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Source (X) and target (Y) dimensionality; passed to LightGCOT as x_dim / y_dim.\n",
    "X_DIM = 2\n",
    "Y_DIM = 2\n",
    "assert X_DIM > 1\n",
    "assert Y_DIM > 1\n",
    "\n",
    "# Seed applied to torch, numpy and random in the next cell.\n",
    "OUTPUT_SEED = 42\n",
    "\n",
    "# LightGCOT constructor arguments (n_potentials, m_potentials, epsilon, A_diagonal_init).\n",
    "# NOTE(review): INIT_BY_SAMPLES is not read anywhere in this notebook.\n",
    "N_POTENTIALS = 50\n",
    "M_POTENTIALS = 25\n",
    "EPSILON = 1.0\n",
    "INIT_BY_SAMPLES = True\n",
    "A_DIAGONAL_INIT = 0.1\n",
    "\n",
    "# BATCH_SIZE goes to get_paired_sampler, SAMPLING_BATCH_SIZE to LightGCOT.\n",
    "BATCH_SIZE = 128\n",
    "SAMPLING_BATCH_SIZE = 128\n",
    "\n",
    "# In this notebook the learning rates only appear in EXP_NAME; D_GRADIENT_MAX_NORM is not read.\n",
    "D_LR_PAIRED = 3e-4\n",
    "D_LR_UNPAIRED = 1e-3\n",
    "D_GRADIENT_MAX_NORM = float(\"inf\")\n",
    "\n",
    "# Numbers of unpaired X samples, unpaired Y samples and paired samples.\n",
    "M_X_UNPAIRED_SAMPLES = 1024\n",
    "N_Y_UNPAIRED_SAMPLES = 1024\n",
    "L_PAIRED_SAMPLES = 128\n",
    "\n",
    "# MAX_STEPS only appears in EXP_NAME here; PLOT_EVERY and CONTINUE are not read in this notebook.\n",
    "PLOT_EVERY = 1000\n",
    "MAX_STEPS = 100_000\n",
    "CONTINUE = -1\n",
    "\n",
    "# EXP_COST is LightGCOT's cost_function; MINIBATCH_COST is OTPlanSampler's cost_function.\n",
    "EXP_COST = \"MLP\"\n",
    "EXP_COST_INCLUDED = True\n",
    "MINIBATCH_COST = \"rotation-v2\"\n",
    "\n",
    "# NOTE(review): EMA_UPDATE is not read anywhere in this notebook.\n",
    "EMA_UPDATE = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "339530f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch, NumPy and the stdlib RNG (in that order) with the same value.\n",
    "for seed_rng in (torch.manual_seed, np.random.seed, random.seed):\n",
    "    seed_rng(OUTPUT_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "caf71c45",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_META_INFO = \"\"\n",
    "# Experiment name: one tag per hyper-parameter, concatenated in a fixed order.\n",
    "# The M_X_UNPAIRED / N_Y_UNPAIRED tags must stay adjacent: the checkpoint-loading cell\n",
    "# rewrites them with a single str.replace call.\n",
    "EXP_NAME = \"\".join(\n",
    "    (\n",
    "        \"Light-GCOT_Swiss_Roll_\",\n",
    "        f\"EPSILON_{EPSILON}_\",\n",
    "        f\"MAX_STEPS_{MAX_STEPS}_\",\n",
    "        f\"N_{N_POTENTIALS}_\",\n",
    "        f\"M_{M_POTENTIALS}_\",\n",
    "        f\"with_{EXP_COST}_cost_included_{EXP_COST_INCLUDED}_\",\n",
    "        f\"M_X_UNPAIRED_{M_X_UNPAIRED_SAMPLES}_\",\n",
    "        f\"N_Y_UNPAIRED_{N_Y_UNPAIRED_SAMPLES}_\",\n",
    "        f\"L_PAIRED_{L_PAIRED_SAMPLES}_\",\n",
    "        f\"LR_PAIRED_{D_LR_PAIRED}_\",\n",
    "        f\"LR_UNPAIRED_{D_LR_UNPAIRED}_\",\n",
    "        f\"MINIBATCH_COST_{MINIBATCH_COST}_\",\n",
    "        EXP_META_INFO,\n",
    "    )\n",
    ")\n",
    "OUTPUT_PATH = \"../checkpoints/{}\".format(EXP_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c47f945",
   "metadata": {},
   "source": [
    "## 3. Create data and samplers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5973a562",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the samplers from the configured dimensions so they stay consistent with the\n",
    "# x_dim / y_dim given to LightGCOT when the parameters cell is overridden.\n",
    "X_sampler = StandardNormalSampler(dim=X_DIM, device=device)\n",
    "Y_sampler = SwissRollSampler(dim=Y_DIM, device=device, dtype=dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fb1d2f0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OT plan sampler configured with \"sinkhorn\" and the mini-batch cost from the config cell.\n",
    "otp_sampler = OTPlanSampler(\n",
    "    \"sinkhorn\",\n",
    "    cost_function=MINIBATCH_COST,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d235fb71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# File-name suffix identifying the paired data set (mini-batch cost + number of pairs).\n",
    "file_postfix = f\"{MINIBATCH_COST}_{L_PAIRED_SAMPLES}\"\n",
    "# Tensor directory, relative to the working directory.\n",
    "data_dir = \"checkpoints/Tensors\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "86defff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paired train/test tensors. The directory comes from `data_dir` (defined in the previous\n",
    "# cell) instead of a second hard-coded copy of the same path.\n",
    "X_paired_train, Y_paired_train, X_paired_test, Y_paired_test = generate_paired_data(\n",
    "    X_sampler, Y_sampler, otp_sampler, L_PAIRED_SAMPLES, data_dir, file_postfix, device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f61b8693",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampler over the paired training tensors.\n",
    "pd_train_sampler = get_paired_sampler(\n",
    "    X_paired_train,\n",
    "    Y_paired_train,\n",
    "    BATCH_SIZE,\n",
    "    L_PAIRED_SAMPLES,\n",
    "    device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "09ffa079",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpaired test samples: L_PAIRED_SAMPLES points from each sampler (X first, then Y).\n",
    "X_unpaired_test, Y_unpaired_test = (\n",
    "    sampler.sample(L_PAIRED_SAMPLES) for sampler in (X_sampler, Y_sampler)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "afe419b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# usd - unpaired source data: fresh X samples, or the paired X set when\n",
    "# M_X_UNPAIRED_SAMPLES is not positive.\n",
    "source_pool = X_sampler.sample(M_X_UNPAIRED_SAMPLES) if M_X_UNPAIRED_SAMPLES > 0 else X_paired_train\n",
    "usd_sampler = DatasetSampler(source_pool, device=device)\n",
    "\n",
    "# utd - unpaired target data: fresh Y samples, or the paired Y set when\n",
    "# N_Y_UNPAIRED_SAMPLES is not positive.\n",
    "target_pool = Y_sampler.sample(N_Y_UNPAIRED_SAMPLES) if N_Y_UNPAIRED_SAMPLES > 0 else Y_paired_train\n",
    "utd_sampler = DatasetSampler(target_pool, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71524ecb",
   "metadata": {},
   "source": [
    "## 4. Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "f65e6586",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative starting points (kept commented out):\n",
    "# starting_points = torch.tensor([[-2.0, 0.0], [2.0, 2.0], [0.0, -2.0]])\n",
    "# starting_points = torch.tensor([[-1.75, 0.0], [2.0, 2.0], [0.25, 0.0]])\n",
    "\n",
    "# Three 2-D source points passed to get_GT_points and plot_swiss_roll.\n",
    "starting_points = torch.tensor(\n",
    "    [\n",
    "        [-2.0, 0.0],\n",
    "        [2.0, 2.0],\n",
    "        [-0.0, -0.0],\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "c512a511",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference target points for the chosen starting points; passed to plot_swiss_roll below.\n",
    "gt_Y_points = get_GT_points(\n",
    "    X_sampler,\n",
    "    Y_sampler,\n",
    "    otp_sampler,\n",
    "    starting_points,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "adc9ddd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpaired sample counts (X and Y) of the runs to compare.\n",
    "M_X_unpaired_samples_list = [0, 1024]\n",
    "N_Y_unpaired_samples_list = [0, 1024]\n",
    "\n",
    "# log_steps[i][j] is the checkpoint step loaded for M_X_unpaired_samples_list[i]\n",
    "# and N_Y_unpaired_samples_list[j].\n",
    "# Alternative values (kept from the original comment): [[66000, 65000], [145000, 138000]]\n",
    "log_steps = [\n",
    "    [23000, 23000],\n",
    "    [23000, 23000],\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "571f673f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print every (row index, column index, M, N) combination of the grid.\n",
    "for row, m_count in enumerate(M_X_unpaired_samples_list):\n",
    "    for col, n_count in enumerate(N_Y_unpaired_samples_list):\n",
    "        print(row, col, m_count, n_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "33eecefe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one LightGCOT checkpoint per (M, N) combination, keyed by its plot title.\n",
    "models_dict = {}\n",
    "\n",
    "# Part of EXP_NAME that encodes the unpaired sample counts of the default config.\n",
    "default_counts_tag = f\"M_X_UNPAIRED_{M_X_UNPAIRED_SAMPLES}_N_Y_UNPAIRED_{N_Y_UNPAIRED_SAMPLES}_\"\n",
    "if default_counts_tag not in EXP_NAME:\n",
    "    # str.replace would silently return EXP_NAME unchanged, so every title would load the same run.\n",
    "    raise ValueError(f\"{default_counts_tag!r} not found in EXP_NAME={EXP_NAME!r}\")\n",
    "\n",
    "for i, M_X_unpaired_samples in enumerate(M_X_unpaired_samples_list):\n",
    "    for j, N_Y_unpaired_samples in enumerate(N_Y_unpaired_samples_list):\n",
    "        model = LightGCOT(\n",
    "            x_dim=X_DIM,\n",
    "            y_dim=Y_DIM,\n",
    "            n_potentials=N_POTENTIALS,\n",
    "            m_potentials=M_POTENTIALS,\n",
    "            epsilon=EPSILON,\n",
    "            sampling_batch_size=SAMPLING_BATCH_SIZE,\n",
    "            A_diagonal_init=A_DIAGONAL_INIT,\n",
    "            cost_function=EXP_COST,\n",
    "        )\n",
    "        # Swap the default sample counts for this combination's counts.\n",
    "        exp_name = EXP_NAME.replace(\n",
    "            default_counts_tag,\n",
    "            f\"M_X_UNPAIRED_{M_X_unpaired_samples}_N_Y_UNPAIRED_{N_Y_unpaired_samples}_\",\n",
    "        )\n",
    "        print(exp_name)\n",
    "        output_path = \"../checkpoints/{}\".format(exp_name)\n",
    "        checkpoint_path = os.path.join(output_path, f\"D_{log_steps[i][j]}.pt\")\n",
    "        model.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
    "        title = f\"M={M_X_unpaired_samples}, N={N_Y_unpaired_samples}, L={L_PAIRED_SAMPLES}\"\n",
    "        models_dict[title] = model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "13cb0a5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory handed to plot_swiss_roll as save_dir (relative to the working directory).\n",
    "SAVE_DIR = \"./plots/Swiss_Roll/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "3bc6d584",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot all loaded models with the paired training data, the starting points and their\n",
    "# reference targets; save_dir is set to SAVE_DIR.\n",
    "plot_swiss_roll(\n",
    "    models_dict,\n",
    "    X_sampler,\n",
    "    Y_sampler,\n",
    "    X_paired_train,\n",
    "    Y_paired_train,\n",
    "    starting_points,\n",
    "    gt_Y_points,\n",
    "    save_dir=SAVE_DIR,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "light-gcot",
   "language": "python",
   "name": "light-gcot"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
