{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f3bd7692",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd972eac",
   "metadata": {},
   "source": [
    "## 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6261775b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import argparse\n",
    "import copy\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.distributions as TD\n",
    "import torchvision.transforms as tr\n",
    "from PIL import Image\n",
    "from torchvision import datasets\n",
    "from tqdm import tqdm\n",
    "\n",
    "import wandb\n",
    "from configs.energy_based.model import EBMConfig\n",
    "from mnist2to3.utils import plot_diagnostics, plot_images, steps_counter\n",
    "from src.costs.convolutional import NonlocalCost, UNetCost, VanillaCost\n",
    "from src.costs.lse import MLPLSECost\n",
    "from src.costs.mlp_based import MLPCost\n",
    "from src.costs.nonlearnable import SquareCost\n",
    "from src.models.energy_based import EGEOT\n",
    "from src.plotting.distributions import plot_swiss_roll\n",
    "from src.potentials.mlp_based import MLPPotential\n",
    "from src.potentials.vanilla import NonlocalPotential, VanillaPotential\n",
    "from src.samplers.energy_based.sample_buffer import SampleBufferEgEOT\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.utils.dataset.colored_mnist import (\n",
    "    apply_random_color,\n",
    "    download_digit_images,\n",
    "    get_paired_digits,\n",
    ")\n",
    "from src.utils.discrete_ot import OTPlanSampler\n",
    "from src.utils.paired import generate_paired_data, get_GT_points, get_paired_sampler\n",
    "from src.utils.train import compute_loss, update_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d3739161",
   "metadata": {},
   "outputs": [],
   "source": [
    "WANDB_PROJECT_NAME = \"eot\"\n",
    "DISCRETE_OT_DIR = \"../src/discreteot\"\n",
    "sys.path.append(DISCRETE_OT_DIR)\n",
    "from src.discreteot import DiscreteEOT_l2sq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c01f68bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_NAME = \"mnist2to3_s500_pVanilla_h0.01_P200\"\n",
    "EXP_DIR = \"./out_data/{}/\".format(EXP_NAME)\n",
    "# json file with experiment config\n",
    "CONFIG_FILE = \"./config_locker/{}.json\".format(EXP_NAME)\n",
    "\n",
    "FROM_ITERATION = 7000\n",
    "EVAL = True\n",
    "\n",
    "FULL_DEVICE = f\"cuda:{torch.cuda.current_device()}\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(FULL_DEVICE)\n",
    "USE_WANDB = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "535e12ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load experiment config\n",
    "with open(CONFIG_FILE) as file:\n",
    "    config = json.load(file)\n",
    "\n",
    "# make directory for saving results\n",
    "os.makedirs(EXP_DIR, exist_ok=True)\n",
    "for folder in [\"checkpoints\", \"shortrun\", \"longrun\", \"plots\", \"code\"]:\n",
    "    # os.mkdir(EXP_DIR + folder, exist_ok=True)\n",
    "    os.makedirs(EXP_DIR + folder, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "734098c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set seed for cpu and CUDA, get device\n",
    "# DEVICE SETTING\n",
    "if FULL_DEVICE.startswith(\"cuda\"):\n",
    "    device = \"cuda\"\n",
    "    GPU_DEVICE = int(FULL_DEVICE.split(\":\")[1])\n",
    "    torch.cuda.set_device(GPU_DEVICE)\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "torch.manual_seed(config[\"seed\"])\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(config[\"seed\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "53e79cc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.set_default_device(device)\n",
    "dtype = torch.float64\n",
    "torch.torch.set_default_dtype(dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b8b7aa0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.manual_seed(train_config.seed)\n",
    "# np.random.seed(train_config.seed)\n",
    "# random.seed(train_config.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b529e7d",
   "metadata": {},
   "source": [
    "## 2. Training Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a423ccec",
   "metadata": {},
   "outputs": [],
   "source": [
    "HREG = config[\"hreg\"]\n",
    "EMA_UPDATE = config[\"ema_update\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "05209422",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up potential\n",
    "potential_bank = {\"vanilla\": VanillaPotential, \"nonlocal\": NonlocalPotential}\n",
    "f = potential_bank[config[\"potential_type\"]](n_c=config[\"im_ch\"]).to(device)\n",
    "# set up optimizer\n",
    "optim_bank = {\"adam\": torch.optim.Adam, \"sgd\": torch.optim.SGD}\n",
    "if config[\"optimizer_type\"] == \"sgd\" and config[\"epsilon\"] > 0:\n",
    "    # scale learning rate according to langevin noise for invariant tuning\n",
    "    config[\"lr_init\"] *= (config[\"epsilon\"] ** 2) / 2\n",
    "    config[\"lr_min\"] *= (config[\"epsilon\"] ** 2) / 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8de9f49d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up cost\n",
    "cost_bank = {\n",
    "    \"vanilla\": VanillaCost,\n",
    "    \"nonlocal\": NonlocalCost,\n",
    "    \"unet\": UNetCost,\n",
    "}\n",
    "cost = cost_bank[config[\"cost_type\"]](n_c=config[\"im_ch\"]).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "95f875b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = EBMConfig()\n",
    "model = EGEOT(\n",
    "    potential=f,\n",
    "    cost=cost,\n",
    "    sample_buffer=None,\n",
    "    config=model_config,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69f93484",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters(), lr=config[\"lr_init\"])\n",
    "if FROM_ITERATION > 0:\n",
    "    model.load_state_dict(\n",
    "        torch.load(\n",
    "            Path(EXP_DIR) / \"checkpoints\" / f\"model_{FROM_ITERATION:>06d}.pth\", weights_only=True, map_location=device\n",
    "        )\n",
    "    )\n",
    "    optimizer.load_state_dict(\n",
    "        torch.load(\n",
    "            Path(EXP_DIR) / \"checkpoints\" / f\"optim_{FROM_ITERATION:>06d}.pth\", weights_only=True, map_location=device\n",
    "        )\n",
    "    )\n",
    "\n",
    "if EMA_UPDATE:\n",
    "    model_copy = copy.deepcopy(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "69d78178",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASET = \"MNIST\"\n",
    "TARGET_DATASET = \"MNIST\"\n",
    "SOURCE_DIGIT = 2\n",
    "TARGET_DIGIT = 3\n",
    "P_XY_PAIRED_SAMPLES = config[\"P_XY\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ebb81cfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "source_images: list[torch.Tensor] = download_digit_images(SOURCE_DATASET, SOURCE_DIGIT, 10000)\n",
    "target_images: list[torch.Tensor] = download_digit_images(TARGET_DATASET, TARGET_DIGIT, 20000)\n",
    "\n",
    "q_x_paired, q_y_paired = get_paired_digits(\n",
    "    source_images, target_images, P_XY_PAIRED_SAMPLES, hue_offset=120, device=device\n",
    ")\n",
    "\n",
    "q_x = torch.stack([apply_random_color(digit, 360 * torch.rand(1)) for digit in source_images]).to(device)\n",
    "q_y = torch.stack([apply_random_color(digit, 360 * torch.rand(1)) for digit in target_images]).to(device)\n",
    "\n",
    "print(f\"P_XY_PAIRED: {q_x_paired.shape}; Q_X_UNPAIRED: {q_x.shape}; R_Y_UNPAIRED: {q_y.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a1aac87a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize persistent images from noise (one persistent image for each data image)\n",
    "# s_t_0 is used when init_type == 'persistent' in sample_s_t()\n",
    "s_t_0 = 2 * torch.rand_like(q_x) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ec49777d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample batch from given array of images\n",
    "def sample_image_set(image_set: torch.Tensor, size: int = config[\"batch_size\"]):\n",
    "    rand_inds = torch.randperm(image_set.shape[0])[0:size]\n",
    "    return image_set[rand_inds], rand_inds\n",
    "\n",
    "\n",
    "################# DOT for init\n",
    "def solve_dot(X: torch.Tensor, Y: torch.Tensor, numitermax: int = 10000, verbose: bool = False):\n",
    "    DOT_DTYPE = \"torch64\"\n",
    "    DOT_NUMITERMAX = numitermax\n",
    "    DOT_VERBOSE = verbose\n",
    "    discr_eot = DiscreteEOT_l2sq(device=device, verbose=DOT_VERBOSE, numItermax=DOT_NUMITERMAX, dtype=DOT_DTYPE).solve(\n",
    "        X.view(X.size(0), -1), Y.view(Y.size(0), -1), HREG\n",
    "    )\n",
    "    x_inds = torch.arange(X.size(0))\n",
    "    y_inds = discr_eot.sample_by_indices(x_inds, return_indices=True)\n",
    "    y_image_subset = Y[y_inds]\n",
    "    return X, y_image_subset, (x_inds, y_inds)\n",
    "\n",
    "\n",
    "if config[\"shortrun_init\"] == \"persistentDOT\":\n",
    "    SC = steps_counter(s0=config[\"pDOT_update_step\"], s1=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c57e0bbf",
   "metadata": {},
   "source": [
    "## 3. Functions for Sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "3c7ad0bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample positive images from dataset distribution q_y (add noise to ensure min sd is at least langevin noise sd)\n",
    "def sample_q_y():\n",
    "    x_q_y = sample_image_set(q_y)[0]\n",
    "    return x_q_y + config[\"data_epsilon\"] * torch.randn_like(x_q_y)\n",
    "\n",
    "\n",
    "def sample_pairs():\n",
    "    x, inds = sample_image_set(q_x_paired)\n",
    "    y = q_y_paired[inds]\n",
    "    return x + config[\"data_epsilon\"] * torch.randn_like(x), y + config[\"data_epsilon\"] * torch.randn_like(y)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "793261e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get initial mcmc states for langevin updates (\"persistent\", \"data\", \"uniform\", or \"gaussian\")\n",
    "def sample_s_t_0(init_type) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    returns (y_samples, x_samples, indices)\n",
    "    \"\"\"\n",
    "    if init_type == \"persistent\":\n",
    "        y_image_subset, rand_inds = sample_image_set(s_t_0)\n",
    "        return y_image_subset, q_x[rand_inds], rand_inds\n",
    "    elif init_type == \"DOT\":\n",
    "        X, _ = sample_image_set(q_x)\n",
    "        Y, _ = sample_image_set(q_y)\n",
    "        x_image_subset, y_image_subset, _ = solve_dot(X, Y, numitermax=2000)\n",
    "        return y_image_subset, x_image_subset, None\n",
    "    elif init_type == \"source_data\":\n",
    "        x_image_subset, inds = sample_image_set(q_x)\n",
    "        return x_image_subset.clone().detach(), x_image_subset, (inds,)\n",
    "    elif init_type == \"target_data\":\n",
    "        x_image_subset, inds = sample_image_set(q_x)\n",
    "        y_image_subset, y_inds = sample_image_set(q_y)\n",
    "        return y_image_subset, x_image_subset, (inds, y_inds)\n",
    "    elif init_type == \"persistentDOT\":\n",
    "        y_image_subset, rand_inds = sample_image_set(s_t_0)\n",
    "        if next(SC):\n",
    "            X, x_inds = sample_image_set(q_x, size=1000)\n",
    "            Y, _ = sample_image_set(q_y, size=1000)\n",
    "            X, Y_dot, _ = solve_dot(X, Y, numitermax=10000)\n",
    "            s_t_0[x_inds] = Y_dot\n",
    "        return y_image_subset, q_x[rand_inds], rand_inds\n",
    "    elif init_type == \"uniform\":\n",
    "        x_image_subset, _ = sample_image_set(q_x)\n",
    "        noise_image = 2 * torch.rand([config[\"batch_size\"], config[\"im_ch\"], config[\"im_sz\"], config[\"im_sz\"]]) - 1\n",
    "        return noise_image.to(device), x_image_subset, None\n",
    "    elif init_type == \"gaussian\":\n",
    "        x_image_subset, _ = sample_image_set(q_x)\n",
    "        noise_image = torch.randn([config[\"batch_size\"], config[\"im_ch\"], config[\"im_sz\"], config[\"im_sz\"]])\n",
    "        return noise_image.to(device), x_image_subset, None\n",
    "    elif init_type == \"from_cost\":\n",
    "        x_image_subset, _ = sample_image_set(q_x)\n",
    "        y_s_t = model.cost.net(x_image_subset)\n",
    "        return y_s_t, x_image_subset, None\n",
    "    else:\n",
    "        raise RuntimeError('Invalid method for \"init_type\" (use \"persistent\", \"data\", \"uniform\", or \"gaussian\")')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "216377e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize and update images with langevin dynamics to obtain samples from finite-step MCMC distribution s_t\n",
    "# TODO: update_s_t_0 seems to be buffer?\n",
    "def sample_s_t(\n",
    "    model: EGEOT,\n",
    "    y_s_t_0: torch.Tensor,\n",
    "    x_s_t_0: torch.Tensor,\n",
    "    num_steps: int,\n",
    "    init_type: str,\n",
    "    s_t_0_inds: torch.Tensor | None = None,\n",
    "    update_s_t_0: bool = True,\n",
    "):\n",
    "    # initialize MCMC samples\n",
    "    # y_s_t_0, x_s_t_0, s_t_0_inds = sample_s_t_0()\n",
    "\n",
    "    # iterative langevin updates of MCMC samples\n",
    "    r_s_t = torch.zeros(1).to(device)  # variable r_s_t (Section 3.2) to record average gradient magnitude\n",
    "    cost_grad_s_t = torch.zeros(1).to(device)\n",
    "    for _ in tqdm(range(num_steps), leave=False):\n",
    "        f_prime = model.potential.grad_y(y_s_t_0)\n",
    "        cost_grad = model.cost.grad_y(x_s_t_0, y_s_t_0)\n",
    "        y_s_t_0 += (f_prime - cost_grad) / (2 * HREG) + config[\"epsilon\"] * torch.randn_like(y_s_t_0)\n",
    "        r_s_t += f_prime.view(f_prime.shape[0], -1).norm(dim=1).mean()\n",
    "        cost_grad_s_t += cost_grad.view(f_prime.shape[0], -1).norm(dim=1).mean()\n",
    "\n",
    "    if init_type == \"persistent\" and update_s_t_0:\n",
    "        # update persistent image bank\n",
    "        s_t_0.data[s_t_0_inds] = y_s_t_0.detach().data.clone()\n",
    "\n",
    "    return y_s_t_0.detach(), x_s_t_0, r_s_t.squeeze() / num_steps, cost_grad_s_t.squeeze() / num_steps"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cefb08b",
   "metadata": {},
   "source": [
    "## 4. Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "66c4c32c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# containers for diagnostic records (see Section 3)\n",
    "d_s_t_record = torch.zeros(config[\"num_train_iters\"]).to(\n",
    "    device\n",
    ")  # energy difference between positive and negative samples\n",
    "r_s_t_record = torch.zeros(config[\"num_train_iters\"]).to(\n",
    "    device\n",
    ")  # average image gradient magnitude along Langevin path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5ab6a7e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "if USE_WANDB:\n",
    "    wandb.init(name=EXP_NAME, project=WANDB_PROJECT_NAME, reinit=True, config=config)\n",
    "    print(\"WandB has initialized.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "13d8dd8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "if EVAL:\n",
    "    print(\"Evaluation has started.\")\n",
    "    print(\n",
    "        \"{:>6d}   Generating long-run samples. (L={:>6d} MCMC steps)\".format(\n",
    "            FROM_ITERATION + 1, config[\"num_longrun_steps\"]\n",
    "        )\n",
    "    )\n",
    "    N_PER = 6\n",
    "    NUM_EVAL_SAMPLES = 3\n",
    "    assert NUM_EVAL_SAMPLES > 0\n",
    "    _step = torch.tensor(360 / NUM_EVAL_SAMPLES)\n",
    "\n",
    "    _q_x_eval, _ = sample_image_set(torch.stack(source_images), NUM_EVAL_SAMPLES)\n",
    "    q_x_eval = torch.stack([apply_random_color(_q_x_eval[i], _step * i) for i in range(NUM_EVAL_SAMPLES)]).to(device)\n",
    "\n",
    "    _x_s_t_0 = q_x_eval.repeat_interleave(N_PER, 0)\n",
    "    y_s_t_0, x_s_t_0 = _x_s_t_0.clone(), _x_s_t_0\n",
    "    for init_type in [config[\"shortrun_init\"]]:  # [\"DOT\", \"persistent\", \"uniform\", \"source_data\", \"target_data\"]:\n",
    "        with torch.no_grad():\n",
    "            _model = model_copy if EMA_UPDATE else model\n",
    "            y_p_theta, x_p_theta, _, _ = sample_s_t(\n",
    "                _model,\n",
    "                y_s_t_0,\n",
    "                x_s_t_0,\n",
    "                num_steps=config[\"num_longrun_steps\"],\n",
    "                init_type=init_type,\n",
    "                update_s_t_0=False,\n",
    "            )\n",
    "\n",
    "        plot_images(\n",
    "            f\"{init_type} init\",\n",
    "            y_p_theta,\n",
    "            step=FROM_ITERATION + 1,\n",
    "            save_dir=Path(EXP_DIR) / \"longrun\",\n",
    "        )\n",
    "        torch.save(y_p_theta, Path(EXP_DIR) / \"longrun\" / f\"{init_type} init.pt\")\n",
    "        print(\"{:>6d}   Long-run samples for init {} saved.\".format(FROM_ITERATION + 1, init_type))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bc1d14c",
   "metadata": {},
   "source": [
    "## Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "df7ca3a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from torchvision.utils import make_grid\n",
    "from torchvision.transforms import functional, Normalize\n",
    "from mnist2to3.utils import tensor2image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "af3e069e",
   "metadata": {},
   "outputs": [],
   "source": [
    "init_type = \"target_data\"\n",
    "N_PER = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "f9186108",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_200 = torch.load(\n",
    "    Path(\"out_data/mnist2to3_s500_pVanilla_h0.01_P200/longrun/\") / f\"{init_type} init.pt\",\n",
    "    weights_only=True,\n",
    "    map_location=device,\n",
    ")\n",
    "y_10 = torch.load(\n",
    "    Path(\"out_data/mnist2to3_s500_pVanilla_h0.01_P10/longrun/\") / f\"{init_type} init.pt\",\n",
    "    weights_only=True,\n",
    "    map_location=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "836ae8e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_200 = torch.load(\n",
    "    Path(\"out_data/mnist2to3_s500_pVanilla_h0.01_P200/longrun/\") / f\"x.pt\", weights_only=True, map_location=device\n",
    ")\n",
    "# x_10 = torch.load(\n",
    "#     Path(\"out_data/mnist2to3_s500_pVanilla_h0.01_P10/longrun/\") / f\"x.pt\", weights_only=True, map_location=device\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "9f62732f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_images(\n",
    "    f\"init\",\n",
    "    x_200[::N_PER, :, :, :],\n",
    "    normalize=True,\n",
    "    nrow=1,\n",
    "    clamp=True,\n",
    "    step=FROM_ITERATION + 1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "60ff3b40",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_shape = y_200[0].shape\n",
    "NUM_EVAL_SAMPLES = 20\n",
    "\n",
    "\n",
    "final_images_indices = torch.tensor([1, 3, 4, 6, 7, 9, 11, 12, 13, 15]).to(device)\n",
    "NUM_FINAL_SAMPLES = len(final_images_indices)\n",
    "\n",
    "_step = torch.tensor(360 / NUM_EVAL_SAMPLES)\n",
    "_q_y_eval, _ = sample_image_set(torch.stack(target_images), NUM_EVAL_SAMPLES)\n",
    "q_y_eval = torch.stack([apply_random_color(_q_y_eval[i], _step * i + 120) for i in range(NUM_EVAL_SAMPLES)]).to(\n",
    "    device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "6079a7aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "tensor1 = x_200[::N_PER, :, :, :][final_images_indices].view(NUM_FINAL_SAMPLES, 1, *image_shape)\n",
    "tensor2 = y_10[::N_PER, :, :, :][final_images_indices].view(NUM_FINAL_SAMPLES, 1, *image_shape)\n",
    "tensor3 = y_200[::N_PER, :, :, :][final_images_indices].view(NUM_FINAL_SAMPLES, 1, *image_shape)\n",
    "tensor4 = q_y_eval[final_images_indices].view(NUM_FINAL_SAMPLES, 1, *image_shape) # functional.adjust_hue(tensor1, 1/3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "e7277e9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_images(\n",
    "    f\"target_data_init_final_plot\",\n",
    "    torch.cat([tensor1, tensor4, tensor2, tensor3], dim=1).view(-1, *image_shape),\n",
    "    normalize=True,\n",
    "    nrow=4,\n",
    "    clamp=True,\n",
    "    step=FROM_ITERATION + 1,\n",
    "    save_dir=Path(\"./out_data\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "e1203f5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_tensor = torch.cat([tensor1, tensor4, tensor2, tensor3], dim=1) #.view(-1, *image_shape)\n",
    "normalize_transorm = Normalize(mean=[0.5], std=[0.5])\n",
    "# final_tensor = tensor2image(final_tensor)\n",
    "# final_tensor = tensor2image(normalize_transorm(final_tensor))\n",
    "# final_tensor = normalize_transorm(final_tensor) \n",
    "final_tensor.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "4d212566",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the grid for plotting\n",
    "fig, axes = plt.subplots(10, 4, figsize=(6, 15))  # (10 rows, 4 columns)\n",
    "\n",
    "# Adjust space between columns\n",
    "plt.subplots_adjust(wspace=0.25, hspace=0.25)\n",
    "\n",
    "# Column titles\n",
    "column_titles = [\n",
    "    r\"$x_1 \\sim \\pi_{x}^*$\",\n",
    "    r\"$y_2 \\sim \\pi^*(\\cdot \\vert x_2)$\",\n",
    "    r\"$y_1 \\sim \\pi^\\theta_{10}(\\cdot \\vert x_1)$\",\n",
    "    r\"$y_1 \\sim \\pi^\\theta_{200}(\\cdot \\vert x_1)$\",\n",
    "]\n",
    "\n",
    "# Set the titles for the columns\n",
    "for col in range(4):\n",
    "    axes[0, col].set_title(column_titles[col], fontsize=14)\n",
    "\n",
    "# Loop through each row\n",
    "for i in range(10):\n",
    "    for col in range(4):\n",
    "        ax = axes[i, col]  # Select the subplot for the i-th row and col-th column\n",
    "        ax.imshow(final_tensor[i, col].permute(1, 2, 0).cpu().numpy())  # Convert (C, H, W) to (H, W, C)\n",
    "        ax.axis(\"off\")  # Hide axis\n",
    "\n",
    "# Tight layout to make sure there is no overlap\n",
    "# plt.tight_layout()\n",
    "# plt.savefig(\"./out_data/EMB_mnist2to3_vertical.pdf\", bbox_inches='tight')\n",
    "plt.savefig(\"./out_data/EMB_mnist2to3_vertical.png\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "6a51151a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row titles\n",
    "row_titles = [\n",
    "    r\"$x_1 \\sim \\pi_{x}^*$\",\n",
    "    r\"$y_2 \\sim \\pi^*(\\cdot \\vert x_2)$\",\n",
    "    r\"$y_1 \\sim \\pi^\\theta_{10}(\\cdot \\vert x_1)$\",\n",
    "    r\"$y_1 \\sim \\pi^\\theta_{200}(\\cdot \\vert x_1)$\",\n",
    "]\n",
    "\n",
    "# Create a figure with an extra column for titles\n",
    "fig, axes = plt.subplots(4, 11, figsize=(16, 5))  # 11 columns now (10 for data + 1 for titles)\n",
    "\n",
    "# Adjust spacing (important to make room for titles)\n",
    "plt.subplots_adjust(wspace=0.05, hspace=0.05, left=0.1) #Added more adjustment on the left\n",
    "\n",
    "for row in range(4):\n",
    "    # Get the \"dummy\" subplot for the title\n",
    "    title_ax = axes[row, 0]\n",
    "\n",
    "    # Set the title using text() for more control\n",
    "    title_ax.text(0.5, 0.5, row_titles[row], fontsize=14, ha='center', va='center', transform=title_ax.transAxes)\n",
    "\n",
    "    # Turn off axis for the title subplot\n",
    "    title_ax.axis(\"off\")\n",
    "\n",
    "    # Plot the data in the remaining subplots\n",
    "    for col in range(10):\n",
    "        ax = axes[row, col + 1]  # Shift by 1 to account for the title column\n",
    "        if final_tensor.shape[2] == 1:\n",
    "            ax.imshow(final_tensor[col, row, 0].cpu().numpy(), cmap='gray')\n",
    "        else:\n",
    "            ax.imshow(final_tensor[col, row].permute(1, 2, 0).cpu().numpy())\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "# plt.tight_layout()\n",
    "# plt.savefig(\"./out_data/EMB_mnist2to3_horizontal.pdf\", bbox_inches='tight')\n",
    "plt.savefig(\"./out_data/EMB_mnist2to3_horizontal.png\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "f4dc4f49",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_images(\n",
    "    f\"init\",\n",
    "    x_200,\n",
    "    normalize=True,\n",
    "    clamp=True,\n",
    "    step=FROM_ITERATION + 1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d6914d3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_images(\n",
    "    f\"init\",\n",
    "    toy_p_theta,\n",
    "    normalize=True,\n",
    "    clamp=True,\n",
    "    step=FROM_ITERATION + 1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "8853bae0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "normalize = transforms.Normalize(mean=[0.5], std=[0.5])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "d12ad76c",
   "metadata": {},
   "outputs": [],
   "source": [
    "normalized_image = normalize(y_p_theta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "1a591b8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "normalized_image.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "42f28ca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_images(\n",
    "    f\"init\",\n",
    "    y_p_theta,\n",
    "    normalize=True,\n",
    "    clamp=True,\n",
    "    step=FROM_ITERATION + 1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "aaf36780",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_images(\n",
    "    f\"init\",\n",
    "    x_s_t_0,\n",
    "    normalize=True,\n",
    "    step=FROM_ITERATION + 1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f22d646",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not EVAL:\n",
    "    print(\"Training has started.\")\n",
    "    for i in range(FROM_ITERATION, config[\"num_train_iters\"]):\n",
    "        # obtain positive and negative samples\n",
    "        samp_q_y = sample_q_y()\n",
    "        y_s_t_0, x_s_t_0, s_t_0_inds = sample_s_t_0(init_type)\n",
    "        with torch.no_grad():\n",
    "            y_s_t, x_s_t, r_s_t, cost_grad_s_t = sample_s_t(\n",
    "                model,\n",
    "                y_s_t_0,\n",
    "                x_s_t_0,\n",
    "                num_steps=config[\"num_shortrun_steps\"],\n",
    "                s_t_0_inds=s_t_0_inds,\n",
    "                init_type=config[\"shortrun_init\"],\n",
    "            )\n",
    "\n",
    "        # calculate ML computational loss d_s_t (Section 3) for data and shortrun samples\n",
    "        d_s_t = -f(samp_q_y).mean() + f(y_s_t).mean()\n",
    "        # Uncomment also lines in sample_s_t. Maybe scale at the end?\n",
    "        if config[\"epsilon\"] > 0:\n",
    "            # scale loss with the langevin implementation\n",
    "            d_s_t *= 2 / (config[\"epsilon\"] ** 2)\n",
    "        # stochastic gradient ML update for model weights\n",
    "        optimizer.zero_grad()\n",
    "        d_s_t.backward()\n",
    "\n",
    "        q_x_p, q_y_p = sample_pairs()\n",
    "        paired_loss = model.compute_paired_loss(q_x_p, q_y_p)[\"loss\"]\n",
    "        if config[\"epsilon\"] > 0:\n",
    "            # scale loss with the langevin implementation\n",
    "            paired_loss *= 2 / (config[\"epsilon\"] ** 2)\n",
    "\n",
    "        paired_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if EMA_UPDATE:\n",
    "            update_average(model_copy, model, 0.99)\n",
    "\n",
    "        # record diagnostics\n",
    "        d_s_t_record[i] = d_s_t.detach().data\n",
    "        r_s_t_record[i] = r_s_t\n",
    "\n",
    "        # anneal learning rate\n",
    "        for lr_gp in optimizer.param_groups:\n",
    "            lr_gp[\"lr\"] = max(config[\"lr_min\"], lr_gp[\"lr\"] * config[\"lr_decay\"])\n",
    "\n",
    "        # update wandb data\n",
    "        if USE_WANDB:\n",
    "            res_dict = {\n",
    "                \"d_s_t\": d_s_t.detach().data,\n",
    "                \"r_s_t\": r_s_t,\n",
    "                \"cost\": paired_loss.detach().data,\n",
    "                \"cost_grad_s_t\": cost_grad_s_t,\n",
    "            }\n",
    "            wandb.log({\"train\": res_dict}, step=i)\n",
    "\n",
    "        # print and save learning info\n",
    "        if (i + 1) == 1 or (i + 1) % config[\"log_freq\"] == 0:\n",
    "            print(\n",
    "                \"{:>6d}   d_s_t={:>14.9f}   r_s_t={:>14.9f}    cost_grad_s_t={:>14.9f}\".format(\n",
    "                    i + 1, d_s_t.detach().data, r_s_t, cost_grad_s_t\n",
    "                )\n",
    "            )\n",
    "            # visualize synthesized images\n",
    "            if EMA_UPDATE:\n",
    "                with torch.no_grad():\n",
    "                    y_s_t_0, x_s_t_0, s_t_0_inds = sample_s_t_0(init_type)\n",
    "                    y_s_t, x_s_t, r_s_t, cost_grad_s_t = sample_s_t(\n",
    "                        model_copy,\n",
    "                        y_s_t_0,\n",
    "                        x_s_t_0,\n",
    "                        num_steps=config[\"num_shortrun_steps\"],\n",
    "                        s_t_0_inds=s_t_0_inds,\n",
    "                        init_type=config[\"shortrun_init\"],\n",
    "                    )\n",
    "            pbuff_dict = plot_images(\n",
    "                f\"pairs x->y, pbuff init\",\n",
    "                x_s_t,\n",
    "                target_tensor=y_s_t,\n",
    "                step=i + 1,\n",
    "                use_wandb=USE_WANDB,\n",
    "                save_dir=Path(EXP_DIR) / \"shortrun\",\n",
    "            )\n",
    "            # WARNING: work only for unet potential\n",
    "            cost_dict = plot_images(\n",
    "                f\"pairs x->g(x)\",\n",
    "                x_s_t,\n",
    "                target_tensor=model.cost.net(x_s_t),\n",
    "                step=i + 1,\n",
    "                use_wandb=USE_WANDB,\n",
    "                save_dir=Path(EXP_DIR) / \"shortrun\",\n",
    "            )\n",
    "            wandb.log(pbuff_dict | cost_dict, step=i)\n",
    "\n",
    "            if config[\"shortrun_init\"] == \"persistent\":\n",
    "                plot_images(\n",
    "                    \"Ys from pbuff\",\n",
    "                    s_t_0[0 : config[\"batch_size\"]],\n",
    "                    step=i,\n",
    "                    save_dir=EXP_DIR + \"shortrun/\" + \"y_s_t_0_{:>06d}.png\".format(i + 1),\n",
    "                    use_wandb=USE_WANDB,\n",
    "                )\n",
    "            # save network weights\n",
    "            torch.save(model.state_dict(), EXP_DIR + \"checkpoints/\" + \"model_{:>06d}.pth\".format(i + 1))\n",
    "            # save optimizer weights\n",
    "            torch.save(optimizer.state_dict(), EXP_DIR + \"checkpoints/\" + \"optim_{:>06d}.pth\".format(i + 1))\n",
    "            # plot diagnostics for energy difference d_s_t and gradient magnitude r_t\n",
    "            # if (i + 1) > 1:\n",
    "            #     plot_diagnostics(i, d_s_t_record, r_s_t_record, EXP_DIR + \"plots/\")\n",
    "            # torch.cuda.empty_cache()\n",
    "\n",
    "        # sample longrun chains to diagnose model steady-state\n",
    "        if config[\"log_longrun\"] and (i + 1) % config[\"log_longrun_freq\"] == 0:\n",
    "            print(\n",
    "                \"{:>6d}   Generating long-run samples. (L={:>6d} MCMC steps)\".format(\n",
    "                    i + 1, config[\"num_longrun_steps\"]\n",
    "                )\n",
    "            )\n",
    "            for init_type in [\n",
    "                config[\"shortrun_init\"]\n",
    "            ]:  # [\"DOT\", \"persistent\", \"uniform\", \"source_data\", \"target_data\"]:\n",
    "                with torch.no_grad():\n",
    "                    y_s_t_0, x_s_t_0, s_t_0_inds = sample_s_t_0(init_type)\n",
    "                    _model = model_copy if EMA_UPDATE else model\n",
    "                    y_p_theta, x_p_theta, _, _ = sample_s_t(\n",
    "                        _model,\n",
    "                        y_s_t_0,\n",
    "                        x_s_t_0,\n",
    "                        num_steps=config[\"num_longrun_steps\"],\n",
    "                        init_type=init_type,\n",
    "                        s_t_0_inds=s_t_0_inds,\n",
    "                        update_s_t_0=False,\n",
    "                    )\n",
    "                longrun_dict = plot_images(\n",
    "                    f\"pairs x->y, longrun, {init_type} init\",\n",
    "                    x_p_theta,\n",
    "                    target_tensor=y_p_theta,\n",
    "                    step=i + 1,\n",
    "                    use_wandb=USE_WANDB,\n",
    "                    save_dir=Path(EXP_DIR) / \"longrun\",\n",
    "                )\n",
    "                wandb.log(longrun_dict, step=i)\n",
    "                print(\"{:>6d}   Long-run samples for init {} saved.\".format(i + 1, init_type))\n",
    "\n",
    "        # WARNING: To reduce memory leakage\n",
    "        # del samp_q_y, y_s_t, x_s_t, r_s_t\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f036daa0",
   "metadata": {},
   "source": [
    "## 2. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d8508840",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Dataset sizes and image resolution\n",
    "Q_X_UNPAIRED_SAMPLES = 5958\n",
    "R_Y_UNPAIRED_SAMPLES = 6131\n",
    "P_XY_PAIRED_SAMPLES = 2000\n",
    "IMG_SIZE = 32\n",
    "\n",
    "# Optimisation\n",
    "LR_PAIRED = 1e-5\n",
    "LR_UNPAIRED = 3e-5\n",
    "PAIRED_BATCH_SIZE = 256\n",
    "UNPAIRED_BATCH_SIZE = 256\n",
    "MAX_STEPS = 100_000\n",
    "\n",
    "# Sampling and cost selection\n",
    "SAMPLING_NUM_ITER = 500\n",
    "COST_FUNCTION = \"MLP\"\n",
    "\n",
    "# Random seed\n",
    "SEED = 30"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c47f945",
   "metadata": {},
   "source": [
    "## 3. Create data and samplers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f0958ad0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.dataset.colored_mnist import (\n",
    "    apply_random_color,\n",
    "    download_digit_images,\n",
    "    get_paired_digits,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "17f756ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASET = \"MNIST\"\n",
    "TARGET_DATASET = \"MNIST\"\n",
    "SOURCE_DIGIT = 2\n",
    "TARGET_DIGIT = 3\n",
    "# NOTE(review): this rebinds P_XY_PAIRED_SAMPLES and so replaces the value set in the parameters cell.\n",
    "P_XY_PAIRED_SAMPLES = 5000\n",
    "\n",
    "\n",
    "def _randomly_colored(digits):\n",
    "    \"\"\"Run apply_random_color on each digit with its own draw of 360 * U[0, 1), stack, and move to `device`.\"\"\"\n",
    "    recolored = [apply_random_color(digit, 360 * torch.rand(1)) for digit in digits]\n",
    "    return torch.stack(recolored).to(device)\n",
    "\n",
    "\n",
    "source_images: list[torch.Tensor] = download_digit_images(SOURCE_DATASET, SOURCE_DIGIT, 1000)\n",
    "target_images: list[torch.Tensor] = download_digit_images(TARGET_DATASET, TARGET_DIGIT, 2000)\n",
    "\n",
    "# Paired source/target tensors, built with a hue offset of 120 between the two sides.\n",
    "q_x_paired, q_y_paired = get_paired_digits(\n",
    "    source_images,\n",
    "    target_images,\n",
    "    P_XY_PAIRED_SAMPLES,\n",
    "    hue_offset=120,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# Unpaired sets: source digits are coloured first, then target digits.\n",
    "q_x = _randomly_colored(source_images)\n",
    "q_y = _randomly_colored(target_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d2afc216",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision as tv\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision.transforms import functional\n",
    "from mnist2to3.utils import plot_im_pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c3b01841",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = q_x[:10]\n",
    "y = q_y[:10]\n",
    "# Interleave the two batches as (x_0, y_0, x_1, y_1, ...) and lay them out on a single grid.\n",
    "to_draw = torch.stack((x, y), dim=1).view(-1, 3, 32, 32)\n",
    "SB_torch_grid = tv.utils.make_grid(to_draw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "e46c4868",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the x / y batches as image pairs and keep the returned dict in log_dict.\n",
    "# use_wandb=False presumably stops plot_im_pairs from uploading the figure — confirm in mnist2to3.utils.\n",
    "log_dict = plot_im_pairs(\"\", x, y, nrow=2, use_wandb=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "98daa3d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hue-shift the first paired source image by a factor of 1/3, then reorder the axes with\n",
    "# permute(1, 2, 0) for matplotlib (assumes the image is stored channels-first — TODO confirm).\n",
    "# The earlier channel-roll assignment was dropped: it was overwritten before ever being read.\n",
    "image_tensor = functional.adjust_hue(q_x_paired[0], 1 / 3).permute(1, 2, 0)\n",
    "\n",
    "# Plot the image\n",
    "plt.imshow(image_tensor.detach().cpu().numpy())\n",
    "plt.axis(\"off\")  # Turn off axes for better visualization\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b71372b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First paired target image with its axes reordered by permute(1, 2, 0) for matplotlib.\n",
    "image_tensor = q_y_paired[0].permute(1, 2, 0)\n",
    "image_array = image_tensor.detach().cpu().numpy()\n",
    "\n",
    "plt.imshow(image_array)\n",
    "plt.axis(\"off\")  # hide the axes so only the picture is shown\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f61b8693",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampler over the paired source/target samples used by the paired loss in the training loop.\n",
    "pd_train_sampler = get_paired_sampler(\n",
    "    paired_source_samples,\n",
    "    paired_target_samples,\n",
    "    train_config.paired_batch_size,\n",
    "    dataset_config.P_XY_paired,\n",
    "    device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "afe419b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# usd - unpaired source data, utd - unpaired target data.\n",
    "# When a side has no unpaired samples configured, its sampler draws from the paired set instead.\n",
    "_usd_pool = unpaired_source_samples if dataset_config.Q_X_unpaired > 0 else paired_source_samples\n",
    "_utd_pool = unpaired_target_samples if dataset_config.R_Y_unpaired > 0 else paired_target_samples\n",
    "\n",
    "usd_sampler = DatasetSampler(_usd_pool, device=device)\n",
    "utd_sampler = DatasetSampler(_utd_pool, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfd8d737",
   "metadata": {},
   "source": [
    "## 4. Model initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "34d9128b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.costs.convolutional import NonlocalCost, VanillaCost, UNetCost\n",
    "from src.potentials.vanilla import NonlocalPotential, VanillaPotential"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "bb931bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Potential that is handed to EGEOT when the model is assembled.\n",
    "# NOTE(review): n_f receives IMG_SIZE here — confirm VanillaPotential's n_f is meant to follow the image size.\n",
    "potential = VanillaPotential(n_f=IMG_SIZE) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "1e6c3861",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cost that is handed to EGEOT; n_c=3 presumably is the number of image channels — confirm.\n",
    "# NOTE(review): the parameters cell sets COST_FUNCTION = \"MLP\" and EXP_NAME records that value,\n",
    "# yet a UNetCost is built here regardless of it.\n",
    "cost = UNetCost(n_c=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "09e4b989",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the model from the potential, the cost, the sample buffer and the model config.\n",
    "model = EGEOT(potential, cost, sample_buffer_instance, model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "fa67a50c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For EMA update\n",
    "# NOTE(review): model_copy is built from the very same `potential` and `cost` objects as `model`,\n",
    "# so the two presumably share parameters unless EGEOT copies its arguments — confirm; if they do\n",
    "# share them, averaging between `model` and `model_copy` has no effect.\n",
    "if train_config.ema_update:\n",
    "    model_copy = EGEOT(potential, cost, sample_buffer_instance, model_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6d84e4b",
   "metadata": {},
   "source": [
    "## 5. Optimizers initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "a876937b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.dataset.colored_mnist import (\n",
    "    apply_random_color,\n",
    "    download_digit_images,\n",
    "    get_paired_digits,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "0e7bd76e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset names and digit classes; these are the same values assigned in the data-creation cell of section 3.\n",
    "SOURCE_DATASET = \"MNIST\"\n",
    "TARGET_DATASET = \"MNIST\"\n",
    "SOURCE_DIGIT = 2\n",
    "TARGET_DIGIT = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "bd9a667d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the shape of the first source digit image.\n",
    "source_images[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "092ec2d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "255cea84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-processing for the raw MNIST digits: resize to IMG_SIZE, convert to a tensor,\n",
    "# then apply Normalize with mean 0.5 and std 0.5.\n",
    "transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(IMG_SIZE),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize([0.5], [0.5]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Training split of MNIST, downloaded into ./data when it is not there yet.\n",
    "data = datasets.MNIST(root=\"./data\", train=True, transform=transform, download=True)\n",
    "\n",
    "# Positions of every TARGET_DIGIT sample, in ascending order. One vectorised comparison on the\n",
    "# label tensor replaces a Python-level loop over all labels; the result is still a list of ints.\n",
    "indices = torch.nonzero(data.targets == TARGET_DIGIT).flatten().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "d1c63842",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print element 0 of data[i] (presumably the transformed image — confirm) for the first two selected indices.\n",
    "# NOTE(review): this dumps the full contents of both items into the cell output.\n",
    "print([data[i][0] for i in indices[:2]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "f87a8b8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the shape of the first source digit image (same expression as the inspection cell a few cells up).\n",
    "source_images[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "d6e2e4c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _stack_with_random_color(digits):\n",
    "    \"\"\"Run apply_random_color on each digit with its own draw of 360 * U[0, 1), stack, and move to `device`.\n",
    "\n",
    "    The random draw is created on the same device as the digit it is applied to.\n",
    "    \"\"\"\n",
    "    recolored = [apply_random_color(digit, 360 * torch.rand(1, device=digit.device)) for digit in digits]\n",
    "    return torch.stack(recolored).to(device)\n",
    "\n",
    "\n",
    "source_images: list[torch.Tensor] = download_digit_images(SOURCE_DATASET, SOURCE_DIGIT)\n",
    "target_images: list[torch.Tensor] = download_digit_images(TARGET_DATASET, TARGET_DIGIT)\n",
    "\n",
    "# NOTE(review): the data cell in section 3 unpacks get_paired_digits into two tensors and passes\n",
    "# hue_offset=120; here the whole result is bound to p_sampler with no hue_offset — confirm which is intended.\n",
    "p_sampler = get_paired_digits(source_images, target_images, P_XY_PAIRED_SAMPLES, device=device)\n",
    "\n",
    "# Unpaired sets: source digits are coloured first, then target digits.\n",
    "q_x = _stack_with_random_color(source_images)\n",
    "q_y = _stack_with_random_color(target_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "ad917a2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the shapes of the stacked unpaired source and target tensors.\n",
    "q_x.shape, q_y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "fb1c93cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the shape of a slice holding the first 64 target images.\n",
    "q_y[:64].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "6eff7fdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the potential's n_c attribute (presumably its number of input channels — confirm).\n",
    "model.potential.n_c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "458651ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the potential's network on the first target image.\n",
    "# NOTE(review): q_y[0] is a single item without a leading batch dimension — confirm the network accepts that.\n",
    "model.potential.net(q_y[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "7e7410a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call the potential's grad_y on the first 64 target images and display the result.\n",
    "model.potential.grad_y(q_y[:64])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b4c4ee69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over the potential's parameters; its keyword arguments come from opt_unpaired_config.model_dump().\n",
    "D_opt_unpaired = torch.optim.Adam(model.potential.parameters(), **opt_unpaired_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "79c060ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over the cost's parameters; its keyword arguments come from opt_paired_config.model_dump().\n",
    "D_opt_paired = torch.optim.Adam(model.cost.parameters(), **opt_paired_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "8494eef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: refactor this config\n",
    "# Run identifier: used as the W&B run name and as the checkpoint sub-directory name.\n",
    "EXP_NAME = (\n",
    "    \"EgEOT_ColoredMNIST_\"\n",
    "    f\"COST_FUNCTION_{COST_FUNCTION}_\"\n",
    "    f\"P_XY_PAIRED_{dataset_config.P_XY_paired}_\"\n",
    "    f\"Q_X_UNPAIRED_{dataset_config.Q_X_unpaired}_\"\n",
    "    f\"R_Y_UNPAIRED_{dataset_config.R_Y_unpaired}_\"\n",
    "    f\"LR_PAIRED_{opt_paired_config.lr}_\"\n",
    "    f\"LR_UNPAIRED_{opt_unpaired_config.lr}_\"\n",
    "    f\"SAMPLING_STEPS_{model_config.sampling.num_iterations}_\"\n",
    ")\n",
    "OUTPUT_PATH = f\"../checkpoints/{EXP_NAME}\"\n",
    "\n",
    "# Hyper-parameters recorded alongside the W&B run.\n",
    "config = {\n",
    "    \"COST_FUNCTION\": COST_FUNCTION,\n",
    "    \"X_DIM\": dataset_config.x_dim,\n",
    "    \"Y_DIM\": dataset_config.y_dim,\n",
    "    \"D_LR_PAIRED\": opt_paired_config.lr,\n",
    "    \"D_LR_UNPAIRED\": opt_unpaired_config.lr,\n",
    "    \"BATCH_SIZE\": train_config.unpaired_batch_size,\n",
    "    \"P_XY_PAIRED_SAMPLES\": dataset_config.P_XY_paired,\n",
    "    \"Q_X_UNPAIRED_SAMPLES\": dataset_config.Q_X_unpaired,\n",
    "    \"R_Y_UNPAIRED_SAMPLES\": dataset_config.R_Y_unpaired,\n",
    "}\n",
    "\n",
    "# exist_ok=True already covers a pre-existing directory, so no separate existence check is needed.\n",
    "os.makedirs(OUTPUT_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "b726194c",
   "metadata": {},
   "outputs": [],
   "source": [
    "if train_config.steps_from > 0:\n",
    "    # Resuming: restore both optimizer states from the files written for step `steps_from`.\n",
    "    for _optimizer, _stem in ((D_opt_unpaired, \"D_opt_unpaired\"), (D_opt_paired, \"D_opt_paired\")):\n",
    "        _state_path = os.path.join(OUTPUT_PATH, f\"{_stem}_{train_config.steps_from}.pt\")\n",
    "        _optimizer.load_state_dict(torch.load(_state_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71524ecb",
   "metadata": {},
   "source": [
    "## 6. Model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adf4c2d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(name=EXP_NAME, config=config)\n",
    "\n",
    "# NOTE(review): the loop always covers range(0, MAX_STEPS), whereas resuming and the final checkpoint\n",
    "# names rely on train_config.steps_from / steps_to — confirm these are meant to agree.\n",
    "for step in tqdm(range(0, MAX_STEPS)):\n",
    "    # --- unpaired objective ---\n",
    "    D_opt_unpaired.zero_grad()\n",
    "\n",
    "    X = usd_sampler.sample(train_config.unpaired_batch_size)\n",
    "    Y = utd_sampler.sample(train_config.unpaired_batch_size)\n",
    "\n",
    "    output_unpaired = model.compute_unpaired_loss(X, Y, compute_stats=True)\n",
    "    D_loss_unpaired = output_unpaired[\"loss\"]\n",
    "\n",
    "    # --- paired objective ---\n",
    "    D_opt_paired.zero_grad()\n",
    "    X_paired, Y_paired = pd_train_sampler.sample(train_config.paired_batch_size)\n",
    "\n",
    "    output_paired = model.compute_paired_loss(X_paired, Y_paired)\n",
    "    D_loss_paired = output_paired[\"loss\"]\n",
    "\n",
    "    # One backward pass over the summed loss, then step both optimizers.\n",
    "    D_loss = D_loss_unpaired + D_loss_paired\n",
    "    D_loss.backward()\n",
    "    D_opt_paired.step()\n",
    "    D_opt_unpaired.step()\n",
    "\n",
    "    # Fold the fresh weights into the averaged copy. `model` is deliberately not rebound to\n",
    "    # `model_copy`: it must remain the object whose parameters the optimizers were created from.\n",
    "    if train_config.ema_update:\n",
    "        update_average(model_copy, model, 0.99)\n",
    "\n",
    "    # All scalars of this step are sent in one call, as plain Python numbers rather than tensors.\n",
    "    # Raw strings keep the literal backslashes of the LaTeX-style keys without invalid escape sequences.\n",
    "    wandb.log(\n",
    "        {\n",
    "            \"Unpaired: Loss\": D_loss_unpaired.item(),\n",
    "            r\"Unpaired: \\int f(y)\": output_unpaired[\"int_potential\"].item(),\n",
    "            r\"Unpaired: \\int\\log Z\": output_unpaired[\"int_log_Z\"].item(),\n",
    "            \"Unpaired: -E(x, y)\": output_unpaired[\"neg_energy_t\"].item(),\n",
    "            \"Unpaired: c(x, y)\": output_unpaired[\"cost_t\"].item(),\n",
    "            \"Unpaired: f(y)\": output_unpaired[\"potential_t\"].item(),\n",
    "            \"Unpaired: noise\": output_unpaired[\"noise\"].item(),\n",
    "            \"Paired: Loss\": D_loss_paired.item(),\n",
    "            \"Loss\": D_loss.item(),\n",
    "        },\n",
    "        step=step,\n",
    "    )\n",
    "\n",
    "    if step % train_config.plot_every == 0:\n",
    "        # Plot with the averaged weights when EMA is enabled; checkpoint the weights being trained.\n",
    "        eval_model = model_copy if train_config.ema_update else model\n",
    "        distr_dict = plot_samples(eval_model, unpaired_source_samples, log=True)\n",
    "        wandb.log(distr_dict, step=step)\n",
    "        torch.save(model.state_dict(), os.path.join(OUTPUT_PATH, f\"model_{step}.pt\"))\n",
    "\n",
    "# Final checkpoints: model weights plus both optimizer states, named after train_config.steps_to.\n",
    "torch.save(model.state_dict(), os.path.join(OUTPUT_PATH, f\"D_{train_config.steps_to}.pt\"))\n",
    "torch.save(D_opt_paired.state_dict(), os.path.join(OUTPUT_PATH, f\"D_opt_paired_{train_config.steps_to}.pt\"))\n",
    "torch.save(D_opt_unpaired.state_dict(), os.path.join(OUTPUT_PATH, f\"D_opt_unpaired_{train_config.steps_to}.pt\"))\n",
    "\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4437648d",
   "metadata": {},
   "source": [
    "## 7. Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "b8b951b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot samples produced by the model for the unpaired source samples (no `log` argument is passed here).\n",
    "plot_samples(model, unpaired_source_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "586e9bce",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "text",
   "language": "python",
   "name": "text"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
