{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f3bd7692",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd972eac",
   "metadata": {},
   "source": [
    "## 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6261775b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.distributions as TD\n",
    "from tqdm import tqdm\n",
    "\n",
    "import wandb\n",
    "from src.costs.lse import MLPLSECost\n",
    "from src.costs.mlp_based import MLPCost, MLPL2Cost\n",
    "from src.models.energy_based import EGEOT\n",
    "from src.plotting.distributions import plot_swiss_roll\n",
    "from src.plotting.parameters import plot_B_parameters\n",
    "from src.potentials.mlp_based import MLPPotential\n",
    "from src.samplers.energy_based.sample_buffer import SampleBufferEgEOT\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.utils.discrete_ot import OTPlanSampler\n",
    "from src.utils.paired import generate_paired_data, get_GT_points, get_paired_sampler\n",
    "from src.utils.train import compute_loss, update_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5ea4743c",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(f\"cuda:{torch.cuda.current_device()}\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "53e79cc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make `device` and float64 the defaults for tensors created without explicit arguments.\n",
    "torch.set_default_device(device)\n",
    "dtype = torch.float64\n",
    "# Call the public API directly instead of going through the accidental `torch.torch` alias.\n",
    "torch.set_default_dtype(dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f036daa0",
   "metadata": {},
   "source": [
    "## 2. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6067778b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from configs.energy_based.cost import MLPCostConfig, MLPLSECostConfig, MLPL2CostConfig\n",
    "from configs.energy_based.dataset import DatasetConfig, MiniBatchConfig\n",
    "from configs.energy_based.model import EBMConfig\n",
    "from configs.energy_based.optimizer import OptPairedConfig, OptUnpairedConfig\n",
    "from configs.energy_based.potential import PotentialConfig\n",
    "from configs.energy_based.sampling import LangevinConfig\n",
    "from configs.energy_based.train import TrainConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d8508840",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "Q_X_UNPAIRED_SAMPLES = 1024\n",
    "R_Y_UNPAIRED_SAMPLES = 1024\n",
    "P_XY_PAIRED_SAMPLES = 128\n",
    "LR_PAIRED = 5e-4\n",
    "LR_UNPAIRED = 2e-4\n",
    "SAMPLING_NUM_ITER = 100\n",
    "MAX_STEPS = 1000\n",
    "COST_FUNCTION = \"MLP\"\n",
    "PAIRED_BATCH_SIZE = 128\n",
    "UNPAIRED_BATCH_SIZE = 128\n",
    "\n",
    "# For MLPCost\n",
    "HIDDEN_LAYERS = [128, 128]\n",
    "\n",
    "# For MLPLSECost\n",
    "M_POTENTIALS = 2\n",
    "LOG_V_M_HIDDEN_CHANNELS = [128, 128]\n",
    "B_M_HIDDEN_CHANNELS = [128, 128]\n",
    "\n",
    "# For MLPL2Cost\n",
    "X_HIDDEN_LAYERS: list[int] = [128, 128]\n",
    "Y_HIDDEN_LAYERS: list[int] = [128, 128]\n",
    "\n",
    "# For MLPPotential\n",
    "POTENTIAL_HIDDEN_LAYERS = [256, 256, 256]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4f8bf62e",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_config = DatasetConfig(\n",
    "    P_XY_paired=P_XY_PAIRED_SAMPLES, Q_X_unpaired=Q_X_UNPAIRED_SAMPLES, R_Y_unpaired=R_Y_UNPAIRED_SAMPLES\n",
    ")\n",
    "minibatch_config = MiniBatchConfig()\n",
    "\n",
    "potential_config = PotentialConfig(hidden_layers=POTENTIAL_HIDDEN_LAYERS)\n",
    "if COST_FUNCTION == \"MLP\":\n",
    "    cost_config = MLPCostConfig(hidden_layers=HIDDEN_LAYERS)\n",
    "    EXP_META_INFO = f\"HIDDEN_LAYERS_{HIDDEN_LAYERS}_\"\n",
    "elif COST_FUNCTION == \"MLPLSE\":\n",
    "    cost_config = MLPLSECostConfig(\n",
    "        m_potentials=M_POTENTIALS,\n",
    "        log_v_m_hidden_channels=LOG_V_M_HIDDEN_CHANNELS,\n",
    "        b_m_hidden_channels=B_M_HIDDEN_CHANNELS,\n",
    "    )\n",
    "    EXP_META_INFO = (\n",
    "        f\"M_POTENTIALS_{M_POTENTIALS}_\"\n",
    "        + f\"LOG_V_M_HIDDEN_CHANNELS_{LOG_V_M_HIDDEN_CHANNELS}_\"\n",
    "        + f\"B_M_HIDDEN_CHANNELS_{B_M_HIDDEN_CHANNELS}_\"\n",
    "    )\n",
    "elif COST_FUNCTION == \"MLPL2\":\n",
    "    cost_config = MLPL2CostConfig(\n",
    "        x_hidden_layers=X_HIDDEN_LAYERS,\n",
    "        y_hidden_layers=Y_HIDDEN_LAYERS,\n",
    "    )\n",
    "    EXP_META_INFO = f\"X_HIDDEN_LAYERS_{X_HIDDEN_LAYERS}_\" + f\"Y_HIDDEN_LAYERS_{Y_HIDDEN_LAYERS}_\"\n",
    "else:\n",
    "    raise ValueError(f\"Unknown cost function: {COST_FUNCTION}!\")\n",
    "model_config = EBMConfig(sampling=LangevinConfig(num_iterations=SAMPLING_NUM_ITER))\n",
    "\n",
    "opt_unpaired_config = OptUnpairedConfig(lr=LR_UNPAIRED)\n",
    "opt_paired_config = OptPairedConfig(lr=LR_PAIRED)\n",
    "\n",
    "train_config = TrainConfig(\n",
    "    steps_to=MAX_STEPS, paired_batch_size=PAIRED_BATCH_SIZE, unpaired_batch_size=UNPAIRED_BATCH_SIZE\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "339530f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(train_config.seed)\n",
    "np.random.seed(train_config.seed)\n",
    "random.seed(train_config.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c47f945",
   "metadata": {},
   "source": [
    "## 3. Create data and samplers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5973a562",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_sampler = StandardNormalSampler(dim=dataset_config.x_dim, device=device)\n",
    "Y_sampler = SwissRollSampler(dim=dataset_config.y_dim, device=device, dtype=dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fb1d2f0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "otp_sampler = OTPlanSampler(**minibatch_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d235fb71",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = \"checkpoints/Tensors\"\n",
    "file_postfix = f\"{minibatch_config.cost_function}_{dataset_config.P_XY_paired}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "86defff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the paired train/test tensors.\n",
    "# `data_dir` (defined in the previous cell) names the same directory as the literal\n",
    "# \"./checkpoints/Tensors\" that used to be repeated here, so the location lives in one place.\n",
    "X_paired_train, Y_paired_train, X_paired_test, Y_paired_test = generate_paired_data(\n",
    "    X_sampler, Y_sampler, otp_sampler, dataset_config.P_XY_paired, data_dir, file_postfix, device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f61b8693",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd_train_sampler = get_paired_sampler(\n",
    "    X_paired_train, Y_paired_train, train_config.paired_batch_size, dataset_config.P_XY_paired, device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "09ffa079",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_unpaired_test = X_sampler.sample(dataset_config.P_XY_paired)\n",
    "Y_unpaired_test = Y_sampler.sample(dataset_config.P_XY_paired)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "afe419b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# usd - unpaired source data: drawn from X_sampler when Q_X_unpaired is positive,\n",
    "# otherwise the paired training inputs are wrapped instead.\n",
    "if not (dataset_config.Q_X_unpaired > 0):\n",
    "    usd_sampler = DatasetSampler(X_paired_train, device=device)\n",
    "else:\n",
    "    source_data = X_sampler.sample(dataset_config.Q_X_unpaired)\n",
    "    usd_sampler = DatasetSampler(source_data, device=device)\n",
    "\n",
    "# utd - unpaired target data: same rule, using Y_sampler or the paired training targets.\n",
    "if not (dataset_config.R_Y_unpaired > 0):\n",
    "    utd_sampler = DatasetSampler(Y_paired_train, device=device)\n",
    "else:\n",
    "    target_data = Y_sampler.sample(dataset_config.R_Y_unpaired)\n",
    "    utd_sampler = DatasetSampler(target_data, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfd8d737",
   "metadata": {},
   "source": [
    "## 4. Model initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "bb931bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "potential = MLPPotential(**potential_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4f812aa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every supported COST_FUNCTION value to the class that implements it.\n",
    "cost_classes = {\"MLP\": MLPCost, \"MLPLSE\": MLPLSECost, \"MLPL2\": MLPL2Cost}\n",
    "\n",
    "# Reject unknown names before anything is constructed.\n",
    "if COST_FUNCTION not in cost_classes:\n",
    "    raise ValueError(f\"Unknown cost function: {COST_FUNCTION}!\")\n",
    "\n",
    "cost = cost_classes[COST_FUNCTION](**cost_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "92991439",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: add to config\n",
    "BASIC_NOISE_VAR = 1.0\n",
    "P_SAMPLE_BUFFER_REPLAY = 0.95\n",
    "SAMPLE_BUFFER_SAMPLES = 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "423f3579",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Noise distribution handed to the sample buffer.\n",
    "# `TD.Normal(loc, scale)` expects a standard deviation, so the configured variance is\n",
    "# converted with a square root (identical result for the current value 1.0).\n",
    "# NOTE(review): the two-component tensors hard-code a 2-D sample space - confirm whether\n",
    "# they should follow dataset_config.y_dim.\n",
    "noise_loc = torch.tensor([0.0, 0.0]).to(device)\n",
    "noise_scale = torch.tensor([1.0, 1.0]).to(device) * BASIC_NOISE_VAR**0.5\n",
    "basic_noise_gen = TD.Normal(noise_loc, noise_scale)\n",
    "\n",
    "sample_buffer_instance = SampleBufferEgEOT(\n",
    "    basic_noise_gen, p=P_SAMPLE_BUFFER_REPLAY, max_samples=SAMPLE_BUFFER_SAMPLES, device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "09e4b989",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = EGEOT(potential, cost, sample_buffer_instance, model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fa67a50c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For EMA update\n",
    "if train_config.ema_update:\n",
    "    model_copy = EGEOT(potential, cost, sample_buffer_instance, model_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6d84e4b",
   "metadata": {},
   "source": [
    "## 5. Optimizers initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "b4c4ee69",
   "metadata": {},
   "outputs": [],
   "source": [
    "D_opt_unpaired = torch.optim.Adam(model.potential.parameters(), **opt_unpaired_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "79c060ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "D_opt_paired = torch.optim.Adam(model.cost.parameters(), **opt_paired_config.model_dump())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8494eef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: refactor this config\n",
    "# Run identifier, used both as the wandb run name and as the checkpoint sub-directory.\n",
    "# Every field ends with \"_\" and the trailing one is stripped afterwards, so the potential\n",
    "# layers no longer run straight into the cost-specific suffix held in EXP_META_INFO.\n",
    "# NOTE: this changes the generated name; checkpoints written under the previous name must\n",
    "# be moved before resuming with steps_from > 0.\n",
    "EXP_NAME = (\n",
    "    \"EgEOT_Swiss_Roll_\"\n",
    "    + f\"COST_FUNCTION_{COST_FUNCTION}_\"\n",
    "    + f\"P_XY_PAIRED_{dataset_config.P_XY_paired}_\"\n",
    "    + f\"Q_X_UNPAIRED_{dataset_config.Q_X_unpaired}_\"\n",
    "    + f\"R_Y_UNPAIRED_{dataset_config.R_Y_unpaired}_\"\n",
    "    + f\"LR_PAIRED_{opt_paired_config.lr}_\"\n",
    "    + f\"LR_UNPAIRED_{opt_unpaired_config.lr}_\"\n",
    "    + f\"MINIBATCH_COST_{minibatch_config.cost_function}_\"\n",
    "    + f\"SAMPLING_STEPS_{model_config.sampling.num_iterations}_\"\n",
    "    + f\"POTENTIAL_HIDDEN_LAYERS_{potential_config.hidden_layers}_\"\n",
    "    + EXP_META_INFO\n",
    ").rstrip(\"_\")\n",
    "OUTPUT_PATH = f\"../checkpoints/{EXP_NAME}\"\n",
    "\n",
    "# Hyper-parameters attached to the wandb run.\n",
    "config = dict(\n",
    "    COST_FUNCTION=COST_FUNCTION,\n",
    "    X_DIM=dataset_config.x_dim,\n",
    "    Y_DIM=dataset_config.y_dim,\n",
    "    D_LR_PAIRED=opt_paired_config.lr,\n",
    "    D_LR_UNPAIRED=opt_unpaired_config.lr,\n",
    "    BATCH_SIZE=train_config.unpaired_batch_size,\n",
    "    P_XY_PAIRED_SAMPLES=dataset_config.P_XY_paired,\n",
    "    Q_X_UNPAIRED_SAMPLES=dataset_config.Q_X_unpaired,\n",
    "    R_Y_UNPAIRED_SAMPLES=dataset_config.R_Y_unpaired,\n",
    ")\n",
    "\n",
    "# exist_ok=True already tolerates an existing directory, so no separate existence check is needed.\n",
    "os.makedirs(OUTPUT_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "b726194c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resume from a previous run. The model weights are restored together with both optimizer\n",
    "# states; restoring only the optimizers would apply their saved Adam state to freshly\n",
    "# initialised parameters. The training cell writes all three files side by side as\n",
    "# D_<step>.pt, D_opt_unpaired_<step>.pt and D_opt_paired_<step>.pt.\n",
    "# map_location lets a checkpoint written on another device be loaded onto `device`.\n",
    "if train_config.steps_from > 0:\n",
    "    model.load_state_dict(\n",
    "        torch.load(os.path.join(OUTPUT_PATH, f\"D_{train_config.steps_from}.pt\"), map_location=device)\n",
    "    )\n",
    "    D_opt_unpaired.load_state_dict(\n",
    "        torch.load(os.path.join(OUTPUT_PATH, f\"D_opt_unpaired_{train_config.steps_from}.pt\"), map_location=device)\n",
    "    )\n",
    "    D_opt_paired.load_state_dict(\n",
    "        torch.load(os.path.join(OUTPUT_PATH, f\"D_opt_paired_{train_config.steps_from}.pt\"), map_location=device)\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71524ecb",
   "metadata": {},
   "source": [
    "## 6. Model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "f65e6586",
   "metadata": {},
   "outputs": [],
   "source": [
    "starting_points = torch.tensor([[-2.0, 0.0], [0.0, 0.0], [0.0, -2.0]])\n",
    "num_ending_points = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e32f4804",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_starting_points_paired = 5\n",
    "indices = random.choices(range(dataset_config.P_XY_paired), k=num_starting_points_paired)\n",
    "starting_points_paired = X_paired_train[indices]\n",
    "ending_points_paired = Y_paired_train[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "c512a511",
   "metadata": {},
   "outputs": [],
   "source": [
    "gt_Y_points = get_GT_points(X_sampler, Y_sampler, otp_sampler, starting_points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adf4c2d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(name=EXP_NAME, config=config)\n",
    "\n",
    "# Model used for evaluation, plots and checkpoints: the averaged copy when EMA is enabled,\n",
    "# otherwise the trained model. `model` itself is never rebound inside the loop, so the\n",
    "# losses and the EMA update always use the model the optimizers were built from.\n",
    "# NOTE(review): model_copy is constructed from the same potential/cost objects as model;\n",
    "# unless EGEOT copies them, the average below cannot differ from model - confirm.\n",
    "eval_model = model_copy if train_config.ema_update else model\n",
    "\n",
    "for step in tqdm(range(train_config.steps_from, train_config.steps_to)):\n",
    "    # --- unpaired objective ---\n",
    "    D_opt_unpaired.zero_grad()\n",
    "\n",
    "    X = usd_sampler.sample(train_config.unpaired_batch_size)\n",
    "    Y = utd_sampler.sample(train_config.unpaired_batch_size)\n",
    "\n",
    "    output_unpaired = model.compute_unpaired_loss(X, Y, compute_stats=True)\n",
    "    D_loss_unpaired = output_unpaired[\"loss\"]\n",
    "\n",
    "    # Raw strings keep the LaTeX backslashes in the metric names without invalid-escape warnings.\n",
    "    wandb.log(\n",
    "        {\n",
    "            \"Unpaired: Loss\": D_loss_unpaired.item(),\n",
    "            r\"Unpaired: \\int f(y)\": output_unpaired[\"int_potential\"].item(),\n",
    "            r\"Unpaired: \\int\\log Z\": output_unpaired[\"int_log_Z\"].item(),\n",
    "            \"Unpaired: -E(x, y)\": output_unpaired[\"neg_energy_t\"].item(),\n",
    "            \"Unpaired: c(x, y)\": output_unpaired[\"cost_t\"].item(),\n",
    "            \"Unpaired: f(y)\": output_unpaired[\"potential_t\"].item(),\n",
    "            \"Unpaired: noise\": output_unpaired[\"noise\"].item(),\n",
    "        },\n",
    "        step=step,\n",
    "    )\n",
    "\n",
    "    # --- paired objective ---\n",
    "    D_opt_paired.zero_grad()\n",
    "    X_paired, Y_paired = pd_train_sampler.sample(train_config.paired_batch_size)\n",
    "\n",
    "    output_paired = model.compute_paired_loss(X_paired, Y_paired, compute_stats=True)\n",
    "    D_loss_paired = output_paired[\"loss\"]\n",
    "\n",
    "    wandb.log({\"Paired: Loss\": D_loss_paired.item()}, step=step)\n",
    "\n",
    "    # One backward pass over the summed loss, then one step per optimizer.\n",
    "    D_loss = D_loss_unpaired + D_loss_paired\n",
    "    D_loss.backward()\n",
    "    D_opt_paired.step()\n",
    "    D_opt_unpaired.step()\n",
    "\n",
    "    if train_config.ema_update:\n",
    "        update_average(model_copy, model, 0.99)\n",
    "\n",
    "    # Log a plain float for the total loss rather than a tensor still attached to the graph.\n",
    "    wandb.log(\n",
    "        {\n",
    "            \"Loss\": D_loss.item(),\n",
    "            \"Train paired loss\": compute_loss(\n",
    "                eval_model, X_paired_train, Y_paired_train, X_paired_train, Y_paired_train\n",
    "            ),\n",
    "            \"Test paired loss\": compute_loss(\n",
    "                eval_model, X_paired_test, Y_paired_test, X_paired_test, Y_paired_test\n",
    "            ),\n",
    "            \"Test unpaired loss\": compute_loss(\n",
    "                eval_model, X_unpaired_test, Y_unpaired_test, X_paired_test, Y_paired_test\n",
    "            ),\n",
    "        },\n",
    "        step=step,\n",
    "    )\n",
    "\n",
    "    # Periodic plots and intermediate checkpoint.\n",
    "    if step % train_config.plot_every == 0:\n",
    "        distr_dict = plot_swiss_roll(\n",
    "            {\"EBM\": eval_model},\n",
    "            X_sampler,\n",
    "            Y_sampler,\n",
    "            X_paired_train,\n",
    "            Y_paired_train,\n",
    "            starting_points,\n",
    "            gt_Y_points,\n",
    "            log=True,\n",
    "        )\n",
    "        if COST_FUNCTION == \"MLPLSE\":\n",
    "            B_dict = plot_B_parameters(eval_model.cost, starting_points, log=True)\n",
    "            distr_dict = distr_dict | B_dict\n",
    "        # Pass the step explicitly so the figures land on the same row as this step's scalars.\n",
    "        wandb.log(distr_dict, step=step)\n",
    "        torch.save(eval_model.state_dict(), os.path.join(OUTPUT_PATH, f\"model_{step}.pt\"))\n",
    "\n",
    "# Final checkpoint: model weights plus both optimizer states, keyed by the last step.\n",
    "torch.save(eval_model.state_dict(), os.path.join(OUTPUT_PATH, f\"D_{train_config.steps_to}.pt\"))\n",
    "torch.save(D_opt_paired.state_dict(), os.path.join(OUTPUT_PATH, f\"D_opt_paired_{train_config.steps_to}.pt\"))\n",
    "torch.save(D_opt_unpaired.state_dict(), os.path.join(OUTPUT_PATH, f\"D_opt_unpaired_{train_config.steps_to}.pt\"))\n",
    "\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4437648d",
   "metadata": {},
   "source": [
    "## 7. Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2e68c54",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_swiss_roll(\n",
    "    {\"EBM\": model},\n",
    "    X_sampler,\n",
    "    Y_sampler,\n",
    "    X_paired_train,\n",
    "    Y_paired_train,\n",
    "    starting_points,\n",
    "    gt_Y_points,\n",
    ") "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "text",
   "language": "python",
   "name": "text"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
