from functools import partial
from typing import Callable, Literal

import jax
import jax.numpy as jnp
from flax.struct import dataclass

from deixc.training.loss.config import DynamicLossConfig, StaticLossConfig
from egxc.systems import System
from egxc.training.loss import density_loss, scalar_loss
from egxc.utils.typing import (
    PRECISION,
    Float1,
    Float2xBxB,
    Float2xOxV,
    FloatAx3,
    FloatB,
    FloatBxB,
    FloatOxV,
    FloatRefSCF,
    FloatRefSCFx2xBxB,
    FloatRefSCFx2xOxV,
    FloatRefSCFxAx3,
    FloatRefSCFxBxB,
    FloatRefSCFxOxV,
    FloatTx2xOxV,
    FloatTxOxV,
    UInt1,
)

from .force import force_loss
from .orbital_rotation import (
    orbital_rotation_gradient_loss,
    orbital_rotation_hessian_loss,
)
from .xc_potential import xc_potential_loss

# Signatures of the per-cycle losses stored in ``LossFns`` and of their
# SCF-vectorized counterparts stored in ``SCFVectorizedLossFns``. The vectorized
# variants take SCF-cycle-stacked inputs (``FloatRefSCF...``) plus an extra
# ``scf_decay`` argument holding one weight per SCF cycle.

# ---------------- Energy / forces ----------------
EnergyLossFn = Callable[[Float1, Float1, UInt1], Float1]
SCFVecEnergyLossFn = Callable[[FloatRefSCF, FloatRefSCF, UInt1, FloatRefSCF], Float1]

ForcesLossFn = Callable[[FloatAx3, FloatAx3], Float1]
SCFVecForcesLossFn = Callable[[FloatRefSCFxAx3, FloatRefSCFxAx3, FloatRefSCF], Float1]

# Unified field loss type aliases (density and XC potential). Each has a plain
# variant and a ``2x...`` variant (presumably the spin-polarized case).
FieldLossFn = (
    Callable[[FloatBxB, FloatBxB, System], Float1]
    | Callable[[Float2xBxB, Float2xBxB, System], Float1]
)
SCFVecFieldLossFn = (
    Callable[[FloatRefSCFxBxB, FloatRefSCFxBxB, System, FloatRefSCF], Float1]
    | Callable[[FloatRefSCFx2xBxB, FloatRefSCFx2xBxB, System, FloatRefSCF], Float1]
)

# Orbital rotation gradient. The first (target) argument is either a set of
# orbital rotation directions (``OxV``; 'static' / 'dynamic' training) or a
# reference Fock / XC potential matrix (``BxB``; 'dynamic_with_reference'),
# matching the two variants built in ``vectorize_loss_fns``.
# Remaining arguments: prediction, MO coefficients, n_occ, orbital energies.
OrbRotGradLossFn = (
    Callable[[FloatOxV | FloatBxB, FloatBxB, FloatBxB, int, FloatB], Float1]
    | Callable[
        [Float2xOxV | Float2xBxB, Float2xBxB, Float2xBxB, int, FloatB], Float1
    ]
)

# Vectorized form: same as above with an SCF-cycle axis on the first three
# arguments and ``scf_decay`` inserted before the orbital energies.
SCFVecOrbRotGradLossFn = (
    Callable[
        [
            FloatRefSCFxOxV | FloatRefSCFxBxB,
            FloatRefSCFxBxB,
            FloatRefSCFxBxB,
            int,
            FloatRefSCF,
            FloatB,
        ],
        Float1,
    ]
    | Callable[
        [
            FloatRefSCFx2xOxV | FloatRefSCFx2xBxB,
            FloatRefSCFx2xBxB,
            FloatRefSCFx2xBxB,
            int,
            FloatRefSCF,
            FloatB,
        ],
        Float1,
    ]
)

# Orbital rotation Hessian loss; it is not vectorized over SCF cycles, so the
# same signature is used in both ``LossFns`` and ``SCFVectorizedLossFns``.
VecOrbRotHessianLossFn = Callable[
    [FloatTxOxV | FloatTx2xOxV, FloatTxOxV | FloatTx2xOxV, FloatB], Float1
]


DensHessianLossFn = (
    Callable[[FloatOxV, FloatOxV], Float1] | Callable[[Float2xOxV, Float2xOxV], Float1]
)


@dataclass
class LossFns:
    """
    Per-cycle loss functions, one per loss term.

    Each entry evaluates the loss for a single SCF cycle and returns a scalar;
    ``vectorize_loss_fns`` maps them over the SCF cycle axis.

    Keep the field order stable: callers may construct instances positionally.

    NOTE(review): the fields hold plain callables and are declared without
    ``flax.struct.field(pytree_node=False)``, so flax treats them as pytree
    children by default; confirm instances are never passed as traced
    arguments through ``jax.jit``.
    """

    # (target, prediction, n_electrons) -> scalar loss
    energy: EnergyLossFn
    # (target, prediction) -> scalar loss
    forces: ForcesLossFn
    # (target, prediction, system) -> scalar loss
    density: FieldLossFn
    # (target, prediction, system) -> scalar loss
    xc_potential: FieldLossFn
    # (target, prediction, mo_coeffs, n_occ, orbital_energies) -> scalar loss
    orbital_rotation_gradient: OrbRotGradLossFn
    # Passed through unchanged by vectorize_loss_fns (not mapped over cycles).
    orbital_rotation_hessian: VecOrbRotHessianLossFn


@dataclass
class SCFVectorizedLossFns:
    """
    Loss functions vectorized along the SCF cycle dimension.

    Produced by ``vectorize_loss_fns``. Every entry except
    ``orbital_rotation_hessian`` takes the per-cycle loss's inputs stacked
    along a leading SCF-cycle axis, together with an ``scf_decay`` array of
    per-cycle weights. It returns the ``scf_decay``-weighted sum of the
    per-cycle losses.

    NOTE(review): fields are plain callables declared without
    ``flax.struct.field(pytree_node=False)``; confirm instances are never
    passed as traced ``jax.jit`` arguments.
    """

    # (target, prediction, n_electrons, scf_decay) -> scalar loss
    energy: SCFVecEnergyLossFn
    # (target, prediction, scf_decay) -> scalar loss
    forces: SCFVecForcesLossFn
    # (target, prediction, system, scf_decay) -> scalar loss
    density: SCFVecFieldLossFn
    # (target, prediction, system, scf_decay) -> scalar loss
    xc_potential: SCFVecFieldLossFn
    # (target, prediction, mo_coeffs, n_occ, scf_decay, orbital_energies)
    #   -> scalar loss
    orbital_rotation_gradient: SCFVecOrbRotGradLossFn
    # Identical to LossFns.orbital_rotation_hessian (not vectorized).
    orbital_rotation_hessian: VecOrbRotHessianLossFn


def vectorize_loss_fns(
    loss_fns: LossFns,
    train_type: Literal['static', 'dynamic', 'dynamic_with_reference'],
) -> SCFVectorizedLossFns:
    """
    Vectorizes the per-cycle loss functions along the SCF cycle dimension.

    Every returned loss except the orbital rotation Hessian works the same way.
    It ``jax.vmap``s the corresponding per-cycle loss over the leading
    SCF-cycle axis of its prediction, weights the per-cycle losses with
    ``scf_decay``, and returns their sum. The orbital rotation Hessian loss is
    passed through unchanged.

    Whether the *target* is mapped along the SCF-cycle axis as well, or a single
    target is shared by every cycle, depends on the loss and on ``train_type``:

    =========================  ======  =======  ======================
    loss                       static  dynamic  dynamic_with_reference
    =========================  ======  =======  ======================
    energy                     mapped  shared   mapped
    forces                     mapped  shared   shared
    density                    mapped  shared   shared
    xc_potential               mapped  shared   mapped
    orbital_rotation_gradient  mapped  mapped   mapped
    =========================  ======  =======  ======================

    Args:
        loss_fns: The per-cycle loss functions to vectorize.
        train_type: The training mode. It selects which targets carry an SCF
            cycle axis (see the table above). For the orbital rotation
            gradient, it also selects the target kind:
            - ``'static'`` / ``'dynamic'``: orbital rotation directions, with
              the loss jit-compiled and ``n_occ`` static.
            - ``'dynamic_with_reference'``: a reference Fock / XC potential
              matrix.

    Returns:
        The vectorized loss functions. Each takes the per-cycle loss's
        arguments plus ``scf_decay`` (one weight per SCF cycle) and returns a
        scalar.
    """

    def _decay_weighted_sum(fn, in_axes, scf_decay, *args) -> Float1:
        # Evaluate ``fn`` once per SCF cycle (``None`` in ``in_axes`` shares an
        # argument across all cycles) and combine the per-cycle losses with
        # their decay weights.
        per_cycle_losses = jax.vmap(fn, in_axes=in_axes)(*args)
        return jnp.sum(per_cycle_losses * scf_decay)

    # Target axis indexing SCF cycles, or ``None`` when one target is shared
    # by every cycle. Energy and XC potential targets carry a cycle axis in
    # both 'static' and 'dynamic_with_reference'; forces and density targets
    # only in 'static'.
    # NOTE(review): confirm that sharing the forces / density target in
    # 'dynamic_with_reference' (unlike energy / XC potential) is intended.
    reference_target_axis = None if train_type == 'dynamic' else 0
    static_target_axis = 0 if train_type == 'static' else None

    # ---------------- Energy ----------------
    def energy_fn(
        target: FloatRefSCF,
        prediction: FloatRefSCF,
        n_electrons: UInt1,
        scf_decay: FloatRefSCF,
    ) -> Float1:
        return _decay_weighted_sum(
            loss_fns.energy,
            (reference_target_axis, 0, None),
            scf_decay,
            target,
            prediction,
            n_electrons,
        )

    # ---------------- Forces ----------------
    def forces_fn(
        target: FloatRefSCFxAx3,
        prediction: FloatRefSCFxAx3,
        scf_decay: FloatRefSCF,
    ) -> Float1:
        return _decay_weighted_sum(
            loss_fns.forces,
            (static_target_axis, 0),
            scf_decay,
            target,
            prediction,
        )

    # ---------------- Density ----------------
    def density_fn(
        target: FloatRefSCFxBxB | FloatRefSCFx2xBxB,
        prediction: FloatRefSCFxBxB | FloatRefSCFx2xBxB,
        sys: System,
        scf_decay: FloatRefSCF,
    ) -> Float1:
        return _decay_weighted_sum(
            loss_fns.density,
            (static_target_axis, 0, None),
            scf_decay,
            target,
            prediction,
            sys,
        )

    # ---------------- XC potential ----------------
    def xc_potential_fn(
        target: FloatRefSCFxBxB | FloatRefSCFx2xBxB,
        prediction: FloatRefSCFxBxB | FloatRefSCFx2xBxB,
        sys: System,
        scf_decay: FloatRefSCF,
    ) -> Float1:
        return _decay_weighted_sum(
            loss_fns.xc_potential,
            (reference_target_axis, 0, None),
            scf_decay,
            target,
            prediction,
            sys,
        )

    # -------- Orbital rotation gradient --------
    def _orbital_rotation_gradient(
        target, prediction, mo_coeffs, n_occ, scf_decay, orbital_energies
    ) -> Float1:
        # Shared body of both variants below: target, prediction and MO
        # coefficients are mapped per cycle; n_occ and the orbital energies
        # are shared by all cycles.
        return _decay_weighted_sum(
            loss_fns.orbital_rotation_gradient,
            (0, 0, 0, None, None),
            scf_decay,
            target,
            prediction,
            mo_coeffs,
            n_occ,
            orbital_energies,
        )

    if train_type != 'dynamic_with_reference':
        # Targets are orbital rotation directions. ``n_occ`` is a static
        # (non-traced) jit argument.
        @partial(jax.jit, static_argnames=['n_occ'])
        def orbital_rotation_gradient(  # type: ignore
            target_directions: FloatRefSCFxOxV | FloatRefSCFx2xOxV,
            predicted_xc_potential_matrices: FloatRefSCFxBxB | FloatRefSCFx2xBxB,
            mo_coeffs: FloatRefSCFxBxB | FloatRefSCFx2xBxB,
            n_occ: int,
            scf_decay: FloatRefSCF,
            orbital_energies: FloatB,
        ) -> Float1:
            return _orbital_rotation_gradient(
                target_directions,
                predicted_xc_potential_matrices,
                mo_coeffs,
                n_occ,
                scf_decay,
                orbital_energies,
            )
    else:
        # Targets are reference Fock / XC potential matrices.
        def orbital_rotation_gradient(
            target_fock_or_xc_pot: FloatRefSCFxBxB,
            predicted_fock_or_xc_pot: FloatRefSCFxBxB,
            predicted_mo_coeffs: FloatRefSCFxBxB,
            n_occ: int,
            scf_decay: FloatRefSCF,
            orbital_energies: FloatB,
        ) -> Float1:
            return _orbital_rotation_gradient(
                target_fock_or_xc_pot,
                predicted_fock_or_xc_pot,
                predicted_mo_coeffs,
                n_occ,
                scf_decay,
                orbital_energies,
            )

    return SCFVectorizedLossFns(
        energy=energy_fn,
        forces=forces_fn,
        density=density_fn,
        xc_potential=xc_potential_fn,
        orbital_rotation_gradient=orbital_rotation_gradient,
        # Not vectorized over SCF cycles; passed through unchanged.
        orbital_rotation_hessian=loss_fns.orbital_rotation_hessian,
    )


def get_loss_fns(
    config: StaticLossConfig | DynamicLossConfig,
    train_type: Literal['static', 'dynamic', 'dynamic_with_reference'],
) -> SCFVectorizedLossFns:
    """
    Builds the SCF-vectorized loss functions for a training run.

    Binds each per-cycle loss to its sub-configuration from ``config`` and
    vectorizes the resulting set along the SCF cycle dimension with
    ``vectorize_loss_fns``. Disabled losses become a constant zero: the density
    loss in ``'static'`` training, and the orbital rotation Hessian loss when
    its weight is not positive.

    Args:
        config: Loss configuration. Must be a ``DynamicLossConfig`` unless
            ``train_type == 'static'``.
        train_type: The training mode.

    Returns:
        The SCF-vectorized loss functions.

    Raises:
        TypeError: If a ``DynamicLossConfig`` is required but ``config`` is
            not one.
        ValueError: If the orbital rotation Hessian weight is positive and
            ``train_type`` is not ``'dynamic_with_reference'``.
    """

    def _zero_loss(*args):
        # Stand-in for a disabled loss: ignores its inputs, returns zero.
        return jnp.array(0, dtype=PRECISION.loss)

    def _require_dynamic_config(
        cfg: StaticLossConfig | DynamicLossConfig,
    ) -> DynamicLossConfig:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(cfg, DynamicLossConfig):
            raise TypeError(
                f"train_type '{train_type}' requires a DynamicLossConfig, "
                f'got {type(cfg).__name__}'
            )
        return cfg

    _energy = partial(
        scalar_loss,
        config=config.energy,
    )

    def _forces(target: FloatAx3, prediction: FloatAx3):
        return force_loss(target, prediction)

    xc_pot = partial(
        xc_potential_loss,
        config=config.xc_potential,
        reference_basis_is_same=config.reference_basis_is_same,
    )

    orbital_rotation_gradient = partial(
        orbital_rotation_gradient_loss,
        config=config.orbital_rotation_gradient,
    )

    if train_type == 'static':
        density = _zero_loss
    else:
        dynamic_config = _require_dynamic_config(config)
        density = partial(
            density_loss,
            config=dynamic_config.density,
            reference_basis_is_same=dynamic_config.reference_basis_is_same,
        )

    if config.weights.orbital_rotation_hessian > 0.0:
        if train_type != 'dynamic_with_reference':
            raise ValueError(
                'orbital rotation Hessian is only supported in dynamic with reference training'
            )
        dynamic_config = _require_dynamic_config(config)
        orbital_rotation_hessian = partial(
            orbital_rotation_hessian_loss,
            config=dynamic_config.orbital_rotation_hessian,
        )
    else:
        orbital_rotation_hessian = _zero_loss

    # Keyword construction keeps this independent of LossFns' field order.
    loss_fns = LossFns(
        energy=_energy,
        forces=_forces,
        density=density,
        xc_potential=xc_pot,
        orbital_rotation_gradient=orbital_rotation_gradient,
        orbital_rotation_hessian=orbital_rotation_hessian,
    )
    return vectorize_loss_fns(loss_fns, train_type)
