{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b89bfc33",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['per_elm_energy', 'E_isolated', 'heat_of_formation', 'r2_score', 'mean_abs_error', 'include_hof', 'dataset_stats', 'uma_per_elm_energy'])\n"
     ]
    }
   ],
   "source": [
    "# Load the pickle file and list the top-level keys it contains.\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.read_pickle('per_atom_energy_hof_False.pkl')\n",
    "print(df.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9bb820b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'H': np.float64(-16.001316695716987), 'He': np.float64(-82.01087430838089), 'Li': np.float64(-201.82956857988097), 'Be': np.float64(-392.37382000025474), 'B': np.float64(-678.7577802660586), 'C': np.float64(-1037.0951083942327), 'N': np.float64(-1490.3891424843594), 'O': np.float64(-2047.6980007878883), 'F': np.float64(-2717.4000859159146), 'Ne': np.float64(-3511.9177958942937), 'Na': np.float64(-4412.34353756339), 'Mg': np.float64(-5442.048852490966), 'Al': np.float64(-6589.238285511481), 'Si': np.float64(-7880.842380936224), 'P': np.float64(-9290.366602061926), 'S': np.float64(-10834.813633824462), 'Cl': np.float64(-12522.319387833275), 'Ar': np.float64(-14357.418706857397), 'K': np.float64(-16323.294755148576), 'Ca': np.float64(-18435.01990024982), 'Sc': np.float64(-20696.697448280793), 'Ti': np.float64(-23112.143048514245), 'V': np.float64(-25682.83219323889), 'Cr': np.float64(-28416.863013932467), 'Mn': np.float64(-31314.828666123234), 'Fe': np.float64(-34381.61625929281), 'Co': np.float64(-37622.0970301371), 'Ni': np.float64(-41038.77074996302), 'Cu': np.float64(-44633.87178654121), 'Zn': np.float64(-48413.89305602666), 'Ga': np.float64(-52371.64723336532), 'Ge': np.float64(-56519.32612658109), 'As': np.float64(-60840.896102839535), 'Se': np.float64(-65348.22274342182), 'Br': np.float64(-70042.31301581895), 'Kr': np.float64(-74932.45491967947), 'Rb': np.float64(-653.3100819462285), 'Sr': np.float64(-831.4256441321643), 'Y': np.float64(-1038.3067074649807), 'Zr': np.float64(-1270.5802453884555), 'Nb': np.float64(-1546.2069619314702), 'Mo': np.float64(-1844.151869397674), 'Tc': np.float64(-2185.315025224292), 'Ru': np.float64(-2572.1188857512716), 'Rh': np.float64(-2998.6928887542217), 'Pd': np.float64(-3471.81081638118), 'Ag': np.float64(-3993.2282048969632), 'Cd': np.float64(-4560.650406219408), 'In': np.float64(-5163.202864090282), 'Sn': np.float64(-5822.573487382639), 'Sb': np.float64(-6540.2964462130585), 'Te': np.float64(-7296.665490756577), 'I': 
np.float64(-8099.4071751019965), 'Xe': np.float64(-8964.710394463762), 'Cs': np.float64(-546.2179067379832), 'Ba': np.float64(-689.6398798417333), 'La': np.float64(-850.4805849513325), 'Ce': np.float64(-12917.08295504127), 'Pr': np.float64(-14058.981210064643), 'Nd': np.float64(-15267.41531050162), 'Pm': np.float64(-16544.445729634244), 'Sm': np.float64(-17894.054958218647), 'Eu': np.float64(-19321.404884189622), 'Gd': np.float64(-20822.18831205945), 'Tb': np.float64(-22422.829984851378), 'Dy': np.float64(-24073.462031913972), 'Ho': np.float64(-25787.271913480094), 'Er': np.float64(-27608.607061532348), 'Tm': np.float64(-29517.834852414315), 'Yb': np.float64(-31519.198453963163), 'Lu': np.float64(-33610.154761607024), 'Hf': np.float64(-1293.3349595217703), 'Ta': np.float64(-1542.107003487292), 'W': np.float64(-1813.074616850505), 'Re': np.float64(-2116.0691118518926), 'Os': np.float64(-2451.555058191304), 'Ir': np.float64(-2827.017838069772), 'Pt': np.float64(-3237.345589780805), 'Au': np.float64(-3682.7850194399866), 'Hg': np.float64(-4168.118509910375), 'Tl': np.float64(-4683.015643495902), 'Pb': np.float64(-5238.7226964026495), 'Bi': np.float64(-5832.098572286902), 'Po': np.float64(-6469.07296), 'At': np.float64(-7140.86455), 'Rn': np.float64(-7854.60638), 'Fr': np.float64(0.0), 'Ra': np.float64(0.0), 'Ac': np.float64(0.0), 'Th': np.float64(0.0), 'Pa': np.float64(0.0), 'U': np.float64(0.0)}\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# New reference values (86 entries), matched by position to the keys of original_dict (first entry -> 'H')\n",
    "new_values = [\n",
    "    -16.001316695716987, -82.01087430838089, -201.82956857988097, -392.37382000025474,\n",
    "    -678.7577802660586, -1037.0951083942327, -1490.3891424843594, -2047.6980007878883,\n",
    "    -2717.4000859159146, -3511.9177958942937, -4412.34353756339, -5442.048852490966,\n",
    "    -6589.238285511481, -7880.842380936224, -9290.366602061926, -10834.813633824462,\n",
    "    -12522.319387833275, -14357.418706857397, -16323.294755148576, -18435.01990024982,\n",
    "    -20696.697448280793, -23112.143048514245, -25682.83219323889, -28416.863013932467,\n",
    "    -31314.828666123234, -34381.61625929281, -37622.0970301371, -41038.77074996302,\n",
    "    -44633.87178654121, -48413.89305602666, -52371.64723336532, -56519.32612658109,\n",
    "    -60840.896102839535, -65348.22274342182, -70042.31301581895, -74932.45491967947,\n",
    "    -653.3100819462285, -831.4256441321643, -1038.3067074649807, -1270.5802453884555,\n",
    "    -1546.2069619314702, -1844.151869397674, -2185.315025224292, -2572.1188857512716,\n",
    "    -2998.6928887542217, -3471.81081638118, -3993.2282048969632, -4560.650406219408,\n",
    "    -5163.202864090282, -5822.573487382639, -6540.2964462130585, -7296.665490756577,\n",
    "    -8099.4071751019965, -8964.710394463762, -546.2179067379832, -689.6398798417333,\n",
    "    -850.4805849513325, -12917.08295504127, -14058.981210064643, -15267.41531050162,\n",
    "    -16544.445729634244, -17894.054958218647, -19321.404884189622, -20822.18831205945,\n",
    "    -22422.829984851378, -24073.462031913972, -25787.271913480094, -27608.607061532348,\n",
    "    -29517.834852414315, -31519.198453963163, -33610.154761607024, -1293.3349595217703,\n",
    "    -1542.107003487292, -1813.074616850505, -2116.0691118518926, -2451.555058191304,\n",
    "    -2827.017838069772, -3237.345589780805, -3682.7850194399866, -4168.118509910375,\n",
    "    -4683.015643495902, -5238.7226964026495, -5832.098572286902, -6469.07296,\n",
    "    -7140.86455, -7854.60638\n",
    "]\n",
    "\n",
    "# Original element dictionary ('H' through 'U'); only its key order is used below, its values are never read\n",
    "original_dict = {'H': np.float64(-16.09954261779785), 'He': np.float64(-0.00390625), 'Li': np.float64(-208.0673828125), 'Be': np.float64(0.00390625), 'B': np.float64(-682.095703125), 'C': np.float64(-1036.865234375), 'N': np.float64(-1489.73828125), 'O': np.float64(-2047.26611328125), 'F': np.float64(-2717.69921875), 'Ne': np.float64(0.0), 'Na': np.float64(-4419.19921875), 'Mg': np.float64(-5447.46875), 'Al': np.float64(0.0), 'Si': np.float64(-7881.47265625), 'P': np.float64(-9292.8857421875), 'S': np.float64(-10834.837890625), 'Cl': np.float64(-12522.71484375), 'Ar': np.float64(0.0), 'K': np.float64(-16328.71484375), 'Ca': np.float64(-18429.73828125), 'Sc': np.float64(0.0), 'Ti': np.float64(0.0), 'V': np.float64(0.0), 'Cr': np.float64(0.0), 'Mn': np.float64(0.0), 'Fe': np.float64(0.0), 'Co': np.float64(0.0), 'Ni': np.float64(0.0), 'Cu': np.float64(0.0), 'Zn': np.float64(0.0), 'Ga': np.float64(0.0), 'Ge': np.float64(0.0), 'As': np.float64(0.0), 'Se': np.float64(0.0), 'Br': np.float64(-70042.6171875), 'Kr': np.float64(0.0), 'Rb': np.float64(0.0), 'Sr': np.float64(0.0), 'Y': np.float64(0.0), 'Zr': np.float64(0.0), 'Nb': np.float64(0.0), 'Mo': np.float64(0.0), 'Tc': np.float64(0.0), 'Ru': np.float64(0.0), 'Rh': np.float64(0.0), 'Pd': np.float64(0.0), 'Ag': np.float64(0.0), 'Cd': np.float64(0.0), 'In': np.float64(0.0), 'Sn': np.float64(0.0), 'Sb': np.float64(0.0), 'Te': np.float64(0.0), 'I': np.float64(-8102.0537109375), 'Xe': np.float64(0.0), 'Cs': np.float64(0.0), 'Ba': np.float64(0.0), 'La': np.float64(0.0), 'Ce': np.float64(0.0), 'Pr': np.float64(0.0), 'Nd': np.float64(0.0), 'Pm': np.float64(0.0), 'Sm': np.float64(0.0), 'Eu': np.float64(0.0), 'Gd': np.float64(0.0), 'Tb': np.float64(0.0), 'Dy': np.float64(0.0), 'Ho': np.float64(0.0), 'Er': np.float64(0.0), 'Tm': np.float64(0.0), 'Yb': np.float64(0.0), 'Lu': np.float64(0.0), 'Hf': np.float64(0.0), 'Ta': np.float64(0.0), 'W': np.float64(0.0), 'Re': np.float64(0.0), 'Os': np.float64(0.0), 'Ir': np.float64(0.0), 
'Pt': np.float64(0.0), 'Au': np.float64(0.0), 'Hg': np.float64(0.0), 'Tl': np.float64(0.0), 'Pb': np.float64(0.0), 'Bi': np.float64(0.0), 'Po': np.float64(0.0), 'At': np.float64(0.0), 'Rn': np.float64(0.0), 'Fr': np.float64(0.0), 'Ra': np.float64(0.0), 'Ac': np.float64(0.0), 'Th': np.float64(0.0), 'Pa': np.float64(0.0), 'U': np.float64(0.0)}\n",
    "\n",
    "# Assign the new reference values to the element symbols by position:\n",
    "# the i-th entry of new_values goes to the i-th key of original_dict.\n",
    "if len(new_values) > len(original_dict):\n",
    "    # Without this guard the surplus values would be dropped silently.\n",
    "    raise ValueError(\n",
    "        f\"new_values has {len(new_values)} entries but original_dict only has {len(original_dict)} elements\"\n",
    "    )\n",
    "\n",
    "# Elements past the end of new_values are set to 0.0.\n",
    "# A comprehension is used so that no loop variables (in particular `_`,\n",
    "# which IPython uses for the last output) are left behind in the notebook namespace.\n",
    "updated_dict = {\n",
    "    element: np.float64(new_values[i]) if i < len(new_values) else np.float64(0.0)\n",
    "    for i, element in enumerate(original_dict)\n",
    "}\n",
    "\n",
    "print(updated_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b8f1cbd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the updated per-element values on the loaded object under the 'uma_per_elm_energy' key.\n",
    "# NOTE(review): df.keys() printed dict_keys in the load cell, which suggests df is a plain dict, not a DataFrame.\n",
    "df[\"uma_per_elm_energy\"] = updated_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ca21c702",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the updated object back to the pickle file it was loaded from.\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "pkl_path = \"per_atom_energy_hof_False.pkl\"\n",
    "tmp_path = pkl_path + \".tmp\"\n",
    "\n",
    "# Opening the target directly with \"wb\" truncates it before pickling starts,\n",
    "# so a failure inside pickle.dump would destroy the only copy of the data.\n",
    "# Write to a temporary file first and swap it in only after the dump succeeded.\n",
    "try:\n",
    "    with open(tmp_path, \"wb\") as f:\n",
    "        pickle.dump(df, f)\n",
    "    os.replace(tmp_path, pkl_path)\n",
    "finally:\n",
    "    # Remove the partial file if the dump or the swap failed.\n",
    "    if os.path.exists(tmp_path):\n",
    "        os.remove(tmp_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "new",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
