from typing import Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp

from egxc.utils.typing import PRECISION, FloatN, FloatNxF, FloatNxT
from egxc.xc_energy.features import (
    transform_tau_to_alpha,
    ueg_e_x,
    ueg_spin_pol_e_x_factor,
    ueg_tau,
)
from egxc.xc_energy.functionals.base import BaseEnergyFunctional
from egxc.xc_energy.functionals.classical.lsda import pw92_correlation_energy_density
from egxc.xc_energy.functionals.classical.mgga import e_c_scan, e_x_scan, e_xc_scan


def log_normalize(x: FloatN, mean: float, std: float) -> FloatN:
    """Standardize ``log(1 + x)`` with a fixed shift and scale.

    Args:
        x: Values to transform (elementwise).
        mean: Shift subtracted after the ``log1p`` transform.
        std: Scale the shifted values are divided by.

    Returns:
        ``(log1p(x) - mean) / std``, elementwise.
    """
    shifted = jnp.log1p(x) - mean
    return shifted / std


def feature_transform(
    n: FloatN, zeta: FloatN, s: FloatN, tau: FloatN
) -> Tuple[FloatN, FloatN, FloatN, FloatN]:
    """Map raw local features to normalized network inputs.

    Each of ``n``, ``s`` and ``tau`` is first turned into a derived quantity ``x``
    and then standardized as ``(log1p(x) - mean) / std`` via ``log_normalize``:

    * ``n``   -> ``x = n ** (1/3)``
    * ``s``   -> ``x = s``
    * ``tau`` -> ``x = (tau - ueg_tau(n)) / ueg_tau(n)`` (relative deviation from
      the UEG value)

    ``zeta`` is returned unchanged.

    Returns:
        ``(n_t, zeta, s_t, tau_t)``.
    """
    tau_ueg = ueg_tau(n)
    cube_root_n = n ** (1 / 3)
    relative_tau = (tau - tau_ueg) / tau_ueg
    # (mean, std) pairs below: manually tuned hyperparameters informed by single QM9 sample
    return (
        log_normalize(cube_root_n, 0.7, 0.5),
        zeta,
        log_normalize(s, 0.7, 1.5),
        log_normalize(relative_tau, 2.0, 4.0),
    )


# Kernel initializer for dense layers whose inputs pass through SiLU: uniform
# variance scaling over the fan-in, with the scale (2.2) chosen to fit the SiLU
# activation.
silu_uniform_init = nn.initializers.variance_scaling(
    scale=2.2,
    mode='fan_in',
    distribution='uniform',
)


def _ueg_xc_energy_density(n: FloatN, zeta: FloatN) -> FloatN:
    """Return the spin-scaled UEG exchange energy density used as the XC reference.

    This is the quantity that the learned enhancement factors of ``DEIXC`` multiply:
    ``ueg_e_x(n) * ueg_spin_pol_e_x_factor(zeta)``. It is kept in one place so the
    local, non-local and diagnostic code paths cannot drift apart.
    """
    return ueg_e_x(n) * ueg_spin_pol_e_x_factor(zeta)


class DEIXC(BaseEnergyFunctional):
    """
    Neural XC functional: a UEG exchange reference times a learned enhancement factor.

    The XC energy density is ``_ueg_xc_energy_density(n, zeta) * F``. ``F`` is
    produced by a SiLU feedforward network on the normalized local features
    ``(n, zeta, s, tau)`` (see ``feature_transform``). Every hidden layer of that
    network adds a zero-initialized scalar readout to a learnable offset ``F_mean``,
    so ``F == F_mean`` everywhere at initialization. Optional non-local grid
    features are added to the first hidden representation through a
    zero-initialized projection (see ``non_local_xc_energy_density``).

    local_n_layers: int
        Number of layers in the feedforward neural network (including the output layer).
        With ``local_n_layers <= 2`` there are no hidden layers and hence no readouts,
        so the enhancement factor is just the learnable constant ``F_mean``.
    local_hidden_dim: int
        Number of hidden units per layer of the network.
    """

    local_n_layers: int  # including the final output layer
    local_hidden_dim: int

    # Initial value of the learnable enhancement-factor offset ``F_mean``
    # (initialization hyperparameters informed by single QM9 sample).
    F_mean = 1.16
    is_graph_based: bool = False

    def xc_energy_density(  # type: ignore
        self,
        n: FloatN,
        zeta: FloatN,
        s: FloatN,
        tau: FloatN,
    ) -> FloatN:
        """Return the XC energy density computed from local features only."""
        # self.__analyze_normalization(n, zeta, s, tau, weights)
        e_ueg_xc = _ueg_xc_energy_density(n, zeta)
        x_local = self.local_input_transform(n, zeta, s, tau)
        return e_ueg_xc * self.learned_enhancement_factors(x_local)

    def non_local_xc_energy_density(
        self,
        non_local_grid_features: FloatNxF,
        n: FloatN,
        zeta: FloatN,
        s: FloatN,
        tau: FloatN,
    ) -> FloatN:
        """Return the XC energy density with non-local grid features mixed in.

        The projected non-local features are added to the projected local features
        before they enter the enhancement-factor network. The non-local projection
        is zero-initialized, so at initialization this matches
        ``xc_energy_density`` (for the same local parameters).
        """
        e_ueg_xc = _ueg_xc_energy_density(n, zeta)
        x_local = self.local_input_transform(n, zeta, s, tau)
        x_non_local = self.non_local_input_transform(non_local_grid_features)
        return e_ueg_xc * self.learned_enhancement_factors(x_local + x_non_local)

    @nn.compact
    def local_input_transform(
        self, n: FloatN, zeta: FloatN, s: FloatN, tau: FloatN
    ) -> FloatNxT:
        """Normalize the local features and project them to ``local_hidden_dim``.

        The four normalized features are stacked along the last axis, passed through
        SiLU, and projected by a single dense layer named ``local_input_transform``.
        """
        n_t, zeta, s_t, tau_t = feature_transform(n, zeta, s, tau)
        x_local = jnp.stack([n_t, zeta, s_t, tau_t], axis=-1)
        x_local = nn.Dense(
            self.local_hidden_dim,
            kernel_init=silu_uniform_init,
            name='local_input_transform',
        )(nn.silu(x_local))
        return x_local

    @nn.compact
    def non_local_input_transform(self, non_local_grid_features: FloatNxF) -> FloatNxT:
        """Project SiLU-activated non-local grid features to ``local_hidden_dim``.

        The kernel is zero-initialized and there is no bias, so this path contributes
        nothing until it is trained.
        """
        x_non_local = nn.Dense(
            self.local_hidden_dim,
            kernel_init=nn.initializers.zeros,
            use_bias=False,  # would be redundant with bias in local input transform
            name='non_local_input_transform',
        )(nn.silu(non_local_grid_features))
        return x_non_local

    @nn.compact
    def learned_enhancement_factors(self, x: FloatNxT) -> FloatN:
        """Map hidden features to one enhancement factor per row of ``x``.

        Returns ``F_mean + sum_l readout_l(h_l)``. Here ``h_l`` are the outputs of
        the ``local_n_layers - 2`` hidden dense layers, and each ``readout_l`` is a
        zero-initialized, bias-free dense layer with a single output.
        """
        F_mean = self.param(
            'F_mean', nn.initializers.constant(self.F_mean), (), PRECISION.local_nn
        )
        out = jnp.zeros(x.shape[0], dtype=PRECISION.local_nn)
        for _ in range(
            self.local_n_layers - 2
        ):  # maximum depth increases by 1 due to final output readout
            x = nn.Dense(self.local_hidden_dim, kernel_init=silu_uniform_init)(nn.silu(x))
            # zero-initialized readout: contributes nothing at initialization
            out += nn.Dense(1, kernel_init=nn.initializers.zeros, use_bias=False)(x)[:, 0]
        return F_mean + out

    def __analyze_normalization(
        self, n: FloatN, zeta: FloatN, s: FloatN, tau: FloatN, weights: FloatN
    ) -> None:
        """
        Manual debugging utility to check the normalization of the DEIXC enhancement factors.
        To use it uncomment the corresponding line in the xc_energy_density method and
        temporarily pass the weights to the xc_energy_density method in the base class.

        Prints, via ``jax.debug.print``:
        * the mean/std of the log1p-transformed input features. These presumably
          informed the normalization constants in ``feature_transform``.
        * the energy-weighted mean/std of the SCAN exchange, correlation and total
          enhancement factors, each relative to its reference energy density.
        """
        n_t = n ** (1 / 3)
        tau_t = (tau - ueg_tau(n)) / ueg_tau(n)

        # print log1p(input) means and variances
        jax.debug.print(
            'log1p(n_t) mean: {mean}, log1p(n_t) std: {std}',
            mean=jnp.log1p(n_t).mean(),
            std=jnp.log1p(n_t).std(),
        )
        jax.debug.print(
            'log1p(s) mean: {mean}, log1p(s) std: {std}',
            mean=jnp.log1p(s).mean(),
            std=jnp.log1p(s).std(),
        )
        jax.debug.print(
            'log1p(tau_t) mean: {mean}, log1p(tau_t) std: {std}',
            mean=jnp.log1p(tau_t).mean(),
            std=jnp.log1p(tau_t).std(),
        )

        # Per-point statistical weights: weights * n * reference energy density of
        # each channel (exchange, correlation, total).
        x_weights = n * weights * ueg_e_x(n)
        c_weights = n * weights * pw92_correlation_energy_density(n, zeta)
        tot_weights = n * weights * _ueg_xc_energy_density(n, zeta)

        alpha = transform_tau_to_alpha(n, zeta, s, tau)

        # SCAN enhancement factors relative to each channel's reference energy density
        x_factor = e_x_scan(n, s, alpha) / ueg_e_x(n)
        c_factor = e_c_scan(n, zeta, s, alpha) / pw92_correlation_energy_density(n, zeta)

        tot_factor = e_xc_scan(n, zeta, s, tau) / _ueg_xc_energy_density(n, zeta)

        # weighted mean and standard deviation of the SCAN enhancement factors
        x_factor_mean = (x_factor * x_weights).sum() / x_weights.sum()
        c_factor_mean = (c_factor * c_weights).sum() / c_weights.sum()
        tot_factor_mean = (tot_factor * tot_weights).sum() / tot_weights.sum()
        x_factor_std = jnp.sqrt(
            ((x_factor - x_factor_mean) ** 2 * x_weights).sum() / x_weights.sum()
        )
        c_factor_std = jnp.sqrt(
            ((c_factor - c_factor_mean) ** 2 * c_weights).sum() / c_weights.sum()
        )
        tot_factor_std = jnp.sqrt(
            ((tot_factor - tot_factor_mean) ** 2 * tot_weights).sum() / tot_weights.sum()
        )

        jax.debug.print(
            'x_factor_mean: {mean}, x_factor_std: {std}',
            mean=x_factor_mean,
            std=x_factor_std,
        )
        jax.debug.print(
            'c_factor_mean: {mean}, c_factor_std: {std}',
            mean=c_factor_mean,
            std=c_factor_std,
        )
        jax.debug.print(
            'tot_factor_mean: {mean}, tot_factor_std: {std}',
            mean=tot_factor_mean,
            std=tot_factor_std,
        )
