from typing import Callable, Tuple

import jax
import jax.numpy as jnp
import optax

from egxc.dataloading import (
    DensityMatrices,
    InitialDensityMatrixFn,
    Targets,
    ToJaxTransform,
)
from egxc.discretization import PreloadedGTOBasis
from egxc.solver.base import Solver
from egxc.systems import PreloadSystem, System, nuclear_energy_fn
from egxc.training.loss import LossConfig, get_loss_fns
from egxc.training.utils import ema
from egxc.utils.typing import (
    Float1,
    Float2xBxB,
    FloatBxB,
    FloatSCF,
    FloatSCFx2xBxB,
    FloatSCFxBxB,
    NnParams,
    PRNGKey,
)

# Signature of the loss function returned by `get_loss_fn`:
#   (params, targets, initial density matrix, system)
#       -> (loss, (e_predicted, e_xc, P_predicted))
LossFn = Callable[
    [NnParams, Targets, FloatBxB | Float2xBxB, System],
    Tuple[Float1, Tuple[FloatSCF, FloatSCF, FloatSCFxBxB | FloatSCFx2xBxB]],
]


def _create_zero_updates(params):
    """Create zero-valued updates matching the structure of params."""
    return jax.tree.map(lambda x: jax.numpy.zeros_like(x), params)


@jax.jit
def energy_volatility(e_scf: FloatSCF) -> Float1:
    """
    Returns the magnitude of the energy change between the
    final two SCF steps, scaled by 1e3 (i.e. expressed in mEh).
    """
    second_to_last, last = e_scf[-2], e_scf[-1]
    return jnp.abs(second_to_last - last) * 1e3


def get_loss_fn(loss_config: LossConfig, model: Solver) -> LossFn:
    """
    Builds the jitted training objective for ``model``.

    The returned function applies the model to the initial density matrix
    ``P0``, adds the nuclear energy to the two energy terms the model returns,
    and sums the energy and density losses obtained from
    ``get_loss_fns(loss_config)``. It returns
    ``(loss, (e_predicted, e_xc, P_predicted))``.
    """
    component_losses = get_loss_fns(loss_config)

    def loss_fn(params, targets: Targets, P0: FloatBxB | Float2xBxB, sys: System):
        energies, P_predicted = model.apply(params, P0, sys)
        e_hj, e_xc = energies
        e_nuclear = nuclear_energy_fn(sys._nuc_pos, sys)
        # Total energy = model energy terms + nuclear contribution.
        e_predicted = e_xc + e_hj + e_nuclear
        energy_loss = component_losses.energy(
            targets.energy, e_predicted, sys.n_electrons  # type: ignore
        )
        density_loss = component_losses.density(
            targets.density_matrix,  # type: ignore
            P_predicted,  # type: ignore
            sys,
        )
        return energy_loss + density_loss, (e_predicted, e_xc, P_predicted)

    return jax.jit(loss_fn)


# Signature of the step function returned by `get_train_step`:
#   (key, params, (optax_state, params_ema), targets, P0, sys)
#       -> (key, params, (optax_state, params_ema), loss, e_predicted, e_xc,
#           dm_pred, volatility, grad_norm, update_norm)
TrainStepFn = Callable[
    [PRNGKey, NnParams, Tuple[optax.OptState, ema.EMA], Targets, DensityMatrices, System],
    Tuple[
        PRNGKey,
        NnParams,
        Tuple[optax.OptState, ema.EMA],
        Float1,
        FloatSCF,
        FloatSCF,
        FloatSCFxBxB | FloatSCFx2xBxB,
        Float1,
        Float1,
        Float1,
    ],
]


def get_train_step(
    loss_fn: LossFn,
    initial_density_matrix_fn: InitialDensityMatrixFn,
    optimizer: optax.GradientTransformation | optax.GradientTransformationExtraArgs,
    ema_decay: float,
    max_energy_volatility: float = jnp.inf,
) -> TrainStepFn:
    """
    Builds the jitted training step for a single system.

    Args:
        loss_fn: Objective returning ``(loss, (e_predicted, e_xc, dm_pred))``;
            it is differentiated with respect to its first argument.
        initial_density_matrix_fn: Called as ``fn(P0, key)`` and expected to
            return ``(P0, key)``.
        optimizer: Optax transformation. A plain ``GradientTransformation`` is
            wrapped so that the ``value=loss`` keyword passed on every update
            is ignored instead of raising ``TypeError``.
        ema_decay: Decay forwarded to ``ema.update``.
        max_energy_volatility: Threshold, in the unit returned by
            ``energy_volatility``, above which the optimizer update is
            skipped. The default of infinity never skips.

    Returns:
        ``step_fn(key, params, (optax_state, params_ema), targets, P0, sys)``
        which returns ``(key, params, (optax_state, params_ema), loss,
        e_predicted, e_xc, dm_pred, volatility, grad_norm, update_norm)``.
    """
    # `step_fn` always forwards `value=loss` to the optimizer. Only
    # `GradientTransformationExtraArgs` accept extra keyword arguments, so
    # normalise the optimizer once here; transformations that already support
    # extra args are returned unchanged.
    optimizer = optax.with_extra_args_support(optimizer)

    @jax.jit
    def step_fn(
        key: PRNGKey,
        params: NnParams,
        opt_state: Tuple[optax.OptState, ema.EMA],
        targets: Targets,
        P0: DensityMatrices,
        sys: System,
    ) -> Tuple[
        PRNGKey,
        NnParams,
        Tuple[optax.OptState, ema.EMA],
        Float1,
        FloatSCF,
        FloatSCF,
        FloatSCFxBxB | FloatSCFx2xBxB,
        Float1,
        Float1,
        Float1,
    ]:
        P0, key = initial_density_matrix_fn(P0, key)  # type: ignore
        (loss, (e_predicted, e_xc, dm_pred)), grads = jax.value_and_grad(
            loss_fn, has_aux=True
        )(params, targets, P0, sys)
        optax_state, params_ema = opt_state
        volatility = energy_volatility(e_predicted)
        # Skip the optimizer update (zero updates, untouched optimizer state)
        # when the energy still changed by more than the threshold between the
        # last two SCF steps. Both branches must return the same pytree
        # structure for `jax.lax.cond`.
        # NOTE(review): a NaN volatility compares False here, so the update is
        # still applied in that case -- confirm this is intended.
        updates, optax_state = jax.lax.cond(
            volatility > max_energy_volatility,
            lambda: (_create_zero_updates(params), optax_state),  # Skip updates
            # `params` is forwarded because some optax transformations read the
            # current parameters; `value` is consumed by transformations that
            # take the loss as an extra argument.
            lambda: optimizer.update(grads, optax_state, params, value=loss),  # Apply updates
        )
        params = optax.apply_updates(params, updates)
        # The EMA is refreshed on every step, including skipped ones.
        params_ema = ema.update(params_ema, params, ema_decay)
        grad_norm = optax.global_norm(grads)
        update_norm = optax.global_norm(updates)
        return (
            key,
            params,
            (optax_state, params_ema),  # type: ignore
            loss,
            e_predicted,
            e_xc,
            dm_pred,
            volatility,
            grad_norm,
            update_norm,
        )

    return step_fn


# Signature of the step function returned by `get_eval_step`:
#   (params, psys, pvec_basis_fns, targets)
#       -> (loss, (e_predicted, e_xc, dm_pred, volatility))
EvalStepFn = Callable[
    [NnParams, PreloadSystem, PreloadedGTOBasis, Targets],
    Tuple[Float1, Tuple[FloatSCF, FloatSCF, FloatSCFxBxB | FloatSCFx2xBxB, Float1]],
]


def get_eval_step(
    loss_fn: LossFn,
    initial_density_matrix_fn: InitialDensityMatrixFn,
    input_transform: ToJaxTransform,
) -> EvalStepFn:
    """
    Builds the evaluation step.

    The returned function converts the preloaded inputs with
    ``input_transform``, derives the initial density matrix, evaluates
    ``loss_fn`` and returns
    ``(loss, (e_predicted, e_xc, dm_pred, volatility))``.
    """

    def eval_step(
        params, psys: PreloadSystem, pvec_basis_fns: PreloadedGTOBasis, targets: Targets
    ):
        densities, system = input_transform(psys, pvec_basis_fns)
        # Evaluation passes no PRNG key; the key slot of the result is dropped.
        P0, _unused_key = initial_density_matrix_fn(densities, None)
        loss, aux = loss_fn(params, targets, P0, system)
        e_predicted, e_xc, dm_pred = aux
        return loss, (e_predicted, e_xc, dm_pred, energy_volatility(e_predicted))

    return eval_step
