import jax
import numpy as onp

from egxc.dataloading import DataLoaders
from egxc.training.step_fn import EvalStepFn
from egxc.utils.logging import Logger
from egxc.utils.typing import NnParams


def _evaluate_split(
    params: NnParams,
    eval_step: EvalStepFn,
    loader,
    logger: Logger,
    tag: str,
) -> None:
    """Run ``eval_step`` on every batch of one data split and log each result.

    Args:
        params: Network parameters, forwarded unchanged to ``eval_step``.
        eval_step: Called as ``eval_step(params, psys, pvec_basis_fns, targets)``;
            must return ``(loss, (e_pred, e_xc, _, volatility))``.
        loader: Iterable yielding ``(psys, pvec_basis_fns, targets)`` triples.
        logger: Receives exactly one ``egxc_loss`` call per batch.
        tag: Split label forwarded to ``logger.egxc_loss``
            (``'final train'``, ``'final val'`` or ``'final test'``).
    """
    for psys, pvec_basis_fns, targets in loader:
        loss, (e_pred, e_xc, _, volatility) = eval_step(
            params, psys, pvec_basis_fns, targets
        )
        logger.egxc_loss(
            loss.item(),
            # Only the last entry of each energy sequence is logged; presumably
            # this is the final-iteration value — TODO confirm with eval_step.
            e_pred[-1].item(),
            e_xc[-1].item(),
            targets.energy.item(),
            tag,
            psys.idx,
            volatility.item(),
            # An infinite threshold presumably disables any volatility-based
            # filtering in the logger so every sample is recorded — confirm.
            max_energy_volatility=onp.inf,
        )


def run_final_evaluation(
    params: NnParams,
    eval_step: EvalStepFn,
    dataloaders: DataLoaders,
    logger: Logger,
) -> None:
    """Evaluate ``params`` on the train, validation and test splits and log results.

    Clears JAX's caches, enables CSV output on ``logger``, and opens a mean
    accumulation for the three ``'final <split> energy error [mEh]'`` keys.
    It then evaluates the splits in order (train, val, test) and closes the
    accumulation with ``logger.stop_mean()``.

    Side effects:
        * ``logger.write_csv`` is set to ``True``.
        * ``dataloaders.train`` and ``dataloaders.val`` are deleted from
          ``dataloaders`` after they have been evaluated; ``dataloaders.test``
          is left in place.

    Args:
        params: Network parameters to evaluate.
        eval_step: Per-batch evaluation function (see ``_evaluate_split``).
        dataloaders: Object exposing ``train``, ``val`` and ``test`` loaders.
        logger: Logger that collects the per-batch losses.
    """
    # Drop compiled executables/caches left over from training before the
    # evaluation pass (presumably to reduce memory pressure).
    jax.clear_caches()
    print('#' * 40, 'Final Evaluation')
    logger.write_csv = True

    logger.start_mean(
        [f'final {prefix} energy error [mEh]' for prefix in ['train', 'val', 'test']]
    )

    print('#' * 20, 'On Training Set')
    _evaluate_split(params, eval_step, dataloaders.train, logger, 'final train')
    # Drop the reference to each finished split so it can be garbage-collected
    # before the next split is loaded (only effective if no other refs exist).
    del dataloaders.train

    print('#' * 20, 'On Validation Set')
    _evaluate_split(params, eval_step, dataloaders.val, logger, 'final val')
    del dataloaders.val

    print('#' * 20, 'On Test Set')
    _evaluate_split(params, eval_step, dataloaders.test, logger, 'final test')
    logger.stop_mean()
