{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0659c1d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from datasets import PICNN_FNLVQR_Banana\n",
    "\n",
    "# Draw 1000 joint samples from PICNN_FNLVQR_Banana and show them as a 3D scatter.\n",
    "tensor_parameters = dict(dtype=torch.float64, device=torch.device(\"cpu\"))\n",
    "dataset = PICNN_FNLVQR_Banana(tensor_parameters=tensor_parameters)\n",
    "X, Y = dataset.sample_joint(n_points=1000)\n",
    "\n",
    "fig = plt.figure()\n",
    "# Use one full-figure 3D axes: a 1x2 grid slot (121) would leave the right\n",
    "# half of the figure blank because no second subplot is ever added.\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "ax.scatter(X[:, 0], Y[:, 0], Y[:, 1])\n",
    "ax.set_title(\"PICNN_FNLVQR_Banana\")\n",
    "# Label each axis with the column it shows so the plot stands alone.\n",
    "ax.set_xlabel(\"X\")\n",
    "ax.set_ylabel(\"Y$_0$\")\n",
    "ax.set_zlabel(\"Y$_1$\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "916a9e71",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from datasets import FNLVQR_Banana\n",
    "\n",
    "# 3D scatter of 1000 joint samples from FNLVQR_Banana:\n",
    "# X[:, 0] against Y[:, 0] and Y[:, 1].\n",
    "tensor_parameters = {\"dtype\": torch.float64, \"device\": torch.device(\"cpu\")}\n",
    "dataset = FNLVQR_Banana(tensor_parameters=tensor_parameters)\n",
    "X, Y = dataset.sample_joint(n_points=1000)\n",
    "\n",
    "# Create the figure and its single 3D axes in one call.\n",
    "fig, ax = plt.subplots(subplot_kw={\"projection\": \"3d\"})\n",
    "ax.scatter(X[:, 0], Y[:, 0], Y[:, 1])\n",
    "ax.set_title(\"FNLVQR_Banana\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c55e5ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from datasets import FNLVQR_Glasses\n",
    "\n",
    "# 2D scatter of 1000 joint samples from FNLVQR_Glasses: X[:, 0] against Y[:, 0].\n",
    "tensor_parameters = {\"dtype\": torch.float64, \"device\": torch.device(\"cpu\")}\n",
    "dataset = FNLVQR_Glasses(tensor_parameters=tensor_parameters)\n",
    "X, Y = dataset.sample_joint(n_points=1000)\n",
    "\n",
    "# Explicit axes object instead of the implicit pyplot current-axes state.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], Y[:, 0])\n",
    "ax.set_title(\"FNLVQR_Glasses\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b99c6f47",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from datasets import PICNN_FNLVQR_Star\n",
    "\n",
    "# 3D scatter of 5000 joint samples from PICNN_FNLVQR_Star (amplitude=1).\n",
    "tensor_parameters = {\"dtype\": torch.float64, \"device\": torch.device(\"cpu\")}\n",
    "dataset = PICNN_FNLVQR_Star(amplitude=1, tensor_parameters=tensor_parameters)\n",
    "X, Y = dataset.sample_joint(n_points=5000)\n",
    "\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(111, projection=\"3d\")\n",
    "# Axis order is Y[:, 0], X[:, 0], Y[:, 1]: X is on the middle axis here.\n",
    "ax.scatter(Y[:, 0], X[:, 0], Y[:, 1], s=2.5)\n",
    "ax.set(title=\"PICNN_FNLVQR_Star\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf8eeb74",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from datasets import FNLVQR_Star\n",
    "\n",
    "# 3D scatter of 5000 joint samples from FNLVQR_Star (amplitude=2.0).\n",
    "tensor_parameters = {\"dtype\": torch.float64, \"device\": torch.device(\"cpu\")}\n",
    "dataset = FNLVQR_Star(amplitude=2.0, tensor_parameters=tensor_parameters)\n",
    "X, Y = dataset.sample_joint(n_points=5000)\n",
    "\n",
    "# Create the figure and its single 3D axes in one call.\n",
    "fig, ax = plt.subplots(subplot_kw={\"projection\": \"3d\"})\n",
    "# Axis order is Y[:, 0], X[:, 0], Y[:, 1]: X is on the middle axis here.\n",
    "ax.scatter(Y[:, 0], X[:, 0], Y[:, 1], s=1)\n",
    "ax.set_title(\"FNLVQR_Star\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "970d8fef",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Import everything this cell uses so it does not rely on names left behind\n",
    "# by earlier cells.\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from datasets import FNLVQR_MVN\n",
    "\n",
    "# 3D scatter of 1000 joint samples from FNLVQR_MVN with one feature and two responses.\n",
    "tensor_parameters = dict(dtype=torch.float64, device=torch.device(\"cpu\"))\n",
    "dataset = FNLVQR_MVN(number_of_responses=2, number_of_features=1, tensor_parameters=tensor_parameters)\n",
    "X, Y = dataset.sample_joint(n_points=1000)\n",
    "\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "ax.scatter(X[:, 0], Y[:, 0], Y[:, 1])\n",
    "ax.set_title(\"FNLVQR_MVN\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6592668",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from pushforward_operators.neural_quantile_regression.amortized_neural_quantile_regression import AmortizedNeuralQuantileRegression\n",
    "from pushforward_operators import NeuralQuantileRegression\n",
    "from datasets import FNLVQR_Banana, FNLVQR_Glasses, FNLVQR_Star, FunnelDistribution\n",
    "\n",
    "# Four-panel figure comparing ground-truth samples with samples pushed through\n",
    "# trained models: Banana and Star at one fixed X, Glasses and Funnel as joint\n",
    "# samples with X on the horizontal axis. The figure is saved to images/heatmap.pdf.\n",
    "color_map = matplotlib.colormaps['viridis']\n",
    "tensor_parameters = dict(dtype=torch.float64, device=torch.device(\"cpu\"))\n",
    "\n",
    "# Seed the global torch RNG so the latent draws U are repeatable between runs.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "size = 30\n",
    "epsilon = 0.1\n",
    "alpha = 0.5\n",
    "number_of_points_to_sample = 150\n",
    "\n",
    "\n",
    "def load_model(model_class, weights_path):\n",
    "    \"\"\"Load a model with ``model_class.load_class`` and move it to ``tensor_parameters``.\"\"\"\n",
    "    loaded = model_class.load_class(weights_path)\n",
    "    loaded.to(**tensor_parameters)\n",
    "    return loaded\n",
    "\n",
    "\n",
    "def add_panel(figure, position, xlim, ylim):\n",
    "    \"\"\"Add subplot ``position`` of the 1x4 grid and fix its axis limits.\"\"\"\n",
    "    panel = figure.add_subplot(1, 4, position)\n",
    "    panel.set_xlim(*xlim)\n",
    "    panel.set_ylim(*ylim)\n",
    "    return panel\n",
    "\n",
    "\n",
    "def sample_conditional_pair(dataset, model, conditional_value):\n",
    "    \"\"\"Return (ground-truth Y, model Y) sampled at a single value of X.\"\"\"\n",
    "    # Build X directly in the target dtype: torch.tensor on a Python float\n",
    "    # defaults to float32, which rounds the value before the cast to float64.\n",
    "    X = torch.full((number_of_points_to_sample, 1), conditional_value, **tensor_parameters)\n",
    "    X, Y = dataset.sample_conditional(X)\n",
    "    U = torch.randn_like(Y)\n",
    "    return Y, model.push_u_given_x(U, X)\n",
    "\n",
    "\n",
    "def sample_joint_pair(dataset, model, n_points=1000):\n",
    "    \"\"\"Return (X, ground-truth Y, model Y) for ``n_points`` joint samples.\"\"\"\n",
    "    X, Y = dataset.sample_joint(n_points)\n",
    "    U = torch.randn_like(Y)\n",
    "    return X, Y, model.push_u_given_x(U, X)\n",
    "\n",
    "\n",
    "def scatter_truth_and_model(panel, truth, approximation):\n",
    "    \"\"\"Overlay ground-truth and model points, each passed as an (x, y) pair.\"\"\"\n",
    "    panel.scatter(*truth, s=size, color=color_map(0.2), marker=\"X\", label=\"Ground Truth\")\n",
    "    panel.scatter(*approximation, s=size, color=color_map(0.7), marker=\"*\", label=\"C-NQR$_Y$\", alpha=alpha)\n",
    "    panel.tick_params(labelleft=False, labelbottom=False)\n",
    "\n",
    "\n",
    "plt.rcParams.update({'font.size': 22})\n",
    "fig = plt.figure(figsize=(20, 5))\n",
    "\n",
    "# Panel 1: Banana at a fixed X.\n",
    "model = load_model(AmortizedNeuralQuantileRegression, \"../../experiments_full_16_09_2025/fnlvqr_banana/amortized_neural_quantile_regression_y/weights.pth\")\n",
    "dataset = FNLVQR_Banana(tensor_parameters=tensor_parameters)\n",
    "conditional_value = 0.8 + (2.4 * (1 / 2.))\n",
    "ax = add_panel(fig, 1, xlim=(-4 - epsilon, 4 + epsilon), ylim=(0 - epsilon, 2 + epsilon))\n",
    "Y, Y_approximated = sample_conditional_pair(dataset, model, conditional_value)\n",
    "scatter_truth_and_model(ax, (Y[:, 1], Y[:, 0]), (Y_approximated[:, 1], Y_approximated[:, 0]))\n",
    "ax.legend()\n",
    "ax.set_title(f\"Banana, X={conditional_value}\")\n",
    "# The horizontal axis shows Y[:, 1] and the vertical axis Y[:, 0].\n",
    "ax.set_xlabel(\"Y$_1$\")\n",
    "ax.set_ylabel(\"Y$_0$\")\n",
    "\n",
    "# Panel 2: Star at a fixed X.\n",
    "# NOTE(review): this weights path ends in ..._u while the legend label reads\n",
    "# C-NQR$_Y$ - confirm which one is intended.\n",
    "dataset = FNLVQR_Star(tensor_parameters=tensor_parameters)\n",
    "model = load_model(AmortizedNeuralQuantileRegression, \"../../experiments_full_16_09_2025/fnlvqr_star/amortized_neural_quantile_regression_u/weights.pth\")\n",
    "conditional_value = (2 * torch.pi / 3) * (1 / 2)\n",
    "ax = add_panel(fig, 2, xlim=(-6, 6), ylim=(-6, 6))\n",
    "Y, Y_approximated = sample_conditional_pair(dataset, model, conditional_value)\n",
    "scatter_truth_and_model(ax, (Y[:, 1], Y[:, 0]), (Y_approximated[:, 1], Y_approximated[:, 0]))\n",
    "ax.set_title(f\"Star, X={conditional_value:.3f}\")\n",
    "# The horizontal axis shows Y[:, 1] and the vertical axis Y[:, 0].\n",
    "ax.set_xlabel(\"Y$_1$\")\n",
    "ax.set_ylabel(\"Y$_0$\")\n",
    "\n",
    "# Panel 3: Glasses, joint samples. X varies along the horizontal axis, so the\n",
    "# title does not quote a conditioning value.\n",
    "dataset = FNLVQR_Glasses(tensor_parameters=tensor_parameters)\n",
    "model = load_model(NeuralQuantileRegression, \"../../experiments_full_16_09_2025/fnlvqr_glasses/neural_quantile_regression_y/weights.pth\")\n",
    "ax = add_panel(fig, 3, xlim=(-0.2, 1.2), ylim=(-6, 10))\n",
    "X, Y, Y_approximated = sample_joint_pair(dataset, model)\n",
    "scatter_truth_and_model(ax, (X.flatten(), Y.flatten()), (X.flatten(), Y_approximated.flatten()))\n",
    "ax.set_title(\"Glasses\")\n",
    "ax.set_xlabel(\"X\")\n",
    "ax.set_ylabel(\"Y\")\n",
    "\n",
    "# Panel 4: Funnel, joint samples, plotted the same way as Glasses.\n",
    "dataset = FunnelDistribution(tensor_parameters=tensor_parameters)\n",
    "model = load_model(AmortizedNeuralQuantileRegression, \"../../experiments_full_14_09_2025/funnel_1/amortized_neural_quantile_regression_y/weights.pth\")\n",
    "ax = add_panel(fig, 4, xlim=(-10, 6), ylim=(-10, 10))\n",
    "X, Y, Y_approximated = sample_joint_pair(dataset, model)\n",
    "scatter_truth_and_model(ax, (X.flatten(), Y.flatten()), (X.flatten(), Y_approximated.flatten()))\n",
    "ax.set_title(\"Funnel\")\n",
    "ax.set_xlabel(\"X\")\n",
    "ax.set_ylabel(\"Y\")\n",
    "\n",
    "plt.tight_layout()\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "os.makedirs(\"images\", exist_ok=True)\n",
    "plt.savefig(\"images/heatmap.pdf\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
