from dataclasses import dataclass
from typing import Any, Dict, Literal

import jax.numpy as jnp
import numpy as onp

from egxc.training.loss.density import DensityFieldLossConfig
from egxc.utils.typing import (
    PRECISION,
    FloatSCF,
)


class InvalidLossWeightsError(ValueError, AssertionError):
    """
    Raised when a relative loss weight is negative or NaN.

    Also an ``AssertionError`` for backward compatibility with code that expected the
    previous ``assert``-based check.
    """


@dataclass
class RelativeLossWeights:
    """
    Relative weights of the energy and density loss terms.

    Raises:
        InvalidLossWeightsError: If any weight is negative or NaN.
    """

    energy: float
    density: float

    def __post_init__(self):
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        # ``not w >= 0`` (rather than ``w < 0``) also rejects NaN.
        if any(not w >= 0 for w in self.__dict__.values()):
            raise InvalidLossWeightsError('Weights must be non-negative.')


def get_decay_factors(
    cycles: int,
    discard_first_n: int,
    key: Literal['constant', 'dick2021', 'li2021', 'egxc2024', 'only_final'],
) -> FloatSCF:
    def _decay_constant(cycles: int, discard_first_n: int) -> FloatSCF:
        out = onp.ones(cycles)
        out[:discard_first_n] = 0.0
        return jnp.array(out, dtype=PRECISION.loss)

    def _decay_only_final(cycles: int) -> FloatSCF:
        out = onp.zeros(cycles)
        out[-1] = 1.0
        return jnp.array(out, dtype=PRECISION.loss)

    def _decay_dick2021(cycles: int, discard_first_n: int = 10) -> FloatSCF:
        """
        Weights as proposed by Dick et al. https://doi.org/10.1103/PhysRevB.104.L161109.
        """

        def w_j(j):
            return (
                j - discard_first_n
            ) ** 4  # NOTE: maybe there was a typo in the paper and this should be squared

        out = onp.fromfunction(w_j, (cycles,))
        out[:discard_first_n] = 0.0  # type: ignore
        return jnp.array(out, dtype=PRECISION.loss)

    def _decay_li2021(cycles: int, discard_first_n: int = 10) -> FloatSCF:
        """
        Weights from original publication on KS-Regularizer.
        """

        def w_j(j):
            return 0.9 ** (cycles - j) * (j > discard_first_n)

        return jnp.fromfunction(w_j, (cycles,), dtype=PRECISION.loss)

    def _decay_egxc2024(cycles: int, discard_first_n: int) -> FloatSCF:
        consider_last_n = cycles - discard_first_n
        weights = jnp.arange(1, consider_last_n + 1) ** 2
        return jnp.hstack((jnp.zeros(discard_first_n), weights), dtype=PRECISION.loss)

    match key:
        case 'constant':
            out = _decay_constant(cycles, discard_first_n)
        case 'dick2021':
            out = _decay_dick2021(cycles, discard_first_n)
        case 'li2021':
            out = _decay_li2021(cycles, discard_first_n)
        case 'egxc2024':
            out = _decay_egxc2024(cycles, discard_first_n)
        case 'only_final':
            out = _decay_only_final(cycles)
        case _:
            raise ValueError(f'Unknown scf trajectory loss decay type: {key}')
    return out / out.sum()  # normalize to 1


@dataclass
class LossConfig:
    """
    Configures the loss weights of the Kohn-Sham (KS) Regularizer.

    See: Li, L. et al. "Kohn-Sham Equations as Regularizer: Building Prior Knowledge into
    Machine-Learned Physics." Phys. Rev. Lett. 2021, 126 (3), 036401.
    https://doi.org/10.1103/PhysRevLett.126.036401.
    """

    # Relative weights of the energy and density loss terms.
    weights: RelativeLossWeights
    # Per-SCF-cycle weights of the trajectory loss, as returned by ``get_decay_factors``.
    decay_factors: FloatSCF
    # NOTE(review): consumed by the loss code, not visible here; presumably a bound on
    # energy changes along the SCF trajectory. Confirm.
    max_energy_volatility: float
    # NOTE(review): presumably flags that reference data uses the same basis as the
    # model. Confirm against the loss implementation.
    reference_basis_is_same: bool
    # Configuration of the density loss (see ``DensityFieldLossConfig``).
    density: DensityFieldLossConfig

    @classmethod
    def create(
        cls,
        cycles: int,
        discard_first_n: int,
        decay_type: Literal['constant', 'dick2021', 'li2021', 'egxc2024', 'only_final'],
        relative_weights: Dict[str, float],
        max_energy_volatility: float,
        density: Dict[str, Any],
        reference_basis_is_same: bool = False,
    ) -> 'LossConfig':
        """
        Build a ``LossConfig`` from plain Python values.

        Args:
            cycles: Number of SCF cycles; length of ``decay_factors``.
            discard_first_n: Leading SCF cycles that receive zero weight.
            decay_type: Name of the weighting scheme passed to ``get_decay_factors``.
            relative_weights: Keyword arguments for ``RelativeLossWeights``
                (``energy`` and ``density``).
            max_energy_volatility: Stored unchanged on the config.
            density: Keyword arguments for ``DensityFieldLossConfig``.
            reference_basis_is_same: Stored unchanged on the config.

        Returns:
            A new instance of ``cls``.
        """
        density_config = DensityFieldLossConfig(**density)
        return cls(
            weights=RelativeLossWeights(**relative_weights),
            # ``cls.`` (not ``LossConfig.``) so subclasses can override the weighting.
            decay_factors=cls.get_decay_factors(cycles, discard_first_n, decay_type),
            max_energy_volatility=max_energy_volatility,
            density=density_config,
            reference_basis_is_same=reference_basis_is_same,
        )

    @staticmethod
    def get_decay_factors(
        cycles: int,
        discard_first_n: int,
        decay_type: Literal['constant', 'dick2021', 'li2021', 'egxc2024', 'only_final'],
    ) -> FloatSCF:
        """Delegate to the module-level ``get_decay_factors``."""
        return get_decay_factors(cycles, discard_first_n, decay_type)
