{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3eade861",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n",
      "[Seed fixed to 0]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from ADEN import ADEN\n",
    "from torchinfo import summary\n",
    "from TestCaseGenerator import data_RLClustering\n",
    "from ADENTrain import TrainAnneal\n",
    "import utils\n",
    "from Env import ClusteringEnvNumpy, ClusteringEnvTorch\n",
    "from ClusteringGroundTruth import cluster_gt\n",
    "import pickle\n",
    "from datetime import datetime\n",
    "from Plotter import PlotClustering\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Using device:\", device)\n",
    "utils.set_seed(0)  # for reproducibility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b85fadb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from TestCaseGenerator import data_RLClustering\n",
    "\n",
    "# X, M, T_P, N, d = data_RLClustering(4)\n",
    "# rho = np.ones(N) / N  # Uniform weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4fa9c5a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy\n",
    "address = f\"MATLAB Codes/UTD19_London.mat\"\n",
    "# read as numpy array\n",
    "data = scipy.io.loadmat(address)\n",
    "locs = data['Xz']\n",
    "# normalize locs to be in [0,1]x[0,1]\n",
    "X = (locs - np.min(locs, axis=0)) / (np.max(locs, axis=0) - np.min(locs, axis=0))\n",
    "N, d = X.shape\n",
    "M = 25\n",
    "rho = np.ones(N) / N  # Uniform weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dde445cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def extract_params(s):\n",
    "    \"\"\"\n",
    "    Extract eps, gamma, zeta, and T from a string like:\n",
    "    'UDT_M50eps0.1gam0.0zet1.0T0.01D64_...'\n",
    "    Returns a dict with floats.\n",
    "    \"\"\"\n",
    "    pattern = (\n",
    "        r\"eps(?P<eps>[\\d.]+)\"\n",
    "        r\"gam(?P<gamma>[\\d.]+)\"\n",
    "        r\"zet(?P<zeta>[\\d.]+)\"\n",
    "        r\"T(?P<T>[\\d.]+)\"\n",
    "    )\n",
    "    m = re.search(pattern, s)\n",
    "    if not m:\n",
    "        raise ValueError(\"Could not parse parameters from string.\")\n",
    "    return {k: float(v) for k, v in m.groupdict().items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c5afbc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get a list of all files inside Benchmark Folder\n",
    "import os\n",
    "import pandas as pd\n",
    "from ClusteringGroundTruth import distortion\n",
    "\n",
    "benchmark_folder = \"BenchmarkUDT\"\n",
    "HAS_GT = True\n",
    "all_files = os.listdir(benchmark_folder)\n",
    "all_files = [f for f in all_files if os.path.isfile(os.path.join(benchmark_folder, f))]\n",
    "results_df = pd.DataFrame(\n",
    "    columns=[\"eps\", \"gamma\", \"zeta\", \"T\", \"error_opt\", \"error_ig\"]\n",
    ")\n",
    "\n",
    "# loop over all files\n",
    "for file_name in all_files:\n",
    "    with open(os.path.join(benchmark_folder, file_name), \"rb\") as f:\n",
    "        data = pickle.load(f)\n",
    "\n",
    "    scenario_name = data[\"scenario_name\"]\n",
    "    # from scenario name read eps, gamma, zeta, T. Example: Benchmark_parametrizedTrue_eps0.1_gamma0.0_zeta0.5_T0.001\n",
    "    # eps = float(scenario_name.split(\"eps\")[1].split(\"_\")[0])\n",
    "    # gamma = float(scenario_name.split(\"gamma\")[1].split(\"_\")[0])\n",
    "    # zeta = float(scenario_name.split(\"zeta\")[1].split(\"_\")[0])\n",
    "    # T = float(scenario_name.split(\"T\")[2].split(\"_\")[0])\n",
    "    params = extract_params(scenario_name)\n",
    "    eps = params[\"eps\"]\n",
    "    gamma = params[\"gamma\"]\n",
    "    zeta = params[\"zeta\"]\n",
    "    T = params[\"T\"]\n",
    "    # print(\"eps:\", eps, \"gamma:\", gamma, \"zeta:\", zeta, \"T:\", T)\n",
    "    if T == 0.1:\n",
    "        continue\n",
    "    env = ClusteringEnvNumpy(\n",
    "        n_data=N,\n",
    "        n_clusters=M,\n",
    "        n_features=d,\n",
    "        parametrized=True,\n",
    "        eps=eps,\n",
    "        gamma=gamma,\n",
    "        zeta=zeta,\n",
    "        T=T,\n",
    "        T_p=0.0,\n",
    "    )\n",
    "    if HAS_GT:\n",
    "        Y_GT = data[\"Y_GT\"]\n",
    "        pi_GT = data[\"pi_GT\"]\n",
    "        if np.isnan(Y_GT).any() or np.isnan(pi_GT).any():\n",
    "            continue\n",
    "        distortion_gt = distortion(X, Y_GT, rho, pi_GT, env)\n",
    "    Y_opt = data[\"Y_opt\"]\n",
    "    pi_opt = data[\"pi_opt\"]\n",
    "    Y_ig = data[\"Y_ig\"]\n",
    "    pi_ig = data[\"pi_ig\"]\n",
    "    # if any of the above values contain NAN, skip this scenario\n",
    "\n",
    "    if np.isnan(Y_opt).any() or np.isnan(pi_opt).any():\n",
    "        continue\n",
    "    if np.isnan(Y_ig).any() or np.isnan(pi_ig).any():\n",
    "        continue\n",
    "\n",
    "    distortion_opt = distortion(X, Y_opt, rho, pi_opt, env)\n",
    "    distortion_ig = distortion(X, Y_ig, rho, pi_ig, env)\n",
    "    \n",
    "\n",
    "    error_opt = (distortion_opt - distortion_gt) / distortion_gt * 100\n",
    "    error_ig = (distortion_ig - distortion_gt) / distortion_gt * 100\n",
    "    # error_opt_ig = (distortion_ig - distortion_opt) / distortion_opt * 100\n",
    "\n",
    "    # Print the results\n",
    "    # print(\n",
    "    #     f\"Scenario: {scenario_name} error_opt: {error_opt:.2f}%, error_ig: {error_ig:.2f}%\"\n",
    "    # )\n",
    "    # based on eps, zeta, gamma, T, and the values of error_opt and error_ig, add a row to a pandas dataframe\n",
    "\n",
    "    # results_df = pd.concat(\n",
    "    #     [\n",
    "    #         results_df,\n",
    "    #         pd.DataFrame(\n",
    "    #             {\n",
    "    #                 # \"scenario_name\": [scenario_name],\n",
    "    #                 \"eps\": [eps],\n",
    "    #                 \"gamma\": [gamma],\n",
    "    #                 \"zeta\": [zeta],\n",
    "    #                 \"T\": [T],\n",
    "    #                 \"error_opt\": [error_opt],\n",
    "    #                 \"error_ig\": [error_ig],\n",
    "    #             }\n",
    "    #         ),\n",
    "    #     ],\n",
    "    #     ignore_index=True,\n",
    "    # )\n",
    "    print(scenario_name)\n",
    "    # print(error_opt_ig)\n",
    "    print(\"ig:{:.4f}, opt:{:.4f}, gt:{:.4f}\".format(distortion_ig, distortion_opt, distortion_gt))\n",
    "    print(\"error opt:{:.2f}%, error_ig:{:.2f}%\".format(error_opt, error_ig))\n",
    "\n",
    "    # PlotClustering(X, Y_opt, pi_opt, figsize=(6, 4),\n",
    "    # point_size=10,\n",
    "    # centroid_size=300,\n",
    "    # alpha=0.9,\n",
    "    # data_edge_color='white',\n",
    "    # cluster_edge_color='black', \n",
    "    # save_path=f\"Results/{scenario_name}.png\"\n",
    "    # )\n",
    "# print the dataframe up to 2 digits\n",
    "pd.set_option(\"display.precision\", 2)\n",
    "# SAVE results_df to a csv file with current date and time\n",
    "# results_df.to_csv(\n",
    "#     f\"benchmark_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv\", index=False\n",
    "# )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
