{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f9b957e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T20:40:32.747304Z",
     "start_time": "2025-01-20T20:40:29.330667Z"
    }
   },
   "outputs": [],
   "source": [
    "!python organize_dataset.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "089868ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python dataset-combine.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "beb55b6a",
   "metadata": {
    "ExecuteTime": {
     "start_time": "2025-07-09T18:26:28.479878Z"
    },
    "jupyter": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "!python train.py --model_type FM --dist_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf899b707ba3830b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T14:04:38.641098Z",
     "start_time": "2025-09-15T14:00:18.426293Z"
    }
   },
   "outputs": [],
   "source": [
    "!python unfold.py --model_type FM --dist_info "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb6eddc58d09ee6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T14:06:00.387209Z",
     "start_time": "2025-09-15T14:05:27.296723Z"
    }
   },
   "outputs": [],
   "source": [
    "!python plot.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2da5e8355230467a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T14:06:30.260573Z",
     "start_time": "2025-09-15T14:06:28.573740Z"
    }
   },
   "outputs": [],
   "source": [
    "##%%\n",
    "import numpy as np\n",
    "# import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "# import configs\n",
    "\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "\n",
    "import pickle\n",
    "\n",
    "dict_name = 'distance_record_res'+'.pkl'\n",
    "vec = [\"pT\", \"eta\", \"phi\", \"E\", 'px', 'py','pz']\n",
    "res_all = []\n",
    "pickle_name = './plots/_real25_batchsize_permutation_invariantNN/'+dict_name\n",
    "# pickle_name2 = path + configs.dataset_name + '_' + configs.train_index_str + '_res.pkl'\n",
    "pickle_name_cDDPM = './plots/real25_cDDPM/'+dict_name\n",
    "pickle_name_cFM = './plots/real25_cFM/'+dict_name\n",
    "pickle_name_Omni = './plots/real25_Omni/'+dict_name\n",
    "pickle_name_SB = './plots/real25_SB/'+dict_name\n",
    "pickle_name_Omni_combine = './plots/real25_Omni_combine/'+dict_name\n",
    "pickle_name2 = './plots/real_condon_y1/'+dict_name\n",
    "pickle_name3 = './plots/real_condon_yall/'+dict_name\n",
    "pickle_name4 = './plots/real25_FM/'+dict_name\n",
    "\n",
    "with open(pickle_name, 'rb') as file:\n",
    "    data = pickle.load(file)\n",
    "with open(pickle_name2, 'rb') as file2:\n",
    "    data2 = pickle.load(file2)\n",
    "with open(pickle_name3, 'rb') as file3:\n",
    "    data3 = pickle.load(file3)\n",
    "with open(pickle_name4, 'rb') as file4:\n",
    "    data4 = pickle.load(file4)\n",
    "with open(pickle_name_cDDPM, 'rb') as file5:\n",
    "    data5 = pickle.load(file5)\n",
    "with open(pickle_name_cFM, 'rb') as file6:\n",
    "    data6 = pickle.load(file6)\n",
    "with open(pickle_name_Omni, 'rb') as file7:\n",
    "    data7 = pickle.load(file7)\n",
    "with open(pickle_name_Omni_combine, 'rb') as file8:\n",
    "    data8 = pickle.load(file8)\n",
    "with open(pickle_name_SB, 'rb') as file9:\n",
    "    data9 = pickle.load(file9)\n",
    "\n",
    "method_group = ['detector','cDDPM', 'cFM' , 'yall', 'y1', 'tDDPM' ,'tFM', 'Omnifold_best', 'Omnifold_combine', 'SB']\n",
    "\n",
    "dataset_group = ['lepqua_NNPDF23lo0130', 'ttbar_CT14lo_vincia', 'wjets_CT14lo', 'zjets_CTEQ6L1']\n",
    "for i in range(len(vec)):\n",
    "    # print(vec[i])\n",
    "    unfold_vec = np.array([])\n",
    "    reco_vec = np.array([])\n",
    "    unfold_vec_cDDPM = np.array([])\n",
    "    unfold_vec_cFM = np.array([])\n",
    "    unfold_vec_Omni = np.array([])\n",
    "    unfold_vec_SB = np.array([])\n",
    "    unfold_vec_Omni_combine = np.array([])\n",
    "    unfold_vec2 = np.array([])\n",
    "    unfold_vec3 = np.array([])\n",
    "    unfold_vec4 = np.array([])\n",
    "    for dataset_name in dataset_group:\n",
    "        unfoldata = data[dataset_name  + vec[i] + 'unfold']\n",
    "        recodata = data[dataset_name  + vec[i] + 'reco']\n",
    "        unfoldata2 = data2[dataset_name  + vec[i] + 'unfold']\n",
    "        unfoldata3 = data3[dataset_name  + vec[i] + 'unfold']\n",
    "        unfoldata4 = data4[dataset_name  + vec[i] + 'unfold']\n",
    "        unfolddata_cDDPM = data5[dataset_name + vec[i] + 'unfold']\n",
    "        unfolddata_cFM = data6[dataset_name + vec[i] + 'unfold']\n",
    "        unfolddata_Omni = data7[dataset_name + vec[i] + 'unfold']\n",
    "        unfolddata_Omni_combine = data8[dataset_name + vec[i] + 'unfold']\n",
    "        unfolddata_SB = data9[dataset_name + vec[i] + 'unfold']\n",
    "        \n",
    "        unfold_vec = np.append(unfold_vec, unfoldata)\n",
    "        reco_vec = np.append(reco_vec, recodata)\n",
    "        unfold_vec2 = np.append(unfold_vec2, unfoldata2)\n",
    "        unfold_vec3 = np.append(unfold_vec3, unfoldata3)\n",
    "        unfold_vec4 = np.append(unfold_vec4, unfoldata4)\n",
    "        unfold_vec_cDDPM = np.append(unfold_vec_cDDPM, unfolddata_cDDPM)\n",
    "        unfold_vec_cFM = np.append(unfold_vec_cFM, unfolddata_cFM)\n",
    "        unfold_vec_Omni = np.append(unfold_vec_Omni, unfolddata_Omni)\n",
    "        unfold_vec_Omni_combine = np.append(unfold_vec_Omni_combine, unfolddata_Omni_combine)\n",
    "        unfold_vec_SB = np.append(unfold_vec_SB, unfolddata_SB)\n",
    "        \n",
    "    res = [reco_vec,unfold_vec_cDDPM, unfold_vec_cFM,unfold_vec3, unfold_vec2,unfold_vec, unfold_vec4, unfold_vec_Omni, unfold_vec_Omni_combine, unfold_vec_SB]\n",
    "    # \n",
    "    res_all.append(np.array(res))\n",
    "    # print(res)\n",
    "    x = np.arange((4))\n",
    "    width = 0.09\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    for j in range(len(res)):\n",
    "\n",
    "        plt.bar(x + j * width, res[j], width=width, label=method_group[j])\n",
    "\n",
    "    plt.semilogy()\n",
    "    plt.xticks(x + width * (len(res)-1)/2, dataset_group)\n",
    "    plt.legend()\n",
    "    plt.ylabel('Wasserstein Distances')\n",
    "\n",
    "    plt.title(vec[i])\n",
    "    plt.savefig('./plots/compare/'+vec[i]+'.png')\n",
    "    plt.show()\n",
    "    \n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
