{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3bd7692",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mode 2 of the autoreload extension re-imports edited modules before each cell runs,\n",
    "# so changes to the local `src` / `configs` packages are picked up without a kernel restart.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd972eac",
   "metadata": {},
   "source": [
    "## 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6261775b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "import wandb\n",
    "from src.costs.lse import MLPLSECost\n",
    "from src.models.gmm_based import GMMEOT\n",
    "from src.plotting.distributions import plot_swiss_roll\n",
    "from src.plotting.parameters import (\n",
    "    plot_A_parameters,\n",
    "    plot_B_parameters,\n",
    "    plot_Z_parameters,\n",
    ")\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.utils.discrete_ot import OTPlanSampler\n",
    "from src.utils.paired import generate_paired_data, get_GT_points, get_paired_sampler\n",
    "from src.utils.train import compute_loss, update_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ea4743c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the currently selected CUDA device when CUDA is available, otherwise the CPU.\n",
    "device = torch.device(f\"cuda:{torch.cuda.current_device()}\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53e79cc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# New tensors are created on `device` in double precision unless stated otherwise.\n",
    "torch.set_default_device(device)\n",
    "dtype = torch.float64\n",
    "torch.set_default_dtype(dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f036daa0",
   "metadata": {},
   "source": [
    "## 2. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13d0ecf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from configs.gmm_based.cost import MLPLSECostConfig\n",
    "from configs.gmm_based.dataset import DatasetConfig, MiniBatchConfig\n",
    "from configs.gmm_based.optimizer import OptPairedConfig, OptUnpairedConfig\n",
    "from configs.gmm_based.train import TrainConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34e77c92",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# NOTE(review): this cell carries the \"parameters\" tag, presumably so a runner such as papermill\n",
    "# can override these values — confirm.\n",
    "\n",
    "# Data: sample counts forwarded to DatasetConfig (paired P_XY, unpaired Q_X and R_Y)\n",
    "Q_X_UNPAIRED_SAMPLES = 16000 # 1024\n",
    "R_Y_UNPAIRED_SAMPLES = 16000 # 1024\n",
    "P_XY_PAIRED_SAMPLES = 16000 # 128\n",
    "\n",
    "# Optimizer: learning rates of the paired (cost) and unpaired optimizers\n",
    "LR_PAIRED = 3e-4\n",
    "LR_UNPAIRED = 1e-3\n",
    "\n",
    "# Sampler: minibatch sizes drawn per training step\n",
    "PAIRED_BATCH_SIZE = 128\n",
    "UNPAIRED_BATCH_SIZE = 128\n",
    "\n",
    "# Train: MAX_STEPS becomes TrainConfig.steps_to; INIT_BY_SAMPLES toggles model.init_a_by_samples\n",
    "MAX_STEPS = 26000\n",
    "INIT_BY_SAMPLES = True\n",
    "\n",
    "# Potential: constructor arguments of GMMEOT (y_dim, n_potentials)\n",
    "Y_DIM = 2\n",
    "N_POTENTIALS = 50\n",
    "\n",
    "# Cost: MLPLSECostConfig arguments; both hidden-channel lists are derived from M_POTENTIALS\n",
    "M_POTENTIALS = 25\n",
    "LOG_V_M_HIDDEN_CHANNELS = [M_POTENTIALS]\n",
    "B_M_HIDDEN_CHANNELS = [M_POTENTIALS * Y_DIM]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c07b9d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset sizes and the minibatch settings (defaults of MiniBatchConfig).\n",
    "dataset_config = DatasetConfig(\n",
    "    P_XY_paired=P_XY_PAIRED_SAMPLES,\n",
    "    Q_X_unpaired=Q_X_UNPAIRED_SAMPLES,\n",
    "    R_Y_unpaired=R_Y_UNPAIRED_SAMPLES,\n",
    ")\n",
    "minibatch_config = MiniBatchConfig()\n",
    "\n",
    "# Cost-network hyperparameters.\n",
    "cost_config = MLPLSECostConfig(\n",
    "    m_potentials=M_POTENTIALS,\n",
    "    log_v_m_hidden_channels=LOG_V_M_HIDDEN_CHANNELS,\n",
    "    b_m_hidden_channels=B_M_HIDDEN_CHANNELS,\n",
    ")\n",
    "# The same hyperparameters rendered as a suffix for the experiment name.\n",
    "EXP_META_INFO = \"\".join(\n",
    "    [\n",
    "        f\"M_POTENTIALS_{M_POTENTIALS}_\",\n",
    "        f\"LOG_V_M_HIDDEN_CHANNELS_{LOG_V_M_HIDDEN_CHANNELS}_\",\n",
    "        f\"B_M_HIDDEN_CHANNELS_{B_M_HIDDEN_CHANNELS}_\",\n",
    "    ]\n",
    ")\n",
    "\n",
    "# One optimizer config per loss term.\n",
    "opt_unpaired_config = OptUnpairedConfig(lr=LR_UNPAIRED)\n",
    "opt_paired_config = OptPairedConfig(lr=LR_PAIRED)\n",
    "\n",
    "# Training schedule and batch sizes.\n",
    "train_config = TrainConfig(\n",
    "    steps_to=MAX_STEPS,\n",
    "    paired_batch_size=PAIRED_BATCH_SIZE,\n",
    "    unpaired_batch_size=UNPAIRED_BATCH_SIZE,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "339530f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every random number generator used below with the configured seed.\n",
    "seed = train_config.seed\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c47f945",
   "metadata": {},
   "source": [
    "## 3. Create data and samplers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5973a562",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source (X) and target (Y) samplers; dimensions come from the dataset config.\n",
    "# NOTE(review): only Y_sampler is given `dtype` explicitly — confirm X_sampler follows the default dtype.\n",
    "X_sampler = StandardNormalSampler(dim=dataset_config.x_dim, device=device)\n",
    "Y_sampler = SwissRollSampler(dim=dataset_config.y_dim, device=device, dtype=dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb1d2f0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OT plan sampler built from the minibatch config\n",
    "# (model_dump() presumably returns the config fields as a dict — confirm).\n",
    "otp_sampler = OTPlanSampler(**minibatch_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d235fb71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory for stored tensors, relative to the notebook's working directory.\n",
    "data_dir = \"checkpoints/Tensors\"\n",
    "# Filename suffix later passed to generate_paired_data: minibatch cost function + number of paired samples.\n",
    "file_postfix = f\"{minibatch_config.cost_function}_{dataset_config.P_XY_paired}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86defff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paired train/test tensors. The storage directory comes from `data_dir` (same location as the\n",
    "# previously hard-coded \"./checkpoints/Tensors\") so the path is defined in one place only.\n",
    "X_paired_train, Y_paired_train, X_paired_test, Y_paired_test = generate_paired_data(\n",
    "    X_sampler, Y_sampler, otp_sampler, dataset_config.P_XY_paired, data_dir, file_postfix, device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f61b8693",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampler over the paired training set; the training loops call\n",
    "# pd_train_sampler.sample(batch_size) and unpack the result as (X_paired, Y_paired).\n",
    "pd_train_sampler = get_paired_sampler(\n",
    "    X_paired_train, Y_paired_train, train_config.paired_batch_size, dataset_config.P_XY_paired, device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09ffa079",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Independent draws from each sampler, used for the \"Test unpaired loss\" metric.\n",
    "# Their size is the number of paired samples, not the unpaired dataset sizes.\n",
    "X_unpaired_test = X_sampler.sample(dataset_config.P_XY_paired)\n",
    "Y_unpaired_test = Y_sampler.sample(dataset_config.P_XY_paired)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe419b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# usd = unpaired source data, utd = unpaired target data.\n",
    "# With a positive unpaired sample count a fresh dataset is drawn from the sampler;\n",
    "# otherwise the matching side of the paired training set is reused.\n",
    "has_unpaired_source = dataset_config.Q_X_unpaired > 0\n",
    "source_data = X_sampler.sample(dataset_config.Q_X_unpaired) if has_unpaired_source else X_paired_train\n",
    "usd_sampler = DatasetSampler(source_data, device=device)\n",
    "\n",
    "has_unpaired_target = dataset_config.R_Y_unpaired > 0\n",
    "target_data = Y_sampler.sample(dataset_config.R_Y_unpaired) if has_unpaired_target else Y_paired_train\n",
    "utd_sampler = DatasetSampler(target_data, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfd8d737",
   "metadata": {},
   "source": [
    "## 4. Model initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e538192b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cost network constructed from the fields of cost_config.\n",
    "cost = MLPLSECost(**cost_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9172cfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model wrapping the cost network, cast to the working dtype (float64).\n",
    "model = GMMEOT(\n",
    "    y_dim=Y_DIM,\n",
    "    n_potentials=N_POTENTIALS,\n",
    "    cost=cost,\n",
    ").to(dtype)\n",
    "\n",
    "# Optional initialisation from N_POTENTIALS target samples (one sample per potential).\n",
    "if INIT_BY_SAMPLES:\n",
    "    model.init_a_by_samples(Y_sampler.sample(N_POTENTIALS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "329756e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For EMA update\n",
    "if train_config.ema_update:\n",
    "    model_copy = GMMEOT(\n",
    "        y_dim=Y_DIM,\n",
    "        n_potentials=N_POTENTIALS,\n",
    "        cost=cost,\n",
    "    ).to(dtype)\n",
    "    # Start the average from the model's current weights (including any sample-based\n",
    "    # initialisation) instead of an unrelated random initialisation.\n",
    "    model_copy.load_state_dict(model.state_dict())\n",
    "    # NOTE(review): the copy shares the same `cost` instance as `model`, so the cost\n",
    "    # network itself is not averaged — confirm this is intended."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22707c22",
   "metadata": {},
   "source": [
    "## 5. Optimizers initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4c4ee69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This optimizer updates only these three tensors, read from the model's private attributes;\n",
    "# the cost network's parameters belong to the paired optimizer.\n",
    "unpaired_params_to_update = [model._log_w_n, model._a_n, model._log_A_n]\n",
    "\n",
    "D_opt_unpaired = torch.optim.Adam(unpaired_params_to_update, **opt_unpaired_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c060ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over the cost network's parameters only.\n",
    "D_opt_paired = torch.optim.Adam(model.cost.parameters(), **opt_paired_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a0d317b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: refactor this config\n",
    "# Run name: dataset sizes, learning rates, minibatch cost function and cost-network settings.\n",
    "EXP_NAME = (\n",
    "    \"GMMEOT_Swiss_Roll_\"\n",
    "    f\"P_XY_PAIRED_{dataset_config.P_XY_paired}_\"\n",
    "    f\"Q_X_UNPAIRED_{dataset_config.Q_X_unpaired}_\"\n",
    "    f\"R_Y_UNPAIRED_{dataset_config.R_Y_unpaired}_\"\n",
    "    f\"LR_PAIRED_{opt_paired_config.lr}_\"\n",
    "    f\"LR_UNPAIRED_{opt_unpaired_config.lr}_\"\n",
    "    f\"MINIBATCH_COST_{minibatch_config.cost_function}_\"\n",
    "    + EXP_META_INFO\n",
    ")\n",
    "# Checkpoint directory of this run, relative to the notebook's working directory.\n",
    "OUTPUT_PATH = f\"../checkpoints/{EXP_NAME}\"\n",
    "\n",
    "# Hyperparameters attached to the wandb run.\n",
    "config = dict(\n",
    "    X_DIM=dataset_config.x_dim,\n",
    "    Y_DIM=dataset_config.y_dim,\n",
    "    D_LR_PAIRED=opt_paired_config.lr,\n",
    "    D_LR_UNPAIRED=opt_unpaired_config.lr,\n",
    "    BATCH_SIZE=train_config.unpaired_batch_size,\n",
    "    P_XY_PAIRED_SAMPLES=dataset_config.P_XY_paired,\n",
    "    Q_X_UNPAIRED_SAMPLES=dataset_config.Q_X_unpaired,\n",
    "    R_Y_UNPAIRED_SAMPLES=dataset_config.R_Y_unpaired,\n",
    ")\n",
    "\n",
    "# exist_ok=True already covers the \"directory exists\" case, so no separate check is needed.\n",
    "os.makedirs(OUTPUT_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b726194c",
   "metadata": {},
   "outputs": [],
   "source": [
    "if train_config.steps_from > 0:\n",
    "    # Resume from a previous run: restore the model weights together with both optimizer\n",
    "    # states saved for that step. Restoring only the optimizers would pair their moment\n",
    "    # estimates with a freshly initialised model.\n",
    "    resume_step = train_config.steps_from\n",
    "    model.load_state_dict(torch.load(os.path.join(OUTPUT_PATH, f\"D_{resume_step}.pt\"), map_location=device))\n",
    "    D_opt_unpaired.load_state_dict(\n",
    "        torch.load(os.path.join(OUTPUT_PATH, f\"D_opt_unpaired_{resume_step}.pt\"), map_location=device)\n",
    "    )\n",
    "    D_opt_paired.load_state_dict(\n",
    "        torch.load(os.path.join(OUTPUT_PATH, f\"D_opt_paired_{resume_step}.pt\"), map_location=device)\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71524ecb",
   "metadata": {},
   "source": [
    "## 6. Model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f65e6586",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three fixed 2-D source points handed to get_GT_points and to the plotting helpers.\n",
    "starting_points = torch.tensor([[-2.0, 0.0], [2.0, 2.0], [0.0, 0.0]])\n",
    "# NOTE(review): not referenced by any other cell of this notebook — confirm it is still needed.\n",
    "num_ending_points = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e32f4804",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A few training pairs passed to plot_Z_parameters. random.choices draws with replacement,\n",
    "# so the same pair can be selected more than once.\n",
    "num_starting_points_paired = 5\n",
    "# assumes X_paired_train / Y_paired_train hold at least P_XY_paired rows — TODO confirm\n",
    "indices = random.choices(range(dataset_config.P_XY_paired), k=num_starting_points_paired)\n",
    "starting_points_paired = X_paired_train[indices]\n",
    "ending_points_paired = Y_paired_train[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c512a511",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference target points for `starting_points`, passed to plot_swiss_roll\n",
    "# (presumably ground-truth plan samples, per the helper's name — confirm).\n",
    "gt_Y_points = get_GT_points(X_sampler, Y_sampler, otp_sampler, starting_points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adf4c2d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(name=EXP_NAME, config=config)\n",
    "\n",
    "# `model` always stays the network being optimized. Evaluation, plots and checkpoints use\n",
    "# `eval_model`: the EMA copy when EMA is enabled, otherwise `model` itself. Rebinding\n",
    "# `model = model_copy` inside the loop would make every later step compute its loss on\n",
    "# parameters that the optimizers do not hold.\n",
    "eval_model = model_copy if train_config.ema_update else model\n",
    "plot_title = f\"P={P_XY_PAIRED_SAMPLES}, Q={Q_X_UNPAIRED_SAMPLES}, R={R_Y_UNPAIRED_SAMPLES}\"\n",
    "\n",
    "for step in tqdm(range(train_config.steps_from, train_config.steps_to)):\n",
    "    D_opt_unpaired.zero_grad()\n",
    "    D_opt_paired.zero_grad()\n",
    "\n",
    "    # Unpaired term on independently drawn source / target minibatches.\n",
    "    X = usd_sampler.sample(train_config.unpaired_batch_size)\n",
    "    Y = utd_sampler.sample(train_config.unpaired_batch_size)\n",
    "    output_unpaired = model.compute_unpaired_loss(X, Y)\n",
    "    D_loss_unpaired = output_unpaired[\"loss\"]\n",
    "\n",
    "    # Paired term on a minibatch of (x, y) pairs.\n",
    "    X_paired, Y_paired = pd_train_sampler.sample(train_config.paired_batch_size)\n",
    "    output_paired = model.compute_paired_loss(X_paired, Y_paired)\n",
    "    D_loss_paired = output_paired[\"loss\"]\n",
    "\n",
    "    # One backward pass through the summed objective, then a step of each optimizer.\n",
    "    D_loss = D_loss_unpaired + D_loss_paired\n",
    "    D_loss.backward()\n",
    "    D_opt_paired.step()\n",
    "    D_opt_unpaired.step()\n",
    "\n",
    "    if train_config.ema_update:\n",
    "        # presumably update_average(target, source, beta) blends `model` into `model_copy` — TODO confirm\n",
    "        update_average(model_copy, model, 0.99)\n",
    "\n",
    "    # All scalar metrics of this step, logged as plain Python numbers where possible.\n",
    "    wandb.log(\n",
    "        {\n",
    "            \"Unpaired loss\": D_loss_unpaired.item(),\n",
    "            \"Paired loss\": D_loss_paired.item(),\n",
    "            \"Loss\": D_loss.item(),\n",
    "            \"Train paired loss\": compute_loss(\n",
    "                eval_model, X_paired_train, Y_paired_train, X_paired_train, Y_paired_train\n",
    "            ),\n",
    "            \"Test paired loss\": compute_loss(\n",
    "                eval_model, X_paired_test, Y_paired_test, X_paired_test, Y_paired_test\n",
    "            ),\n",
    "            \"Test unpaired loss\": compute_loss(\n",
    "                eval_model, X_unpaired_test, Y_unpaired_test, X_paired_test, Y_paired_test\n",
    "            ),\n",
    "            r\"$-f^c(x)$\": -output_unpaired[\"f_c\"].mean().item(),\n",
    "            r\"$-f(y)$\": -output_unpaired[\"f\"].mean().item(),\n",
    "            \"lam_min(A_n)\": torch.min(output_unpaired[\"A_n\"]).item(),\n",
    "            \"lam_max(A_n)\": torch.max(output_unpaired[\"A_n\"]).item(),\n",
    "        },\n",
    "        step=step,\n",
    "    )\n",
    "\n",
    "    if step % train_config.plot_every == 0:\n",
    "        # Periodic diagnostics plus a checkpoint of the evaluated model.\n",
    "        A_dict = plot_A_parameters(eval_model, log=True)\n",
    "        B_dict = plot_B_parameters(eval_model.cost, starting_points, log=True)\n",
    "        if num_starting_points_paired > 0:\n",
    "            Z_dict = plot_Z_parameters(\n",
    "                eval_model, starting_points, starting_points_paired, ending_points_paired, log=True\n",
    "            )\n",
    "        else:\n",
    "            Z_dict = plot_Z_parameters(eval_model, starting_points, log=True)\n",
    "        distr_dict = plot_swiss_roll(\n",
    "            {plot_title: eval_model},\n",
    "            X_sampler,\n",
    "            Y_sampler,\n",
    "            X_paired,\n",
    "            Y_paired,\n",
    "            starting_points,\n",
    "            gt_Y_points,\n",
    "            log=True,\n",
    "        )\n",
    "        wandb.log(A_dict | B_dict | Z_dict | distr_dict)\n",
    "\n",
    "        torch.save(eval_model.state_dict(), os.path.join(OUTPUT_PATH, f\"D_{step}.pt\"))\n",
    "\n",
    "# Final checkpoint: evaluated model plus both optimizer states (the resume cell reads these names).\n",
    "torch.save(eval_model.state_dict(), os.path.join(OUTPUT_PATH, f\"D_{MAX_STEPS}.pt\"))\n",
    "torch.save(D_opt_paired.state_dict(), os.path.join(OUTPUT_PATH, f\"D_opt_paired_{MAX_STEPS}.pt\"))\n",
    "torch.save(D_opt_unpaired.state_dict(), os.path.join(OUTPUT_PATH, f\"D_opt_unpaired_{MAX_STEPS}.pt\"))\n",
    "\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a8cd779",
   "metadata": {},
   "source": [
    "## 7. Naive Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2414e3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning rate of the single Adam optimizer used by the naive baseline.\n",
    "LR_NAIVE = 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63ec4f65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# New cost network instance for the naive baseline (rebinds `cost`).\n",
    "cost = MLPLSECost(**cost_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c6f4839",
   "metadata": {},
   "outputs": [],
   "source": [
    "# New model for the naive baseline (rebinds `model`), built on the new cost network.\n",
    "model = GMMEOT(\n",
    "    y_dim=Y_DIM,\n",
    "    n_potentials=N_POTENTIALS,\n",
    "    cost=cost,\n",
    ").to(dtype)\n",
    "\n",
    "# Optional initialisation from N_POTENTIALS target samples (one sample per potential).\n",
    "if INIT_BY_SAMPLES:\n",
    "    model.init_a_by_samples(Y_sampler.sample(N_POTENTIALS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d9da9c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For EMA update\n",
    "if train_config.ema_update:\n",
    "    model_copy = GMMEOT(\n",
    "        y_dim=Y_DIM,\n",
    "        n_potentials=N_POTENTIALS,\n",
    "        cost=cost,\n",
    "    ).to(dtype)\n",
    "    # Start the average from the model's current weights (including any sample-based\n",
    "    # initialisation) instead of an unrelated random initialisation.\n",
    "    model_copy.load_state_dict(model.state_dict())\n",
    "    # NOTE(review): the copy shares the same `cost` instance as `model`, so the cost\n",
    "    # network itself is not averaged — confirm this is intended."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bd5002f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The naive baseline trains every model parameter with a single Adam optimizer.\n",
    "D_opt_naive = torch.optim.Adam(model.parameters(), lr=LR_NAIVE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b141dcad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# exist_ok=True already covers the \"directory exists\" case, so no separate check is needed.\n",
    "os.makedirs(OUTPUT_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dee8cc6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this run reuses EXP_NAME and OUTPUT_PATH, so its D_{step}.pt checkpoints\n",
    "# overwrite the ones written by the training cell of section 6 — confirm this is intended.\n",
    "wandb.init(name=EXP_NAME, config=config)\n",
    "\n",
    "# `model` always stays the network being optimized. Evaluation, plots and checkpoints use\n",
    "# `eval_model`: the EMA copy when EMA is enabled, otherwise `model` itself. Rebinding\n",
    "# `model = model_copy` inside the loop would detach training from the optimizer's parameters.\n",
    "eval_model = model_copy if train_config.ema_update else model\n",
    "plot_title = f\"P={P_XY_PAIRED_SAMPLES}, Q={Q_X_UNPAIRED_SAMPLES}, R={R_Y_UNPAIRED_SAMPLES}\"\n",
    "unpaired_batch = train_config.unpaired_batch_size\n",
    "\n",
    "for step in tqdm(range(train_config.steps_from, train_config.steps_to)):\n",
    "    D_opt_naive.zero_grad()\n",
    "\n",
    "    X = usd_sampler.sample(unpaired_batch)\n",
    "    Y = utd_sampler.sample(unpaired_batch)\n",
    "\n",
    "    log_w_n = model.log_w_n()\n",
    "    a_n = model.a_n()\n",
    "    A_n = model.A_n()\n",
    "\n",
    "    # Evaluate log p(y_j | x_i) for every (i, j) combination: X is tiled as a whole while each\n",
    "    # row of Y is repeated consecutively, so flat index j * batch + i holds the pair (x_i, y_j).\n",
    "    # Tiling both tensors with .repeat would only ever pair x_i with y_i.\n",
    "    X_grid = X.repeat(unpaired_batch, 1)\n",
    "    Y_grid = Y.repeat_interleave(unpaired_batch, dim=0)\n",
    "    cond_distr_unpaired = model.get_conditional_distribution(X_grid, log_w_n, a_n, A_n)\n",
    "    log_prob_grid = cond_distr_unpaired.log_prob(Y_grid).reshape(unpaired_batch, unpaired_batch)\n",
    "    # -mean_j log(mean_i p(y_j | x_i)); logsumexp avoids underflow of exp(log_prob).\n",
    "    log_marginal = torch.logsumexp(log_prob_grid, dim=-1) - float(np.log(unpaired_batch))\n",
    "    D_loss_unpaired = -log_marginal.mean()\n",
    "\n",
    "    # Paired term: negative log-likelihood of y given its paired x.\n",
    "    X_paired, Y_paired = pd_train_sampler.sample(train_config.paired_batch_size)\n",
    "    cond_distr_paired = model.get_conditional_distribution(X_paired, log_w_n, a_n, A_n)\n",
    "    D_loss_paired = -cond_distr_paired.log_prob(Y_paired).mean()\n",
    "\n",
    "    D_loss = D_loss_unpaired + D_loss_paired\n",
    "    D_loss.backward()\n",
    "    D_opt_naive.step()\n",
    "\n",
    "    if train_config.ema_update:\n",
    "        # presumably update_average(target, source, beta) blends `model` into `model_copy` — TODO confirm\n",
    "        update_average(model_copy, model, 0.99)\n",
    "\n",
    "    # All scalar metrics of this step, logged as plain Python numbers where possible.\n",
    "    wandb.log(\n",
    "        {\n",
    "            \"Unpaired loss\": D_loss_unpaired.item(),\n",
    "            \"Paired loss\": D_loss_paired.item(),\n",
    "            \"Loss\": D_loss.item(),\n",
    "            \"Train paired loss\": compute_loss(\n",
    "                eval_model, X_paired_train, Y_paired_train, X_paired_train, Y_paired_train\n",
    "            ),\n",
    "            \"Test paired loss\": compute_loss(\n",
    "                eval_model, X_paired_test, Y_paired_test, X_paired_test, Y_paired_test\n",
    "            ),\n",
    "            \"Test unpaired loss\": compute_loss(\n",
    "                eval_model, X_unpaired_test, Y_unpaired_test, X_paired_test, Y_paired_test\n",
    "            ),\n",
    "            \"lam_min(A_n)\": torch.min(A_n).item(),\n",
    "            \"lam_max(A_n)\": torch.max(A_n).item(),\n",
    "        },\n",
    "        step=step,\n",
    "    )\n",
    "\n",
    "    if step % train_config.plot_every == 0:\n",
    "        # Periodic diagnostics plus a checkpoint of the evaluated model.\n",
    "        A_dict = plot_A_parameters(eval_model, log=True)\n",
    "        B_dict = plot_B_parameters(eval_model.cost, starting_points, log=True)\n",
    "        if num_starting_points_paired > 0:\n",
    "            Z_dict = plot_Z_parameters(\n",
    "                eval_model, starting_points, starting_points_paired, ending_points_paired, log=True\n",
    "            )\n",
    "        else:\n",
    "            Z_dict = plot_Z_parameters(eval_model, starting_points, log=True)\n",
    "        distr_dict = plot_swiss_roll(\n",
    "            {plot_title: eval_model},\n",
    "            X_sampler,\n",
    "            Y_sampler,\n",
    "            X_paired,\n",
    "            Y_paired,\n",
    "            starting_points,\n",
    "            gt_Y_points,\n",
    "            log=True,\n",
    "        )\n",
    "        wandb.log(A_dict | B_dict | Z_dict | distr_dict)\n",
    "\n",
    "        torch.save(eval_model.state_dict(), os.path.join(OUTPUT_PATH, f\"D_{step}.pt\"))\n",
    "\n",
    "torch.save(eval_model.state_dict(), os.path.join(OUTPUT_PATH, f\"D_{MAX_STEPS}.pt\"))\n",
    "# Save the optimizer this loop actually stepped; the paired / unpaired optimizers of\n",
    "# section 6 are not used by the naive baseline.\n",
    "torch.save(D_opt_naive.state_dict(), os.path.join(OUTPUT_PATH, f\"D_opt_naive_{MAX_STEPS}.pt\"))\n",
    "\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0686dedb",
   "metadata": {},
   "source": [
    "## 8. Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adc9ddd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpaired-sample settings whose checkpoints are compared, and the training step to load.\n",
    "# NOTE(review): log_step (99000) exceeds MAX_STEPS (26000) and 1024 differs from the 16000\n",
    "# unpaired samples set in the parameters cell, so these checkpoints must come from other\n",
    "# runs — confirm they exist on disk.\n",
    "Q_X_unpaired_samples_list = [0, 1024]\n",
    "R_Y_unpaired_samples_list = [0, 1024]\n",
    "log_step = 99000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33eecefe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot title -> model restored from the checkpoint of the matching experiment.\n",
    "models_dict = dict()\n",
    "\n",
    "for Q_X_unpaired_samples in Q_X_unpaired_samples_list:\n",
    "    for R_Y_unpaired_samples in R_Y_unpaired_samples_list:\n",
    "        # Each model gets its own cost network. With one shared `cost` instance (and assuming\n",
    "        # the cost weights are part of the model's state_dict), every load_state_dict call\n",
    "        # would overwrite the cost weights of the models loaded before it.\n",
    "        model = GMMEOT(\n",
    "            y_dim=Y_DIM,\n",
    "            n_potentials=N_POTENTIALS,\n",
    "            cost=MLPLSECost(**cost_config.model_dump()),\n",
    "        ).to(dtype)\n",
    "        # Derive the other experiment's name by swapping the unpaired-sample counts in EXP_NAME.\n",
    "        exp_name = EXP_NAME.replace(\n",
    "            f\"Q_X_UNPAIRED_{Q_X_UNPAIRED_SAMPLES}_R_Y_UNPAIRED_{R_Y_UNPAIRED_SAMPLES}_\",\n",
    "            f\"Q_X_UNPAIRED_{Q_X_unpaired_samples}_R_Y_UNPAIRED_{R_Y_unpaired_samples}_\",\n",
    "        )\n",
    "        print(exp_name)\n",
    "        output_path = \"../checkpoints/{}\".format(exp_name)\n",
    "        model.load_state_dict(torch.load(os.path.join(output_path, f\"D_{log_step}.pt\"), map_location=device))\n",
    "        title = f\"P={P_XY_PAIRED_SAMPLES}, Q={Q_X_unpaired_samples}, R={R_Y_unpaired_samples}\"\n",
    "        models_dict[title] = model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ea78482",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot every loaded model together with the paired training data and the reference points.\n",
    "plot_swiss_roll(\n",
    "    models_dict,\n",
    "    X_sampler,\n",
    "    Y_sampler,\n",
    "    X_paired_train,\n",
    "    Y_paired_train,\n",
    "    starting_points,\n",
    "    gt_Y_points,\n",
    ") "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "text",
   "language": "python",
   "name": "text"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
