{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f3bd7692",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd972eac",
   "metadata": {},
   "source": [
    "## 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6261775b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "sys.path.append(\"./ALAE\")\n",
    "\n",
    "import random\n",
    "\n",
    "from comet_ml import Experiment\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "from src.costs.lse import MLPLSECost\n",
    "from src.models.gmm_based import GMMEOT\n",
    "from src.models.light_sbm import LightSBM\n",
    "from src.plotting.parameters import (\n",
    "    plot_A_parameters,\n",
    "    plot_B_parameters,\n",
    "    plot_Z_parameters,\n",
    ")\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.utils.train import compute_loss, update_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5ea4743c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = torch.device(f\"cuda:{torch.cuda.current_device()}\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "53e79cc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_default_device(device)\n",
    "# dtype = torch.float64\n",
    "dtype = torch.float32\n",
    "# torch.torch.set_default_dtype(dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f036daa0",
   "metadata": {},
   "source": [
    "## 2. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "13d0ecf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from configs.gmm_based.cost import MLPLSECostConfig\n",
    "from configs.gmm_based.optimizer import OptPairedConfig, OptUnpairedConfig\n",
    "from configs.gmm_based.train import TrainConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "34e77c92",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Data Type\n",
    "X_DIM = 512\n",
    "Y_DIM = 512\n",
    "INPUT_DATA = \"WOMAN\" # \"MAN\" # MAN, WOMAN, ADULT, CHILDREN\n",
    "TARGET_DATA = \"MAN\" # \"WOMAN\" # MAN, WOMAN, ADULT, CHILDREN\n",
    "\n",
    "# Data\n",
    "Q_X_UNPAIRED_SAMPLES = 48786 # 1024\n",
    "R_Y_UNPAIRED_SAMPLES = 10762 # 1024\n",
    "P_XY_PAIRED_SAMPLES = 2000 # 128\n",
    "\n",
    "# Optimizer\n",
    "LR_PAIRED = 3e-4\n",
    "LR_UNPAIRED = 1e-3\n",
    "\n",
    "# Sampler\n",
    "PAIRED_BATCH_SIZE = 512\n",
    "UNPAIRED_BATCH_SIZE = 512\n",
    "\n",
    "# Train\n",
    "MAX_STEPS = 10000\n",
    "INIT_BY_SAMPLES = True\n",
    "\n",
    "# Potential\n",
    "N_POTENTIALS = 10\n",
    "\n",
    "# Cost\n",
    "M_POTENTIALS = 1\n",
    "LOG_V_M_HIDDEN_CHANNELS = [M_POTENTIALS]\n",
    "B_M_HIDDEN_CHANNELS = [M_POTENTIALS * Y_DIM]\n",
    "\n",
    "SEED = 44"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0c07b9d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "cost_config = MLPLSECostConfig(\n",
    "    x_dim=X_DIM,\n",
    "    y_dim=Y_DIM,\n",
    "    m_potentials=M_POTENTIALS,\n",
    "    log_v_m_hidden_channels=LOG_V_M_HIDDEN_CHANNELS,\n",
    "    b_m_hidden_channels=B_M_HIDDEN_CHANNELS,\n",
    ")\n",
    "EXP_META_INFO = (\n",
    "    f\"M_POTENTIALS_{M_POTENTIALS}_\"\n",
    "    + f\"LOG_V_M_HIDDEN_CHANNELS_{LOG_V_M_HIDDEN_CHANNELS}_\"\n",
    "    + f\"B_M_HIDDEN_CHANNELS_{B_M_HIDDEN_CHANNELS}_\"\n",
    ")\n",
    "\n",
    "opt_unpaired_config = OptUnpairedConfig(lr=LR_UNPAIRED)\n",
    "opt_paired_config = OptPairedConfig(lr=LR_PAIRED)\n",
    "\n",
    "train_config = TrainConfig(\n",
    "    seed=SEED, steps_to=MAX_STEPS, paired_batch_size=PAIRED_BATCH_SIZE, unpaired_batch_size=UNPAIRED_BATCH_SIZE\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "339530f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(train_config.seed)\n",
    "np.random.seed(train_config.seed)\n",
    "random.seed(train_config.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c47f945",
   "metadata": {},
   "source": [
    "## 3. Create data and samplers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "96f8e018",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.datasets import get_latents\n",
    "from src.samplers.base import TensorSampler\n",
    "from src.utils.paired import get_paired_sampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d07e5c1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test = get_latents(INPUT_DATA, dtype=dtype)\n",
    "Y_train, Y_test = get_latents(TARGET_DATA, dtype=dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8fb361e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_sampler = TensorSampler(X_train.to(dtype), device=device)\n",
    "Y_sampler = TensorSampler(Y_train.to(dtype), device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5fa4888e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from_dir = f\"./datasets/FFHQ/pairs/{INPUT_DATA}->{TARGET_DATA}\"\n",
    "X_paired_train_ = torch.load(os.path.join(from_dir, f\"X_train.pt\"), map_location=device, weights_only=True).to(dtype)\n",
    "Y_paired_train_ = torch.load(os.path.join(from_dir, f\"Y_train.pt\"), map_location=device, weights_only=True).to(dtype)\n",
    "\n",
    "X_paired_test_ = torch.load(os.path.join(from_dir, f\"X_test.pt\"), map_location=device, weights_only=True).to(dtype)\n",
    "Y_paired_test_ = torch.load(os.path.join(from_dir, f\"Y_test.pt\"), map_location=device, weights_only=True).to(dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2c3ed6f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_paired_train = X_paired_train_[:2000]\n",
    "Y_paired_train = Y_paired_train_[:2000]\n",
    "\n",
    "X_paired_test = X_paired_test_[2000:4000]\n",
    "Y_paired_test = Y_paired_test_[2000:4000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "048c5d75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from_dir = f\"./datasets/FFHQ/\"\n",
    "# paired_data = torch.load(os.path.join(from_dir, f\"kp_traj_ffhq_gen.pt\"), map_location=device, weights_only=True).to(\n",
    "#     dtype\n",
    "# )\n",
    "\n",
    "# X_paired = paired_data[:, -1, :]\n",
    "# Y_paired = paired_data[:, 0, :]\n",
    "\n",
    "# N = X_paired.shape[0]\n",
    "# perm = torch.randperm(N)\n",
    "\n",
    "# test_size = int(0.1 * N)\n",
    "# test_idx = perm[:test_size]\n",
    "# train_idx = perm[test_size:]\n",
    "\n",
    "# # X_paired_train = X_paired[train_idx]\n",
    "# # Y_paired_train = Y_paired[train_idx]\n",
    "# X_paired_train = X_paired\n",
    "# Y_paired_train = Y_paired\n",
    "\n",
    "# # X_paired_test = X_paired[test_idx]\n",
    "# # Y_paired_test = Y_paired[test_idx]\n",
    "# X_paired_test = X_paired\n",
    "# Y_paired_test = Y_paired\n",
    "\n",
    "# print(f\"Total pairs: {N}. Train: {X_paired_train.shape[0]}, Test: {X_paired_test.shape[0]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0ee08414",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd_train_sampler = get_paired_sampler(\n",
    "    X_paired_train, Y_paired_train, train_config.paired_batch_size, P_XY_PAIRED_SAMPLES, device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "efccf0b7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32816"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Q_X_UNPAIRED_SAMPLES = min(Q_X_UNPAIRED_SAMPLES, X_train.shape[0])\n",
    "Q_X_UNPAIRED_SAMPLES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "0820974b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10762"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "R_Y_UNPAIRED_SAMPLES = min(R_Y_UNPAIRED_SAMPLES, Y_train.shape[0])\n",
    "R_Y_UNPAIRED_SAMPLES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "afe419b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "if Q_X_UNPAIRED_SAMPLES > 0:\n",
    "    source_data = X_sampler.sample(Q_X_UNPAIRED_SAMPLES)\n",
    "    usd_sampler = DatasetSampler(source_data, device=device) # usd - unpaired source data\n",
    "else:\n",
    "    usd_sampler = DatasetSampler(X_paired_train, device=device)\n",
    "\n",
    "if R_Y_UNPAIRED_SAMPLES > 0:\n",
    "    target_data = Y_sampler.sample(R_Y_UNPAIRED_SAMPLES)\n",
    "    utd_sampler = DatasetSampler(target_data, device=device) # utd - unpaired target data\n",
    "else:\n",
    "    utd_sampler = DatasetSampler(Y_paired_train, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfd8d737",
   "metadata": {},
   "source": [
    "## 4. Model initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "64b2ddbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.costs.lse import BatchedLSECost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "32ad8aaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "e538192b",
   "metadata": {},
   "outputs": [],
   "source": [
    "cost = MLPLSECost(**cost_config.model_dump())\n",
    "# cost = BatchedLSECost(u, log_v_m_net, m_potentials=M_POTENTIALS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "c9172cfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "light_gcot_model = GMMEOT(\n",
    "    y_dim=Y_DIM,\n",
    "    n_potentials=N_POTENTIALS,\n",
    "    cost=cost,\n",
    ").to(dtype)\n",
    "\n",
    "if INIT_BY_SAMPLES:\n",
    "    light_gcot_model.init_a_by_samples(Y_sampler.sample(N_POTENTIALS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "329756e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For EMA update\n",
    "if train_config.ema_update:\n",
    "    model_copy = GMMEOT(\n",
    "    y_dim=Y_DIM,\n",
    "    n_potentials=N_POTENTIALS,\n",
    "    cost=cost,\n",
    ").to(dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22707c22",
   "metadata": {},
   "source": [
    "## 5. Optimizers initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b4c4ee69",
   "metadata": {},
   "outputs": [],
   "source": [
    "unpaired_params_to_update = [light_gcot_model._log_w_n, light_gcot_model._a_n, light_gcot_model._log_A_n]\n",
    "\n",
    "D_opt_unpaired = torch.optim.Adam(unpaired_params_to_update, **opt_unpaired_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "79c060ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "D_opt_paired = torch.optim.Adam(light_gcot_model.cost.parameters(), **opt_paired_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "7a0d317b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: refactor this config\n",
    "EXP_NAME = (\n",
    "    \"GMMEOT_ALAE_\"\n",
    "    + f\"FROM_{INPUT_DATA}_\"\n",
    "    + f\"TO_{TARGET_DATA}_\"\n",
    "    + f\"P_XY_PAIRED_{P_XY_PAIRED_SAMPLES}_\"\n",
    "    + f\"Q_X_UNPAIRED_{Q_X_UNPAIRED_SAMPLES}_\"\n",
    "    + f\"R_Y_UNPAIRED_{R_Y_UNPAIRED_SAMPLES}_\"\n",
    "    + f\"LR_PAIRED_{opt_paired_config.lr}_\"\n",
    "    + f\"LR_UNPAIRED_{opt_unpaired_config.lr}_\"\n",
    "    + f\"SEED_{SEED}_\"\n",
    "    + EXP_META_INFO\n",
    ")\n",
    "OUTPUT_PATH = \"../checkpoints/{}\".format(EXP_NAME)\n",
    "\n",
    "config = dict(\n",
    "    D_LR_PAIRED=opt_paired_config.lr,\n",
    "    D_LR_UNPAIRED=opt_unpaired_config.lr,\n",
    "    BATCH_SIZE=train_config.unpaired_batch_size,\n",
    "    P_XY_PAIRED_SAMPLES=P_XY_PAIRED_SAMPLES,\n",
    "    Q_X_UNPAIRED_SAMPLES=Q_X_UNPAIRED_SAMPLES,\n",
    "    R_Y_UNPAIRED_SAMPLES=R_Y_UNPAIRED_SAMPLES,\n",
    ")\n",
    "\n",
    "if not os.path.exists(OUTPUT_PATH):\n",
    "    os.makedirs(OUTPUT_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "b726194c",
   "metadata": {},
   "outputs": [],
   "source": [
    "if train_config.steps_from > 0:\n",
    "    D_opt_unpaired.load_state_dict(torch.load(os.path.join(OUTPUT_PATH, f\"D_opt_unpaired_{train_config.steps_from}.pt\")))\n",
    "    D_opt_paired.load_state_dict(torch.load(os.path.join(OUTPUT_PATH, f\"D_opt_paired_{train_config.steps_from}.pt\")))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71524ecb",
   "metadata": {},
   "source": [
    "## 6. Model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b695494",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1;38;5;214mCOMET WARNING:\u001b[0m As you are running in a Jupyter environment, you will need to call `experiment.end()` when finished to ensure all metrics and code are logged before exiting.\n"
     ]
    }
   ],
   "source": [
    "experiment = Experiment(\n",
    "    project_name=\"Light-GCOT-ALE\",\n",
    "    auto_output_logging=False,\n",
    "    parse_args=False,\n",
    ")\n",
    "experiment.set_name(EXP_NAME)\n",
    "experiment.log_parameters(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adf4c2d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "for step in tqdm(range(train_config.steps_from, train_config.steps_to)):\n",
    "    # training loop\n",
    "    D_opt_unpaired.zero_grad()\n",
    "\n",
    "    X = usd_sampler.sample(train_config.unpaired_batch_size)\n",
    "    Y = utd_sampler.sample(train_config.unpaired_batch_size)\n",
    "\n",
    "    output_unpaired = light_gcot_model.compute_unpaired_loss(X, Y)\n",
    "    D_loss_unpaired = output_unpaired[\"loss\"]\n",
    "\n",
    "    experiment.log_metric(\"Unpaired loss\", D_loss_unpaired.item(), step=step)\n",
    "\n",
    "    D_opt_paired.zero_grad()\n",
    "    X_paired, Y_paired = pd_train_sampler.sample(train_config.paired_batch_size)\n",
    "    \n",
    "    output_paired = light_gcot_model.compute_paired_loss(X_paired, Y_paired)\n",
    "    D_loss_paired = output_paired[\"loss\"]\n",
    "\n",
    "    experiment.log_metric(\"Paired loss\", D_loss_paired.item(), step=step)\n",
    "\n",
    "    D_loss = D_loss_unpaired + D_loss_paired\n",
    "    D_loss.backward()\n",
    "    D_opt_paired.step()\n",
    "    D_opt_unpaired.step()\n",
    "\n",
    "    if train_config.ema_update:\n",
    "        update_average(model_copy, light_gcot_model, 0.99)\n",
    "        light_gcot_model = model_copy\n",
    "    else:\n",
    "        light_gcot_model = light_gcot_model\n",
    "\n",
    "    experiment.log_metric(\n",
    "        \"Train paired loss\",\n",
    "        compute_loss(light_gcot_model, X_paired_train, Y_paired_train, X_paired_train, Y_paired_train),\n",
    "        step=step,\n",
    "    )\n",
    "    experiment.log_metric(\n",
    "        \"Test paired loss\",\n",
    "        compute_loss(light_gcot_model, X_paired_test, Y_paired_test, X_paired_test, Y_paired_test),\n",
    "        step=step,\n",
    "    )\n",
    "\n",
    "    experiment.log_metric(\"-f^c(x)\", -output_unpaired[\"f_c\"].mean().item(), step=step)\n",
    "    experiment.log_metric(\"-f(y)\", -output_unpaired[\"f\"].mean().item(), step=step)\n",
    "    experiment.log_metric(\"lam_min(A_n)\", torch.min(output_unpaired[\"A_n\"]).item(), step=step)\n",
    "    experiment.log_metric(\"lam_max(A_n)\", torch.max(output_unpaired[\"A_n\"]).item(), step=step)\n",
    "\n",
    "    if step % train_config.plot_every == 0:\n",
    "        torch.save(light_gcot_model.state_dict(), os.path.join(OUTPUT_PATH, f\"D_{step}.pt\"))\n",
    "\n",
    "torch.save(light_gcot_model.state_dict(), os.path.join(OUTPUT_PATH, f\"D_{MAX_STEPS}.pt\"))\n",
    "torch.save(D_opt_paired.state_dict(), os.path.join(OUTPUT_PATH, f\"D_opt_paired_{MAX_STEPS}.pt\"))\n",
    "torch.save(D_opt_unpaired.state_dict(), os.path.join(OUTPUT_PATH, f\"D_opt_unpaired_{MAX_STEPS}.pt\"))\n",
    "\n",
    "experiment.end()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32f96e03",
   "metadata": {},
   "source": [
    "## 7. Light-SBM training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "411d5d3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "a28906ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "eps = 0.1\n",
    "lr = 1e-3\n",
    "\n",
    "n_potentials = 10\n",
    "is_diag = True\n",
    "S_init = 0.1\n",
    "\n",
    "max_iter = 20000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "dc1eba4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "light_sbm = LightSBM(dim=X_DIM, n_potentials=n_potentials, epsilon=eps, S_diagonal_init=S_init, is_diagonal=is_diag)\n",
    "\n",
    "light_sbm.to(device)\n",
    "light_sbm_opt = torch.optim.Adam(light_sbm.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "4a92968f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, max_iter, eps, opt, batch_size=512, safe_t=1e-2, device=device):\n",
    "    \n",
    "    pbar = tqdm(range(1, max_iter + 1))\n",
    "    \n",
    "    for i in pbar:\n",
    "        \n",
    "        x_0_samples = X_sampler.sample(batch_size).to(device)      \n",
    "        x_1_samples = Y_sampler.sample(batch_size).to(device)\n",
    "        \n",
    "        t = torch.rand([batch_size, 1], device=device) * (1 - safe_t)\n",
    "        \n",
    "        x_t = x_1_samples * t + x_0_samples * (1 - t) + torch.sqrt(eps * t * (1 - t)) * torch.randn_like(x_0_samples)\n",
    "                \n",
    "        predicted_drift = model.get_drift(x_t, t.squeeze())\n",
    "        \n",
    "        loss_plan = (model.get_log_C(x_0_samples) - model.get_log_potential(x_1_samples)).mean()\n",
    "        \n",
    "        target_drift = (x_1_samples - x_t) / (1 - t)\n",
    "        \n",
    "        loss = F.mse_loss(target_drift, predicted_drift)\n",
    "        \n",
    "        opt.zero_grad()\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        opt.step()\n",
    "        \n",
    "        pbar.set_description(f'Loss : {loss.item()} Plan Loss: {loss_plan.item()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "b50e57e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                                | 0/20000 [00:00<?, ?it/s]/trinity/home/m.persiyanov/miniconda3/envs/light-gcot/lib/python3.12/site-packages/torch/utils/_device.py:78: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  return func(*args, **kwargs)\n",
      "Loss : 1.332700252532959 Plan Loss: 4024.15673828125: 100%|█████████████████████████████████████████████████████████████████████████████| 20000/20000 [02:27<00:00, 136.00it/s]\n"
     ]
    }
   ],
   "source": [
    "train(light_sbm, max_iter, eps, light_sbm_opt, batch_size=512, safe_t=1e-2, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18abb913",
   "metadata": {},
   "source": [
    "## 7.1 FSBM baseline: restore from checkpoint and visualize trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "dfcba2db",
   "metadata": {},
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf\n",
    "from pathlib import Path\n",
    "\n",
    "sys.path.append(\"../FSBM\")\n",
    "from fsbm.dataset import get_dist_boundary\n",
    "from fsbm.utils import restore_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "05917307",
   "metadata": {},
   "outputs": [],
   "source": [
    "TRNSF = \"gen\"\n",
    "FSBM_SEED = 2 # 0, 1, 2\n",
    "\n",
    "ckpt_dir = f\"../FSBM/outputs/runs/ffhq_{TRNSF}/\"\n",
    "\n",
    "if FSBM_SEED == 0:\n",
    "    subdir = \"2025.11.27/162837\"\n",
    "elif FSBM_SEED == 1:\n",
    "    subdir = \"2025.12.02/124004\"\n",
    "elif FSBM_SEED == 2:\n",
    "    subdir = \"2025.12.02/174536\"\n",
    "else:\n",
    "    raise ValueError(f\"Unknown SEED: {FSBM_SEED}!\")\n",
    "\n",
    "cfg = OmegaConf.load(os.path.join(ckpt_dir, f\"{subdir}/.hydra/config.yaml\"))\n",
    "ckpt_file_path = os.path.join(ckpt_dir, f\"{subdir}/checkpoints/last.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "fa22809c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint keys: dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers'])\n"
     ]
    }
   ],
   "source": [
    "## Load model\n",
    "fsbm_model, cfg = restore_model(ckpt_file_path, device=device)\n",
    "\n",
    "fsbm_model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "f4af22e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from alae_ffhq_inference import decode, load_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f8da9b59",
   "metadata": {},
   "outputs": [],
   "source": [
    "alae_model = load_model(\"./ALAE/configs/ffhq.yaml\", training_artifacts_dir=\"./ALAE/training_artifacts/ffhq/\").to(\n",
    "    device\n",
    ").to(dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "3241a020",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor:\n",
    "    normalized = tensor / 2 + 0.5\n",
    "    return normalized.clamp_(0, 1)\n",
    "\n",
    "def to_uint8(normalized_tensor: torch.Tensor) -> torch.Tensor:\n",
    "    return normalized_tensor.mul(255).add_(0.5).clamp_(0, 255).to(torch.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "efa3e3f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = X_paired_test[:10]\n",
    "y = Y_paired_test[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "f136790f",
   "metadata": {},
   "outputs": [],
   "source": [
    "direction = \"bwd\"\n",
    "output = fsbm_model.sample(x, log_steps=20, nfe=1000, direction=direction)\n",
    "y_pred = output[\"xs\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "ed9c0f21",
   "metadata": {},
   "outputs": [],
   "source": [
    "decoded = []\n",
    "\n",
    "for traj in y_pred:\n",
    "    decoded.append(normalize_tensor(decode(alae_model, traj)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "e7522fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_show = min(5, y_pred.shape[0])\n",
    "T = y_pred.shape[1]\n",
    "\n",
    "fig, axes = plt.subplots(num_show, T, figsize=(2*T, 2*num_show))\n",
    "\n",
    "for s in range(num_show):\n",
    "    for t in range(T):\n",
    "        axes[s, t].imshow(decoded[s][t].permute(1,2,0).cpu().numpy())\n",
    "        axes[s, t].axis(\"off\")\n",
    "        if s == 0:\n",
    "            axes[s, t].set_title(f\"t={t}\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'{INPUT_DATA}->{TARGET_DATA}_FSBM_{direction}.png', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "098fc231",
   "metadata": {},
   "source": [
    "## 8. Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "9218752c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchmetrics.image import StructuralSimilarityIndexMeasure\n",
    "from torchmetrics.image.fid import FrechetInceptionDistance\n",
    "from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "6517b92f",
   "metadata": {},
   "outputs": [],
   "source": [
    "EVAL_MODEL_STEP = 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "99b9036b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "light_gcot_model.load_state_dict(torch.load(os.path.join(OUTPUT_PATH, f\"D_{EVAL_MODEL_STEP}.pt\"), map_location=device, weights_only=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "133f3eb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_model(model: torch.nn.Module, model_name: str) -> tuple[float, float, float]:\n",
    "    loss_fid = FrechetInceptionDistance().to(device)\n",
    "    loss_ssim = StructuralSimilarityIndexMeasure(data_range=(-1.0, 1.0)).to(device)\n",
    "    loss_lpip = LearnedPerceptualImagePatchSimilarity(net_type='alex').to(device)\n",
    "    \n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    alae_model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        sampling_batch_size = 128\n",
    "        num_samples = min(len(X_test), len(Y_test))\n",
    "        print(f\"Number of X_test samples: {len(X_test)}\")\n",
    "        print(f\"Number of Y_test samples: {len(Y_test)}\")\n",
    "        print(f\"Using {num_samples} paired samples\")\n",
    "\n",
    "        num_iters = (num_samples + sampling_batch_size - 1) // sampling_batch_size\n",
    "\n",
    "        for i in tqdm(range(num_iters)):\n",
    "            start = i * sampling_batch_size\n",
    "            end   = min(start + sampling_batch_size, num_samples)\n",
    "        \n",
    "            # sub_batch_x = X_paired_test[sampling_batch_size * i : sampling_batch_size * (i + 1)]\n",
    "            # sub_batch_y = Y_paired_test[sampling_batch_size * i : sampling_batch_size * (i + 1)]\n",
    "            sub_batch_x = X_test[start:end].to(device)\n",
    "            sub_batch_y = Y_test[start:end].to(device)\n",
    "\n",
    "            if \"FSBM\" in model_name:\n",
    "                output = model.sample(sub_batch_x, log_steps=20, nfe=1000, direction=\"fwd\")\n",
    "                y_pred = output[\"xs\"][:, -1, :]\n",
    "            else:\n",
    "                y_pred = model(sub_batch_x)\n",
    "            normalized_pred_images = normalize_tensor(decode(alae_model, y_pred))\n",
    "            normalized_true_images = normalize_tensor(decode(alae_model, sub_batch_y))\n",
    "\n",
    "            loss_fid.update(to_uint8(normalized_pred_images), real=False)\n",
    "            loss_fid.update(to_uint8(normalized_true_images), real=True)\n",
    "\n",
    "            loss_ssim.update(normalized_pred_images, normalized_true_images)\n",
    "            loss_lpip.update(normalized_pred_images, normalized_true_images)\n",
    "\n",
    "            # Explicitly free sub-batches to release GPU memory\n",
    "            del sub_batch_x, sub_batch_y, y_pred, normalized_pred_images, normalized_true_images\n",
    "            torch.cuda.empty_cache()\n",
    "    \n",
    "    loss_fid_out = loss_fid.compute()\n",
    "    loss_ssim_out = loss_ssim.compute()\n",
    "    loss_lpip_out = loss_lpip.compute()\n",
    "    \n",
    "    torch.save(loss_fid_out, os.path.join(OUTPUT_PATH, f\"FID_{model_name}_{EVAL_MODEL_STEP}.pt\"))\n",
    "    torch.save(loss_ssim_out, os.path.join(OUTPUT_PATH, f\"SSIM_{model_name}_{EVAL_MODEL_STEP}.pt\"))\n",
    "    torch.save(loss_lpip_out, os.path.join(OUTPUT_PATH, f\"LPIP_{model_name}_{EVAL_MODEL_STEP}.pt\"))\n",
    "    \n",
    "    return loss_fid_out, loss_ssim_out, loss_lpip_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "c3076388",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.load(os.path.join(OUTPUT_PATH, f\"FID_light-gcot_{EVAL_MODEL_STEP}.pt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "0093493c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "1ef57f87",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of X_test samples: 5572\n",
      "Number of Y_test samples: 4351\n",
      "Using 4351 paired samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [01:58<00:00,  3.48s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FID: 9.207294464111328\n",
      "SSIM: 0.5312055945396423\n",
      "LPIPS: 0.553896427154541\n"
     ]
    }
   ],
   "source": [
    "loss_fid, loss_ssim, loss_lpip = eval_model(light_gcot_model, f\"light-gcot-{SEED}\")\n",
    "print(f\"FID: {loss_fid}\")\n",
    "print(f\"SSIM: {loss_ssim}\")\n",
    "print(f\"LPIPS: {loss_lpip}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5aff5438",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FID: 9.339266777038574 +- 0.10169362405831284\n",
      "SSIM: 0.5314827362696329 +- 0.00021199561661770747\n",
      "LPIPS: 0.5531338055928549 +- 0.0005572825875298105\n"
     ]
    }
   ],
   "source": [
    "# SEEDS 42, 43, 44\n",
    "light_gcot_fids = np.array([9.45474910736084, 9.355756759643555, 9.207294464111328])\n",
    "light_gcot_ssims = np.array([0.5317203402519226, 0.531522274017334, 0.5312055945396423])\n",
    "light_gcod_lpips = np.array([0.5529246926307678, 0.5525802969932556, 0.553896427154541])\n",
    "print(f\"FID: {light_gcot_fids.mean()} +- {light_gcot_fids.std()}\")\n",
    "print(f\"SSIM: {light_gcot_ssims.mean()} +- {light_gcot_ssims.std()}\")\n",
    "print(f\"LPIPS: {light_gcod_lpips.mean()} +- {light_gcod_lpips.std()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "56138917",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of X_test samples: 5572\n",
      "Number of Y_test samples: 4351\n",
      "Using 4351 paired samples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [02:34<00:00,  4.54s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FID: 10.184301376342773\n",
      "SSIM: 0.5239837765693665\n",
      "LPIPS: 0.5625840425491333\n"
     ]
    }
   ],
   "source": [
    "loss_fid, loss_ssim, loss_lpip = eval_model(fsbm_model, f\"FSBM_{FSBM_SEED}\")\n",
    "print(f\"FID: {loss_fid}\")\n",
    "print(f\"SSIM: {loss_ssim}\")\n",
    "print(f\"LPIPS: {loss_lpip}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "d9241a88",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FID: 10.237724622090658 +- 0.6478091607417709\n",
      "SSIM: 0.5236614147822062 +- 0.0005105691526271854\n",
      "LPIPS: 0.5624561309814453 +- 0.00034899172904372333\n"
     ]
    }
   ],
   "source": [
    "# SEEDS 0, 1, 2 \n",
    "fsbm_fids = np.array([9.47238540649414, 11.056487083435059, 10.184301376342773])\n",
    "fsbm_ssims = np.array([0.5240597724914551, 0.5229406952857971, 0.5239837765693665])\n",
    "fsbm_lpips = np.array([0.561979353427887, 0.5628049969673157, 0.5625840425491333])\n",
    "print(f\"FID: {fsbm_fids.mean()} +- {fsbm_fids.std()}\")\n",
    "print(f\"SSIM: {fsbm_ssims.mean()} +- {fsbm_ssims.std()}\")\n",
    "print(f\"LPIPS: {fsbm_lpips.mean()} +- {fsbm_lpips.std()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "661e4e7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss_fid_light_sbm, loss_ssim_light_sbm, loss_lpip_light_sbm = eval_model(light_sbm, \"light-sbm\")\n",
    "# print(f\"FID: {loss_fid_light_sbm}\")\n",
    "# print(f\"SSIM: {loss_ssim_light_sbm}\")\n",
    "# print(f\"LPIPS: {loss_lpip_light_sbm}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ee44087",
   "metadata": {},
   "source": [
    "# 8. Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54ac938c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing: Our\n",
      "Processing: Our\n",
      "Processing: Our\n",
      "Processing: Our\n",
      "Processing: Our\n",
      "Processing: FSBM\n",
      "Processing: FSBM\n",
      "Processing: FSBM\n",
      "Processing: FSBM\n",
      "Processing: FSBM\n"
     ]
    }
   ],
   "source": [
    "# Parameters\n",
    "num_images = 10        # Number of test images to show\n",
    "num_gen = 5            # Number of generated versions per image\n",
    "models = [light_gcot_model, fsbm_model]\n",
    "model_names = [\"Our\", \"FSBM\"]\n",
    "\n",
    "# Select random test samples\n",
    "x = X_paired_test[indices]\n",
    "y = Y_paired_test[indices]\n",
    "\n",
    "\n",
    "# Decode input and target images\n",
    "init_img = normalize_tensor(decode(alae_model, x))\n",
    "true_img = normalize_tensor(decode(alae_model, y))\n",
    "\n",
    "# Generate predictions for each model\n",
    "all_model_preds = []  # List to hold predictions from each model\n",
    "\n",
    "for model, model_name in zip(models, model_names):\n",
    "    model_preds = []\n",
    "    for _ in range(num_gen):\n",
    "        with torch.no_grad():\n",
    "            print(f\"Processing: {model_name}\")\n",
    "            if model_name == \"FSBM\":\n",
    "                output = model.sample(x, log_steps=20, nfe=1000, direction=\"fwd\")\n",
    "                y_pred = output[\"xs\"][:, -1, :]\n",
    "            else:\n",
    "                y_pred = model(x)\n",
    "            \n",
    "            decoded = normalize_tensor(decode(alae_model, y_pred))\n",
    "            model_preds.append(decoded)\n",
    "    model_preds = torch.stack(model_preds, dim=1)  # [num_images, num_gen, C, H, W]\n",
    "    all_model_preds.append(model_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "f76394b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to numpy arrays for plotting\n",
    "init_img_np = init_img.cpu().permute(0, 2, 3, 1).numpy()      # [num_images, H, W, C]\n",
    "true_img_np = true_img.cpu().permute(0, 2, 3, 1).numpy()      # [num_images, H, W, C]\n",
    "all_model_preds_np = [\n",
    "    preds.cpu().permute(0, 1, 3, 4, 2).numpy()                # [num_images, num_gen, H, W, C]\n",
    "    for preds in all_model_preds\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "f5781297",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting\n",
    "cols = 2 + num_gen * len(models)\n",
    "fig, axes = plt.subplots(\n",
    "    num_images,\n",
    "    cols,\n",
    "    figsize=(cols, num_images * 1.5),\n",
    "    dpi=200\n",
    ")\n",
    "\n",
    "for i in range(num_images):\n",
    "    # Input image\n",
    "    axes[i, 0].imshow(init_img_np[i])\n",
    "    axes[i, 0].set_title('Input' if i == 0 else '')\n",
    "    axes[i, 0].axis('off')\n",
    "\n",
    "    # Target image\n",
    "    axes[i, 1].imshow(true_img_np[i])\n",
    "    axes[i, 1].set_title('Target' if i == 0 else '')\n",
    "    axes[i, 1].axis('off')\n",
    "\n",
    "    # Generated images from each model\n",
    "    col_idx = 2\n",
    "    for m_idx, model_preds in enumerate(all_model_preds_np):\n",
    "        for g_idx in range(num_gen):\n",
    "            axes[i, col_idx].imshow(model_preds[i, g_idx])\n",
    "            if i == 0:\n",
    "                axes[i, col_idx].set_title(f'{model_names[m_idx]}')\n",
    "            axes[i, col_idx].axis('off')\n",
    "            col_idx += 1\n",
    "\n",
    "plt.tight_layout(pad=0.5)\n",
    "plt.savefig(f'{INPUT_DATA}->{TARGET_DATA}_full.png', bbox_inches='tight')\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "22820e18",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'WOMAN->MAN_full.png'"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f'{INPUT_DATA}->{TARGET_DATA}_full.png'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "f78ac037",
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_indices = [0, 1, 2, 3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "cf1f73b1",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'selected_model_preds_np' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[47], line 24\u001b[0m\n\u001b[1;32m     22\u001b[0m \u001b[38;5;66;03m# Generated images from each model\u001b[39;00m\n\u001b[1;32m     23\u001b[0m col_idx \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m m_idx, model_preds \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[43mselected_model_preds_np\u001b[49m):\n\u001b[1;32m     25\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m g_idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_gen):\n\u001b[1;32m     26\u001b[0m         axes[i, col_idx]\u001b[38;5;241m.\u001b[39mimshow(model_preds[i, g_idx])\n",
      "\u001b[0;31mNameError\u001b[0m: name 'selected_model_preds_np' is not defined"
     ]
    }
   ],
   "source": [
    "# Plotting\n",
    "cols = 2 + num_gen * len(models)\n",
    "fig, axes = plt.subplots(\n",
    "    len(selected_indices),\n",
    "    cols,\n",
    "    figsize=(cols, len(selected_indices) * 1.5),\n",
    "    dpi=200\n",
    ")\n",
    "\n",
    "for i in range(num_images):\n",
    "    # Input image\n",
    "    if i in selected_indices:\n",
    "        axes[i, 0].imshow(init_img_np[i])\n",
    "        axes[i, 0].set_title('Input' if i == 0 else '')\n",
    "        axes[i, 0].axis('off')\n",
    "\n",
    "        # Target image\n",
    "        axes[i, 1].imshow(true_img_np[i])\n",
    "        axes[i, 1].set_title('Target' if i == 0 else '')\n",
    "        axes[i, 1].axis('off')\n",
    "\n",
    "        # Generated images from each model\n",
    "        col_idx = 2\n",
    "        for m_idx, model_preds in enumerate(selected_model_preds_np):\n",
    "            for g_idx in range(num_gen):\n",
    "                axes[i, col_idx].imshow(model_preds[i, g_idx])\n",
    "                if i == 0:\n",
    "                    axes[i, col_idx].set_title(f'{model_names[m_idx]}')\n",
    "                axes[i, col_idx].axis('off')\n",
    "                col_idx += 1\n",
    "\n",
    "plt.tight_layout(pad=0.5)\n",
    "plt.savefig(f'{INPUT_DATA}->{TARGET_DATA}_selected.png', bbox_inches='tight')\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16a0049f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "107b9259",
   "metadata": {},
   "source": [
    "# Saving images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "383feb8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import trange\n",
    "\n",
    "from torchvision.utils import make_grid, save_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "f8afd7a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_row(x, y, model_preds, model_names, num_gen, out_dir, idx):\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    row = [x, y]  # first two images\n",
    "\n",
    "    # each entry: [batch, num_gen, C,H,W]\n",
    "    for m_idx, _ in enumerate(model_names):\n",
    "        for g in range(num_gen):\n",
    "            row.append(model_preds[m_idx][g])  #  (C,H,W)\n",
    "\n",
    "    row = torch.stack(row, dim=0)  # (N,C,H,W)\n",
    "    grid = make_grid(row, nrow=row.shape[0])\n",
    "\n",
    "    save_path = os.path.join(out_dir, f\"row_{idx:05d}.png\")\n",
    "    save_image(grid, save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "5559cfcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_all_rows_batched(\n",
    "    X_test: torch.Tensor,\n",
    "    Y_test: torch.Tensor,\n",
    "    models: list[nn.Module],\n",
    "    model_names: list[str],\n",
    "    alae_model: nn.Module,\n",
    "    batch_size: int = 64,\n",
    "    num_gen: int = 1,\n",
    "    out_dir: str = \"generated_rows\",\n",
    "    indices: None | list[int] = None,  # NEW\n",
    "):\n",
    "    device = next(models[0].parameters()).device  # same GPU for all models\n",
    "\n",
    "    if indices is not None:\n",
    "        indices = sorted(list(indices))\n",
    "        num_samples = len(indices)\n",
    "    else:\n",
    "        num_samples = min(len(X_test), len(Y_test))\n",
    "\n",
    "    num_batches = (num_samples + batch_size - 1) // batch_size\n",
    "\n",
    "    idx_global = 0  # global counter for saving row_NNNNN.png files\n",
    "\n",
    "    for batch_idx in range(num_batches):\n",
    "        start = batch_idx * batch_size\n",
    "        end = min((batch_idx + 1) * batch_size, num_samples)\n",
    "\n",
    "        print(f\"Batch {batch_idx+1}/{num_batches} → samples {start}:{end}\")\n",
    "\n",
    "        # Load batch\n",
    "        if indices is None:\n",
    "            batch_ids = list(range(start, end))\n",
    "        else:\n",
    "            batch_ids = indices[start:end]\n",
    "\n",
    "        x = X_test[batch_ids].to(device)\n",
    "        y = Y_test[batch_ids].to(device)\n",
    "\n",
    "        # Decode X and Y (once per batch)\n",
    "        x_dec = normalize_tensor(decode(alae_model, x))\n",
    "        y_dec = normalize_tensor(decode(alae_model, y))\n",
    "\n",
    "        # Predict with each model\n",
    "        batch_model_preds = []  # list of shape [num_models] -> [batch, num_gen, C,H,W]\n",
    "\n",
    "        for model, mname in zip(models, model_names):\n",
    "            preds_gens = []\n",
    "\n",
    "            for _ in range(num_gen):\n",
    "                with torch.no_grad():\n",
    "                    if mname == \"FSBM\":\n",
    "                        out = model.sample(x, log_steps=20, nfe=1000, direction=\"fwd\")\n",
    "                        y_pred = out[\"xs\"][:, -1, :]\n",
    "                    else:\n",
    "                        y_pred = model(x)\n",
    "\n",
    "                    y_decoded = normalize_tensor(decode(alae_model, y_pred))\n",
    "                    preds_gens.append(y_decoded)\n",
    "\n",
    "            preds_gens = torch.stack(preds_gens, dim=1)  # [batch, num_gen, C,H,W]\n",
    "            batch_model_preds.append(preds_gens)\n",
    "\n",
    "        # Save each sample row\n",
    "        for local_i in trange(len(batch_ids)):\n",
    "            real_idx = batch_ids[local_i]\n",
    "\n",
    "            x_i = x_dec[local_i]\n",
    "            y_i = y_dec[local_i]\n",
    "\n",
    "            preds_i = [m[local_i] for m in batch_model_preds]\n",
    "\n",
    "            save_row(x_i, y_i, preds_i, model_names=model_names, num_gen=num_gen, out_dir=out_dir, idx=real_idx)\n",
    "\n",
    "            idx_global += 1\n",
    "\n",
    "        # Clear memory\n",
    "        del x, y, x_dec, y_dec, batch_model_preds\n",
    "        torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "0e735fe7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 1/1 → samples 0:1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.20it/s]\n"
     ]
    }
   ],
   "source": [
    "models = [fsbm_model, light_gcot_model]\n",
    "model_names = [\"FSBM\", \"Our\"]\n",
    "\n",
    "generate_all_rows_batched(\n",
    "    X_paired_test,\n",
    "    Y_paired_test,\n",
    "    models=models,\n",
    "    model_names=model_names,\n",
    "    alae_model=alae_model,\n",
    "    batch_size=64,\n",
    "    num_gen=1,\n",
    "    out_dir=\"rows_output_right_order\",\n",
    "    indices=[454]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0c11f46",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
