from typing import Any, Dict, Literal, NamedTuple, Tuple

from flax.struct import dataclass

from deixc.training.loss.orbital_rotation import (
    OrbitalRotationHessianLossConfig,
    OrbitalRotationTensorLossConfig,
)
from deixc.training.loss.xc_potential import (
    XCPotentialFieldLossConfig,
)
from egxc.training.loss import DensityFieldLossConfig
from egxc.training.loss.config import get_decay_factors
from egxc.training.loss.scalar import ScalarLossConfig
from egxc.utils.typing import (
    Float1x1,
    FloatRefSCF,
)

# Names of the supported SCF weight-decay schemes. The selected name is passed
# unchanged as the last argument of get_decay_factors.
WeightDecayType = Literal['only_final', 'constant', 'dick2021', 'li2021', 'egxc2024']
# Either a bare scheme name, or a (scheme, discard_first_n) pair that carries
# its own discard_first_n value for that loss term.
WeightDecaySpec = WeightDecayType | Tuple[WeightDecayType, int]


@dataclass
class SCFDecayWeights:
    """SCF decay weights, one entry per loss term.

    Instances are filled by ``SCFLossVectorization.create``, which obtains
    every field from ``get_decay_factors``.
    """

    # NOTE(review): presumably FloatRefSCF holds one weight per SCF cycle and
    # Float1x1 a single weight -- confirm against egxc.utils.typing.
    xc_energy: FloatRefSCF | Float1x1
    forces: FloatRefSCF | Float1x1
    xc_potential: FloatRefSCF | Float1x1
    orbital_rotation_gradient: FloatRefSCF | Float1x1
    total_energy: FloatRefSCF | Float1x1
    density: FloatRefSCF | Float1x1


class RelativeLossWeightsTuple(NamedTuple):
    """Plain named-tuple form of the relative loss weights.

    The field order is part of the interface: positional construction and
    tuple unpacking depend on it, so it must not be rearranged.
    """

    xc_energy: float
    forces: float
    xc_potential: float
    orbital_rotation_hessian: float  # only in dynamic training stage
    orbital_rotation_gradient: float
    total_energy: float  # only in dynamic training stage
    density: float  # only in dynamic training stage


@dataclass
class RelativeDEILossWeights:
    """Relative weights of the individual loss terms."""

    xc_energy: float
    forces: float
    xc_potential: float
    orbital_rotation_gradient: float
    orbital_rotation_hessian: float  # only in dynamic training stage
    total_energy: float  # only in dynamic training stage
    density: float  # only in dynamic training stage

    def to_host(self) -> RelativeLossWeightsTuple:
        """Copy every weight into a ``RelativeLossWeightsTuple``.

        Values are matched by field name, so the differing field order of
        this class and of the tuple does not matter.
        """
        field_values = {
            name: getattr(self, name) for name in RelativeLossWeightsTuple._fields
        }
        return RelativeLossWeightsTuple(**field_values)


@dataclass
class IsVectorized:
    """Boolean flag per loss term, stored on ``SCFLossVectorization``.

    NOTE(review): presumably a True flag means the term is evaluated along the
    SCF axis rather than on the final step only -- confirm with the loss code
    that consumes these flags.
    """

    xc_energy: bool
    forces: bool
    xc_potential: bool
    orbital_rotation_gradient: bool
    total_energy: bool
    density: bool


@dataclass
class SCFLossVectorization:
    """Per-term vectorization flags together with the SCF decay weights."""

    is_vectorized: IsVectorized
    scf_decay_weights: SCFDecayWeights

    @classmethod
    def create(
        cls,
        xc_energy: WeightDecaySpec,
        forces: WeightDecaySpec,
        xc_potential: WeightDecaySpec,
        orbital_rotation_gradient: WeightDecaySpec,
        total_energy: WeightDecaySpec,
        density: WeightDecaySpec,
        n_cycles: int,
        discard_first_n: int | None,
    ) -> 'SCFLossVectorization':
        """Build the flags and decay weights from one spec per loss term.

        Args:
            xc_energy, forces, xc_potential, orbital_rotation_gradient,
            total_energy, density: Decay spec of the respective loss term,
                either a scheme name or a ``(scheme, discard_first_n)`` pair
                (a list of two items is accepted as well).
            n_cycles: Forwarded as the first argument of
                ``get_decay_factors``.
            discard_first_n: Fallback used for every spec that is a bare
                scheme name. May be ``None`` only if all specs are pairs.

        Returns:
            An instance of ``cls`` with all flags set to True and one weight
            entry per loss term.

        Raises:
            ValueError: If a spec is a bare scheme name while
                ``discard_first_n`` is ``None``, or if a pair spec does not
                have exactly two items.
        """
        # Every term is flagged as vectorized; 'only_final' is expressed
        # through its decay weights instead (see _get_decay_weights).
        is_vectorized = IsVectorized(
            xc_energy=True,
            forces=True,
            xc_potential=True,
            orbital_rotation_gradient=True,
            total_energy=True,
            density=True,
        )

        def _get_decay_weights(
            spec: WeightDecaySpec,
        ) -> FloatRefSCF | Float1x1:
            """Return SCF-decay weights aligned with the SCF length.

            Always delegate to the shared get_decay_factors so that
            'only_final' produces a proper one-hot vector of shape (SCF,)
            with weight on the final step, instead of a length-1 vector
            that would incorrectly broadcast across steps.
            """
            if isinstance(spec, (tuple, list)):
                key, local_discard_first_n = spec
            else:
                # Raised explicitly instead of asserted: an assert is removed
                # under ``python -O`` and None would then be forwarded
                # silently to get_decay_factors.
                if discard_first_n is None:
                    raise ValueError(
                        'discard_first_n must be provided if spec is not a tuple'
                    )
                key, local_discard_first_n = spec, discard_first_n
            return get_decay_factors(n_cycles, local_discard_first_n, key)

        scf_decay_weights = SCFDecayWeights(
            xc_energy=_get_decay_weights(xc_energy),
            forces=_get_decay_weights(forces),
            xc_potential=_get_decay_weights(xc_potential),
            orbital_rotation_gradient=_get_decay_weights(orbital_rotation_gradient),
            total_energy=_get_decay_weights(total_energy),
            density=_get_decay_weights(density),
        )
        # Use cls, like the other factories in this module, so that a subclass
        # calling create() receives an instance of itself.
        return cls(
            is_vectorized=is_vectorized,
            scf_decay_weights=scf_decay_weights,
        )


@dataclass
class BaseLossConfig:
    """Fields shared by ``StaticLossConfig`` and ``DynamicLossConfig``."""

    # Relative weight of each loss term.
    weights: RelativeDEILossWeights
    # Per-term vectorization flags and SCF decay weights.
    scf_loss_vectorization: SCFLossVectorization
    # NOTE(review): presumably states whether the reference data uses the same
    # basis set as the training calculation -- confirm with the consumers.
    reference_basis_is_same: bool
    # Settings of the individual loss terms.
    energy: ScalarLossConfig
    xc_potential: XCPotentialFieldLossConfig
    orbital_rotation_gradient: OrbitalRotationTensorLossConfig


@dataclass
class StaticLossConfig(BaseLossConfig):
    """Loss configuration for the static training stage.

    Adds no fields to ``BaseLossConfig``; it only provides the factory below.
    """

    @classmethod
    def create(
        cls,
        relative_weights: Dict[str, float],
        vectorize_along_scf: Dict[str, WeightDecaySpec],
        reference_basis_is_same: bool,
        energy: Dict[str, Any],
        xc_potential: Dict[str, Any],
        orbital_rotation_gradient: Dict[str, Any],
        ref_scf_cycles: int,  # of the reference calculation (static)
    ) -> 'StaticLossConfig':
        """Assemble a static-stage config from plain dictionaries.

        Args:
            relative_weights: Keyword arguments for ``RelativeDEILossWeights``.
                The ``density``, ``total_energy`` and
                ``orbital_rotation_hessian`` weights are fixed to 0.0 here, so
                supplying any of them raises ``TypeError``.
            vectorize_along_scf: Decay spec per loss term, forwarded to
                ``SCFLossVectorization.create``. ``density`` and
                ``total_energy`` are fixed to ``'only_final'`` here and must
                not be supplied.
            reference_basis_is_same: Stored unchanged on the config.
            energy: Keyword arguments for ``ScalarLossConfig``.
            xc_potential: Keyword arguments for ``XCPotentialFieldLossConfig``.
            orbital_rotation_gradient: Keyword arguments for
                ``OrbitalRotationTensorLossConfig``.
            ref_scf_cycles: SCF cycle count of the reference calculation,
                passed on as ``n_cycles``.

        Returns:
            The populated config instance.
        """
        scalar_cfg = ScalarLossConfig(**energy)
        potential_cfg = XCPotentialFieldLossConfig(**xc_potential)
        gradient_cfg = OrbitalRotationTensorLossConfig(**orbital_rotation_gradient)

        # Terms that exist only in the dynamic stage get a zero weight.
        term_weights = RelativeDEILossWeights(
            density=0.0,
            total_energy=0.0,
            orbital_rotation_hessian=0.0,
            **relative_weights,
        )

        # The density and total-energy specs are irrelevant for static
        # training; fixed placeholders merely satisfy the factory signature.
        scf_vectorization = SCFLossVectorization.create(
            **vectorize_along_scf,
            density='only_final',
            total_energy='only_final',
            n_cycles=ref_scf_cycles,
            discard_first_n=0,
        )

        return cls(
            weights=term_weights,
            scf_loss_vectorization=scf_vectorization,
            reference_basis_is_same=reference_basis_is_same,
            energy=scalar_cfg,
            xc_potential=potential_cfg,
            orbital_rotation_gradient=gradient_cfg,
        )


@dataclass
class DynamicLossConfig(BaseLossConfig):
    """Loss configuration for the dynamic training stage.

    Extends ``BaseLossConfig`` with the orbital-rotation Hessian and density
    loss settings, an energy-volatility value and the dynamic-reference
    switch.
    """

    orbital_rotation_hessian: OrbitalRotationHessianLossConfig
    max_energy_volatility: float
    with_dynamic_reference: bool
    density: DensityFieldLossConfig

    @classmethod
    def create(  # type: ignore[override]
        cls,
        relative_weights: Dict[str, float],
        vectorize_along_scf: Dict[str, WeightDecaySpec],
        reference_basis_is_same: bool,
        energy: Dict[str, Any],
        xc_potential: Dict[str, Any],
        orbital_rotation_gradient: Dict[str, Any],
        orbital_rotation_hessian: Dict[str, Any],
        scf_cycles: int,  # of the training calculation (dynamic)
        max_energy_volatility: float,  # mEh
        with_dynamic_reference: bool,
        density: Dict[str, Any],
    ) -> 'DynamicLossConfig':
        """Assemble a dynamic-stage config from plain dictionaries.

        Args:
            relative_weights: Keyword arguments for ``RelativeDEILossWeights``;
                every weight has to be supplied.
            vectorize_along_scf: Decay spec per loss term, forwarded to
                ``SCFLossVectorization.create``.
            reference_basis_is_same: Stored unchanged on the config.
            energy: Keyword arguments for ``ScalarLossConfig``.
            xc_potential: Keyword arguments for ``XCPotentialFieldLossConfig``.
            orbital_rotation_gradient: Keyword arguments for
                ``OrbitalRotationTensorLossConfig``.
            orbital_rotation_hessian: Keyword arguments for
                ``OrbitalRotationHessianLossConfig``.
            scf_cycles: SCF cycle count of the training calculation, passed on
                as ``n_cycles``.
            max_energy_volatility: Stored unchanged on the config (mEh).
            with_dynamic_reference: When False, the two precomputed-reference
                checks below are applied.
            density: Keyword arguments for ``DensityFieldLossConfig``.

        Returns:
            The populated config instance.

        Raises:
            AssertionError: If ``with_dynamic_reference`` is False and one of
                the precomputed-reference checks fails.
        """
        scalar_cfg = ScalarLossConfig(**energy)
        density_cfg = DensityFieldLossConfig(**density)
        potential_cfg = XCPotentialFieldLossConfig(**xc_potential)
        gradient_cfg = OrbitalRotationTensorLossConfig(**orbital_rotation_gradient)
        hessian_cfg = OrbitalRotationHessianLossConfig(**orbital_rotation_hessian)

        term_weights = RelativeDEILossWeights(**relative_weights)

        # No shared discard_first_n is supplied here, so every spec has to be
        # a (scheme, discard_first_n) pair; a bare scheme name is rejected by
        # SCFLossVectorization.create.
        scf_vectorization = SCFLossVectorization.create(
            **vectorize_along_scf,
            n_cycles=scf_cycles,
            discard_first_n=None,
        )

        uses_precomputed_reference = not with_dynamic_reference
        if uses_precomputed_reference:
            assert hessian_cfg.n_perturbations == 0, (
                'n_perturbations is not used for precomputed reference'
            )
            # NOTE(review): SCFLossVectorization.create sets every
            # is_vectorized flag to True, so this check appears to fail for
            # every precomputed-reference configuration -- confirm whether
            # that is intended.
            assert not scf_vectorization.is_vectorized.orbital_rotation_gradient, (
                'This requires dynamic evaluation of the reference functional'
            )

        return cls(
            weights=term_weights,
            scf_loss_vectorization=scf_vectorization,
            reference_basis_is_same=reference_basis_is_same,
            energy=scalar_cfg,
            xc_potential=potential_cfg,
            orbital_rotation_gradient=gradient_cfg,
            orbital_rotation_hessian=hessian_cfg,
            max_energy_volatility=max_energy_volatility,
            with_dynamic_reference=with_dynamic_reference,
            density=density_cfg,
        )
