from typing import Callable, Literal, Tuple

import jax
import jax.numpy as jnp
import numpy as onp

from deixc.scf import DerivativeInformedSolver
from deixc.training.loss import DynamicLossConfig, StaticLossConfig
from deixc.training.utils import dynamic_reference_step, dynamic_step, static_step
from egxc.dataloading import DataLoaders, InitialDensityMatrixFn, ToJaxTransform
from egxc.training.optimizer import OptConfig, get_optimizer
from egxc.training.utils import ema, metropolis
from egxc.training.utils.early_stopping import EarlyStopping
from egxc.utils.checkpointing import CheckpointManager
from egxc.utils.logging import Logger
from egxc.utils.typing import NnParams, PRNGKey
from egxc.xc_energy import XCModule

# Training mode as a pair: (how the loss is evaluated, which training stage).
Mode = Tuple[
    Literal['static', 'dynamic'],
    Literal['pretrain', 'train'],
]


def final_evaluation(
    params: NnParams,
    solver: DerivativeInformedSolver,
    loss_config: DynamicLossConfig,
    dataloaders: DataLoaders,
    main_thread_transform: ToJaxTransform,
    logger: Logger,
    initial_density_matrix_fn: InitialDensityMatrixFn,
) -> None:
    """Run one evaluation pass of `params` over the train, val and test splits.

    A dynamic evaluation step is built under the ``'final_evaluation'`` logging
    prefix (with `logger` handed to it), ``logger.write_csv`` is switched on,
    and every batch of each split is pushed through that step inside a single
    ``start_mean``/``stop_mean`` window.

    Args:
        params: Network parameters to evaluate.
        solver: Solver the dynamic model loss is built from.
        loss_config: Dynamic loss configuration; its weights and SCF decay
            weights are forwarded to the evaluation step.
        dataloaders: Object exposing ``train``, ``val`` and ``test`` iterables
            that yield ``(psys, pvec_basis_fns, targets)`` triples.
        main_thread_transform: Transform forwarded to the evaluation step.
        logger: Logger used for mean tracking and by the evaluation step.
        initial_density_matrix_fn: Forwarded to the evaluation step.
    """
    split_names = ('train', 'val', 'test')
    print('#' * 40, 'Final Evaluation')

    model_loss = dynamic_step.get_model_loss_fn(loss_config, solver)
    evaluate_batch = dynamic_step.get_eval_step_fn(
        'final_evaluation',
        model_loss,
        initial_density_matrix_fn,
        main_thread_transform,
        loss_config.weights,
        loss_config.scf_loss_vectorization.scf_decay_weights,
        logger,
        max_energy_volatility=onp.inf,
    )
    # NOTE: the flag is not reset; it stays enabled after this function returns.
    logger.write_csv = True

    mean_keys = [
        f'final_evaluation/{name} energy error [mEh]' for name in split_names
    ]
    logger.start_mean(mean_keys)
    for split in split_names:
        split_loader = getattr(dataloaders, split)
        print('#' * 20, f'On {split.capitalize()} Set')
        for psys, pvec_basis_fns, targets in split_loader:
            evaluate_batch(params, psys, pvec_basis_fns, targets)
        # Drop the local reference to this split's loader before the next one.
        del split_loader
    logger.stop_mean()


def isolate_xc_module(
    solver: DerivativeInformedSolver, local_only: bool = False
) -> XCModule:
    """Wrap the solver's functional (or its local part) in a new `XCModule`.

    Args:
        solver: Solver whose ``xc_module`` supplies the functional and the
            feature function.
        local_only: When true and the functional has a ``local_model``
            attribute, that sub-model is wrapped instead of the full
            functional. If the attribute is missing the full functional is
            used regardless.

    Returns:
        A fresh `XCModule` sharing the solver's ``feature_fn``.
    """
    source_module = solver.xc_module
    functional = source_module.functional
    # Short-circuit keeps the attribute probe from running unless requested.
    use_local = local_only and hasattr(functional, 'local_model')
    chosen = functional.local_model if use_local else functional
    return XCModule(chosen, source_module.feature_fn)


def _build_step_fns(
    logging_prefix: str,
    is_static: bool,
    xc_module: XCModule,
    solver: DerivativeInformedSolver,
    optimizer,
    gradient_post_processing,
    opt_config: OptConfig,
    loss_config: StaticLossConfig | DynamicLossConfig,
    main_thread_transform: ToJaxTransform,
    logger: Logger,
    prng_key: PRNGKey,
    initial_density_matrix_fn: InitialDensityMatrixFn | None,
    max_energy_volatility: float,
    ref_functional: XCModule | None,
) -> Tuple[Callable, Callable, PRNGKey]:
    """Build the train and eval step functions for one training phase.

    Three variants exist: static, dynamic, and dynamic with a dynamic
    reference (selected by ``loss_config.with_dynamic_reference``). The
    evaluation steps of both dynamic variants are always built with
    ``max_energy_volatility=onp.inf``; only the train step receives the
    caller's `max_energy_volatility`.

    Returns:
        ``(train_step, eval_step, prng_key)``. The returned key differs from
        the input only in the dynamic-reference variant, where the key is
        split to seed the evaluation step.
    """
    if is_static:
        assert isinstance(loss_config, StaticLossConfig)
        loss_fn = static_step.get_model_loss_fn(loss_config, xc_module)
        train_step = static_step.get_train_step_fn(
            logging_prefix,
            loss_fn,
            main_thread_transform,
            optimizer,
            gradient_post_processing,
            opt_config.schedule.base_rate,
            opt_config.ema_decay,
            loss_config.weights,
            loss_config.scf_loss_vectorization.scf_decay_weights,
            logger,
        )
        eval_step = static_step.get_eval_step_fn(
            logging_prefix,
            loss_fn,
            main_thread_transform,
            loss_config.weights,
            loss_config.scf_loss_vectorization.scf_decay_weights,
            logger,
        )
        return train_step, eval_step, prng_key

    assert isinstance(loss_config, DynamicLossConfig)
    assert initial_density_matrix_fn is not None

    if loss_config.with_dynamic_reference:
        assert ref_functional is not None
        loss_fn = dynamic_reference_step.get_model_loss_fn(
            loss_config, solver, ref_functional
        )
        train_step = dynamic_reference_step.get_train_step_fn(
            logging_prefix,
            loss_fn,
            main_thread_transform,
            optimizer,
            gradient_post_processing,
            opt_config.schedule.base_rate,
            opt_config.ema_decay,
            loss_config.weights,
            loss_config.scf_loss_vectorization.scf_decay_weights,
            logger,
            initial_density_matrix_fn,
            max_energy_volatility,
        )
        # The evaluation step gets its own key; the caller continues with the
        # other half of the split.
        prng_key, split_key = jax.random.split(prng_key)
        eval_step = dynamic_reference_step.get_eval_step_fn(
            logging_prefix,
            loss_fn,
            initial_density_matrix_fn,
            main_thread_transform,
            loss_config.weights,
            loss_config.scf_loss_vectorization.scf_decay_weights,
            logger,
            max_energy_volatility=onp.inf,
            prng_key=split_key,
        )
        return train_step, eval_step, prng_key

    loss_fn = dynamic_step.get_model_loss_fn(loss_config, solver)
    train_step = dynamic_step.get_train_step_fn(
        logging_prefix,
        loss_fn,
        main_thread_transform,
        optimizer,
        gradient_post_processing,
        opt_config.schedule.base_rate,
        opt_config.ema_decay,
        loss_config.weights,
        loss_config.scf_loss_vectorization.scf_decay_weights,
        logger,
        initial_density_matrix_fn,
        max_energy_volatility,
    )
    eval_step = dynamic_step.get_eval_step_fn(
        logging_prefix,
        loss_fn,
        initial_density_matrix_fn,
        main_thread_transform,
        loss_config.weights,
        loss_config.scf_loss_vectorization.scf_decay_weights,
        logger,
        max_energy_volatility=onp.inf,
    )
    return train_step, eval_step, prng_key


def run(
    mode: Tuple[Literal['static', 'dynamic'], Literal['pretrain', 'train']],
    params: NnParams,
    params_init_fn: Callable[[int], NnParams],
    solver: DerivativeInformedSolver,
    opt_config: OptConfig,
    loss_config: StaticLossConfig | DynamicLossConfig,
    dataloaders: DataLoaders,
    main_thread_transform: ToJaxTransform,
    logger: Logger,
    checkpointer: CheckpointManager,
    benchmark: bool,
    prng_key: PRNGKey,
    initial_density_matrix_fn: InitialDensityMatrixFn | None = None,
    max_energy_volatility: float = jnp.inf,
    ref_functional: XCModule | None = None,
) -> NnParams:
    """Train `params` for one phase and return the early-stopping best.

    Each epoch consists of: an optional optimizer restart, one pass over
    ``dataloaders.train``, a proposal to the Metropolis training stabilizer
    based on the mean training loss, and - only if the proposal is accepted -
    a validation pass with the EMA parameters followed by an early-stopping
    check. A rejected proposal advances the early-stopping patience counter
    and skips validation.

    Args:
        mode: ``(loss kind, stage)``; joined as ``'{kind}_{stage}'`` to form
            the logging prefix. ``'static'`` selects the static step
            functions; ``'pretrain'`` requests the local-only XC module.
        params: Initial network parameters.
        params_init_fn: Forwarded to the Metropolis stabilizer.
        solver: Solver used for the dynamic losses and as the source of the
            XC module.
        opt_config: Optimizer settings (epochs, schedule, EMA decay, restarts,
            early stopping, stabilizer configuration).
        loss_config: `StaticLossConfig` for static mode, `DynamicLossConfig`
            otherwise (checked with ``assert``).
        dataloaders: Provides the ``train`` and ``val`` iterables.
        main_thread_transform: Forwarded to the step functions.
        logger: Receives epoch, mean and benchmark events.
        checkpointer: Forwarded to `EarlyStopping`.
        benchmark: Emit ``logger.benchmark*`` markers around data loading and
            the train step.
        prng_key: Random key threaded through the train step and stabilizer.
        initial_density_matrix_fn: Required (non-``None``) in dynamic mode.
        max_energy_volatility: Forwarded to the dynamic train step only.
        ref_functional: Required when ``loss_config.with_dynamic_reference``.

    Returns:
        `params` unchanged when ``opt_config.epochs == 0``. If a restart epoch
        is reached, the result of the recursive restarted run. Otherwise
        ``early_stopping.best_params``.
    """
    # triviality skip
    if opt_config.epochs == 0:
        return params

    is_static = mode[0] == 'static'
    is_pretrain = mode[1] == 'pretrain'
    logging_prefix = f'{mode[0]}_{mode[1]}'
    xc_module = isolate_xc_module(solver, local_only=is_pretrain)
    if opt_config.decay_only_graph_readout and not is_pretrain:
        assert xc_module.functional.is_graph_based
        graph_readout_decay_mask = xc_module.functional.graph_readout_decay_mask(params)
    else:
        graph_readout_decay_mask = None
    optimizer, gradient_post_processing = get_optimizer(
        opt_config, manual_scaling=True, graph_readout_decay_mask=graph_readout_decay_mask
    )

    train_step, eval_step, prng_key = _build_step_fns(
        logging_prefix,
        is_static,
        xc_module,
        solver,
        optimizer,
        gradient_post_processing,
        opt_config,
        loss_config,
        main_thread_transform,
        logger,
        prng_key,
        initial_density_matrix_fn,
        max_energy_volatility,
        ref_functional,
    )

    # The EMA state is kept last so it can be read back as opt_state[-1].
    opt_state = (
        optimizer.init(params),
        gradient_post_processing.init(params),
        ema.EMA.create(params),
    )
    early_stopping = EarlyStopping(
        opt_config.early_stopping_patience,
        opt_config.early_stopping_min_relative_improvement,
        checkpointer,
        logging_prefix,
    )

    metropolis_stabilizer = metropolis.MetropolisTrainingStabilizer.create_from_dict(
        params, params_init_fn, opt_state, opt_config.metropolis_stabilizer
    )

    for e in range(opt_config.epochs):
        # Optional hot-restarting of the optimizer: continue from the current
        # EMA parameters in a fresh run with the restart configuration. The
        # recursive call builds new optimizer, early-stopping and stabilizer
        # state, and its result is returned directly.
        if opt_config.with_restarts and e in opt_config.restart_epochs:  # type: ignore
            params = ema.value(opt_state[-1])
            new_opt_config = opt_config.create_restart_config()
            return run(
                mode,
                params,
                params_init_fn,
                solver,
                new_opt_config,
                loss_config,
                dataloaders,
                main_thread_transform,
                logger,
                checkpointer,
                benchmark,
                prng_key,
                initial_density_matrix_fn,
                max_energy_volatility,
                ref_functional,
            )
        logger.start_epoch(e, logging_prefix)
        logger.start_mean(
            [
                f'{logging_prefix}/train/loss',
                f'{logging_prefix}/train/total energy error [mEh]',
                f'{logging_prefix}/train/xc energy error [mEh]',
            ]
        )
        if benchmark:
            logger.benchmark_start(logging_prefix)
        for psys, pvec_basis_fns, targets in dataloaders.train:
            if benchmark:
                logger.benchmark('data_loader')
            params, opt_state, prng_key, grad_norm, update_norm = train_step(
                params, opt_state, psys, pvec_basis_fns, targets, prng_key
            )
            logger.updates(
                grad_norm.item(), update_norm.item(), f'{logging_prefix}/train'
            )
            if benchmark:
                logger.benchmark('step_fn')
        mean_train_loss = logger.get_current_mean(
            f'{logging_prefix}/train/loss',
            max_nans=50,  # stabilizer doesn't need NaN guarding
        )
        logger.stop_mean()
        accept, params, opt_state, prng_key = metropolis_stabilizer.propose_update(
            mean_train_loss, params, prng_key, opt_state
        )
        if not accept:
            logger.log({f'{logging_prefix}/train/rejected_update': mean_train_loss})
            if early_stopping.increment_patience_counter():
                break
            continue  # no need to validate with identical parameters
        logger.start_mean(
            [
                f'{logging_prefix}/val/loss',
                f'{logging_prefix}/val/total energy error [mEh]',
                f'{logging_prefix}/val/xc energy error [mEh]',
            ]
        )
        # Validation (and early stopping) use the EMA parameters, not the raw
        # ones returned by the train step.
        eval_params = ema.value(opt_state[-1])
        for psys, pvec_basis_fns, targets in dataloaders.val:
            eval_step(
                eval_params,
                psys,
                pvec_basis_fns,
                targets,
            )
        mean_val_loss = logger.get_current_mean(f'{logging_prefix}/val/loss')
        logger.stop_mean()
        if early_stopping.stop(mean_val_loss, eval_params):
            break

    jax.clear_caches()

    return early_stopping.best_params