from typing import NamedTuple, Protocol, Tuple

from flax.struct import dataclass
from optax import OptState

from deixc.dataset import DEIXCTargets
from egxc.discretization import PreloadedGTOBasis
from egxc.systems import PreloadSystem
from egxc.training.utils.ema import EMA
from egxc.utils.typing import Float1, NnParams, PRNGKey


class TrainStepFn(Protocol):
    """Structural type for a single training-step callable.

    Because this is a ``typing.Protocol``, any callable whose signature
    matches ``__call__`` below satisfies it; no subclassing is required.
    """

    def __call__(
        self,
        params: NnParams,
        opt_state: Tuple[OptState, OptState, EMA],
        psys: PreloadSystem,
        preloaded_basis_fns: PreloadedGTOBasis,
        targets: DEIXCTargets,
        prng_key: PRNGKey,
    ) -> Tuple[
        NnParams,
        Tuple[OptState, OptState, EMA],
        PRNGKey,
        Float1,
        Float1,
    ]:
        """Perform one training step.

        Args:
            params: Current network parameters.
            opt_state: Triple of two optax ``OptState`` objects and an
                ``EMA`` object, threaded from step to step.
            psys: Preloaded system for this step.
            preloaded_basis_fns: Preloaded GTO basis functions.
            targets: Reference targets for this step.
            prng_key: Random key for this step.

        Returns:
            A 5-tuple ``(params, opt_state, prng_key, a, b)``. The first
            three have the same types as the corresponding inputs
            (presumably their updated values — confirm against the
            implementation). ``a`` and ``b`` are two ``Float1`` values;
            NOTE(review): likely loss/metric scalars, confirm.
        """
        ...


class EvalStepFn(Protocol):
    """Structural type for a single evaluation-step callable.

    Any callable matching the ``__call__`` signature below satisfies this
    ``typing.Protocol``; no subclassing is required.
    """

    def __call__(
        self,
        params: NnParams,
        psys: PreloadSystem,
        preloaded_basis_fns: PreloadedGTOBasis,
        targets: DEIXCTargets,
    ) -> None:
        """Perform one evaluation step.

        Args:
            params: Network parameters to evaluate.
            psys: Preloaded system to evaluate on.
            preloaded_basis_fns: Preloaded GTO basis functions.
            targets: Reference targets to compare against.

        Returns:
            Nothing. Per the annotation, any results must be reported via
            side effects (NOTE(review): confirm how implementations surface
            evaluation metrics).
        """
        ...


class LossTuple(NamedTuple):
    """Plain-Python (``float``) view of the individual loss terms.

    Field order is part of the tuple's positional interface and must not
    change.
    """

    xc_energy: float
    forces: float
    xc_potential: float
    orbital_rotation_gradient: float
    orbital_rotation_hessian: float
    total_energy: float  # only in dynamic training stage
    density: float  # only in dynamic training stage


@dataclass
class LossComponents:
    """Individual loss terms as ``Float1`` values.

    The field names and their order match ``LossTuple`` exactly. ``to_host``
    relies on that to copy every term across by name.
    """

    xc_energy: Float1
    forces: Float1
    xc_potential: Float1
    orbital_rotation_gradient: Float1
    orbital_rotation_hessian: Float1
    total_energy: Float1  # only in dynamic training stage
    density: Float1  # only in dynamic training stage

    def to_host(self) -> LossTuple:
        """Return every term as a Python scalar packed into a ``LossTuple``.

        Each component is converted with ``.item()``, one at a time, in
        ``LossTuple`` field order.
        """
        return LossTuple._make(
            getattr(self, field_name).item() for field_name in LossTuple._fields
        )
