{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7182ad6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from infrastructure.classes import Experiment, TrainParameters\n",
    "from utils.plot import plot_quantile_levels_from_dataset\n",
    "from infrastructure.training import train\n",
    "import torch\n",
    "\n",
    "# Seed torch's global RNG so that torch-based randomness (presumably parameter\n",
    "# initialisation, shuffle=True batching and covariate sampling in the next cell)\n",
    "# is reproducible under Restart & Run All.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Train an entropic neural quantile regression model on the PICNN FNLVQR banana dataset.\n",
    "experiment = Experiment(\n",
    "    tensor_parameters=dict(dtype=torch.float64, device=torch.device(\"cpu\")),\n",
    "    dataset_name=\"picnn_fnlvqr_banana\",\n",
    "    dataset_number_of_points=10**4,\n",
    "    # NOTE(review): number_of_features=10 / number_of_classes=2 look inconsistent with\n",
    "    # feature_dimension=1 below; confirm the banana dataset ignores or expects these.\n",
    "    dataset_parameters=dict(number_of_features=10, number_of_classes=2),\n",
    "    dataloader_parameters=dict(batch_size=124, shuffle=True),\n",
    "    pushforward_operator_name=\"entropic_neural_quantile_regression\",\n",
    "    pushforward_operator_parameters=dict(\n",
    "        feature_dimension=1,\n",
    "        response_dimension=2,\n",
    "        hidden_dimension=8,\n",
    "        number_of_hidden_layers=4,\n",
    "        epsilon=1e-3,\n",
    "        amount_of_samples_to_estimate_psi=512,\n",
    "    ),\n",
    "    train_parameters=TrainParameters(\n",
    "        number_of_epochs_to_train=100,\n",
    "        verbose=True,\n",
    "        optimizer_parameters=dict(lr=0.01),\n",
    "        scheduler_parameters=dict(eta_min=0),\n",
    "        warmup_iterations=30\n",
    "    )\n",
    ")\n",
    "\n",
    "model = train(experiment)\n",
    "# Switch to inference mode; assign to `_` so the cell does not echo the model repr.\n",
    "_ = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6def3489",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Draw a single covariate value from the banana dataset and plot the trained\n",
    "# model's quantile levels conditioned on it.\n",
    "from datasets import PICNN_FNLVQR_Banana\n",
    "\n",
    "dataset = PICNN_FNLVQR_Banana(tensor_parameters=experiment.tensor_parameters)\n",
    "conditional_value = dataset.sample_covariates(1)\n",
    "\n",
    "plot_quantile_levels_from_dataset(\n",
    "    model=model,\n",
    "    dataset=dataset,\n",
    "    conditional_value=conditional_value,\n",
    "    number_of_quantile_levels=10,\n",
    "    tensor_parameters=experiment.tensor_parameters,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
