import optax

from egxc.dataloading import (
    DataLoaders,
    InitialDensityMatrixFn,
    ToJaxTransform,
)
from egxc.solver.base import Solver  # , PulayForceWrapper
from egxc.training.evaluate import run_final_evaluation
from egxc.training.loss import LossConfig
from egxc.training.step_fn import (
    get_eval_step,
    get_loss_fn,
    get_train_step,
)
from egxc.training.utils import ema
from egxc.training.utils.early_stopping import EarlyStopping
from egxc.utils.checkpointing import CheckpointManager
from egxc.utils.logging import Logger
from egxc.utils.typing import NnParams, PRNGKey


def run(
    init_params: NnParams,
    model: Solver,
    optimizer: optax.GradientTransformation | optax.GradientTransformationExtraArgs,
    ema_decay: float,
    early_stopping_patience: int,
    early_stopping_min_relative_improvement: float,
    loss_config: LossConfig,
    epochs: int,
    dataloaders: DataLoaders,
    input_transform: ToJaxTransform,
    initial_density_matrix_fn: InitialDensityMatrixFn,
    logger: Logger,
    checkpointer: CheckpointManager | None,
    test: bool,
    prng_key: PRNGKey,
) -> None:
    """Train ``model`` with EMA-tracked parameters and validation-based early stopping.

    Each epoch does one pass over ``dataloaders.train``, calling the step
    function built by ``get_train_step`` and logging loss, update norms and
    per-stage timings. After the pass, the EMA parameters are evaluated on
    ``dataloaders.val``. The mean validation loss is handed to
    ``EarlyStopping``, and training stops early if it says so.

    Args:
        init_params: Initial network parameters.
        model: Solver whose parameters are being trained.
        optimizer: Optax transformation used by the train step.
        ema_decay: Decay rate passed to the train step for the parameter EMA.
        early_stopping_patience: Patience forwarded to ``EarlyStopping``.
        early_stopping_min_relative_improvement: Minimum relative improvement
            forwarded to ``EarlyStopping``.
        loss_config: Loss configuration; ``max_energy_volatility`` is also
            passed to the train step and to the loss logger.
        epochs: Maximum number of training epochs (may be 0).
        dataloaders: Provides ``train`` and ``val`` iterables of
            ``(psys, pvec_basis_fns, targets)``.
        input_transform: Converts a loader batch into ``(P0, sys)`` for
            training; also handed to the eval step.
        initial_density_matrix_fn: Forwarded to the train and eval steps.
        logger: Receives all metrics, benchmarks and epoch bookkeeping.
        checkpointer: Optional checkpoint manager used by ``EarlyStopping``.
        test: If true, the validation pass of the final epoch is skipped and
            ``run_final_evaluation`` is executed after training.
        prng_key: Random key threaded through the train step.
    """
    loss_fn = get_loss_fn(loss_config, model)
    step_fn = get_train_step(
        loss_fn,
        initial_density_matrix_fn,
        optimizer,
        ema_decay,
        loss_config.max_energy_volatility,
    )
    eval_step = get_eval_step(loss_fn, initial_density_matrix_fn, input_transform)

    params = init_params
    optax_state = optimizer.init(params)
    params_ema = ema.EMA.create(params)
    # The train step carries optimizer state and EMA state together as one tuple.
    opt_state = (optax_state, params_ema)
    early_stopping = EarlyStopping(
        early_stopping_patience,
        early_stopping_min_relative_improvement,
        checkpointer,
        'train',
    )
    # True once early_stopping has been fed at least one validation loss.
    # Without this, a test run with epochs == 1 would read best_params even
    # though no validation ever happened.
    validated = False

    for epoch in range(epochs):
        logger.start_epoch(epoch, 'train')
        logger.start_mean(['train/energy error [mEh]'])
        logger.benchmark_start('train')
        for psys, pvec_basis_fns, targets in dataloaders.train:
            logger.benchmark('data_loader')
            P0, sys = input_transform(psys, pvec_basis_fns)
            logger.benchmark('input_transform')
            (
                prng_key,
                params,
                opt_state,
                loss,
                e_pred,
                e_xc,
                _,  # predicted density matrix, unused here
                volatility,
                grad_norm,
                update_norm,
            ) = step_fn(prng_key, params, opt_state, targets, P0, sys)
            logger.benchmark('step_fn')
            logger.egxc_loss(
                loss.item(),
                e_pred[-1].item(),
                e_xc[-1].item(),
                targets.energy.item(),
                'train',
                psys.idx,
                volatility.item(),
                loss_config.max_energy_volatility,
            )
            logger.updates(grad_norm.item(), update_norm.item(), 'train')
            logger.benchmark('logging')

        logger.stop_mean()
        logger.epoch_training_duration(len(dataloaders.train), 'train')

        if epoch == epochs - 1 and test:
            # Skip the last validation pass in test runs. NOTE(review): the
            # last epoch's updates are therefore never validated, so they
            # presumably cannot become early_stopping.best_params unless no
            # earlier validation happened (see fallback below) — confirm intent.
            print('#' * 20, 'skipping last validation')
            continue

        logger.start_mean(['val/loss', 'val/energy error [mEh]'])
        # Validation (and early-stopping selection) uses the EMA parameters.
        eval_params = ema.value(opt_state[1])
        for psys, pvec_basis_fns, targets in dataloaders.val:
            loss, (e_pred, e_xc, _, volatility) = eval_step(
                eval_params, psys, pvec_basis_fns, targets
            )
            logger.egxc_loss(
                loss.item(),
                e_pred[-1].item(),
                e_xc[-1].item(),
                targets.energy.item(),
                'val',
                psys.idx,
                volatility.item(),
                loss_config.max_energy_volatility,
            )

        mean_val_loss = logger.get_current_mean('val/loss')

        # Query early stopping before closing the running mean, matching the
        # original call order, then close the mean in either case.
        should_stop = early_stopping.stop(mean_val_loss, eval_params)
        validated = True
        logger.stop_mean()
        if should_stop:
            break

    if test:
        if validated:
            final_params = early_stopping.best_params
        elif epochs > 0:
            # Training ran but validation never did (epochs == 1 in a test
            # run): evaluate the trained EMA parameters, i.e. exactly what
            # validation would have scored.
            final_params = ema.value(opt_state[1])
        else:
            # No training at all: evaluate the initial parameters.
            final_params = params
        run_final_evaluation(final_params, eval_step, dataloaders, logger)
