"""Concatenate MACE results from multiple data files generated by slurm job array
into single file.
"""

# %%
import os
from glob import glob

import pandas as pd
from pymatgen.core import Structure
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.entries.computed_entries import ComputedStructureEntry
from pymatviz.enums import Key
from tqdm import tqdm

from matbench_discovery.data import as_dict_handler, df_wbm
from matbench_discovery.energy import calc_energy_from_e_refs, mp_elemental_ref_energies
from matbench_discovery.enums import DataFiles, MbdKey, Model, Task

__author__ = "Janosh Riebesell"
__date__ = "2023-03-01"


# %%
# Locate the per-job MACE result files for the chosen model, task and run date.
module_dir = os.path.dirname(__file__)
date = "2025-01-30"
task_type = Task.IS2RE
model_name = Model.mace_mpa_0.key
glob_pattern = f"{model_name}/{date}-wbm-{task_type}*/*.json.gz"
# sort in place so the files are always processed in the same order
file_paths = glob(f"{module_dir}/{glob_pattern}")
file_paths.sort()
print(f"Found {len(file_paths):,} files for {glob_pattern = }")

# column that will hold the MACE-predicted formation energy per atom
e_form_mace_col = "e_form_per_atom_mace"
# presumably the column of MACE-relaxed structures in the result files (unused below)
struct_col = "mace_structure"

# loaded result dataframes keyed by the file path they were read from
dfs: dict[str, pd.DataFrame] = {}


# %%
# Read every per-job results file once. `dfs` is keyed by file path, so re-running
# this cell skips files that were already loaded.
for file_path in tqdm(file_paths):
    if file_path in dfs:
        continue
    df_i = pd.read_json(file_path).set_index(Key.mat_id)
    # drop trajectory to save memory
    dfs[file_path] = df_i.drop(columns="mace_trajectory", errors="ignore")

if not dfs:
    # pd.concat on an empty collection raises an opaque "No objects to concatenate";
    # fail with a message that points at the glob pattern instead
    raise ValueError(f"No MACE result files found for {glob_pattern = }")

# stack all per-job frames into one, rounded to 4 decimals
df_mace = pd.concat(dfs.values()).round(4)


# %%
# Read the WBM computed structure entries (JSON lines file) indexed by material ID.
wbm_cse_path = DataFiles.wbm_computed_structure_entries.path
df_wbm_cse = pd.read_json(wbm_cse_path, lines=True)
df_wbm_cse = df_wbm_cse.set_index(Key.mat_id)

# turn the serialized dicts into ComputedStructureEntry objects, with a progress bar
cse_dicts = tqdm(df_wbm_cse[Key.computed_structure_entry], desc="Hydrate CSEs")
df_wbm_cse[Key.computed_structure_entry] = list(
    map(ComputedStructureEntry.from_dict, cse_dicts)
)


# %% transfer MACE energies and relaxed structures to the WBM CSEs since the MP2020
# energy corrections applied below are structure-dependent (for oxides and sulfides)
cse: ComputedStructureEntry
# each row tuple is (index, 1st column, 2nd column, ...): material ID, structure dict
# and MACE energy are taken by position, so this relies on df_mace's column order
for mat_id, struct_dict, mace_energy, *_ in tqdm(
    df_mace.itertuples(), total=len(df_mace), desc="ML energies to CSEs"
):
    relaxed_struct = Structure.from_dict(struct_dict)
    cse = df_wbm_cse.loc[mat_id, Key.computed_structure_entry]
    # overwrite the entry's private attributes in place
    cse._energy = mace_energy  # cse._energy is the uncorrected energy  # noqa: SLF001
    cse._structure = relaxed_struct  # noqa: SLF001
    df_mace.loc[mat_id, Key.computed_structure_entry] = cse


# %% apply energy corrections
# Run the MP2020 compatibility scheme over all entries and fail loudly if the number
# of returned entries differs from the number passed in.
compat = MaterialsProject2020Compatibility()
cse_column = df_mace[Key.computed_structure_entry]
processed = compat.process_entries(cse_column, verbose=True, clean=True)
n_expected = len(df_mace)
if len(processed) != n_expected:
    raise ValueError(f"not all entries processed: {len(processed)=} {len(df_mace)=}")


# %% compute corrected formation energies
# attach formulas from the WBM summary frame (pandas aligns the assignment on index)
df_mace[Key.formula] = df_wbm[Key.formula]

print("Calculating formation energies")
# formation energy per material ID, computed from each entry's total energy relative
# to the MP elemental reference energies; the formula is looked up via Key.formula,
# the same key used to create the column above
e_form_list: dict[str, float] = {
    mat_id: calc_energy_from_e_refs(
        row[Key.formula],
        ref_energies=mp_elemental_ref_energies,
        total_energy=row[Key.computed_structure_entry].energy,
    )
    for mat_id, row in tqdm(df_mace.iterrows(), total=len(df_mace))
}

# wrap in a Series so values are explicitly aligned to df_mace's index by material ID
# rather than relying on how pandas interprets a raw dict on column assignment
df_mace[e_form_mace_col] = pd.Series(e_form_list)
# copy all MACE columns into the WBM summary frame (index-aligned)
df_wbm[[*df_mace]] = df_mace


# %%
# flag predictions whose formation energy differs from DFT by more than 5
# (in the units of these columns, presumably eV/atom)
bad_mask = abs(df_wbm[e_form_mace_col] - df_wbm[MbdKey.e_form_dft]) > 5
n_bad = int(bad_mask.sum())
n_preds = len(df_wbm[e_form_mace_col].dropna())
# the fraction is taken over the number of predictions that is printed next to it
# (not over all rows of df_wbm); guard against having no predictions at all
bad_frac = n_bad / n_preds if n_preds else 0.0
print(f"sum(bad_mask)={n_bad} is {bad_frac:.2%} of {n_preds:,}")

# outputs are named after the directory holding the per-job files; dirname copes
# with the platform's path separators, unlike splitting on "/"
out_path = os.path.dirname(file_paths[0])
df_mace = df_mace.round(4)
# CSV gets the numeric columns only; JSON lines gets the full frame, with
# non-serializable objects passed through as_dict_handler
df_mace.select_dtypes("number").to_csv(f"{out_path}.csv.gz")
df_mace.reset_index().to_json(
    f"{out_path}.json.gz", default_handler=as_dict_handler, orient="records", lines=True
)

# in_path = f"{module_dir}/2024-07-20-mace-wbm-IS2RE-FIRE"
# df_mace = pd.read_csv(f"{in_path}.csv.gz").set_index(Key.mat_id)
# df_mace = pd.read_json(f"{in_path}.json.gz").set_index(Key.mat_id)
