{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064b5b33-0350-47cd-873d-106ef3ee2037",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import random\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.mixture import BayesianGaussianMixture\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.datasets import make_circles, load_iris\n",
    "from sklearn import preprocessing\n",
    "dir_ = os.path.abspath(os.path.join(os.getcwd(), \"..\", \"..\"))\n",
    "os.chdir(dir_)\n",
    "# Local modules\n",
    "import utils\n",
    "import prior\n",
    "import transformer\n",
    "import main\n",
    "\n",
    "# Settings\n",
    "matplotlib.use(\"TkAgg\")\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "%matplotlib inline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "773d9077-dc1c-4c7f-874c-f5a9bc4965f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/csv/datasets/faithful.csv\"\n",
    "faithful_np = pd.read_csv(url, index_col=0).to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80c17b85-039d-4687-a564-bbd53c6b6a1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Using device: {torch.cuda.get_device_name(torch.cuda.current_device())}\")\n",
    "device = torch.device(\"cuda\")\n",
    "d_model, nhead, nhid, nlayers = 256, 4, 512, 4\n",
    "in_features = 2\n",
    "num_outputs = 10\n",
    "model = transformer.Transformer(d_model, nhead, nhid, nlayers,in_features=in_features, buckets_size=num_outputs).to(device)\n",
    "print(f\"total params:{sum(p.numel() for p in model.parameters())}\")\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "model.criterion = criterion\n",
    "\n",
    "\n",
    "checkpoint = torch.load(\"models/models_original/pfn_easy_2D.pt\", weights_only=True)\n",
    "model.load_state_dict(checkpoint['model_state_dict'])\n",
    "model.eval() \n",
    "print(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf7ec924-f239-402f-a35c-882326088a93",
   "metadata": {},
   "outputs": [],
   "source": [
    "faithful_standard = preprocessing.MinMaxScaler().fit_transform(faithful_np)\n",
    "faithful_tensor = torch.tensor(faithful_standard, dtype=torch.float32).to(device)\n",
    "faithful_tensor = faithful_tensor.unsqueeze(1)\n",
    "logits,cluster_output = model(faithful_tensor, torch.full((1,1), 0, dtype=torch.long, device=device))\n",
    "logits = logits.squeeze(1)\n",
    "cluster_output = cluster_output.cpu()\n",
    "predictions = torch.argmax(logits, dim=1)\n",
    "predictions = predictions.cpu()\n",
    "\n",
    "plt.scatter(faithful_np[:, 0] , faithful_np[:, 1], c=predictions)\n",
    "plt.title(\"Cluster-PFN prediction on The Old Faithful Dataset\")\n",
    "plt.xlabel(\"Eruption duration\")\n",
    "plt.ylabel(\"Waiting times\")\n",
    "plt.show()\n",
    "\n",
    "\n",
    "probs_tensor = F.softmax(cluster_output, dim=-1)  # still shape [1, 1, 10]\n",
    "probs = probs_tensor.squeeze().detach().numpy()\n",
    "bins = np.arange(1, 11)\n",
    "\n",
    "#Plot\n",
    "plt.figure(figsize=(8, 4))\n",
    "plt.bar(bins, probs, tick_label=bins)\n",
    "plt.xlabel(\"Number of clusters\")\n",
    "plt.ylabel(\"Probability\")\n",
    "plt.title(\"Probability distribution over the possible number of clusters in the dataset\")\n",
    "plt.ylim(0, max(probs) * 1.2)\n",
    "plt.grid(axis='y', linestyle='--', alpha=0.6)\n",
    "plt.savefig(\"old_faitful_cluster_prediction.png\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cuda_test",
   "language": "python",
   "name": "cuda_test"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
