{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "51d10559",
   "metadata": {},
   "outputs": [],
   "source": [
    "from high_dimensionality_distances import *\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e65dd51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration read by the experiment and plotting cells\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Dimensions to sweep: 2, then 10..90 in steps of 10, then 100..1000 in steps of 50\n",
    "dims = [2, *range(10, 100, 10), *range(100, 1001, 50)]\n",
    "\n",
    "# Scale t: numbers are used as given; for each dimension D,\n",
    "# 'auto' means t = 1/D and 'sqrt_auto' means t = 1/sqrt(D)\n",
    "t_values = [0.01, 0.1, 'auto', 'sqrt_auto']\n",
    "\n",
    "n_trials = 100\n",
    "n_samples = 500\n",
    "\n",
    "# A list of sigma settings; this replaces the earlier single `sigma = 1.0`.\n",
    "# NOTE(review): how 'sqrt_auto' is resolved for sigma is not visible in this notebook - confirm\n",
    "sigma_values = [1.0, 'sqrt_auto']\n",
    "\n",
    "mean_difference = 2.0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5a0f221",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the full sweep with the settings from the configuration cell.\n",
    "# Every argument is passed by keyword, so their order here does not matter.\n",
    "mmd_stats, wass_stats, mag_stats_dict = experiment(\n",
    "    device=device,\n",
    "    dims=dims,\n",
    "    t_values=t_values,\n",
    "    sigma_values=sigma_values,\n",
    "    n_trials=n_trials,\n",
    "    n_samples=n_samples,\n",
    "    mean_difference=mean_difference,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1781674f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved: experiments_meandiff_10.0/plots/mean_dims2-1000_n500_trials100.pgf\n",
      "Saved: experiments_meandiff_10.0/plots/mean_dims2-1000_n500_trials100.png\n",
      "Saved: experiments_meandiff_10.0/plots/mean_log_dims2-1000_n500_trials100.pgf\n",
      "Saved: experiments_meandiff_10.0/plots/mean_log_dims2-1000_n500_trials100.png\n",
      "Saved: experiments_meandiff_10.0/plots/std_dims2-1000_n500_trials100.pgf\n",
      "Saved: experiments_meandiff_10.0/plots/std_dims2-1000_n500_trials100.png\n",
      "Saved: experiments_meandiff_10.0/plots/cv_dims2-1000_n500_trials100.pgf\n",
      "Saved: experiments_meandiff_10.0/plots/cv_dims2-1000_n500_trials100.png\n",
      "Saved: experiments_meandiff_10.0/plots/errorbar_dims2-1000_n500_trials100.pgf\n",
      "Saved: experiments_meandiff_10.0/plots/errorbar_dims2-1000_n500_trials100.png\n"
     ]
    }
   ],
   "source": [
    "# Load a previously saved experiment from disk and plot it.\n",
    "# Everything read back is stored under `loaded_*` names so this cell does not\n",
    "# overwrite the configuration (dims, sigma_values, mean_difference) or the\n",
    "# results (wass_stats, mag_stats_dict) used by the experiment cell.\n",
    "loaded_mean_difference = 10.0\n",
    "\n",
    "# Build the directory name from the value above so the two cannot drift apart;\n",
    "# for 10.0 this gives 'experiments_meandiff_10.0'\n",
    "base_dir = f'experiments_meandiff_{loaded_mean_difference}'\n",
    "\n",
    "loaded_dims, loaded_mmd_stats, loaded_wass_stats, loaded_mag_stats = load_experiment_notes(base_dir)\n",
    "\n",
    "# The sigma values are taken from the keys of the loaded MMD statistics\n",
    "loaded_sigma_values = list(loaded_mmd_stats.keys())\n",
    "\n",
    "# n_trials and n_samples still come from the configuration cell.\n",
    "# NOTE(review): they are assumed to match the saved run - confirm\n",
    "plot_single(\n",
    "    loaded_dims,\n",
    "    loaded_mmd_stats,\n",
    "    loaded_wass_stats,\n",
    "    loaded_mag_stats,\n",
    "    n_trials,\n",
    "    n_samples,\n",
    "    loaded_sigma_values,\n",
    "    loaded_mean_difference,\n",
    ")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_mag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
