"""
This module contains the loss function and train/eval step functions for the static training stage.
In this stage, the model is trained to MATCH a set of targets along a precomputed SCF trajectory.
"""

from functools import partial
from typing import Protocol, Tuple

import jax
import jax.numpy as jnp
import optax
from jax import lax

from deixc.dataset import DEIXCTargets
from deixc.training.loss import (
    RelativeDEILossWeights,
    SCFDecayWeights,
    StaticLossConfig,
    get_loss_fns,
)
from deixc.training.utils.header import (
    EvalStepFn,
    LossComponents,
    TrainStepFn,
)
from egxc.dataloading import ToJaxTransform
from egxc.discretization import PreloadedGTOBasis
from egxc.systems.base import Grid, PreloadSystem, System
from egxc.training.optimizer import Optimizer
from egxc.training.utils import ema
from egxc.utils.logging import Logger
from egxc.utils.typing import (
    PRECISION,
    BoolB,
    Float1,
    Float2xBxB,
    FloatBxB,
    NnParams,
    PRNGKey,
)
from egxc.xc_energy import XCModule
from egxc.xc_energy.functionals.classical.hybrid import Hybrid


class StaticLossFn(Protocol):
    """Structural type of the loss callable used in the static training stage.

    The callable returns ``(loss, (aux_scalar, loss_components))``, i.e. a scalar
    loss followed by an auxiliary pair. This is the layout consumed with
    ``jax.value_and_grad(..., has_aux=True)`` in the train step of this module.
    In the implementation built by ``get_model_loss_fn`` the auxiliary scalar is
    the XC energy predicted for the last density matrix that was evaluated.
    """

    def __call__(
        self,
        params: NnParams,
        targets: DEIXCTargets,
        sys: System,
        n_occ: int,
        rel_loss_w: RelativeDEILossWeights,
        scf_decay_weights: SCFDecayWeights,
    ) -> Tuple[Float1, Tuple[Float1, LossComponents]]:
        ...


def get_model_loss_fn(
    config: StaticLossConfig,
    model: XCModule,
) -> StaticLossFn:
    """Return a loss function that does not recompile for runtime weight / LR changes.

    Parameters
    ----------
    config
        Static loss configuration. It selects the individual loss functions
        (through `get_loss_fns(config, train_type='static')`) and, through
        `config.scf_loss_vectorization.is_vectorized`, which terms are evaluated
        on all stored SCF density matrices instead of only the final one.
    model
        XC model. `model.apply` predicts the XC energy and, when called with
        `method=model.xc_potential`, the XC potential.

    Returns
    -------
    StaticLossFn
        A jitted callable returning `(loss, (final_xc_energy, LossComponents))`.

    Notes
    -----
    * Only `n_occ` is static.
    * Vectorization mode ("only_final" vs SCF-wide) and model traits (hybrid/graph-based)
      remain static by design. Changing those will recompile.
    * The relative loss weights are traced values, so the XC-potential and
      orbital-rotation-gradient terms are switched on/off with `lax.cond`
      rather than Python `if`.
    """

    loss_fns = get_loss_fns(config, train_type='static')
    is_vec = config.scf_loss_vectorization.is_vectorized

    @partial(jax.jit, static_argnames=('n_occ',))
    def loss_fn(
        params: NnParams,
        targets: DEIXCTargets,
        sys: System,
        n_occ: int,
        rel_loss_w: RelativeDEILossWeights,
        scf_decay_weights: SCFDecayWeights,
    ):
        # Shared zero in the loss dtype; used for disabled / unimplemented terms.
        ZERO = jnp.array(0.0, dtype=PRECISION.loss)
        # Set up non_local_kwargs for graph-based and hybrid functionals
        non_local_kwargs = {}
        if isinstance(model.functional, Hybrid):
            non_local_kwargs['eri_tensor'] = sys.fock_tensors.ert
        if model.functional.is_graph_based:
            non_local_kwargs['atom_mask'] = sys.atom_mask
            non_local_kwargs['nuc_pos'] = sys._nuc_pos
            non_local_kwargs['grid_coords'] = sys.grid.coords

        # ------------------------------ Energy term -----------------------------
        # vmapped over the leading axis of the density-matrix stack.
        @jax.vmap
        def _xc_energy_fn(
            P: FloatBxB | Float2xBxB,  # vec: FloatRefSCFxBxB | FloatRefSCFx2xBxB
        ) -> Float1:
            return model.apply(
                params,
                P,
                sys.grid,
                **non_local_kwargs,
            )  # type: ignore

        if is_vec.xc_energy:
            dm_e = targets.density_matrices
            target_xc_energies = targets.xc_energies
        else:
            # Add a leading axis of length 1 so the vmapped function applies.
            dm_e = targets.density_matrix[None, ...]
            target_xc_energies = targets.xc_energy[None, ...]

        e_xc_pred = _xc_energy_fn(dm_e)

        L_xc_energy = loss_fns.energy(
            target_xc_energies,
            e_xc_pred,
            sys.n_electrons,
            scf_decay_weights.xc_energy,
        )
        # Last entry of the evaluated stack (the only entry in non-vectorized mode).
        final_predicted_xc_energy = e_xc_pred[-1]

        # ------------------------------ Forces term -----------------------------
        L_forces: Float1 = ZERO  # TODO: forces loss is not implemented; contributes zero

        # ------------------ XC potential & orbital rotation gradient ------------
        # Computational flow:
        #   1) The XC potential is predicted once, either for every stored SCF
        #      density matrix (if any of the two terms below is vectorized) or
        #      only for the final one.
        #   2) The prediction is shared by the XC-potential loss and the
        #      orbital-rotation-gradient loss, each of which is gated by
        #      `lax.cond` on whether its relative weight is positive.
        #
        # No linear-response term is computed in this function.

        # Switches (JAX scalars)
        include_xc_pot = jnp.greater(rel_loss_w.xc_potential, 0.0)
        include_min_dir = jnp.greater(rel_loss_w.orbital_rotation_gradient, 0.0)
        require_vec_xc_pot = is_vec.xc_potential or is_vec.orbital_rotation_gradient

        # Only the density matrix (argument 1) is mapped; the rest is broadcast.
        @partial(jax.vmap, in_axes=(None, 0, None, None))
        def _xc_pot_fn(
            params_: NnParams,
            P0: FloatBxB | Float2xBxB,
            grid: Grid,
            basis_mask: BoolB,
        ) -> FloatBxB | Float2xBxB:
            return model.apply(
                params_,
                P0,
                grid,
                basis_mask,
                method=model.xc_potential,
                **non_local_kwargs,
            )  # type: ignore

        def _orb_rot_grad_loss(operands):
            V_targ, V_pred, C = operands
            _is_vec = is_vec.orbital_rotation_gradient
            # In non-vectorized mode only the last entry is used, re-wrapped
            # with a leading axis of length 1.
            return loss_fns.orbital_rotation_gradient(
                V_targ if _is_vec else V_targ[-1][None, ...],
                V_pred if _is_vec else V_pred[-1][None, ...],
                C if _is_vec else C[-1][None, ...],
                n_occ,
                scf_decay_weights.orbital_rotation_gradient,
                targets.orbital_energies,
            )

        def _xc_potential_loss(operands):
            xc_pot_target, xc_pot_pred = operands
            _is_vec = is_vec.xc_potential
            return loss_fns.xc_potential(
                xc_pot_target if _is_vec else xc_pot_target[-1][None, ...],
                xc_pot_pred if _is_vec else xc_pot_pred[-1][None, ...],
                sys,
                scf_decay_weights.xc_potential,
            )

        if require_vec_xc_pot:
            dm = targets.density_matrices
        else:
            dm = targets.density_matrix[None, ...]
        xc_pot_pred = _xc_pot_fn(params, dm, sys.grid, sys.fock_tensors.basis_mask)

        # `_xc_potential_loss` unpacks a (target, prediction) pair, so both must
        # be passed. Handing over the prediction alone would make the branch
        # iterate over the prediction's leading axis instead.
        L_xc_pot = lax.cond(
            include_xc_pot,
            _xc_potential_loss,
            lambda _: ZERO,
            operand=(
                targets.xc_potential_matrices,
                xc_pot_pred,
            ),
        )
        L_orb_rot_grad = lax.cond(
            include_min_dir,
            _orb_rot_grad_loss,
            lambda _: ZERO,
            operand=(
                targets.xc_potential_matrices,
                xc_pot_pred,
                targets.mo_coeffs,
            ),
        )

        # ----------------------------- Sum everything -----------------------------
        loss: Float1 = (
            rel_loss_w.xc_energy * L_xc_energy
            + rel_loss_w.forces * L_forces
            + rel_loss_w.xc_potential * L_xc_pot
            + rel_loss_w.orbital_rotation_gradient * L_orb_rot_grad
        )

        # The three trailing LossComponents fields are not computed by this
        # function and are reported as zero.
        return loss, (
            final_predicted_xc_energy,
            LossComponents(
                L_xc_energy,
                L_forces,
                L_xc_pot,
                L_orb_rot_grad,
                ZERO,
                ZERO,
                ZERO,
            ),
        )

    return loss_fn


# ------------------------------- Train step fn -------------------------------


def get_train_step_fn(
    logging_prefix: str,
    loss_fn: StaticLossFn,
    main_thread_transform: ToJaxTransform,
    optimizer: Optimizer,
    gradient_post_processing: Optimizer,
    learning_rate: float,
    ema_decay: float,
    relative_loss_weights: RelativeDEILossWeights,
    scf_decay_weights: SCFDecayWeights,
    logger: Logger,
) -> TrainStepFn:
    """Build the training step for the static stage.

    The returned `step_fn` converts the preloaded system with
    `main_thread_transform`, runs one jitted parameter update and logs the
    result under ``f'{logging_prefix}/train'``.

    Inside the jitted update the order is:
    optimizer update -> scaling by the learning rate -> `gradient_post_processing`
    -> `optax.apply_updates` -> EMA update with the new parameters.

    The learning rate and both weight containers are passed to the jitted
    function as regular (non-static) arguments; only `n_occ` is static.
    `params` and the optimizer state tuple are donated to the jitted call.
    """
    log_tag = f'{logging_prefix}/train'
    loss_and_grad = jax.value_and_grad(loss_fn, has_aux=True)

    @partial(
        jax.jit,
        static_argnames=('n_occ',),
        donate_argnames=('params', 'state'),
    )
    def _jitted_update(
        params: NnParams,
        state: Tuple[optax.OptState, optax.OptState, ema.EMA],
        targets: DEIXCTargets,
        sys: System,
        n_occ: int,
        lr: Float1,
        loss_w: RelativeDEILossWeights,
        scf_w: SCFDecayWeights,
    ):
        (loss, aux), grads = loss_and_grad(params, targets, sys, n_occ, loss_w, scf_w)
        e_xc, loss_components = aux

        opt_state, post_state, ema_state = state

        # The loss value is forwarded to the optimizer via `value=`.
        raw_updates, opt_state = optimizer.update(
            grads, opt_state, params, value=loss
        )  # type: ignore
        # Learning rate is applied here, before the post-processing transform.
        scaled_updates = jax.tree.map(lambda u: lr * u, raw_updates)
        final_updates, post_state = gradient_post_processing.update(
            scaled_updates, post_state, params
        )
        new_params = optax.apply_updates(params, final_updates)
        ema_state = ema.update(ema_state, new_params, ema_decay)

        return (
            new_params,
            (opt_state, post_state, ema_state),
            loss,
            e_xc,
            loss_components,
            optax.global_norm(grads),
            optax.global_norm(final_updates),
        )

    def step_fn(
        params: NnParams,
        opt_state: Tuple[optax.OptState, optax.OptState, ema.EMA],
        psys: PreloadSystem,
        preloaded_basis_fns: PreloadedGTOBasis,
        targets: DEIXCTargets,
        prng_key: PRNGKey
        | None,  # unused_prng_key to enforce same signature as dynamic step
    ) -> Tuple[
        NnParams,
        Tuple[optax.OptState, optax.OptState, ema.EMA],
        PRNGKey | None,
        Float1,
        Float1,
    ]:
        """Run one update on `psys`, log it, and return the new state.

        Returns ``(params, opt_state, prng_key, grad_norm, update_norm)``;
        `prng_key` is handed back unchanged.
        """
        _unused, sys = main_thread_transform(psys, preloaded_basis_fns)

        step_out = _jitted_update(
            params,
            opt_state,
            targets,
            sys,
            targets.n_occ,
            jnp.asarray(learning_rate),
            relative_loss_weights,
            scf_decay_weights,
        )
        (
            params,
            opt_state,
            loss,
            e_xc,
            loss_components,
            grad_norm,
            update_norm,
        ) = step_out

        logger.deixc_loss(
            loss.item(),
            None,
            e_xc.item(),
            targets.total_energy.item(),
            targets.xc_energy.item(),
            loss_components.to_host(),
            relative_loss_weights.to_host(),  # type: ignore
            log_tag,
            psys.idx,
            volatility=None,
            max_energy_volatility=None,
        )
        return params, opt_state, prng_key, grad_norm, update_norm

    return step_fn


# -------------------------------- Eval step fn --------------------------------
def get_eval_step_fn(
    logging_prefix: str,
    loss_fn: StaticLossFn,
    main_thread_transform: ToJaxTransform,
    relative_loss_weights: RelativeDEILossWeights,
    scf_decay_weights: SCFDecayWeights,
    logger: Logger,
) -> EvalStepFn:
    """Build the evaluation step for the static stage.

    The returned `step_fn` converts the preloaded system with
    `main_thread_transform`, evaluates `loss_fn` once (no gradients, no
    parameter update) and logs the result under ``f'{logging_prefix}/val'``.
    It has no return value.
    """
    log_tag = f'{logging_prefix}/val'

    def step_fn(
        params: NnParams,
        psys: PreloadSystem,
        preloaded_basis_fns: PreloadedGTOBasis,
        targets: DEIXCTargets,
    ):
        _unused, sys = main_thread_transform(psys, preloaded_basis_fns)

        loss, aux = loss_fn(
            params,
            targets,
            sys,
            targets.n_occ,
            relative_loss_weights,
            scf_decay_weights,
        )
        e_xc, loss_components = aux

        logger.deixc_loss(
            loss.item(),
            None,
            e_xc.item(),
            targets.total_energy.item(),
            targets.xc_energy.item(),
            loss_components.to_host(),  # type: ignore
            relative_loss_weights.to_host(),  # type: ignore
            log_tag,
            psys.idx,
            volatility=None,
            max_energy_volatility=None,
        )

    return step_fn
