"""
This module contains the loss function and train/eval step functions for the dynamic training stage.
In this stage, the model is trained to CONVERGE an SCF calculation to the desired fixed point.
"""

from functools import partial
from typing import Protocol, Tuple

import jax
import jax.numpy as jnp
import optax
from jax import lax

from deixc.dataset import DEIXCTargets
from deixc.scf import DerivativeInformedSolver
from deixc.training.loss import (
    DynamicLossConfig,
    RelativeDEILossWeights,
    SCFDecayWeights,
    get_loss_fns,
)
from deixc.training.utils.header import (
    EvalStepFn,
    LossComponents,
    TrainStepFn,
)
from egxc.dataloading import (
    DensityMatrices,
    InitialDensityMatrixFn,
    ToJaxTransform,
)
from egxc.discretization import PreloadedGTOBasis
from egxc.systems import (
    PreloadSystem,
    System,
    homo_lumo_gap_fn,
    nuclear_energy_fn,
)
from egxc.training.optimizer import Optimizer
from egxc.training.utils import ema
from egxc.utils.logging import Logger
from egxc.utils.typing import (
    PRECISION,
    Float1,
    Float2xBxB,
    FloatBxB,
    FloatSCF,
    NnParams,
    PRNGKey,
)
from egxc.xc_energy.functionals.classical.hybrid import Hybrid


class DynamicLossFn(Protocol):
    """Structural type of the dynamic-stage loss built by ``get_model_loss_fn``.

    Calling it yields ``(loss, aux)`` where ``aux`` bundles the per-SCF-step
    total energies, the final XC energy, a Fock-shaped matrix, and the
    individual loss components.
    """

    def __call__(
        self,
        params: NnParams,
        targets: DEIXCTargets,
        P0: FloatBxB | Float2xBxB,
        sys: System,
        occ_virtual_shape: Tuple[int, int],
        rel_loss_w: RelativeDEILossWeights,
        scf_decay_weights: SCFDecayWeights,
    ) -> Tuple[
        Float1,
        Tuple[
            FloatSCF,
            Float1,
            FloatBxB | Float2xBxB,
            LossComponents,
        ],
    ]:
        """Evaluate the weighted dynamic loss for a single system."""


def _create_zero_updates(params):
    """Create zero-valued updates matching the structure of params."""
    return jax.tree.map(lambda x: jax.numpy.zeros_like(x), params)


@jax.jit
def energy_volatility(e_scf: FloatSCF) -> Float1:
    """Absolute energy change between the last two SCF iterations, in mEh.

    The difference is scaled by 1e3, i.e. ``e_scf`` is assumed to be in Hartree.
    """
    last_step_change = e_scf[-1] - e_scf[-2]
    return abs(last_step_change) * 1e3


def get_model_loss_fn(
    config: DynamicLossConfig, model: DerivativeInformedSolver
) -> DynamicLossFn:
    """Build the jitted dynamic-stage loss function for ``model``.

    The returned function runs ``model.apply`` from an initial density matrix and
    combines the per-term losses from ``get_loss_fns(config, train_type='dynamic')``
    into one scalar, weighted by the relative loss weights. For the density,
    XC-potential and orbital-rotation-gradient terms,
    ``config.scf_loss_vectorization.is_vectorized`` selects whether the loss sees
    the full SCF trajectory or only its last step.

    Args:
        config: Dynamic loss configuration.
        model: Solver whose ``apply(params, P0, sys)`` returns
            ``(e_hj, e_xc), (C, dm, F, vxc)``.

    Returns:
        A ``DynamicLossFn``; ``occ_virtual_shape`` is a static (hashable) argument.
    """
    loss_fns = get_loss_fns(config, train_type='dynamic')
    is_vec = config.scf_loss_vectorization.is_vectorized

    def _last_step(x):
        """Keep only the last SCF step, retaining a leading axis of size 1."""
        return x[-1][None, ...]

    @partial(jax.jit, static_argnames=('occ_virtual_shape',))
    def loss_fn(
        params,
        targets: DEIXCTargets,
        P0: FloatBxB | Float2xBxB,
        sys: System,
        occ_virtual_shape: Tuple[int, int],
        rel_loss_w: RelativeDEILossWeights,
        scf_decay_weights: SCFDecayWeights,
    ):
        """Return ``(loss, (total_energies, final_xc_energy, F_last, components))``."""
        ZERO = jnp.array(0.0, dtype=PRECISION.loss)

        # Solver outputs; the MO coefficients are not needed by any loss term here.
        (e_hj, e_xc), (_, dm_pred, F_pred, vxc_pred) = model.apply(params, P0, sys)

        # ------------------------------ Energy terms -----------------------------
        predicted_total_energies = e_xc + e_hj + nuclear_energy_fn(sys._nuc_pos, sys)
        final_predicted_xc_energy = e_xc[-1]

        L_xc_energy = loss_fns.energy(
            targets.xc_energy, e_xc, sys.n_electrons, scf_decay_weights.xc_energy
        )

        L_total_energy = loss_fns.energy(
            targets.total_energy,
            predicted_total_energies,
            sys.n_electrons,
            scf_decay_weights.total_energy,
        )

        # ------------------------------ Forces term ------------------------------
        L_forces: Float1 = ZERO  # TODO: forces loss is not implemented yet.

        # The optional terms below go through lax.cond so the loss branch is only
        # selected when its relative weight is strictly positive.

        # ------------------------------ Density term -----------------------------
        density_target = (
            targets.density_matrices
            if is_vec.density
            else _last_step(targets.density_matrices)
        )

        def _density_loss(dm: FloatBxB | Float2xBxB) -> Float1:
            return loss_fns.density(
                density_target,
                dm if is_vec.density else _last_step(dm),
                sys,
                scf_decay_weights.density,
            )

        L_density = lax.cond(
            jnp.greater(rel_loss_w.density, 0.0),
            _density_loss,
            lambda _: ZERO,
            operand=dm_pred,
        )

        # ------------------------------ XC potential term ------------------------
        xc_pot_target = (
            targets.xc_potential_matrices
            if is_vec.xc_potential
            else _last_step(targets.xc_potential_matrices)
        )

        def _xc_potential_loss(v_xc: FloatBxB | Float2xBxB) -> Float1:
            return loss_fns.xc_potential(
                xc_pot_target,
                v_xc if is_vec.xc_potential else _last_step(v_xc),
                sys,
                scf_decay_weights.xc_potential,
            )

        L_xc_pot = lax.cond(
            jnp.greater(rel_loss_w.xc_potential, 0.0),
            _xc_potential_loss,
            lambda _: ZERO,
            operand=vxc_pred,
        )

        # ------------------------------ Orbital rotation gradient term -----------
        def _orbital_rotation_gradient_loss(v_xc) -> Float1:
            # Target orbital-rotation gradients are derived from the target XC
            # potential matrices.
            orb_rot_dir_target = targets.get_orbital_rotation_gradient(
                targets.xc_potential_matrices
            )
            mo_coeffs_target = targets.mo_coeffs
            v_xc_used = v_xc
            if not is_vec.orbital_rotation_gradient:
                orb_rot_dir_target = _last_step(orb_rot_dir_target)
                mo_coeffs_target = _last_step(mo_coeffs_target)
                v_xc_used = _last_step(v_xc)

            # NOTE(review): orbital_energies is passed unsliced even in the
            # last-step-only mode — confirm it carries no SCF axis.
            return loss_fns.orbital_rotation_gradient(
                orb_rot_dir_target,
                v_xc_used,
                mo_coeffs_target,
                occ_virtual_shape[0],
                scf_decay_weights.orbital_rotation_gradient,
                targets.orbital_energies,
            )

        L_orbital_rotation_gradient = lax.cond(
            jnp.greater(rel_loss_w.orbital_rotation_gradient, 0.0),
            _orbital_rotation_gradient_loss,
            lambda _: ZERO,
            operand=vxc_pred,
        )

        # ----------------------------- Sum everything -----------------------------
        loss: Float1 = (
            rel_loss_w.xc_energy * L_xc_energy
            + rel_loss_w.total_energy * L_total_energy
            + rel_loss_w.forces * L_forces
            + rel_loss_w.density * L_density
            + rel_loss_w.xc_potential * L_xc_pot
            + rel_loss_w.orbital_rotation_gradient * L_orbital_rotation_gradient
        )

        return loss, (
            predicted_total_energies,
            final_predicted_xc_energy,
            F_pred[-1],
            LossComponents(
                L_xc_energy,
                L_forces,
                L_xc_pot,
                L_orbital_rotation_gradient,
                ZERO,
                L_total_energy,
                L_density,
            ),
        )

    return loss_fn


def get_train_step_fn(
    logging_prefix: str,
    loss_fn: DynamicLossFn,
    main_thread_transform: ToJaxTransform,
    optimizer: Optimizer,
    gradient_post_processing: Optimizer,
    learning_rate: float,
    ema_decay: float,
    relative_loss_weights: RelativeDEILossWeights,
    scf_decay_weights: SCFDecayWeights,
    logger: Logger,
    initial_density_matrix_fn: InitialDensityMatrixFn,
    max_energy_volatility: float,
) -> TrainStepFn:
    """Build the dynamic-stage training step.

    Each step transforms the preloaded system, differentiates ``loss_fn``, applies
    the optimizer (scaled by ``learning_rate``), then ``gradient_post_processing``,
    updates the EMA of the parameters, and logs the loss.

    The optimizer update is replaced by zero updates (optimizer state kept) when
    the SCF energy volatility exceeds ``max_energy_volatility`` or when the
    volatility or the loss is not finite. Post-processing and the EMA update still
    run in that case.

    Returns:
        ``step_fn(params, opt_state, psys, preloaded_basis_fns, targets, prng_key)``
        returning ``(params, opt_state, prng_key, grad_norm, update_norm)``.
    """

    @partial(
        jax.jit,
        static_argnames=('occ_virtual_shape',),
        donate_argnames=('params', 'state'),
    )
    def jit_step(
        params: NnParams,
        state: Tuple[optax.OptState, optax.OptState, ema.EMA],
        targets: DEIXCTargets,
        sys: System,
        occ_virtual_shape: Tuple[int, int],
        lr: Float1,
        loss_w: RelativeDEILossWeights,
        scf_w: SCFDecayWeights,
        P0: DensityMatrices,
        prng_key: PRNGKey | None,
    ):
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        P0, prng_key = initial_density_matrix_fn(P0, prng_key)  # type: ignore
        (loss, (pred_e_tot, pred_final_e_xc, _, loss_components)), grads = grad_fn(
            params,
            targets,
            P0,
            sys,
            occ_virtual_shape,
            loss_w,
            scf_w,
        )
        opt_state, post_processing_state, ema_state = state
        volatility = energy_volatility(pred_e_tot)
        # `volatility > max` is False for NaN, so a diverged SCF would otherwise
        # slip past the guard and write non-finite gradients into the parameters.
        skip_update = (
            (volatility > max_energy_volatility)
            | ~jnp.isfinite(volatility)
            | ~jnp.isfinite(loss)
        )
        updates, opt_state = jax.lax.cond(
            skip_update,
            lambda: (_create_zero_updates(params), opt_state),
            lambda: optimizer.update(grads, opt_state, params, value=loss),  # type: ignore
        )
        updates = jax.tree.map(lambda g: lr * g, updates)
        updates, post_processing_state = gradient_post_processing.update(
            updates, post_processing_state, params
        )
        params = optax.apply_updates(params, updates)
        ema_state = ema.update(ema_state, params, ema_decay)
        grad_norm = optax.global_norm(grads)
        update_norm = optax.global_norm(updates)
        return (
            params,
            (opt_state, post_processing_state, ema_state),  # type: ignore
            loss,
            pred_e_tot[-1],
            pred_final_e_xc,
            volatility,
            grad_norm,
            update_norm,
            loss_components,
            prng_key,
        )

    def step_fn(
        params: NnParams,
        opt_state: Tuple[optax.OptState, optax.OptState, ema.EMA],
        psys: PreloadSystem,
        preloaded_basis_fns: PreloadedGTOBasis,
        targets: DEIXCTargets,
        prng_key: PRNGKey | None,
    ) -> Tuple[
        NnParams,
        Tuple[optax.OptState, optax.OptState, ema.EMA],
        PRNGKey | None,
        Float1,
        Float1,
    ]:
        # Host-side transform; `params` and `opt_state` are donated to jit_step
        # and must not be reused after the call.
        P0, sys = main_thread_transform(psys, preloaded_basis_fns)
        (
            params,
            opt_state,
            loss,
            pred_final_e_tot,
            pred_final_e_xc,
            volatility,
            grad_norm,
            update_norm,
            loss_components,
            prng_key,
        ) = jit_step(
            params,
            opt_state,
            targets,
            sys,
            targets.occupied_virtual_shape,
            jnp.asarray(learning_rate),
            relative_loss_weights,
            scf_decay_weights,
            P0,
            prng_key,
        )
        logger.deixc_loss(
            loss.item(),
            pred_final_e_tot.item(),
            pred_final_e_xc.item(),
            targets.total_energy.item(),
            targets.xc_energy.item(),
            loss_components.to_host(),
            relative_loss_weights.to_host(),  # type: ignore
            f'{logging_prefix}/train',
            psys.idx,
            volatility.item(),
            max_energy_volatility,
        )
        return params, opt_state, prng_key, grad_norm, update_norm

    return step_fn


def get_eval_step_fn(
    logging_prefix: str,
    loss_fn: DynamicLossFn,
    initial_density_matrix_fn: InitialDensityMatrixFn,
    input_transform: ToJaxTransform,
    relative_loss_weights: RelativeDEILossWeights,
    scf_decay_weights: SCFDecayWeights,
    logger: Logger,
    max_energy_volatility: float,
) -> EvalStepFn:
    """Build the dynamic-stage evaluation step.

    The returned ``step_fn(params, psys, preloaded_basis_fns, targets)`` evaluates
    ``loss_fn`` without any parameter update, logs the absolute HOMO-LUMO gap
    error, and then logs the loss breakdown under the ``/val`` prefix. It
    returns nothing.
    """
    homo_lumo_key = f'observables/{logging_prefix}/val/homo_lumo_gap_mae'
    val_prefix = f'{logging_prefix}/val'

    def step_fn(
        params,
        psys: PreloadSystem,
        preloaded_basis_fns: PreloadedGTOBasis,
        targets: DEIXCTargets,
    ):
        init_densities, sys = input_transform(psys, preloaded_basis_fns)
        # No PRNG key at evaluation time.
        P0, _ = initial_density_matrix_fn(init_densities, None)
        loss, aux = loss_fn(
            params,
            targets,
            P0,
            sys,
            targets.occupied_virtual_shape,
            relative_loss_weights,
            scf_decay_weights,
        )
        energies_per_step, final_xc_energy, fock_final, components = aux

        volatility = energy_volatility(energies_per_step)

        predicted_gap = homo_lumo_gap_fn(
            fock_final, sys.fock_tensors.diagonal_overlap, int(sys.n_electrons)
        )
        gap_abs_error = abs(predicted_gap - targets.homo_lumo_gap).item()
        logger.log({homo_lumo_key: gap_abs_error})

        logger.deixc_loss(
            loss.item(),
            energies_per_step[-1].item(),
            final_xc_energy.item(),
            targets.total_energy.item(),
            targets.xc_energy.item(),
            components.to_host(),  # type: ignore
            relative_loss_weights.to_host(),  # type: ignore
            val_prefix,
            psys.idx,
            volatility.item(),
            max_energy_volatility,
        )

    return step_fn
