"""
This module contains the loss function and train/eval step functions for the dynamic reference
training stage. It adapts the dynamic training stage to a distillation setting. Here, the model is
trained to CONVERGE an SCF calculation. The training is supervised by evaluating the targeted
reference functional on the fly.
"""

from functools import partial
from typing import Protocol, Tuple

import jax
import jax.numpy as jnp
import optax
from jax import lax

from deixc.dataset import DEIXCTargets
from deixc.scf import DerivativeInformedSolver
from deixc.training.loss import (
    DynamicLossConfig,
    RelativeDEILossWeights,
    SCFDecayWeights,
    get_loss_fns,
)
from deixc.training.utils.header import (
    EvalStepFn,
    LossComponents,
    TrainStepFn,
)
from egxc.dataloading import (
    DensityMatrices,
    InitialDensityMatrixFn,
    ToJaxTransform,
)
from egxc.discretization import PreloadedGTOBasis
from egxc.systems import PreloadSystem, System, homo_lumo_gap_fn, nuclear_energy_fn
from egxc.training.optimizer import Optimizer
from egxc.training.utils import ema
from egxc.utils.logging import Logger
from egxc.utils.typing import (
    PRECISION,
    Float1,
    Float2xBxB,
    FloatBxB,
    FloatSCF,
    FloatSCFx2xBxB,
    FloatSCFxBxB,
    NnParams,
    PRNGKey,
)
from egxc.xc_energy import XCModule
from egxc.xc_energy.functionals.classical.hybrid import Hybrid


class DynamicLossFn(Protocol):
    """
    Call signature of the loss function consumed by the train and eval step factories below.

    Both step functions in this module invoke the loss purely positionally, so only the
    argument order is relied upon. Note that the parameter names ``targets`` and
    ``consumable_prng_key`` declared here differ from the ones used by the implementation
    returned by ``get_model_loss_fn`` (``static_targets`` and ``prng_key``).

    The return value is a pair ``(loss, aux)``; ``aux`` is documented field by field on the
    return annotation below, following what ``get_model_loss_fn`` returns.
    """

    def __call__(
        self,
        params: NnParams,
        targets: DEIXCTargets,
        P0: FloatBxB | Float2xBxB,
        sys: System,
        # Passed as a static (hashable) argument to the jitted implementations in this module.
        occ_virtual_shape: Tuple[int, int],
        rel_loss_w: RelativeDEILossWeights,
        scf_decay_weights: SCFDecayWeights,
        consumable_prng_key: PRNGKey,
    ) -> Tuple[
        # Scalar, weighted sum of all loss terms.
        Float1,
        Tuple[
            # Total energies, one entry per SCF step.
            FloatSCF,
            # XC energy of the last SCF step.
            Float1,
            # Last entry of the predicted Fock matrices.
            FloatBxB | Float2xBxB,
            # Unweighted individual loss terms.
            LossComponents,
        ],
    ]: ...


def _create_zero_updates(params):
    """Return a pytree with the same structure as ``params`` whose leaves are all zeros."""
    return jax.tree.map(jnp.zeros_like, params)


@jax.jit
def energy_volatility(e_scf: FloatSCF) -> Float1:
    """
    Measure how much the energy still moves at the end of an SCF trajectory.

    Returns the magnitude of the difference between the last two entries of ``e_scf``,
    multiplied by 1e3 so that the result is reported in mEh.
    """
    second_to_last, last = e_scf[-2], e_scf[-1]
    return jnp.abs(second_to_last - last) * 1e3


def get_model_loss_fn(
    config: DynamicLossConfig,
    model: DerivativeInformedSolver,
    reference_functional: XCModule,
) -> DynamicLossFn:
    """
    Returns a jitted loss function that scores the model against the reference functional.

    The total-energy and density targets are not dynamically recomputed, as these are the
    observables of the desired fixed point; they are read from the static targets.

    The reference functional is evaluated on the fly on the density matrices produced by the
    model. Its xc energy and xc potential serve as targets that teach the model how to converge
    to the desired fixed point.

    Args:
        config: Loss configuration. It selects the individual loss functions (via
            ``get_loss_fns``), the density measure and the settings of the orbital rotation
            Hessian term.
        model: The solver that is trained. Its default ``apply`` is expected to return
            ``((e_mf, e_xc), (coeffs, dms, pred_F, vxc))``.
        reference_functional: The functional that is distilled. It is always applied with an
            empty parameter dict.

    Returns:
        A function ``loss_fn(params, static_targets, P0, sys, occ_virtual_shape, rel_loss_w,
        scf_decay_weights, prng_key)`` returning ``(loss, (e_tot, e_xc[-1], pred_F[-1],
        LossComponents))``. ``occ_virtual_shape`` is a static argument of the jit.
    """
    loss_fns = get_loss_fns(config, train_type='dynamic_with_reference')

    @partial(jax.jit, static_argnames=('occ_virtual_shape',))
    def loss_fn(
        params,
        static_targets: DEIXCTargets,
        P0: FloatBxB | Float2xBxB,
        sys: System,
        occ_virtual_shape: Tuple[int, int],
        rel_loss_w: RelativeDEILossWeights,
        scf_decay_weights: SCFDecayWeights,
        prng_key: PRNGKey,
    ):
        """
        Run the model starting from ``P0`` and accumulate all weighted loss terms.

        ``prng_key`` is only consumed by the orbital rotation Hessian term, where it is used to
        draw the random perturbations.
        """
        ZERO = jnp.array(0.0, dtype=PRECISION.loss)
        # False branch for the `lax.cond` below that is called with an operand (which it ignores).
        ZERO_FN = lambda _: ZERO
        # The reference functional is applied with this empty parameter dict.
        NO_FREE_PARAMS = {}
        (e_mf, e_xc), (coeffs, dms, pred_F, vxc) = model.apply(params, P0, sys)
        e_nuclei = nuclear_energy_fn(sys._nuc_pos, sys)
        e_tot = e_mf + e_xc + e_nuclei

        # Set up non_local_kwargs for graph-based and hybrid functionals.
        # These are only consumed by the model's `xc_rotation_hvp` call in the Hessian term below.
        non_local_kwargs = {}
        if isinstance(model.xc_module.functional, Hybrid):
            non_local_kwargs['eri_tensor'] = sys.fock_tensors.ert
        if model.xc_module.functional.is_graph_based:
            non_local_kwargs['atom_mask'] = sys.atom_mask
            non_local_kwargs['nuc_pos'] = sys._nuc_pos
            non_local_kwargs['grid_coords'] = sys.grid.coords

        # Same for the reference functional; only the hybrid case is handled here.
        ref_non_local_kwargs = {}
        if isinstance(reference_functional.functional, Hybrid):
            ref_non_local_kwargs['eri_tensor'] = sys.fock_tensors.ert

        def ref_xc_energy_and_potential_fn(
            _dms: FloatSCFxBxB | FloatSCFx2xBxB,
            _sys: System,
        ) -> Tuple[FloatSCF, FloatSCFxBxB | FloatSCFx2xBxB]:
            """
            Evaluate the reference xc energy and xc potential for every entry along the leading
            axis of ``_dms`` (the SCF-step axis according to the type annotations).
            """

            def single(_dm: FloatBxB | Float2xBxB):
                # Reference evaluation for one density matrix.
                ref_e_xc, ref_vxc = reference_functional.apply(
                    NO_FREE_PARAMS,
                    jax.lax.stop_gradient(_dm),  # treat reference as a fixed target
                    _sys.grid,
                    _sys.fock_tensors.basis_mask,
                    **ref_non_local_kwargs,
                    method=reference_functional.xc_energy_and_potential,
                )
                return ref_e_xc, ref_vxc

            return jax.vmap(single)(_dms)  # type: ignore

        target_e_xc, target_vxc = ref_xc_energy_and_potential_fn(dms, sys)  # type: ignore

        # ------------------------------ Energy terms -----------------------------
        # xc energy loss (target recomputed on the fly from the reference functional)
        L_xc_energy = loss_fns.energy(
            target_e_xc,
            e_xc,
            sys.n_electrons,
            scf_decay_weights.xc_energy,
        )
        # When evaluating the total energy on the reference functional, the only energy difference
        # is the xc energy. Hence, constraining the total energy in this sense once more does not
        # add any new information. Instead, we constrain the total energy to the static target,
        # which is broadcast to the shape of `e_tot`.
        L_final_energy = loss_fns.energy(
            static_targets.total_energy * jnp.ones_like(e_tot),
            e_tot,
            sys.n_electrons,
            scf_decay_weights.total_energy,
        )

        # ------------------------------ Forces term -----------------------------
        # Not implemented yet: the term is a constant zero and only kept for the bookkeeping.
        L_forces: Float1 = ZERO  # TODO

        # ------------------------------ Density term -----------------------------
        # `config` is a closure constant, so this branch is resolved at trace time.
        if config.density.measure == 'mean_field':
            # The density error induced energy error can be isolated by constraining the mean field energy.
            target_e_mf = (
                static_targets.total_energy - e_nuclei - static_targets.xc_energy
            )
            L_density = loss_fns.energy(
                target_e_mf * jnp.ones_like(e_mf),
                e_mf,
                sys.n_electrons,
                scf_decay_weights.density,
            )
        else:
            # Otherwise compare the predicted density matrices with the static target directly.
            L_density = loss_fns.density(
                static_targets.density_matrix,
                dms,  # type: ignore
                sys,
                scf_decay_weights.density,
            )

        # ------------------------------ XC potential term ------------------------
        L_xc_pot = loss_fns.xc_potential(
            target_vxc,  # type: ignore
            vxc,  # type: ignore
            sys,
            scf_decay_weights.xc_potential,
        )
        # ------------------------------ Orbital rotation direction term ------------------------
        # `rel_loss_w` is a traced argument of the jitted function, so the decision whether to
        # include this term is made with `lax.cond` instead of a Python `if`.
        include_rot_dir = jnp.greater(rel_loss_w.orbital_rotation_gradient, 0.0)

        def _includes_direct_minimization_direction(operands) -> Float1:
            """
            The direction-matching term is computed by comparing the direct XC-energy minimization
            directions of the density in occupied-virtual space evaluated along the predicted
            SCF trajectory. The result is divided by the number of electrons.
            """
            _target_vxc, _vxc, _mo_coeffs = operands
            # NOTE: we observed best performance when not stopping the gradient w.r.t. the MO coefficients
            # _mo_coeffs = jax.lax.stop_gradient(_mo_coeffs)
            out = loss_fns.orbital_rotation_gradient(
                _target_vxc,
                _vxc,
                _mo_coeffs,
                occ_virtual_shape[0],
                scf_decay_weights.orbital_rotation_gradient,
                static_targets.orbital_energies,
            )
            return out / sys.n_electrons

        L_orb_rot_grad = lax.cond(
            include_rot_dir,
            _includes_direct_minimization_direction,
            ZERO_FN,
            operand=(target_vxc, vxc, coeffs),
        )

        # --------------------- XC potential linear response (optional) -----------------
        include_orb_rot_hessian = jnp.greater(rel_loss_w.orbital_rotation_hessian, 0.0)

        def _compute_orbital_rotation_hessian_loss() -> Float1:
            """
            We compute the linear response of the XC potential w.r.t. a set of random perturbations.
            By dynamically evaluating these HVPs at the currently learned ground state density,
            we effectively teach the functional to learn both the correct stability behavior and
            how to achieve it if the learned ground state is currently wrong.
            By differentiating through the predicted ground state, one can actively
            nudge the model towards the correct ground state. Alternatively, one can stop the
            gradient to only learn the correct linear response around the current ground state.
            """

            T = config.orbital_rotation_hessian.n_perturbations

            # Trace-time shortcut: without perturbations there is nothing to compare.
            if T == 0:
                return ZERO

            # MO coefficients of the last SCF step.
            predicted_mo_coeffs: FloatBxB | Float2xBxB = coeffs[-1]  # type: ignore

            if not config.orbital_rotation_hessian.differentiate_through_ground_state:
                # Optionally stop the gradient w.r.t. the predicted ground state, so that only
                # the linear response around the current ground state is learned.
                predicted_mo_coeffs = jax.lax.stop_gradient(predicted_mo_coeffs)

            # T random perturbations in occupied-virtual space.
            perturbations = jax.random.normal(
                prng_key, (T, *occ_virtual_shape), dtype=PRECISION.loss
            )

            predicted_linear_responses = model.apply(
                params,
                predicted_mo_coeffs,
                perturbations,
                sys.grid,
                sys.fock_tensors.basis_mask,
                sys.fock_tensors.occupancies,
                occ_virtual_shape,
                method=DerivativeInformedSolver.xc_rotation_hvp,
                **non_local_kwargs,
            )

            # The reference response is evaluated at the same MO coefficients and perturbations.
            target_linear_responses = reference_functional.apply(
                NO_FREE_PARAMS,
                predicted_mo_coeffs,
                perturbations,
                sys.grid,
                sys.fock_tensors.basis_mask,
                sys.fock_tensors.occupancies,
                occ_virtual_shape,
                method=reference_functional.xc_rotation_hvp,
                **ref_non_local_kwargs,
            )

            _loss = loss_fns.orbital_rotation_hessian(
                target_linear_responses,  # type: ignore
                predicted_linear_responses,  # type: ignore
                static_targets.orbital_energies,
            )

            return _loss

        # No operands are passed here, hence both branches take no arguments.
        L_orb_rot_hessian = lax.cond(
            include_orb_rot_hessian, _compute_orbital_rotation_hessian_loss, lambda: ZERO
        )

        # ----------------------------- Sum everything -----------------------------
        loss: Float1 = (
            rel_loss_w.xc_energy * L_xc_energy
            + rel_loss_w.total_energy * L_final_energy
            + rel_loss_w.forces * L_forces
            + rel_loss_w.density * L_density
            + rel_loss_w.xc_potential * L_xc_pot
            + rel_loss_w.orbital_rotation_gradient * L_orb_rot_grad
            + rel_loss_w.orbital_rotation_hessian * L_orb_rot_hessian
        )

        # Auxiliary outputs: total energies, the last xc energy, the last Fock matrix and the
        # unweighted loss terms.
        # NOTE(review): `LossComponents` is filled positionally; confirm that this order matches
        # its field order.
        return loss, (
            e_tot,
            e_xc[-1],
            pred_F[-1],
            LossComponents(
                L_xc_energy,
                L_forces,
                L_xc_pot,
                L_orb_rot_grad,
                L_orb_rot_hessian,
                L_final_energy,
                L_density,
            ),
        )

    return loss_fn


def get_train_step_fn(
    logging_prefix: str,
    loss_fn: DynamicLossFn,
    main_thread_transform: ToJaxTransform,
    optimizer: Optimizer,
    gradient_post_processing: Optimizer,
    learning_rate: float,
    ema_decay: float,
    relative_loss_weights: RelativeDEILossWeights,
    scf_decay_weights: SCFDecayWeights,
    logger: Logger,
    initial_density_matrix_fn: InitialDensityMatrixFn,
    max_energy_volatility: float,
) -> TrainStepFn:
    """
    Build the training step for the dynamic reference stage.

    Args:
        logging_prefix: Prefix of the logging tag; the step logs under ``{logging_prefix}/train``.
        loss_fn: Loss function, differentiated w.r.t. its first argument (the parameters).
        main_thread_transform: Maps the preloaded system and basis functions to ``(P0, sys)``.
        optimizer: Produces the raw updates from the gradients.
        gradient_post_processing: Transformation applied to the learning-rate-scaled updates.
        learning_rate: Factor the optimizer updates are multiplied with.
        ema_decay: Decay passed to ``ema.update``.
        relative_loss_weights: Relative weights of the individual loss terms.
        scf_decay_weights: Per-term SCF decay weights forwarded to the loss.
        logger: Receives the loss, energies, loss components and volatility of every step.
        initial_density_matrix_fn: Maps ``(P0, prng_key)`` to the initial density matrix and a
            new key.
        max_energy_volatility: Threshold (same unit as ``energy_volatility``, i.e. mEh) above
            which the optimizer update of a step is skipped.

    Returns:
        ``step_fn(params, opt_state, psys, preloaded_basis_fns, targets, prng_key)`` returning
        ``(params, opt_state, prng_key, grad_norm, update_norm)``. Here ``opt_state`` is the
        triple ``(optimizer state, post-processing state, EMA state)``.
    """

    @partial(
        jax.jit,
        static_argnames=('occ_virtual_shape',),
    )
    def jit_step(
        params: NnParams,
        state: Tuple[optax.OptState, optax.OptState, ema.EMA],
        targets: DEIXCTargets,
        sys: System,
        occ_virtual_shape: Tuple[int, int],
        lr: Float1,
        loss_w: RelativeDEILossWeights,
        scf_w: SCFDecayWeights,
        P0: DensityMatrices,
        prng_key: PRNGKey,
    ):
        """Jitted part of the step: loss, gradients, guarded update and EMA."""
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

        P0, prng_key = initial_density_matrix_fn(P0, prng_key)  # type: ignore
        # One half of the split is consumed by the loss, the other one is handed back.
        prng_key, split_key = jax.random.split(prng_key)
        (loss, (pred_e_tot, pred_final_e_xc, _, loss_components)), grads = grad_fn(
            params,
            targets,
            P0,
            sys,
            occ_virtual_shape,
            loss_w,
            scf_w,
            split_key,
        )
        opt_state, post_processing_state, ema_state = state
        volatility = energy_volatility(pred_e_tot)
        # Skip the optimizer update unless the SCF energy is verifiably stable. The condition is
        # written as "not (volatility <= threshold)" instead of "volatility > threshold" because
        # every comparison with NaN is False: a diverged SCF (NaN energies) must be skipped as
        # well instead of applying its gradients. For all non-NaN values both forms agree.
        skip_update = jnp.logical_not(volatility <= max_energy_volatility)
        updates, opt_state = jax.lax.cond(
            skip_update,
            lambda: (_create_zero_updates(params), opt_state),
            lambda: optimizer.update(grads, opt_state, params, value=loss),  # type: ignore
        )
        # The learning rate is applied here, on top of whatever the optimizer returns.
        updates = jax.tree.map(lambda g: lr * g, updates)
        # Post-processing and the EMA update run for skipped steps as well.
        updates, post_processing_state = gradient_post_processing.update(
            updates, post_processing_state, params
        )
        params = optax.apply_updates(params, updates)
        ema_state = ema.update(ema_state, params, ema_decay)
        # The gradient norm is reported for skipped steps too; the update norm refers to the
        # updates that were actually applied.
        grad_norm = optax.global_norm(grads)
        update_norm = optax.global_norm(updates)
        return (
            params,
            (opt_state, post_processing_state, ema_state),  # type: ignore
            loss,
            pred_e_tot[-1],
            pred_final_e_xc,
            volatility,
            grad_norm,
            update_norm,
            loss_components,
            prng_key,
        )

    def step_fn(
        params: NnParams,
        opt_state: Tuple[optax.OptState, optax.OptState, ema.EMA],
        psys: PreloadSystem,
        preloaded_basis_fns: PreloadedGTOBasis,
        targets: DEIXCTargets,
        prng_key: PRNGKey,
    ) -> Tuple[
        NnParams,
        Tuple[optax.OptState, optax.OptState, ema.EMA],
        PRNGKey,
        Float1,
        Float1,
    ]:
        """Transform the inputs, run the jitted step and log its results."""
        P0, sys = main_thread_transform(psys, preloaded_basis_fns)
        (
            params,
            opt_state,
            loss,
            pred_final_e_tot,
            pred_final_e_xc,
            volatility,
            grad_norm,
            update_norm,
            loss_components,
            prng_key,
        ) = jit_step(
            params,
            opt_state,
            targets,
            sys,
            targets.occupied_virtual_shape,
            # Passed in as an array argument instead of being closed over.
            jnp.asarray(learning_rate),
            relative_loss_weights,
            scf_decay_weights,
            P0,
            prng_key,
        )
        # The `.item()` calls pull the scalars to the host for logging.
        logger.deixc_loss(
            loss.item(),
            pred_final_e_tot.item(),
            pred_final_e_xc.item(),
            targets.total_energy.item(),
            targets.xc_energy.item(),
            loss_components.to_host(),  # type: ignore
            relative_loss_weights.to_host(),  # type: ignore
            f'{logging_prefix}/train',
            psys.idx,
            volatility.item(),
            max_energy_volatility,
        )
        return params, opt_state, prng_key, grad_norm, update_norm

    return step_fn


def get_eval_step_fn(
    logging_prefix: str,
    loss_fn: DynamicLossFn,
    initial_density_matrix_fn: InitialDensityMatrixFn,
    input_transform: ToJaxTransform,
    relative_loss_weights: RelativeDEILossWeights,
    scf_decay_weights: SCFDecayWeights,
    logger: Logger,
    max_energy_volatility: float,
    prng_key: PRNGKey,
) -> EvalStepFn:
    """
    Build the validation step for the dynamic reference stage.

    The returned ``step_fn(params, psys, preloaded_basis_fns, targets)`` evaluates the loss for
    one system, logs the absolute HOMO-LUMO gap error followed by the loss summary under
    ``{logging_prefix}/val``, and returns nothing. The same ``prng_key`` is passed to the loss
    on every call, and the initial density matrix function is called without a key (``None``).
    """
    val_tag = f'{logging_prefix}/val'
    gap_metric_key = f'observables/{logging_prefix}/val/homo_lumo_gap_mae'

    def step_fn(
        params: NnParams,
        psys: PreloadSystem,
        preloaded_basis_fns: PreloadedGTOBasis,
        targets: DEIXCTargets,
    ):
        raw_densities, system = input_transform(psys, preloaded_basis_fns)
        initial_dm, _ = initial_density_matrix_fn(raw_densities, None)

        loss, aux = loss_fn(
            params,
            targets,
            initial_dm,
            system,
            targets.occupied_virtual_shape,
            relative_loss_weights,
            scf_decay_weights,
            prng_key,
        )
        e_tot_trajectory, final_e_xc, final_fock, components = aux

        # HOMO-LUMO gap derived from the final Fock matrix, compared against the target gap.
        gap = homo_lumo_gap_fn(
            final_fock,
            system.fock_tensors.diagonal_overlap,
            int(system.n_electrons),
        )
        gap_abs_error = abs(gap - targets.homo_lumo_gap).item()
        logger.log({gap_metric_key: gap_abs_error})

        # Same loss summary as in training, reported under the validation tag.
        volatility = energy_volatility(e_tot_trajectory)
        logger.deixc_loss(
            loss.item(),
            e_tot_trajectory[-1].item(),
            final_e_xc.item(),
            targets.total_energy.item(),
            targets.xc_energy.item(),
            components.to_host(),  # type: ignore
            relative_loss_weights.to_host(),  # type: ignore
            val_tag,
            psys.idx,
            volatility.item(),
            max_energy_volatility,
        )

    return step_fn
