from typing import Any, Callable, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp

from egxc.utils.typing import PRECISION, Float1, FloatN, FloatNx3, FloatNx7
from egxc.xc_energy.features import (
    ueg_e_x,
    ueg_spin_pol_e_x_factor,
)
from egxc.xc_energy.functionals.base import BaseEnergyFunctional
from egxc.xc_energy.functionals.learnable.nn.layers import ScaledSigmoid


def _input_transform(
    n_up: FloatN,
    n_down: FloatN,
    grad_n_up: FloatNx3,
    grad_n_down: FloatNx3,
    tau_up: FloatN,
    tau_down: FloatN,
) -> Tuple[FloatNx7, FloatNx7]:
    abs_grad_up_sq = jnp.sum(grad_n_up**2, axis=-1)
    abs_grad_dn_sq = jnp.sum(grad_n_down**2, axis=-1)
    abs_grad_total_sq = jnp.sum((grad_n_up + grad_n_down) ** 2, axis=-1)
    x = jnp.stack(
        [
            n_up,
            n_down,
            abs_grad_up_sq,
            abs_grad_dn_sq,
            tau_up,
            tau_down,
            abs_grad_total_sq,
        ],
        axis=-1,
    )
    x = jnp.log(jnp.abs(x) + 1.0e-5)
    feats_ab = x
    feats_ba = x[:, [1, 0, 3, 2, 5, 4, 6]]
    return feats_ab, feats_ba


class Skala(BaseEnergyFunctional):
    """
    JAX port of the Skala functional (non-local disabled).
    Architecture: two-layer input MLP -> symmetrize -> output MLP
    ((n_layers - 3) hidden Dense+activation layers and a 1-unit Dense head,
    i.e. four Dense layers by default) -> ScaledSigmoid.

    Reference: https://github.com/microsoft/skala/blob/main/src/skala/functional/model.py
    """

    # Total number of Dense layers: 2 in the input model, (n_layers - 3)
    # hidden layers in the output model, and the final 1-unit head
    # (default 6 = 2 + 3 + 1).
    n_layers: int = 6
    hidden_dim: int = 256
    activation: Callable[[jax.Array], jax.Array] = nn.silu

    is_graph_based = False
    # Requires spin-resolved density features (n_up, n_dn, grad_n_up, grad_n_dn, tau_up, tau_dn)
    requires_spin_resolved_features: bool = True

    def setup(self) -> None:
        """Create the input and output MLPs.

        NOTE: the attribute names and the layout of each Sequential list
        determine the Flax parameter tree; keep them stable so existing
        checkpoints still load.
        """
        # The reference implementation explicitly uses xavier_uniform.
        xavier = nn.initializers.xavier_uniform()

        # Input model: Dense -> activation -> Dense -> activation
        self.input_model = nn.Sequential(
            [
                nn.Dense(
                    self.hidden_dim,
                    dtype=PRECISION.local_nn,
                    kernel_init=xavier,
                ),
                self.activation,
                nn.Dense(
                    self.hidden_dim,
                    dtype=PRECISION.local_nn,
                    kernel_init=xavier,
                ),
                self.activation,
            ]
        )

        # Output model: (n_layers - 3) x [Dense -> activation] -> Dense(1)
        # -> ScaledSigmoid(2.0). With the default n_layers=6 that is 3 hidden
        # layers.
        out_layers = []
        for _ in range(self.n_layers - 3):
            out_layers.extend(
                [
                    nn.Dense(
                        self.hidden_dim,
                        dtype=PRECISION.local_nn,
                        kernel_init=xavier,
                    ),
                    self.activation,
                ]
            )
        out_layers.append(nn.Dense(1, dtype=PRECISION.local_nn, kernel_init=xavier))
        # Scale 2.0 matches the reference implementation's default.
        out_layers.append(ScaledSigmoid(initial_scale=2.0))
        self.output_model = nn.Sequential(out_layers)

    def __call__(
        self, weights: FloatN, *feats: FloatN, **non_local_kwargs: Any
    ) -> Float1:
        """Integrate the spin-resolved XC energy density with quadrature ``weights``.

        ``feats`` are forwarded positionally to
        ``integrate_spin_resolved_energy_density``; ``non_local_kwargs`` are
        accepted but currently ignored (non-local contribution is disabled).
        """
        # Expect spin-resolved features (n_up, n_dn, grad_n_up, grad_n_dn, tau_up, tau_dn)
        return self.integrate_spin_resolved_energy_density(weights, *feats)

    def learned_enhancement_factor(
        self, feats_ab: FloatNx7, feats_ba: FloatNx7
    ) -> jax.Array:
        """Return the learned enhancement factor per grid point.

        Both spin orderings go through the shared input model and their
        embeddings are averaged, so swapping up/down spin leaves the result
        unchanged.
        """
        x_ab = self.input_model(feats_ab)
        x_ba = self.input_model(feats_ba)
        x_sym = 0.5 * (x_ab + x_ba)
        # TODO: add non-local contribution here later
        F = self.output_model(x_sym)
        # Drop the trailing size-1 output axis of the final Dense(1).
        return F[..., 0]

    def xc_energy_density_spin_resolved(  # type: ignore
        self,
        n_up: FloatN,
        n_down: FloatN,
        grad_n_up: FloatNx3,
        grad_n_down: FloatNx3,
        tau_up: FloatN,
        tau_down: FloatN,
    ) -> FloatN:
        """Return the per-particle LSDA exchange energy scaled by the learned factor."""
        feats_ab, feats_ba = _input_transform(
            n_up, n_down, grad_n_up, grad_n_down, tau_up, tau_down
        )
        F = self.learned_enhancement_factor(feats_ab, feats_ba)
        # Per-particle LSDA exchange baseline (spin-polarized UEG), then enhance
        n = n_up + n_down
        # Spin polarization zeta = (n_up - n_down) / n. Upstream reportedly
        # applies a minimum density threshold (1e-15 by default); these guards
        # are a safety net and leave zeta unchanged whenever n > 0 and
        # |zeta| <= 1:
        #  * zero total density would give 0/0 = NaN. Making the denominator
        #    itself safe, not just masking the result, keeps the gradient
        #    finite too (NaN * 0 is still NaN in the backward pass).
        #  * tiny negative spin densities can push |zeta| past 1. NOTE(review):
        #    assumes ueg_spin_pol_e_x_factor uses (1 +/- zeta)**(4/3), which is
        #    only real for |zeta| <= 1 — confirm. `where` is used instead of
        #    jnp.clip so the gradient at exactly zeta = +/-1 (fully polarized
        #    points) is not split at the tie.
        has_density = n > 0
        safe_n = jnp.where(has_density, n, 1.0)
        zeta = jnp.where(has_density, (n_up - n_down) / safe_n, 0.0)
        zeta = jnp.where(zeta > 1.0, 1.0, jnp.where(zeta < -1.0, -1.0, zeta))
        e_x_lsda_per_particle = ueg_e_x(n) * ueg_spin_pol_e_x_factor(zeta)
        return e_x_lsda_per_particle * F
