"""Standalone evaluation entrypoint for DEI-XC checkpoints.

This script can be launched either directly (``python scripts/deixc_evaluate.py``)
or through SEML.  It reconstructs the training pipeline that produced a saved
checkpoint, regenerates any missing DEI-XC auxiliary targets on the fly, and
writes per-sample metrics to CSV.  Key configuration knobs are exposed through
the SEML/Sacred config defined in :func:`default_config`:

- ``logging``: output directory and optional CSV path.
- ``data``: dataset key, filtering options, SCF alignment, and reference method
  settings (matching the fields used during training).
- ``model``: overrides for the XC functional.  When evaluating a checkpoint, the
  saved YAML that lives next to ``params.flax`` takes precedence, but any values
  provided here can further override that config.
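- ``evaluation_checkpoints``: checkpoint type ``chkp_type`` (e.g.,
  ``dynamic_train``), ``root_dir`` containing one subdirectory per training
  run, the base ``name`` of those subdirectories, and the run ``indices`` to
  evaluate (``None`` evaluates the configured functional without loading any
  checkpoint).
- ``base.splits``: which dataset splits to traverse (default: ``('test',)``).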

With ``data.lazy=True`` (the default), all iteration happens on the main thread;
otherwise ``data.workers`` loader workers are used.  The helper
``LazyDEIXCDataset`` wraps the base dataset and materializes DEI-XC auxiliary
files when they are missing.
"""

import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple

import grain.python as grain
import jax
import jax.numpy as jnp
import pandas as pd
from jax import config as jax_config
from jax.experimental.compilation_cache import compilation_cache
from seml.experiment import Experiment

from deixc.data_generation.lazy_dataloader import (
    LazyDEIXCDataset,
    get_psys_and_lazy_dataloader,
)
from deixc.dataset import DEIXCDataset, DEIXCTargets
from deixc.scf import DerivativeInformedSelfConsistentFieldSolver
from egxc import dataloading, discretization, systems, xc_energy
from egxc.dataloading.dataloader import get_psys_and_dataloaders
from egxc.dataloading.io import (
    checkpoint_best_path,
    copy_directory_to_directory,
    results_directory,
    results_metrics_csv_path,
    results_summary_path,
)
from egxc.training.loss.density import (
    dipole_difference,
    get_coulomb_energy_error_fn,
    get_density_mean_field_error_fn,
)
from egxc.training.loss.field import (
    delta_field_fn,
    field_integral_measures,
    overlap_based_mae_surrogate,
    overlap_based_mse_surrogate,
)
from egxc.utils.typing import Alignment, BaseInitialGuess, MethodKey

# Persistent JAX compilation cache. The path is relative, so it is resolved
# against the current working directory of the process.
compilation_cache.set_cache_dir('./caches/jax_compile/deixc_eval')

# Global JAX settings, applied at import time: select the GPU backend (no
# fallback is configured here), enable 64-bit dtypes, and set the default
# matmul precision to 'float32'.
jax_config.update('jax_platform_name', 'gpu')
jax_config.update('jax_enable_x64', True)
jax_config.update('jax_default_matmul_precision', 'float32')

# SEML experiment; the config functions and ``@ex.capture`` methods below
# register against it.
ex = Experiment()

# NOTE(review): anonymized placeholder path, not referenced in the visible
# code; checkpoints are located via ``evaluation_checkpoints.root_dir``.
CHECKPOINT_ROOT = 'ANONYMOUS_DIR'


@ex.config
def default_config():
    """Default Sacred/SEML configuration for the evaluation script.

    Sacred collects the local variables of this function as config entries.
    Some entries have no default here and must come from the user or a named
    config: ``data.key`` and ``data.data_set_kwargs`` (captured by
    ``EvaluationExperiment._init_dataset``), and ``model.name`` (``None`` here),
    plus any further ``model`` fields, which are forwarded to
    ``get_functional(**model)``.
    NOTE(review): ``EvaluationExperiment.__init__`` also captures an
    ``overwrite`` entry from the ``logging`` section, which is not defined here.
    """
    logging = {
        'dir': None,  # output directory
        'name': None,  # name of individual evaluation run (subdirectory of dir)
        'output_csv': 'results.csv',
    }
    model = {
        'name': None,  # e.g. NNmGGA, XCdiff, etc.
    }
    base = {
        'use_density_fitting': True,
        'spin_restricted': True,
        'splits': ('test',),  # dataset splits traversed during evaluation
    }
    # ``indices=None`` evaluates the model functional once without loading checkpoints.
    # NOTE(review): run directories are named f'{name}_{idx}', so a name with a
    # trailing underscore (as in the example below) yields a double underscore.
    evaluation_checkpoints = {
        'chkp_type': 'dynamic_train',
        'root_dir': None,  # Directory containing checkpoint subdirectories
        'name': None,  # Base name for checkpoints (e.g., 'orbital_free_distillation_')
        'indices': None,  # List of indices to evaluate (e.g., [1, 4, 5] or [1] for single)
    }
    basis = {
        'name': 'def2-TZVPD',  # def2-TZVPD
        'derivative': 1,
    }
    alignment = {
        'atom': 1,
        'basis': 1,
        'grid': 32 * 1024,
    }
    quadrature = {
        'level': 1,
    }
    initial_density_guess = {
        'base_initial_density_guess_key': 'minao',  # 'minao', 'atom'
        'initial_ref_density_method': None,  # if None, no reference density is used
        # {
        #     'key': 'ks_dft',
        #     'kwargs': {  # method_specific kwargs
        #         'xc_str': 'LDA',
        #         'basis': basis['name'],
        #         'backend': 'pyscf',  # pyscf or custom
        #         'use_eri_density_fitting': True,
        #         'use_exchange_density_fitting': True,  # LDA does not use hf exchange
        #         'quadrature_grid_level': 3,
        #         'spin_restricted': base['spin_restricted'],
        #     },
        # },
        # Both interpolation bounds must be 0.0 when no reference method is set.
        'min_ref_density_interpolation': 0.0,
        'max_ref_density_interpolation': 0.0,
        'noise_eps': 0.0,
    }
    # Data configuration
    data = {
        'batch_size': 1,
        'shuffle': False,  # stored, but the dataloaders are built with shuffle=False
        'lazy': True,
        'workers': 3,  # not used when lazy is True
        'data_split_seed': 0,  # random seed for splitting the dataset
        'n_test_samples': None,  # number of samples to use for testing, if None, use the full test set
        'preload': {
            'center': False,
        },
        'split': {'val_fraction': 0.1},
        'deixc_method_key': 'ks_dft',
        'align_scf_trajectory': None,
        'shift_dispersion': False,
        # 'xc_str' also names the reference functional used for reference energies.
        'deixc_method_kwargs': {
            'xc_str': 'B3LYP',
            'basis': basis['name'],
            'backend': 'custom',  # TODO: check
            'use_eri_density_fitting': base['use_density_fitting'],
            'use_exchange_density_fitting': base['use_density_fitting'],
            'spin_restricted': base['spin_restricted'],
            'quadrature_grid_level': quadrature['level'],
        },
    }
    # Solver configuration
    solver = {
        'method': 'scf',
        'convergence_threshold': 1e-7,  # in Hartree
        'kwargs': {
            'cycles': 15,
            'convergence_acceleration_method': 'DIIS',
        },
    }
    # Reference energy evaluation on model density
    reference_energy_on_model_density = {
        'enabled': True,  # if True, evaluate reference functional energy on model's final density
    }
    # SCF Acceleration Benchmark: measure speedup from warm-starting reference with learned density
    scf_acceleration = {
        'enabled': False,  # if True, run reference SCF from learned density and measure cycle savings
        'reference_scf_cycles': 15,  # max cycles for warm-start reference SCF
    }


@ex.named_config
def scan_functional():
    """Use SCAN as the model functional (no learned parameters)."""
    model = {'name': 'scan'}
    # With ``indices=None`` the evaluation runs once with empty parameters
    # instead of loading checkpoints.
    evaluation_checkpoints = {
        'indices': None,  # No checkpoints needed for traditional functionals
    }


@ex.named_config
def scaled_nagai2020():
    """
    Nagai2020 functional with the hidden dimension scaled down to 23 (from 100 in
    the original Nagai2020 functional), using ``local_n_layers=4``.
    This is done to match the number of trainable parameters of the XCDiff functional.
    """
    # Forwarded to ``get_functional(**model)`` in ``EvaluationExperiment._xc_module``.
    model = {
        'name': 'nagai2020',
        'local_n_layers': 4,
        'local_hidden_dim': 23,
    }


@ex.named_config
def scaled_skala_mgga():
    """Scaled Skala (``skala_mgga``) model with ``local_n_layers=4`` and ``local_hidden_dim=22``."""
    # Skala requires true spin-resolved features; these are requested through the
    # functional's ``requires_spin_resolved_features`` attribute in
    # ``EvaluationExperiment._xc_module``.
    # NOTE(review): despite the intent to enforce unrestricted spin, this config
    # does not set ``base.spin_restricted``; pass ``base.spin_restricted=False``
    # explicitly if the checkpoint was trained spin-unrestricted.
    model = {
        'name': 'skala_mgga',
        'local_n_layers': 4,
        'local_hidden_dim': 22,
    }


@ex.named_config
def nagai2020_deeper():
    """Custom model whose ``local`` part is Nagai2020 with 6 layers and hidden dim 32."""
    model = {
        'name': 'custom',
        'local': {
            'name': 'nagai2020',
            'local_n_layers': 6,
            'local_hidden_dim': 32,
        },
    }


@ex.named_config
def skala_mgga_deeper():
    """Custom model whose ``local`` part is ``skala_mgga`` with 6 layers and hidden dim 32."""
    model = {
        'name': 'custom',
        'local': {
            'name': 'skala_mgga',
            'local_n_layers': 6,
            'local_hidden_dim': 32,
        },
    }


@ex.named_config
def nagai2020_orbital_free():
    """
    Orbital-free Nagai2020 variant (``local_n_layers=5``, ``local_hidden_dim=24``),
    scaled to fit on an A100 GPU.
    """
    model = {
        'name': 'nagai2020_orbital_free',
        'local_n_layers': 5,
        'local_hidden_dim': 24,
    }


@ex.named_config
def egxc_orbital_free():
    """Custom model: orbital-free Nagai2020 local part plus a NequIP GNN and non-locality.

    The ``gnn`` and ``non_locality`` settings are identical to those of
    :func:`egxc`; only the ``local`` functional differs.
    """
    model = {
        'name': 'custom',
        'local': {
            'name': 'nagai2020_orbital_free',
            'local_n_layers': 5,
            'local_hidden_dim': 24,
        },
    }
    model['gnn'] = {
        'type': 'nequip',
        'encoder': {
            'cutoff': 5.0,  # Angstrom
            'num_radial_filters': 32,
            'nuclei_partitioning': 'Exponential',
            '_quadrature_points_per_atom_scaling': 12,  # ablate
        },
        'kwargs': {
            'irreps_str': '128x0e + 64x1o',  # per-L message passing dims
            'output_irreps_str': '16x0e + 16x1o',  # per-L output dims
            'message_cutoff': 5.0,
            'n_radial_basis': 8,
            'layers': 3,
            'init_graph_readout_to_zero': True,
            'energy_graph_readout_hidden_dims': (256, 256, 1),
        },
    }
    model['non_locality'] = {
        'grid_feature_mode': 'reweighting_without_mGGA_feats',
        'graph_readout': True,
        'decoder': {
            'spatial_feature_dim': 16,  # output spatial feature dimension
        },
        'reweighting': {  # ablate these
            'layers': 3,
            'hidden_dim': 16,
            'init_scale': 0.1,
            # NOTE: this is the string 'None', not Python None.
            'output_activation_type': 'None',  # TODO: try scaled_sigmoid again
        },
    }


@ex.named_config
def egxc():
    """Custom model: Nagai2020 local part (6 layers, hidden dim 32) plus a NequIP GNN.

    Non-local features use ``grid_feature_mode='reweighting_without_mGGA_feats'``
    together with a graph readout.
    """
    model = {
        'name': 'custom',
        'local': {
            'name': 'nagai2020',
            'local_n_layers': 6,
            'local_hidden_dim': 32,
        },
    }
    model['gnn'] = {
        'type': 'nequip',
        'encoder': {
            'cutoff': 5.0,  # Angstrom
            'num_radial_filters': 32,
            'nuclei_partitioning': 'Exponential',
            '_quadrature_points_per_atom_scaling': 12,  # ablate
        },
        'kwargs': {
            'irreps_str': '128x0e + 64x1o',  # per-L message passing dims
            'output_irreps_str': '16x0e + 16x1o',  # per-L output dims
            'message_cutoff': 5.0,
            'n_radial_basis': 8,
            'layers': 3,
            'init_graph_readout_to_zero': True,
            'energy_graph_readout_hidden_dims': (256, 256, 1),
        },
    }
    model['non_locality'] = {
        'grid_feature_mode': 'reweighting_without_mGGA_feats',
        'graph_readout': True,
        'decoder': {
            'spatial_feature_dim': 16,  # output spatial feature dimension
        },
        'reweighting': {  # ablate these
            'layers': 3,
            'hidden_dim': 16,
            'init_scale': 0.1,
            # NOTE: this is the string 'None', not Python None.
            'output_activation_type': 'None',  # TODO: try scaled_sigmoid again
        },
    }


class EvaluationExperiment:
    """Evaluate an XC functional on DEI-XC data, optionally for trained checkpoints.

    ``__init__`` reads the Sacred/SEML config through ``@ex.capture`` methods and
    builds the dataset, dataloaders, main-thread transform, SCF solver and the
    optional reference evaluators. Calling the instance runs the evaluation;
    ``evaluate`` collects one record per sample and hands them to ``_write_csv``.

    Members assigned during construction but not annotated below include
    ``chkp_type``, ``chkp_root_dir``, ``chkp_name``, ``chkp_indices``,
    ``convergence_threshold``, ``density_mean_field_error_fn``,
    ``coulomb_energy_error_fn`` and ``initial_density_matrix_fn``.
    """

    # Logging and output
    output_dir: Path
    output_csv: Path
    metrics_summary_path: Path
    checkpoint_copy_path: Path

    # Physics model
    basis_name: str
    basis_derivative: int
    use_density_fitting: bool

    # NOTE(review): not assigned in the visible code; checkpoints are selected
    # via the ``chkp_*`` members set by ``set_evaluation_checkpoints``.
    load_from: str | None
    checkpoint_prefix: str

    # Data
    dataset: LazyDEIXCDataset | DEIXCDataset
    dataloaders: dataloading.DataLoaders
    init_psys: systems.PreloadSystem
    target_functional: str  # name of the reference functional (data.deixc_method_kwargs['xc_str'])
    data_split_seed: int
    data_split: Dict[str, Any]
    n_test_samples: int | None
    preload: Dict[str, bool]
    preload_center: bool
    shuffle: bool
    lazy: bool
    workers: int
    batch_size: int
    max_angular_momentum: int

    # Transforms
    preload_transform: Sequence[grain.Transformation]
    init_preloaded_basis_fns: discretization.PreloadedGTOBasis
    main_thread_transform: dataloading.ToJaxTransform

    # Configuration objects
    alignment: Alignment
    grid_level: int
    spin_restricted: bool

    # Initial guess
    base_initial_density_guess_key: BaseInitialGuess
    initial_ref_density_method_key: MethodKey | None
    initial_ref_density_method_kwargs: Dict[str, Any] | None
    min_ref_density_interpolation: float
    max_ref_density_interpolation: float
    noise_eps: float

    # Solver
    solver: DerivativeInformedSelfConsistentFieldSolver
    _solver_apply: Callable[..., Any]  # jax.jit of solver.apply

    # Evaluation
    evaluation_splits: Tuple[str, ...]

    # Reference energy evaluation on model density
    evaluate_reference_on_model_density: bool
    _reference_energy_fn: Callable[..., Any] | None  # None when disabled

    # SCF Acceleration benchmark
    scf_acceleration_enabled: bool
    _reference_scf_apply: Callable[..., Any] | None  # None when disabled
    reference_scf_cycles: int

    @ex.capture(prefix='logging')  # type: ignore
    def __init__(
        self,
        overwrite: int = 0,
        dir: str | None = None,
        output_csv: str | None = None,
        name: str | None = None,
    ) -> None:
        """Build the complete evaluation pipeline from the Sacred/SEML config.

        The arguments are captured from the ``logging`` config section; the
        other sections are read by the ``@ex.capture`` methods called below.
        Config values take precedence over the Python defaults.

        Args:
            overwrite: Not used by this method. It defaults to ``0`` so the
                experiment can be constructed from configs that lack a
                ``logging.overwrite`` entry (e.g. :func:`default_config`).
            dir: Output directory passed to ``results_directory``.
            output_csv: Metrics CSV name, combined with the output directory by
                ``results_metrics_csv_path``.
            name: Name of this evaluation run (subdirectory of ``dir``).
        """
        # Output locations; an existing directory is accepted.
        self.output_dir = results_directory(dir, name, exists_ok=True)
        self.output_csv = results_metrics_csv_path(self.output_dir, output_csv)
        self.metrics_summary_path = results_summary_path(self.output_dir)
        # Order matters: the dataloaders built inside
        # ``get_initial_density_matrix_fn`` use the basis/alignment/spin settings
        # from ``set_all``, and the reference evaluators read
        # ``self.target_functional``, which is set while building the dataset.
        self.set_all()
        self.get_initial_density_matrix_fn()  # type: ignore
        self.init_solver()  # type: ignore
        self.init_reference_evaluator()  # type: ignore
        self.init_reference_scf_solver()  # type: ignore

    def set_all(self) -> None:
        """Populate the config-derived members that do not depend on each other.

        Each setter is a Sacred-captured method that reads its own config section;
        they are invoked in a fixed order.
        """
        independent_setters = (
            self.set_base,
            self.set_alignment,
            self.set_quadrature,
            self.set_basis,
            self.set_evaluation_checkpoints,
        )
        for setter in independent_setters:
            setter()  # type: ignore

    @ex.capture(prefix='evaluation_checkpoints')  # type: ignore
    def set_evaluation_checkpoints(
        self,
        chkp_type: str,
        root_dir: str | None,
        name: str | None,
        indices: List[int] | None,
    ) -> None:
        """Store which trained checkpoints ``__call__`` should evaluate.

        Args:
            chkp_type: Checkpoint type forwarded to ``checkpoint_best_path``
                (e.g. ``'dynamic_train'``).
            root_dir: Directory containing one subdirectory per training run.
            name: Base name of those subdirectories; run ``idx`` lives in
                ``f'{name}_{idx}'``.
            indices: Training-run indices to evaluate, or ``None`` to evaluate
                the configured functional without loading any checkpoint.

        Raises:
            ValueError: If ``indices`` is given but ``root_dir`` or ``name`` is
                missing. No checkpoint path could be built, and every run would
                otherwise fail later inside the best-effort loop of ``__call__``.
        """
        if indices is not None and (root_dir is None or name is None):
            raise ValueError(
                'evaluation_checkpoints.root_dir and evaluation_checkpoints.name '
                'must be set when evaluation_checkpoints.indices is given'
            )
        self.chkp_type = chkp_type
        self.chkp_root_dir = root_dir
        self.chkp_name = name
        self.chkp_indices = indices

    @ex.capture(prefix='base')  # type: ignore
    def set_base(
        self,
        use_density_fitting: bool,
        spin_restricted: bool,
        splits: Tuple[str, ...],
    ) -> None:
        """Store base settings, build the density error metrics, and pick the splits.

        Args:
            use_density_fitting: Stored on ``self`` and forwarded to the
                error-metric factories.
            spin_restricted: Stored on ``self`` and forwarded to the
                error-metric factories.
            splits: Dataset splits to evaluate (e.g. ``('test',)``). A list or a
                single split name is also accepted and normalized to a tuple.
        """
        self.spin_restricted = spin_restricted
        self.use_density_fitting = use_density_fitting

        # Mean-field density error. No per-electron scaling here.
        self.density_mean_field_error_fn = get_density_mean_field_error_fn(
            self.spin_restricted,
            use_density_fitting,
            scale_per_electron=False,
        )

        # Coulomb energy error (density-only). No per-electron scaling here.
        self.coulomb_energy_error_fn = get_coulomb_energy_error_fn(
            self.spin_restricted,
            use_density_fitting,
            scale_per_electron=False,
        )

        # Config values may arrive as lists (YAML has no tuples) or as a bare
        # string. Normalize so comparisons like ``== ('test',)`` work and a
        # single name is not iterated character by character.
        self.evaluation_splits = (splits,) if isinstance(splits, str) else tuple(splits)

    @ex.capture(prefix='alignment')  # type: ignore
    def set_alignment(self, atom: int, basis: int, grid: int) -> None:
        """Bundle the ``alignment`` config values (atom, basis, grid) into an ``Alignment``."""
        alignment = Alignment(atom, basis, grid)
        self.alignment = alignment

    @ex.capture(prefix='quadrature')  # type: ignore
    def set_quadrature(self, level: int) -> None:
        """Remember the quadrature level later passed to ``discretization.get_grid_fn``."""
        grid_level = level
        self.grid_level = grid_level

    @ex.capture(prefix='basis')  # type: ignore
    def set_basis(self, name: str, derivative: int) -> None:
        """Record the basis-set name and the derivative order used for GTO evaluation."""
        self.basis_name, self.basis_derivative = name, derivative

    @ex.capture(prefix='initial_density_guess')  # type: ignore
    def get_initial_density_matrix_fn(
        self,
        base_initial_density_guess_key: BaseInitialGuess,
        initial_ref_density_method: Dict[str, Any] | None,
        min_ref_density_interpolation: float,
        max_ref_density_interpolation: float,
        noise_eps: float,
    ) -> None:
        """Configure the initial density guess, then build the data pipeline.

        Despite its name, this returns ``None``. The function produced by
        ``dataloading.get_initial_density_matrix_fn`` is stored on
        ``self.initial_density_matrix_fn``. The method finishes by calling
        ``_init_dataset``, which builds the dataset, dataloaders and main-thread
        transform.

        Args:
            base_initial_density_guess_key: Base guess key (e.g. ``'minao'``),
                later passed to the preload transform.
            initial_ref_density_method: ``None`` or a dict with ``'key'`` and
                ``'kwargs'`` entries; these are handed to the base dataset.
            min_ref_density_interpolation: Forwarded to
                ``dataloading.get_initial_density_matrix_fn``.
            max_ref_density_interpolation: Forwarded to
                ``dataloading.get_initial_density_matrix_fn``.
            noise_eps: Forwarded to ``dataloading.get_initial_density_matrix_fn``.

        Raises:
            ValueError: If an interpolation bound is non-zero while no reference
                density method is configured.
        """
        self.base_initial_density_guess_key = base_initial_density_guess_key

        if initial_ref_density_method is None:
            # Explicit exceptions rather than ``assert``: asserts are stripped
            # under ``python -O`` and would let a bad config through silently.
            if min_ref_density_interpolation != 0.0:
                raise ValueError(
                    'min_ref_density_interpolation must be 0.0 when no reference density is provided'
                )
            if max_ref_density_interpolation != 0.0:
                raise ValueError(
                    'max_ref_density_interpolation must be 0.0 when no reference density is provided'
                )
            self.initial_ref_density_method_key = None
            self.initial_ref_density_method_kwargs = None
        else:
            self.initial_ref_density_method_key = initial_ref_density_method['key']
            self.initial_ref_density_method_kwargs = initial_ref_density_method['kwargs']

        self.min_ref_density_interpolation = min_ref_density_interpolation
        self.max_ref_density_interpolation = max_ref_density_interpolation
        self.noise_eps = noise_eps

        self.initial_density_matrix_fn = dataloading.get_initial_density_matrix_fn(
            self.min_ref_density_interpolation,
            self.max_ref_density_interpolation,
            self.noise_eps,
        )

        # The base dataset needs the reference-density settings stored above.
        self._init_dataset()  # type: ignore

    @ex.capture(prefix='data')  # type: ignore
    def _init_dataset(
        self,
        key: str,
        data_set_kwargs: Dict[str, Any],
        deixc_method_key: MethodKey,
        deixc_method_kwargs: Dict[str, Any],
        align_scf_trajectory: int | None,
        shift_dispersion: bool,
        data_split_seed: int,
        split: Dict[str, Any],
        n_test_samples: int | None,
        preload: Dict[str, bool],
        batch_size: int,
        shuffle: bool,
        lazy: bool,
        workers: int,
    ) -> None:
        """Create the DEI-XC dataset wrapper, then its loaders and main-thread transform.

        All arguments come from the ``data`` config section. ``key`` (lower-cased)
        selects the base dataset class from ``dataloading.key_to_dataset``; that
        class is built from ``data_set_kwargs`` plus the reference-density
        settings. ``lazy`` chooses ``LazyDEIXCDataset`` over ``DEIXCDataset``.
        The reference functional name is taken from ``deixc_method_kwargs['xc_str']``.
        """
        # Settings consumed later by ``_init_dataloader`` and the evaluators.
        self.data_split_seed = data_split_seed
        self.data_split = split
        self.n_test_samples = n_test_samples
        self.preload = preload
        self.preload_center = preload['center']
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.target_functional = deixc_method_kwargs['xc_str']
        self.lazy = lazy
        self.workers = workers

        dataset_cls = dataloading.key_to_dataset[key.lower()]
        base_dataset: dataloading.BaseDataset = dataset_cls(
            initial_ref_density_method_key=self.initial_ref_density_method_key,
            initial_ref_density_method_kwargs=self.initial_ref_density_method_kwargs,
            **data_set_kwargs,
        )

        # Both wrappers are constructed with exactly the same arguments.
        wrapper_cls = LazyDEIXCDataset if self.lazy else DEIXCDataset
        self.dataset = wrapper_cls(
            base_dataset,
            deixc_method_key,
            method_kwargs=deixc_method_kwargs,
            align_scf_trajectory=align_scf_trajectory,
            shift_dispersion=shift_dispersion,
        )
        self._init_dataloader()  # type: ignore
        self._init_main_thread_transform()  # type: ignore

    def _init_dataloader(self) -> None:
        """Construct sequential loaders for the train/val/test splits.

        The method does the following:

        - Splits ``self.dataset`` with ``DatasetEnsemble.infer_split``.
        - Builds the preload transform for the configured basis.
        - Creates either the lazy loaders or the worker-based loaders.
          Shuffling is always off here, regardless of ``data.shuffle``.

        It sets ``max_angular_momentum``, ``preload_transform``, ``init_psys``,
        ``dataloaders`` and ``init_preloaded_basis_fns``.
        """
        dataset_ensemble = dataloading.DatasetEnsemble.infer_split(
            self.dataset, data_split_seed=self.data_split_seed, **self.data_split
        )
        # For eval-only runs, reuse test data in train/val slots so model init
        # and dataloader creation succeed even without aux data for those splits.
        # Compare as a tuple: a list-valued config such as ``['test']`` (e.g. from
        # YAML) never equals ``('test',)`` and would silently skip this branch.
        if tuple(self.evaluation_splits) == ('test',):
            dataset_ensemble = dataloading.DatasetEnsemble(
                train=dataset_ensemble.test,
                val=dataset_ensemble.test,
                test=dataset_ensemble.test,
            )

        self.max_angular_momentum, basis_fn_preloader = discretization.get_gto_preloader(
            self.basis_name, self.dataset.unique_elements
        )
        self.preload_transform = dataloading.get_preload_transform(
            self.batch_size,
            self.basis_name,
            self.spin_restricted,
            self.alignment,
            self.use_density_fitting,
            base_initial_density_guess=self.base_initial_density_guess_key,
            basis_fn_preloader=basis_fn_preloader,
            center=self.preload_center,
        )

        if self.lazy:
            self.init_psys, self.dataloaders = get_psys_and_lazy_dataloader(
                datasets=dataset_ensemble,
                transformations=self.preload_transform,
                shuffle=False,
                shuffling_seed=0,
                n_test_samples=self.n_test_samples,
            )
        else:
            self.init_psys, self.dataloaders = get_psys_and_dataloaders(
                datasets=dataset_ensemble,
                transformations=self.preload_transform,
                shuffle=False,
                shuffling_seed=0,
                workers=self.workers,
                worker_buffer_size=1,
                n_test_samples=self.n_test_samples,
            )
        self.init_preloaded_basis_fns = basis_fn_preloader(self.init_psys.atom_z)

    def _init_main_thread_transform(self) -> None:
        """Build ``self.main_thread_transform`` from a grid function and a GTO grid evaluator.

        The grid function uses the quadrature level, the dataset's unique elements
        and the grid alignment. The basis evaluator uses the configured derivative
        order and the maximum angular momentum found by ``_init_dataloader``.
        """
        unique_elements = self.dataset.unique_elements
        self.main_thread_transform = dataloading.get_jax_transform(
            discretization.get_grid_fn(
                self.grid_level, unique_elements, self.alignment.grid
            ),
            discretization.get_gto_grid_eval_fn(
                deriv=self.basis_derivative,
                max_angular_momentum=self.max_angular_momentum,
            ),
        )

    @ex.capture  # type: ignore
    def init_solver(self, solver: Dict[str, Any], model: Dict[str, Any]) -> None:
        """Build the model's SCF solver and a JIT-compiled version of its ``apply``.

        Args:
            solver: The ``solver`` config section. Only ``method='scf'`` is
                supported; ``convergence_threshold`` is stored on ``self`` and
                ``kwargs`` are forwarded to the solver constructor.
            model: Captured ``model`` section. It is not read here directly,
                because ``_xc_module`` captures it itself.

        Raises:
            ValueError: If ``solver['method']`` is not ``'scf'``.
        """
        xc_module = self._xc_module()  # type: ignore
        method = solver['method']
        extra_solver_kwargs = solver['kwargs']
        if method != 'scf':
            raise ValueError(f'Unsupported solver type for evaluation: {method}')
        self.convergence_threshold = solver['convergence_threshold']
        scf_solver = DerivativeInformedSelfConsistentFieldSolver(
            xc_module,
            use_density_fitting=self.use_density_fitting,
            spin_restricted=self.spin_restricted,
            **extra_solver_kwargs,
        )
        self.solver = scf_solver
        self._solver_apply = jax.jit(scf_solver.apply)

    @ex.capture(prefix='reference_energy_on_model_density')  # type: ignore
    def init_reference_evaluator(self, enabled: bool) -> None:
        """Optionally JIT-compile the reference functional's energy for a given density.

        When ``enabled``, ``self._reference_energy_fn(nuc_pos, density_matrix, sys)``
        evaluates ``FockMatrix.energy`` for the reference (target) functional,
        e.g. B3LYP. Otherwise ``self._reference_energy_fn`` is ``None``.

        Args:
            enabled: Value of ``reference_energy_on_model_density.enabled``.
        """
        self.evaluate_reference_on_model_density = enabled
        if not enabled:
            self._reference_energy_fn = None
            return

        reference_xc_module = self._build_reference_xc_module()

        # Function-local import: only loaded when reference energies are requested.
        from egxc.solver.fock import FockMatrix

        reference_fock_module = FockMatrix(
            reference_xc_module,
            use_density_fitting=self.use_density_fitting,
            spin_restricted=self.spin_restricted,
        )

        def ref_energy_fn(nuc_pos, density_matrix, sys):
            # An empty parameter tree: the reference functional is used without learned weights.
            return reference_fock_module.apply(
                {'params': {}},
                nuc_pos,
                density_matrix,
                sys,
                method=reference_fock_module.energy,
            )

        self._reference_energy_fn = jax.jit(ref_energy_fn)

    @ex.capture(prefix='scf_acceleration')  # type: ignore
    def init_reference_scf_solver(
        self,
        enabled: bool,
        reference_scf_cycles: int,
    ) -> None:
        """Initialize reference functional SCF solver for warm-start acceleration benchmark.

        When enabled, this creates a separate SCF solver using the reference functional
        (e.g., B3LYP) that can be warm-started with the learned functional's converged
        density to measure SCF acceleration. Otherwise ``self._reference_scf_apply``
        is ``None``.

        Args:
            enabled: Value of ``scf_acceleration.enabled``.
            reference_scf_cycles: ``cycles`` argument of the reference SCF solver.
        """
        self.scf_acceleration_enabled = enabled
        self.reference_scf_cycles = reference_scf_cycles
        if not enabled:
            self._reference_scf_apply = None
            return

        # DIIS is fixed here rather than taken from the ``solver`` config section.
        reference_scf_solver = DerivativeInformedSelfConsistentFieldSolver(
            self._build_reference_xc_module(),
            use_density_fitting=self.use_density_fitting,
            spin_restricted=self.spin_restricted,
            cycles=reference_scf_cycles,
            convergence_acceleration_method='DIIS',
        )
        self._reference_scf_apply = jax.jit(reference_scf_solver.apply)

    @ex.capture  # type: ignore
    def _xc_module(self, model) -> xc_energy.XCModule:
        """Build the model functional's XC module from the ``model`` config.

        Called by ``init_solver``. ``model`` is forwarded to ``get_functional``
        as keyword arguments.
        """
        functional = xc_energy.functionals.get_functional(**model)
        return self._wrap_functional_in_xc_module(functional)

    def _build_reference_xc_module(self) -> xc_energy.XCModule:
        """Build the XC module of the reference (target) functional, e.g. B3LYP."""
        reference_functional = xc_energy.functionals.get_functional(
            self.target_functional,
            spin_restricted=self.spin_restricted,
            use_density_fitting=self.use_density_fitting,
        )
        return self._wrap_functional_in_xc_module(reference_functional)

    def _wrap_functional_in_xc_module(self, functional: Any) -> xc_energy.XCModule:
        """Pair ``functional`` with ``DensityFeatures`` matching its spin needs.

        Some functionals (e.g., Skala) require spin-resolved features. They signal
        this through a ``requires_spin_resolved_features`` attribute, whose value
        (``False`` when absent) is passed as ``spin_resolved``.
        """
        requires_spin_resolved = getattr(
            functional, 'requires_spin_resolved_features', False
        )
        feature_fn = xc_energy.features.DensityFeatures(
            self.spin_restricted, spin_resolved=requires_spin_resolved
        )
        return xc_energy.XCModule(functional, feature_fn)

    def get_checkpoint_name(self, idx: int) -> str:
        """Return the subdirectory name ``'<chkp_name>_<idx>'`` of training run ``idx``."""
        return '{}_{}'.format(self.chkp_name, idx)

    def get_checkpoint_path(self, idx: int) -> str:
        """Return the checkpoint directory of training run ``idx``.

        The method looks in ``<chkp_root_dir>/<checkpoint name>`` for
        directories matching ``split_0*`` (e.g. ``split_0`` or
        ``split_0_<timestamp>``) and returns the lexicographically last one. As a
        result, a timestamped directory is preferred over plain ``split_0``, and
        later timestamps win, assuming they sort lexicographically.

        Files that merely match the pattern (e.g. ``split_0.log``) are ignored.
        If no candidate exists, the method falls back to ``.../split_0``.
        """
        checkpoint_dir = Path(self.chkp_root_dir) / self.get_checkpoint_name(idx)
        split_dirs = sorted(
            candidate
            for candidate in checkpoint_dir.glob('split_0*')
            if candidate.is_dir()
        )
        if not split_dirs:
            # Fall back to default if no split_0* directories exist yet
            return str(checkpoint_dir / 'split_0')
        return str(split_dirs[-1])

    def load_params(self, idx: int) -> Any:
        """Load the checkpoint parameters of training run ``idx``.

        The checkpoint file is located with ``checkpoint_best_path`` for
        ``self.chkp_type`` inside the run directory from ``get_checkpoint_path``.
        After loading, that whole directory is copied to
        ``<output_dir>/<checkpoint name>/checkpoint``.

        Args:
            idx: Index of the training run (see ``get_checkpoint_name``).

        Returns:
            The unpickled parameter object.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        # Resolve once so the loaded file and the copied directory always agree.
        checkpoint_dir = self.get_checkpoint_path(idx)
        checkpoint_file_path = Path(checkpoint_best_path(checkpoint_dir, self.chkp_type))
        if not checkpoint_file_path.exists():
            raise FileNotFoundError(f'Checkpoint not found at {checkpoint_file_path}')
        # NOTE: unpickling can execute arbitrary code; only load trusted checkpoints.
        with checkpoint_file_path.open('rb') as stream:
            params = pickle.load(stream)
        copy_directory_to_directory(
            checkpoint_dir,
            self.output_dir / self.get_checkpoint_name(idx) / 'checkpoint',
        )
        return params

    def __call__(self) -> None:
        """Evaluate the configured functional, optionally for several checkpoints.

        If ``evaluation_checkpoints.indices`` is ``None``, the functional is
        evaluated once with empty parameters (``{'params': {}}``). This is the
        case for functionals without learned weights such as SCAN, and the
        metrics go to ``self.output_csv``.

        Otherwise, the parameters of every listed training run are loaded with
        ``load_params`` and evaluated on ``self.evaluation_splits``. All runs are
        assumed to share the architecture configured in ``model`` and the
        checkpoint type ``chkp_type``. Results for run ``idx`` go to
        ``<output_dir>/<checkpoint name>/``. A single integer index is accepted
        as shorthand for a one-element list.

        Evaluation is best-effort: an exception in one run is printed with its
        traceback, and the remaining runs are still evaluated. Failed runs are
        listed at the end.
        """
        import traceback

        separator = '=' * 80

        # Handle traditional functionals without checkpoints
        if self.chkp_indices is None:
            print(f'\n{separator}')
            print('Evaluating traditional functional (no checkpoints)')
            print(f'{separator}\n')

            succeeded = False
            try:
                # Traditional functionals have no learnable parameters
                params = {'params': {}}
                self.evaluate(params, self.evaluation_splits)
                succeeded = True
                print('\nCompleted evaluation')
                print(f'Results saved to: {self.output_csv}')
            except Exception as e:
                # Deliberately best-effort: report instead of aborting the run.
                print(f'Error during evaluation: {e}')
                traceback.print_exc()

            print(f'\n{separator}')
            if succeeded:
                print(f'Evaluation complete. Results saved to: {self.output_dir}')
            else:
                print(f'Evaluation failed. Output directory: {self.output_dir}')
            print(separator)
            return

        indices = (
            [self.chkp_indices]
            if isinstance(self.chkp_indices, int)
            else list(self.chkp_indices)
        )
        print(f'\n{separator}')
        print(
            f'Evaluating {len(indices)} checkpoint(s) of the same architecture: {indices}'
        )
        print(f'Checkpoint type: {self.chkp_type}')
        print(f'{separator}\n')

        failed: List[int] = []
        for idx in indices:
            checkpoint_name = self.get_checkpoint_name(idx)
            print(f'\n{separator}')
            print(f'Evaluating weights from training run {idx}: {checkpoint_name}')
            print(f'{separator}\n')

            # Create subdirectory for this checkpoint's results
            checkpoint_output_dir = self.output_dir / checkpoint_name
            checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

            # Temporarily point the output paths at this checkpoint's
            # subdirectory; they are restored in ``finally``.
            original_output_csv = self.output_csv
            original_metrics_summary = self.metrics_summary_path
            self.output_csv = checkpoint_output_dir / original_output_csv.name
            self.metrics_summary_path = (
                checkpoint_output_dir / original_metrics_summary.name
            )

            try:
                params = self.load_params(idx)
                self.evaluate(params, self.evaluation_splits)
                print(f'\nCompleted evaluation of training run {idx}')
                print(f'Results saved to: {self.output_csv}')
            except Exception as e:
                # Deliberately best-effort: one broken run must not stop the others.
                failed.append(idx)
                print(f'Error evaluating training run {idx}: {e}')
                traceback.print_exc()
            finally:
                self.output_csv = original_output_csv
                self.metrics_summary_path = original_metrics_summary

        print(f'\n{separator}')
        print(f'Checkpoint evaluation complete. Results saved to: {self.output_dir}')
        if failed:
            print(f'Failed training runs ({len(failed)}/{len(indices)}): {failed}')
        print(separator)

    def evaluate(
        self,
        params: Any,
        splits: Sequence[str],
    ) -> None:
        """Evaluate every sample of each requested split and write a single CSV.

        Args:
            params: Parameters passed to the SCF solver for every sample.
            splits: Names of ``self.dataloaders`` attributes, traversed in order.
        """
        records = [
            self._evaluate_sample(params, split, psys, basis_fns, targets)
            for split in splits
            for psys, basis_fns, targets in getattr(self.dataloaders, split)
        ]
        self._write_csv(records)

    @staticmethod
    def _cycles_to_convergence(
        total_energies: Any, threshold: float
    ) -> Tuple[int, bool]:
        """Count SCF cycles until the total energy stops changing.

        Args:
            total_energies: Per-cycle total-energy trace; differences are taken
                along the leading axis.
            threshold: Convergence criterion applied to ``|E[i+1] - E[i]|``.

        Returns:
            ``(cycles, converged)``. ``cycles`` is the 1-based index of the first
            energy difference below ``threshold``. When no difference qualifies,
            ``cycles`` is the trace length and ``converged`` is ``False``. This
            includes traces with fewer than two entries, which have no
            differences at all.
        """
        n_steps = int(total_energies.shape[0])
        converged_mask = (
            jnp.abs(total_energies[1:] - total_energies[:-1]) < threshold
        )
        # jnp.argmax raises on an empty array, so only call it once a converged
        # step is known to exist (jnp.where would evaluate it unconditionally).
        if not bool(jnp.any(converged_mask)):
            return n_steps, False
        return int(jnp.argmax(converged_mask)) + 1, True

    def _evaluate_sample(
        self,
        params: Any,
        split: str,
        psys: systems.PreloadSystem,
        basis_fns,
        targets: DEIXCTargets,
    ) -> Dict[str, Any]:
        """Run SCF, compute evaluation metrics, and return a CSV row for one sample.

        Args:
            params: Model parameters passed to ``self._solver_apply``.
            split: Dataset split name, stored verbatim in the ``'split'`` column.
            psys: Preloaded system; ``psys.idx`` becomes the ``'idx'`` column.
            basis_fns: Basis functions forwarded to ``self.main_thread_transform``.
            targets: Reference targets (density matrix, total/XC energy, HOMO-LUMO
                gap, cold-start cycle count) the prediction is compared against.

        Returns:
            A dict keyed by the CSV column names. The ``*_mEh`` columns are
            absolute differences scaled by 1e3. Extra columns are added when
            ``evaluate_reference_on_model_density`` or
            ``scf_acceleration_enabled`` is set.
        """
        density_matrices, sys = self.main_thread_transform(psys, basis_fns)
        initial_dm, _ = self.initial_density_matrix_fn(density_matrices, None)

        # Only the energy traces, density matrices and Fock matrices are used.
        (e_hj, e_xc), (_, P_pred, F_pred, _) = self._solver_apply(
            params, initial_dm, sys
        )
        # Depends only on the system, so it is computed once and reused for every
        # total-energy expression below.
        nuclear_energy = systems.nuclear_energy_fn(sys._nuc_pos, sys)
        total_energies = e_xc + e_hj + nuclear_energy
        cycles_to_convergence, did_converge = self._cycles_to_convergence(
            total_energies, self.convergence_threshold
        )
        ref_cycles_cold = targets.cycles_to_convergence(self.convergence_threshold)
        # jnp.divide keeps JAX division semantics (no ZeroDivisionError).
        relative_cycle_count = jnp.divide(cycles_to_convergence, ref_cycles_cold)

        P_final = P_pred[-1]
        P_target = targets.density_matrix
        total_pred = total_energies[-1]
        mf_mae = self.density_mean_field_error_fn(sys, P_target, P_final)
        coulomb_mae = self.coulomb_energy_error_fn(sys, P_target, P_final)

        # Dipole difference using existing implementation (same-basis case here).
        dipole_error = dipole_difference(P_target, P_final, sys.grid, True)
        gap_pred = systems.homo_lumo_gap_fn(
            F_pred[-1],
            sys.fock_tensors.diagonal_overlap,
            int(sys.n_electrons),
            self.spin_restricted,
        )
        # Grid-based density difference norms (same-basis evaluation)
        delta_field = delta_field_fn(
            P_target, P_final, sys.grid.aos, reference_basis_is_same=True
        )
        density_L1 = field_integral_measures(delta_field, sys.grid.weights, norm='L1')
        density_L2 = field_integral_measures(delta_field, sys.grid.weights, norm='L2')
        overlap = sys.fock_tensors.overlap
        record = {
            'split': split,
            'idx': int(psys.idx),
            'did_converge': did_converge,
            'total_energy_mae_mEh': float(
                jnp.abs(total_pred - targets.total_energy) * 1e3
            ),
            'cycles_to_convergence': int(cycles_to_convergence),
            'relative_cycle_count': float(relative_cycle_count),
            'xc_energy_mae_mEh': float(jnp.abs(e_xc[-1] - targets.xc_energy) * 1e3),
            'mean_field_mae_mEh': float(mf_mae * 1e3),
            'coulomb_energy_mae_mEh': float(coulomb_mae * 1e3),
            'dipole_difference_au': float(dipole_error),
            'homo_lumo_gap_mae_mEh': float(
                jnp.abs(gap_pred - targets.homo_lumo_gap) * 1e3
            ),
            'density_L1': float(density_L1),
            'density_L2': float(density_L2),
            # Grid-free density-difference surrogates
            'overlap_based_mae_surrogate': float(
                overlap_based_mae_surrogate(P_target, P_final, overlap)
            ),
            'overlap_based_mse_surrogate': float(
                overlap_based_mse_surrogate(P_target, P_final, overlap)
            ),
        }

        # Evaluate reference functional energy on model's final density if enabled
        if self.evaluate_reference_on_model_density:
            ref_e_hj, ref_e_xc = self._reference_energy_fn(  # type: ignore
                sys._nuc_pos,
                P_final,
                sys,
            )
            ref_total_on_model_dm = ref_e_hj + ref_e_xc + nuclear_energy
            record['ref_total_energy_on_model_dm_mEh'] = float(
                jnp.abs(ref_total_on_model_dm - targets.total_energy) * 1e3
            )
            record['ref_xc_energy_on_model_dm_mEh'] = float(
                jnp.abs(ref_e_xc - targets.xc_energy) * 1e3
            )

        # SCF Acceleration Benchmark: warm-start reference with learned density
        if self.scf_acceleration_enabled:
            # Run reference functional SCF from learned density (warm-start)
            (ref_e_hj_warm, ref_e_xc_warm), _ = self._reference_scf_apply(  # type: ignore
                {'params': {}},  # Reference functional has no learnable params
                P_final,  # Warm-start from learned functional's converged density
                sys,
            )
            ref_total_energies_warm = ref_e_xc_warm + ref_e_hj_warm + nuclear_energy
            ref_cycles_warm, ref_did_converge_warm = self._cycles_to_convergence(
                ref_total_energies_warm, self.convergence_threshold
            )

            # RIC = Relative Iteration Count (warm/cold), lower is better
            warm_start_ric = float(jnp.divide(ref_cycles_warm, ref_cycles_cold))

            record['ref_cycles_cold'] = int(ref_cycles_cold)
            record['ref_cycles_warm'] = ref_cycles_warm
            record['ref_did_converge_warm'] = ref_did_converge_warm
            record['warm_start_ric'] = warm_start_ric
            # Energy error at index 0 of the warm-start trace. NOTE(review): this
            # assumes the solver's first entry is the energy after one cycle —
            # confirm against the reference SCF implementation.
            record['ref_energy_after_1_cycle_mEh'] = float(
                jnp.abs(ref_total_energies_warm[0] - targets.total_energy) * 1e3
            )

        return record

    def _write_csv(self, records: Iterable[Dict[str, Any]]) -> None:
        fieldnames = [
            'split',
            'idx',
            'did_converge',
            'cycles_to_convergence',
            'relative_cycle_count',
            'total_energy_mae_mEh',
            'xc_energy_mae_mEh',
            'mean_field_mae_mEh',
            'coulomb_energy_mae_mEh',
            'dipole_difference_au',
            'homo_lumo_gap_mae_mEh',
            'density_L1',
            'density_L2',
            'overlap_based_mae_surrogate',
            'overlap_based_mse_surrogate',
        ]
        # Add reference energy on model density columns if enabled
        if self.evaluate_reference_on_model_density:
            fieldnames.extend(
                [
                    'ref_total_energy_on_model_dm_mEh',
                    'ref_xc_energy_on_model_dm_mEh',
                ]
            )
        # Add SCF acceleration benchmark columns if enabled
        if self.scf_acceleration_enabled:
            fieldnames.extend(
                [
                    'ref_cycles_cold',
                    'ref_cycles_warm',
                    'ref_did_converge_warm',
                    'warm_start_ric',
                    'ref_energy_after_1_cycle_mEh',
                ]
            )
        records = list(records)
        df = pd.DataFrame.from_records(records, columns=fieldnames)
        df.to_csv(self.output_csv, index=False)
        metric_columns = fieldnames[3:]  # Skip 'split', 'idx', and 'did_converge'
        # Note: energy metrics are in mEh; dipole is in atomic units.
        summary_lines = ['Dataset mean metrics (mEh; dipole in a.u.):']
        # Convergence statistics
        total_samples = len(df)
        converged_samples = df['did_converge'].sum()
        summary_lines.append(f'  Total samples: {total_samples}')
        summary_lines.append(f'  Converged samples: {converged_samples}')
        summary_lines.append('')
        # Metric means
        means = df[metric_columns].mean()
        for key in metric_columns:
            summary_lines.append(f'  {key}: {means[key]:.6f}')  # type: ignore

        print(summary_lines[0])
        for line in summary_lines[1:]:
            print(line)

        self.metrics_summary_path.write_text('\n'.join(summary_lines) + '\n')


@ex.automain
def main(overwrite: int):
    """Sacred/SEML entry point: build an ``EvaluationExperiment`` and run it.

    ``overwrite`` is injected by Sacred from the experiment config and passed
    straight to the ``EvaluationExperiment`` constructor.
    """
    EvaluationExperiment(overwrite)()  # type: ignore
