from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
from flax.struct import dataclass as flax_dataclass

from egxc.systems import System
from egxc.utils.typing import (
    PRECISION,
    Float1,
    Float2xBxB,
    FloatBxB,
    FloatSCF,
    FloatSCFx2xBxB,
    FloatSCFxBxB,
    UInt1,
)

from .config import LossConfig
from .density import density_loss
from .scalar import ScalarLossConfig, scalar_loss


def zero_fn(*args, **kwargs) -> Float1:
    """Return a constant scalar zero, ignoring every argument.

    ``get_loss_fns`` uses this as the placeholder for a loss term that is
    switched off. The real loss functions take named parameters, so keyword
    arguments are accepted (and discarded) here as well; otherwise a caller
    passing arguments by name would only fail when a term is disabled.

    Args:
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        A 0-d array holding ``0`` with dtype ``PRECISION.loss``.
    """
    return jnp.array(0, dtype=PRECISION.loss)


@flax_dataclass
class LossFns:
    """Container for the two loss callables returned by ``get_loss_fns``.

    Either field may hold ``zero_fn`` when the corresponding loss weight is
    not strictly positive.
    """

    # NOTE(review): flax struct dataclass fields are pytree nodes by default,
    # so these callables would become pytree leaves — confirm instances are
    # only closed over and never passed as arguments through jit/vmap.

    # Energy loss, called as (target, prediction, n_electrons) -> scalar loss.
    energy: Callable[[Float1, FloatSCF, UInt1], Float1]
    # Density loss, called as (target, prediction, system) -> scalar loss.
    # Two signatures are allowed: BxB targets with SCFxBxB predictions, or
    # 2xBxB targets with SCFx2xBxB predictions (presumably the extra axis of
    # size 2 is a spin channel — TODO confirm).
    density: (
        Callable[[FloatBxB, FloatSCFxBxB, System], Float1]
        | Callable[[Float2xBxB, FloatSCFx2xBxB, System], Float1]
    )


def get_loss_fns(config: LossConfig) -> LossFns:
    """Build the energy and density loss callables described by ``config``.

    A term whose weight in ``config.weights`` is not strictly positive is
    replaced by ``zero_fn``. An enabled term multiplies its loss values by
    ``config.decay_factors``, sums the result, and scales the sum by the
    term's weight.

    Args:
        config: Loss configuration. ``decay_factors``, ``weights``,
            ``density`` and ``reference_basis_is_same`` are read from it.

    Returns:
        A ``LossFns`` holding the energy and the density loss callable.
    """
    scf_weights = config.decay_factors

    def _energy_loss(
        target: Float1, prediction: FloatSCF, n_electrons: UInt1
    ) -> Float1:
        # scalar_loss is configured with the 'mse' measure and per-electron
        # scaling enabled.
        scalar_cfg = ScalarLossConfig(measure='mse', scale_per_electron=True)
        errors = scalar_loss(target, prediction, n_electrons, scalar_cfg)
        weighted = scf_weights * errors
        return config.weights.energy * weighted.sum()

    def _density_loss(
        target: FloatBxB | Float2xBxB,
        prediction: FloatSCFxBxB | FloatSCFx2xBxB,
        sys: System,
    ) -> Float1:
        single_step_loss = partial(
            density_loss,
            config=config.density,
            reference_basis_is_same=config.reference_basis_is_same,
        )
        # Map over the leading axis of `prediction` only; `target` and `sys`
        # are shared by every mapped call.
        over_steps = jax.vmap(single_step_loss, in_axes=(None, 0, None))
        step_losses = over_steps(target, prediction, sys)
        return config.weights.density * jnp.sum(scf_weights * step_losses)

    # Disabled terms fall back to the constant-zero placeholder.
    energy_fn = _energy_loss if config.weights.energy > 0.0 else zero_fn
    density_fn = _density_loss if config.weights.density > 0.0 else zero_fn
    return LossFns(energy_fn, density_fn)
