from ase.io import read
import sys
import numpy as np
import os
import pandas as pd

# Benchmark reference files. The "mad-settings" file is the target set for
# models whose name contains "mad"; the "mptrj-settings" file is used for all
# other models. Each path can be overridden with an environment variable of
# the same name so the script runs outside the original directory layout; the
# defaults are the original hard-coded locations.
MAD_BENCH_MAD_SETTINGS = os.environ.get(
    "MAD_BENCH_MAD_SETTINGS",
    "/work/anon/anon/projects/mad-bench/data/mad-bench-mad-settings.xyz",
)
MAD_BENCH_MPTRJ_SETTINGS = os.environ.get(
    "MAD_BENCH_MPTRJ_SETTINGS",
    "/work/anon/anon/projects/mad-bench/data/mad-bench-mptrj-settings.xyz",
)
# Dataset tags (values of atoms.info["dataset"]) to report, in table row order.
DATASETS = ["MAD", "MPtrj", "MatBench", "Alexandria", "OC2020", "SPICE", "MD22"]
# Selects which predictions file is read and which results CSV is written.
DTYPE = "float64"

# Parse every frame (index=":") of both benchmark reference files when the
# module is imported; the MAD-settings file is read first.
mad_bench_mad_settings, mad_bench_mptrj_settings = (
    read(xyz_path, index=":")
    for xyz_path in (MAD_BENCH_MAD_SETTINGS, MAD_BENCH_MPTRJ_SETTINGS)
)


def calculate_model_stats(model_name, predicted_atoms, target_atoms, datasets=None):
    """Compute per-dataset energy and force MAEs of one model against targets.

    Structures are grouped by ``target_atoms[i].info["dataset"]`` and
    ``predicted_atoms[i]`` is compared with ``target_atoms[i]``, so both
    sequences must list the same structures in the same order.

    Args:
        model_name: Key under which the errors are stored. For models whose
            name does not contain "mad" (case-insensitive) the MatBench force
            MAE is not computed and is reported as NaN.
        predicted_atoms: Structures carrying the model's predicted energies
            and forces.
        target_atoms: Reference structures, each with an ``info["dataset"]``
            tag.
        datasets: Dataset tags to evaluate, in output order. Defaults to the
            module-level ``DATASETS``.

    Returns:
        ``(energy_errors, forces_errors)``, each shaped
        ``{model_name: {dataset: mae}}``. Energy MAE is in meV/atom and force
        MAE in meV/Å (the original inputs scaled by 1000). A dataset with no
        structures, or a skipped force metric, is NaN.

    Raises:
        ValueError: If ``predicted_atoms`` and ``target_atoms`` differ in
            length.
    """
    if datasets is None:
        datasets = DATASETS
    if len(predicted_atoms) != len(target_atoms):
        raise ValueError(
            f"{model_name}: {len(predicted_atoms)} predicted structures but "
            f"{len(target_atoms)} target structures"
        )

    # Group structure indices by dataset tag in a single pass. Plain lists are
    # used deliberately: np.array(list_of_atoms, dtype=object) treats Atoms
    # objects as sequences and builds a 2-D array of single atoms whenever all
    # structures have the same size.
    indices_by_dataset = {}
    for index, atoms in enumerate(target_atoms):
        indices_by_dataset.setdefault(atoms.info["dataset"], []).append(index)

    skip_matbench_forces = "mad" not in model_name.lower()
    model_energy_errors = {model_name: {}}
    model_forces_errors = {model_name: {}}
    for dataset in datasets:
        dataset_indices = indices_by_dataset.get(dataset, [])
        predicted_atoms_dataset = [predicted_atoms[i] for i in dataset_indices]
        target_atoms_dataset = [target_atoms[i] for i in dataset_indices]

        if not target_atoms_dataset:
            # No structures carry this tag: report NaN instead of letting
            # np.mean warn and np.concatenate raise on empty input.
            energy_mae = np.nan
            forces_mae = np.nan
        else:
            predicted_energies = np.array(
                [a.get_potential_energy() for a in predicted_atoms_dataset]
            )
            target_energies = np.array(
                [a.get_potential_energy() for a in target_atoms_dataset]
            )
            num_atoms = np.array([len(a) for a in target_atoms_dataset])
            predicted_energies_per_atom = predicted_energies / num_atoms
            target_energies_per_atom = target_energies / num_atoms
            energy_mae = (
                np.mean(np.abs(predicted_energies_per_atom - target_energies_per_atom))
                * 1000
            )  # meV/atom
            if dataset == "MatBench" and skip_matbench_forces:
                # Force MAE is not reported on MatBench for non-MAD models.
                forces_mae = np.nan
            else:
                # Component-wise MAE over all atoms of all structures.
                predicted_forces = np.concatenate(
                    [a.get_forces() for a in predicted_atoms_dataset]
                ).flatten()
                target_forces = np.concatenate(
                    [a.get_forces() for a in target_atoms_dataset]
                ).flatten()
                forces_mae = (
                    np.mean(np.abs(predicted_forces - target_forces)) * 1000
                )  # meV/Å
        print(f"\t{dataset}: {energy_mae:.2f} meV/atom, {forces_mae:.2f} meV/Å")

        model_energy_errors[model_name][dataset] = energy_mae
        model_forces_errors[model_name][dataset] = forces_mae
    return model_energy_errors, model_forces_errors


if __name__ == "__main__":
    # Usage: pass one or more model directories. Each directory must contain
    # the predictions file selected by DTYPE. The model name is the part of the
    # directory's base name after its first "." (e.g. ".../eval.mace" -> "mace").
    model_paths = sys.argv[1:]
    if not model_paths:
        # Without this guard the script would overwrite the results CSV with
        # an empty table.
        sys.exit(f"usage: {sys.argv[0]} MODEL_DIR [MODEL_DIR ...]")

    # Derive each model name once, from the directory's base name only, so
    # prefixes such as "./" or "../" and dots in parent directories cannot
    # leak into (or replace) the model name.
    models = []
    for model_path in model_paths:
        name_parts = os.path.basename(os.path.normpath(model_path)).split(".")
        if len(name_parts) < 2:
            sys.exit(
                f"cannot derive a model name from {model_path!r}: "
                "expected a directory named '<prefix>.<model>'"
            )
        models.append((model_path, name_parts[1]))

    # Both reference sets were already parsed at module level from the same
    # paths; reuse them instead of reading the files a second time.
    target_atoms_mad_settings = mad_bench_mad_settings
    target_atoms_mptrj_settings = mad_bench_mptrj_settings
    predictions_filename = (
        "predictions_float64.xyz" if DTYPE == "float64" else "predictions.xyz"
    )

    benchmark_energy_errors = {}
    benchmark_forces_errors = {}
    for model_path, model_name in models:
        print(f"Evaluating model: {model_name} from {model_path}")
        predicted_atoms = read(
            os.path.join(model_path, predictions_filename), index=":"
        )
        # Models with "mad" in their name are scored against the MAD-settings
        # references, all others against the MPtrj-settings references.
        if "mad" in model_name.lower():
            target_atoms = target_atoms_mad_settings
        else:
            target_atoms = target_atoms_mptrj_settings
        model_energy_errors, model_forces_errors = calculate_model_stats(
            model_name,
            predicted_atoms,
            target_atoms,
        )
        benchmark_energy_errors.update(model_energy_errors)
        benchmark_forces_errors.update(model_forces_errors)

    # One column per model, one row per dataset. Each cell is
    # "<energy MAE>/<forces MAE>", both rounded to one decimal. Look values up
    # by dataset name rather than relying on dict value order.
    table = {}
    for model_name, energy_errors in benchmark_energy_errors.items():
        forces_errors = benchmark_forces_errors[model_name]
        table[model_name] = [
            f"{np.round(energy_errors[dataset], 1)}/"
            f"{np.round(forces_errors[dataset], 1)}"
            for dataset in DATASETS
        ]
    df = pd.DataFrame(table, index=DATASETS)
    print("\nBenchmark Results (Energy MAE [meV/atom] / Forces MAE [meV/Å]):")
    print(df)
    filename = (
        "benchmark_results_float64.csv"
        if DTYPE == "float64"
        else "benchmark_results.csv"
    )
    df.to_csv(filename)
    print(f"\nBenchmark results saved to {filename}")
