import optax

from egxc.dataloading import (
    DataLoaders,
    InitialDensityMatrixFn,
    ToJaxTransform,
)
from egxc.solver.base import Solver
from egxc.training.loss import LossConfig
from egxc.training.step_fn import get_eval_step, get_loss_fn, get_train_step
from egxc.training.utils import ema
from egxc.training.utils.early_stopping import EarlyStopping
from egxc.utils.checkpointing import CheckpointManager
from egxc.utils.logging import Logger
from egxc.utils.typing import NnParams, PRNGKey


def run(
    init_params: NnParams,
    model: Solver,
    optimizer: optax.GradientTransformation | optax.GradientTransformationExtraArgs,
    ema_decay: float,
    early_stopping_patience: int,
    early_stopping_min_relative_improvement: float,
    loss_config: LossConfig,
    epochs: int,
    dataloaders: DataLoaders,
    input_transform: ToJaxTransform,
    initial_density_matrix_fn: InitialDensityMatrixFn,
    logger: Logger,
    checkpointer: CheckpointManager | None,
    prng_key: PRNGKey,
) -> NnParams:
    """Run the 'pretrain' stage: train, validate with EMA parameters, early-stop.

    For each of at most ``epochs`` epochs the function

    1. iterates ``dataloaders.train``, applies ``input_transform`` to every
       sample, performs one ``step_fn`` call and logs the resulting metrics;
    2. reads the EMA parameters from the second slot of the optimizer state and
       evaluates them on ``dataloaders.val``;
    3. hands the mean of ``'pretrain/val/loss'`` together with the EMA
       parameters to ``EarlyStopping.stop`` and leaves the loop when that
       returns a truthy value.

    Args:
        init_params: Starting parameters; used to initialise both the optimizer
            state and the EMA.
        model: Solver passed to ``get_loss_fn``.
        optimizer: Optax transformation; its ``init`` is called here and the
            object itself is forwarded to ``get_train_step``.
        ema_decay: Forwarded to ``get_train_step``.
        early_stopping_patience: Forwarded to ``EarlyStopping``.
        early_stopping_min_relative_improvement: Forwarded to ``EarlyStopping``.
        loss_config: Passed to ``get_loss_fn``; its ``max_energy_volatility``
            is also forwarded to ``get_train_step`` and to the logger.
        epochs: Upper bound on the number of epochs.
        dataloaders: Provides the ``train`` and ``val`` iterables, each yielding
            ``(psys, pvec_basis_fns, targets)`` triples.
        input_transform: Maps ``(psys, pvec_basis_fns)`` to the pair consumed by
            the train step; also forwarded to ``get_eval_step``.
        initial_density_matrix_fn: Forwarded to both step factories.
        logger: Receives all epoch, benchmark and metric events.
        checkpointer: Forwarded to ``EarlyStopping``; may be ``None``.
        prng_key: Random key threaded through ``step_fn``, which returns a
            replacement key on every call.

    Returns:
        ``early_stopping.best_params`` as it stands when the loop ends
        (presumably the parameters with the best validation loss -- confirm
        against ``EarlyStopping``).
    """
    # Tags shared by the logger calls below.
    stage = 'pretrain'
    train_tag = 'pretrain/train'
    val_tag = 'pretrain/val'
    val_loss_key = 'pretrain/val/loss'

    loss_fn = get_loss_fn(loss_config, model)
    train_step = get_train_step(
        loss_fn,
        initial_density_matrix_fn,
        optimizer,
        ema_decay,
        loss_config.max_energy_volatility,
    )
    eval_step = get_eval_step(loss_fn, initial_density_matrix_fn, input_transform)

    params = init_params

    # The training state is the pair (optax state, EMA of the parameters).
    optax_state = optimizer.init(params)
    ema_state = ema.EMA.create(params)
    opt_state = (optax_state, ema_state)

    early_stopping = EarlyStopping(
        early_stopping_patience,
        early_stopping_min_relative_improvement,
        checkpointer,
        stage,
    )

    for epoch in range(epochs):
        # ---- training pass ------------------------------------------------
        logger.start_epoch(epoch, stage)
        logger.start_mean(['pretrain/train/energy error [mEh]'])
        logger.benchmark_start(stage)
        for psys, pvec_basis_fns, targets in dataloaders.train:
            logger.benchmark('data_loader')
            init_dm, system = input_transform(psys, pvec_basis_fns)
            logger.benchmark('input_transform')

            # step_fn yields ten values: the three pieces of carried state
            # followed by seven per-step outputs (the fifth of those, the
            # predicted density matrix, is not used here).
            prng_key, params, opt_state, *step_metrics = train_step(
                prng_key, params, opt_state, targets, init_dm, system
            )
            (
                train_loss,
                train_e_pred,
                train_e_xc,
                _,
                train_volatility,
                grad_norm,
                update_norm,
            ) = step_metrics
            logger.benchmark('step_fn')

            # Only the last entry of the energy arrays is logged.
            logger.egxc_loss(
                train_loss.item(),
                train_e_pred[-1].item(),
                train_e_xc[-1].item(),
                targets.energy.item(),
                train_tag,
                psys.idx,
                train_volatility.item(),
                loss_config.max_energy_volatility,
            )
            logger.updates(grad_norm.item(), update_norm.item(), train_tag)
            logger.benchmark('logging')

        logger.stop_mean()
        logger.epoch_training_duration(len(dataloaders.train), stage)

        # ---- validation pass (EMA parameters) -----------------------------
        logger.start_mean(['pretrain/val/energy error [mEh]', val_loss_key])
        eval_params = ema.value(opt_state[1])
        for psys, pvec_basis_fns, targets in dataloaders.val:
            val_loss, val_aux = eval_step(eval_params, psys, pvec_basis_fns, targets)
            val_e_pred, val_e_xc, _, val_volatility = val_aux
            logger.egxc_loss(
                val_loss.item(),
                val_e_pred[-1].item(),
                val_e_xc[-1].item(),
                targets.energy.item(),
                val_tag,
                psys.idx,
                val_volatility.item(),
                loss_config.max_energy_volatility,
            )

        # The mean must be read before stop_mean() is called.
        mean_val_loss = logger.get_current_mean(val_loss_key)
        logger.stop_mean()

        should_stop = early_stopping.stop(mean_val_loss, eval_params)
        if should_stop:
            break

    return early_stopping.best_params
