{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f7182ad6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Warm up iteration: 3 Potential loss: 0.036, Amortization loss: 0.007: 100%|██████████| 3/3 [00:08<00:00,  2.69s/it]\n",
      "Epoch: 25, Potential Objective: 0.859, Amortization Objective: 0.124, Potential LR: 0.000000, Amortized LR: 0.001216: 100%|██████████| 25/25 [18:48<00:00, 45.15s/it]\n"
     ]
    }
   ],
   "source": [
    "# Train an amortized neural quantile-regression model on the 2-D \"funnel\" dataset.\n",
    "from infrastructure.classes import Experiment, TrainParameters\n",
    "from infrastructure.training import train\n",
    "import torch\n",
    "\n",
    "# Seed torch's global RNG so Restart & Run All reproduces the same run.\n",
    "# NOTE(review): assumes dataset sampling, weight init and shuffle=True batching all draw\n",
    "# from torch's global generator -- confirm; seed numpy/random too if they are used.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "experiment = Experiment(\n",
    "    tensor_parameters=dict(dtype=torch.float32, device=torch.device(\"cpu\")),\n",
    "    dataset_name=\"funnel\",\n",
    "    dataset_parameters={\n",
    "        \"target_dimension\":2\n",
    "    },\n",
    "    dataset_number_of_points=10**5,\n",
    "    dataloader_parameters=dict(batch_size=256, shuffle=True),\n",
    "    pushforward_operator_name=\"amortized_neural_quantile_regression\",\n",
    "    pushforward_operator_parameters=dict(\n",
    "        feature_dimension=1,\n",
    "        # presumably must equal dataset_parameters[\"target_dimension\"] -- confirm\n",
    "        response_dimension=2,\n",
    "        hidden_dimension=18,\n",
    "        number_of_hidden_layers=8,\n",
    "        potential_to_estimate_with_neural_network=\"u\",\n",
    "    ),\n",
    "    train_parameters=TrainParameters(\n",
    "        number_of_epochs_to_train=25,\n",
    "        verbose=True,\n",
    "        optimizer_parameters=dict(\n",
    "            lr=1e-2,\n",
    "            weight_decay=1e-4\n",
    "        ),\n",
    "        scheduler_parameters=dict(eta_min=0),\n",
    "        warmup_iterations=3,\n",
    "    )\n",
    ")\n",
    "\n",
    "model = train(experiment)\n",
    "# Switch to inference mode before plotting (presumably a torch.nn.Module);\n",
    "# `_ =` suppresses the returned repr.\n",
    "_ = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6def3489",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Plot the trained model's quantile levels for a fixed conditioning value.\n",
    "import torch\n",
    "\n",
    "from datasets import PICNN_FNLVQR_Banana\n",
    "from utils.plot import plot_quantile_levels_from_dataset\n",
    "\n",
    "# NOTE(review): the model above is trained on the \"funnel\" dataset, but the reference\n",
    "# samples come from PICNN_FNLVQR_Banana. Confirm this comparison is intended; otherwise\n",
    "# plot against the funnel dataset class instead.\n",
    "dataset = PICNN_FNLVQR_Banana(tensor_parameters=experiment.tensor_parameters)\n",
    "\n",
    "# Build the conditioning value with the experiment's dtype/device so it matches the model;\n",
    "# a bare torch.tensor(...) would always use the default dtype on the CPU.\n",
    "conditional_value = torch.tensor([[1.5]], **experiment.tensor_parameters)\n",
    "\n",
    "plot_quantile_levels_from_dataset(\n",
    "    model=model,\n",
    "    dataset=dataset,\n",
    "    conditional_value=conditional_value,\n",
    "    number_of_quantile_levels=10,\n",
    "    tensor_parameters=experiment.tensor_parameters\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
