from typing import Generic, NamedTuple, TypeVar

import einops
import folx
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.struct import PyTreeNode, field
from jaxtyping import Array, Float, Integer

from neural_pfaffian.clipping import Clipping, Masking
from neural_pfaffian.hamiltonian import KineticEnergyOp, make_local_energy
from neural_pfaffian.mcmc import MetropolisHastings
from neural_pfaffian.nn.wave_function import (
    GeneralizedWaveFunction,
    WaveFunctionParameters,
)
from neural_pfaffian.overlap import OverlapPenalty, OverlapState
from neural_pfaffian.preconditioner import Identity, Preconditioner
from neural_pfaffian.sample_reweighting import (
    LOG_NORMALIZER_CONSTANTS_KEY,
    ReweightingMode,
    compute_log_abs_wf_overlap,
    compute_log_density_overlap,
    compute_logpsi,
    compute_reweighting_factor,
    compute_sample_mask,
    convert_reweighted_tensor_to_unreweighted,
    get_normalizing_constant_ratios,
    update_normalizing_constant_ratios,
)
from neural_pfaffian.spin_operator import SpinPenalty
from neural_pfaffian.systems import Systems
from neural_pfaffian.utils import EMA, RollingAverage
from neural_pfaffian.utils.jax_utils import (
    REPLICATE_SPEC,
    SerializeablePyTree,
    distribute_keys,
    jit,
    pmax_if_pmap,
    pmean,
    pmean_if_pmap,
    psum,
    shmap,
)
from neural_pfaffian.utils.segment_utils import unsegment_axis
from neural_pfaffian.utils.summary_stats import (
    weighted_centering,
    weighted_mean,
    weighted_std,
)
from neural_pfaffian.utils.tree_utils import (
    is_tree_finite,
    tree_add,
    tree_mul,
    tree_squared_norm,
)

# Local energies with one row per walker and one column per entry of the
# `n_mols` axis (returned by `VMC.local_energy` when the energy is not reweighted).
LocalEnergy = Float[Array, 'batch_size n_mols']
ReweightedLocalEnergy = Float[Array, '(sample_src*batch_size) n_mols']
"""Tensor of local energies, where each walkers' local energy has been evaluated on all states.
The batch dimension is now segmented by the sample source, i.e. the wave function state from
which the sample is sourced."""
# Concrete `Systems` subtype; lets the `init_systems` helpers return the same
# type they were given.
S = TypeVar('S', bound=Systems)
# Key under which the rolling average of `SmoothData` is stored in the systems'
# `mol_data`.
SMOOTH_DATA_KEY = 'smooth_data'


# Type parameters forwarded to `GeneralizedWaveFunction`, `OverlapPenalty` and
# `SpinPenalty` in the `VMC` field annotations.
O = TypeVar('O')
COrb = TypeVar('COrb')
OS = TypeVar('OS')
# Type of the preconditioner state (`VMCState.preconditioner`).
PS = TypeVar('PS')


class VMCState(Generic[PS], SerializeablePyTree):
    """Optimisation state created by `VMC.init` and threaded through `VMC.step`."""

    # Wave-function parameters (initialised with `wave_function.init`).
    params: WaveFunctionParameters
    # EMA fed with the gradient norm on every `VMC.step`; its value is the
    # reference for the exploding-gradient check of the next step.
    grad_norm_ema: EMA[Float[Array, '']]
    # State of the optax optimizer (`optimizer.init(params)`).
    optimizer: optax.OptState
    # State owned by the preconditioner.
    preconditioner: PS
    # State of the overlap penalty; None when no overlap penalty is configured.
    overlap: OverlapState | None
    # Step counter, initialised to 0. `VMC.step` reads it (warm-up check and
    # penalties) but does not advance it in this file — presumably the caller
    # increments it; TODO confirm.
    step: Integer[Array, '']
    # Epoch counter, initialised to 1. Not read by any method in this file.
    epoch: Integer[Array, '']


class SmoothData(PyTreeNode):
    """Per-entry energy mean and standard deviation tracked as a rolling average.

    `VMC.step` updates the rolling average stored under `SMOOTH_DATA_KEY` with
    the current per-entry energy and its standard deviation.
    """

    energy: Float[Array, ' n_mols']
    std: Float[Array, ' n_mols']

    @classmethod
    def init_systems(cls, systems: S) -> S:
        """Attach a zero-initialised rolling average to `systems` if it has none.

        Args:
            systems: Systems whose `mol_data` may already contain
                `SMOOTH_DATA_KEY`.

        Returns:
            `systems` unchanged when the key is already present, otherwise the
            result of `set_mol_data` with a fresh `RollingAverage`.
        """
        if SMOOTH_DATA_KEY in systems.mol_data:
            return systems

        def _zeros_per_mol():
            return jnp.zeros(systems.n_mols, dtype=jnp.float32)

        # Each field gets its own freshly allocated array.
        rolling = RollingAverage.init(
            cls(energy=_zeros_per_mol(), std=_zeros_per_mol()),
        )
        return systems.set_mol_data(SMOOTH_DATA_KEY, rolling)


class EvalData(NamedTuple):
    """Quantities returned by `VMC.eval_step`.

    The three overlap fields are None when no overlap penalty is configured.
    """

    # `pmean` of the local energy averaged over the walker axis.
    energy: Float[Array, ' n_mols']
    # `psum` of the local energy summed over the walker axis.
    energy_sum: Float[Array, ' n_mols']
    # `psum` of the squared local energy summed over the walker axis.
    energy_squared_sum: Float[Array, ' n_mols']
    # Result of `overlap_penalty.pairwise_overlap`.
    overlap: Float[Array, 'n_mol_ids n_states n_states'] | None
    # Result of `compute_log_density_overlap`.
    log_density_overlap: Float[Array, 'n_mol_ids n_states n_states'] | None
    # Result of `compute_log_abs_wf_overlap`.
    log_abs_wf_overlap: Float[Array, 'n_mol_ids n_states n_states'] | None


class VMC(Generic[PS, O, COrb, OS, S], PyTreeNode):
    """Variational Monte Carlo driver: sampling, local energies and parameter updates.

    All fields are declared with `pytree_node=False`, i.e. they are
    configuration rather than pytree leaves.
    """

    # Wave function that is sampled from and optimised.
    wave_function: GeneralizedWaveFunction[O, COrb, OS, S] = field(pytree_node=False)
    # Turns the per-sample cotangent plus auxiliary gradients into the final gradient.
    preconditioner: Preconditioner[PS] = field(pytree_node=False)
    # optax transformation applied to the preconditioned gradient.
    optimizer: optax.GradientTransformation = field(pytree_node=False)
    # Sampler called at the start of every step to update the walkers in `systems`.
    sampler: MetropolisHastings = field(pytree_node=False)
    # Applied to the raw local energies before they enter the gradient.
    clipping: Clipping = field(pytree_node=False)
    # Passed to `compute_sample_mask` to build the per-sample mask.
    masking: Masking = field(pytree_node=False)
    # Optional overlap penalty; required for any reweighting mode other than NONE.
    overlap_penalty: OverlapPenalty[O, COrb, OS, S] | None = field(
        pytree_node=False,
        default=None,
    )
    # Optional spin penalty; contributes an auxiliary gradient in `step`.
    spin_penalty: SpinPenalty[O, COrb, OS, S] | None = field(
        pytree_node=False,
        default=None,
    )
    # Sample-reweighting mode; non-enum values are converted in `__post_init__`.
    reweighting_mode: ReweightingMode = field(
        pytree_node=False,
        default=ReweightingMode.NONE,
    )
    # Scale of the determinant loss gradient; only evaluated in `step` when > 0.
    determinant_regularization: float = field(pytree_node=False, default=0.0)
    # Scale of the normalizer-ratio loss; only evaluated in `step` when > 0 and
    # an overlap penalty is configured.
    normalizer_regularization: float = field(pytree_node=False, default=0.0)

    def __post_init__(self):
        """Normalise `reweighting_mode` to a `ReweightingMode` member.

        A value that is not already a `ReweightingMode` is converted with
        `ReweightingMode.from_str`; enum members are left untouched.
        """
        mode = self.reweighting_mode
        if not isinstance(mode, ReweightingMode):
            # Assigned through object.__setattr__ instead of a plain attribute
            # assignment (presumably because the flax dataclass is frozen).
            object.__setattr__(
                self,
                'reweighting_mode',
                ReweightingMode.from_str(mode),
            )

    def init(self, key: Array, systems: Systems):
        """Create the initial `VMCState`.

        Args:
            key: PRNG key. It is split once; the first half initialises the
                wave-function parameters, the second half the preconditioner.
            systems: Systems whose `example_input` is used to initialise the
                wave function and which is handed to the preconditioner.

        Returns:
            A `VMCState` with `step` set to 0 and `epoch` set to 1. `overlap`
            is None when no overlap penalty is configured.

        Raises:
            AssertionError: If a reweighting mode other than
                `ReweightingMode.NONE` is configured without an overlap penalty.
        """
        # Validate the configuration before any parameters are initialised.
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`; the exception type is kept for backward compatibility.
        if (
            self.overlap_penalty is None
            and self.reweighting_mode != ReweightingMode.NONE
        ):
            raise AssertionError(
                'Sample reweighting is only supported with an overlap penalty.',
            )

        params_key, preconditioner_key = jax.random.split(key)
        params = self.wave_function.init(params_key, systems.example_input)
        # `is not None` matches how the rest of the class tests for the penalty.
        overlap_state = (
            self.overlap_penalty.init() if self.overlap_penalty is not None else None
        )
        return VMCState(
            params=params,
            grad_norm_ema=EMA.init(jnp.array(0.0, dtype=jnp.float32)),
            optimizer=self.optimizer.init(params),  # type: ignore
            preconditioner=self.preconditioner.init(
                preconditioner_key,
                params,
                systems,
            ),
            overlap=overlap_state,
            step=jnp.zeros((), dtype=jnp.int32),
            epoch=jnp.ones((), dtype=jnp.int32),
        )

    def init_systems(self, key: Array, systems: S) -> S:
        """Prepare `systems` with the per-system data the other methods expect.

        Inside a `shmap` over the systems' partition spec, in this order:
        the sampler initialises the systems, the `SmoothData` rolling average is
        attached, and, if configured, the overlap penalty (plus zeroed log
        normalizer constants) and the spin penalty initialise their data.

        Args:
            key: Replicated PRNG key; passed through `distribute_keys` inside
                the mapped function.
            systems: Systems to initialise.

        Returns:
            The initialised systems, with the same partition spec.
        """
        systems_spec = systems.partition_spec

        @shmap(in_specs=(REPLICATE_SPEC, systems_spec), out_specs=systems_spec)
        def _init_local(rng: Array, local: S):
            rng = distribute_keys(rng)
            # Only the second half of the split is consumed.
            _, sampler_key = jax.random.split(rng)
            local = SmoothData.init_systems(
                self.sampler.init_systems(sampler_key, local),
            )

            if self.overlap_penalty is not None:
                local = self.overlap_penalty.init_systems(local)
                zero_log_normalizers = jnp.zeros((local.n_mols,), dtype=jnp.float64)
                local = local.set_mol_data(
                    LOG_NORMALIZER_CONSTANTS_KEY,
                    zero_log_normalizers,
                )

            if self.spin_penalty is not None:
                local = self.spin_penalty.init_systems(local)

            return local

        return _init_local(key, systems)

    @jit
    def local_energy(
        self,
        state: VMCState,
        systems: Systems,
        key: Array,
    ) -> tuple[LocalEnergy | ReweightedLocalEnergy, dict[str, Array]]:
        """Evaluate the local energy of every walker in `systems`.

        The local-energy function is vmapped over the walkers in chunks
        (`folx.batched_vmap`) whose size shrinks with the largest nucleus and
        electron counts.

        Args:
            state: Current state; only `state.params` is used.
            systems: Systems holding the walkers (`systems.electrons`).
            key: PRNG key, split into one key per walker.

        Returns:
            `(e_l, e_aux)`. Without energy reweighting `e_l` has shape
            `(batch_size, n_mols)`. With energy reweighting, `e_l` and every
            entry of `e_aux` are rearranged so that the sample source is folded
            into the leading (batch) axis.
        """
        reweighted = self.reweighting_mode.energy_reweighted

        per_walker_energy = make_local_energy(
            self.wave_function,
            KineticEnergyOp.FORWARD,
            sample_reweighting=reweighted,
        )

        # One PRNG key per walker (leading axis of `systems.electrons`).
        n_walkers = systems.electrons.shape[0]
        walker_keys = jax.random.split(key, n_walkers)

        # Chunk size for the batched vmap:
        # largest dense jac should not exceed ~3Gb
        largest_n_nuc = max(systems.n_nuc_by_mol)
        largest_n_elec = max(systems.n_elec_by_mol)
        chunk_size = max(1, 360_000 // (largest_n_nuc * largest_n_elec**2))

        batched_energy = folx.batched_vmap(
            per_walker_energy,
            max_batch_size=chunk_size,
            in_axes=(None, systems.electron_vmap, None, 0),
        )
        reparams = self.wave_function.reparams(state.params, systems)
        e_l, e_aux = batched_energy(state.params, systems, reparams, walker_keys)

        if not reweighted:
            # e_l: (batch_size, n_mols)
            return e_l, e_aux

        # e_l: (batch_size, n_mols, n_states)
        # The sample source is segmented in the n_mols axis; move it into the
        # batch axis instead.
        group_ids = np.array(systems.mol_id_groups)

        def _fold_sample_src_into_batch(tensor: Array) -> Array:
            per_group = unsegment_axis(
                tensor,
                group_ids,
                axis=1,
                indices_are_grouped=True,
                num_segments=systems.n_unique_mols,
            )
            folded = einops.rearrange(
                per_group,
                'batch_size n_unique_mols sample_src wf_state -> '
                '(sample_src batch_size) n_unique_mols wf_state',
            )
            return folded[:, systems.mol_id_groups, systems.excitations]

        folded_aux = {
            name: _fold_sample_src_into_batch(value) for name, value in e_aux.items()
        }
        return _fold_sample_src_into_batch(e_l), folded_aux

    @jit
    def determinant_loss_grad(self, state: VMCState[PS], systems: Systems):
        """Gradient of the determinant regularisation loss w.r.t. the parameters.

        The loss is the mean over walkers of the orbital module's
        `effective_determinants` output, averaged with `pmean`.

        Args:
            state: Current state; differentiated with respect to `state.params`.
            systems: Systems holding the walkers.

        Returns:
            `(grad, aux)`: the gradient scaled by `determinant_regularization`
            and the walker-averaged auxiliary data extended with the scaled
            loss under `'determinant_loss'`.
        """
        scale = self.determinant_regularization
        orbital_module = self.wave_function.wave_function.orbital_module

        def _mean_determinant_loss(params: WaveFunctionParameters):
            def _per_walker(walker_systems: Systems):
                orbitals = self.wave_function.orbitals(params, walker_systems)
                return orbital_module.effective_determinants(walker_systems, orbitals)

            batched = jax.vmap(_per_walker, in_axes=(systems.electron_vmap,))
            walker_losses, walker_aux = batched(systems)
            mean_aux = jax.tree.map(lambda leaf: leaf.mean(0), walker_aux)
            return pmean(walker_losses.mean()), mean_aux

        loss_and_grad = jax.value_and_grad(_mean_determinant_loss, has_aux=True)
        (det_loss, aux_data), grad = pmean_if_pmap(loss_and_grad(state.params))

        scaled_grad = tree_mul(grad, scale)
        metrics = aux_data | {'determinant_loss': det_loss * scale}
        return scaled_grad, metrics

    @jit
    def normalizer_ratio_loss_grad(self, state: VMCState[PS], systems: Systems):
        """Gradient of the squared log normalizer-ratio penalty.

        The loss is the mean of the squared upper-triangular (k=1) entries of
        `get_normalizing_constant_ratios(systems)`, scaled by
        `normalizer_regularization`.

        NOTE(review): in the code visible here `params` only determines the
        number of states (via the shape of `compute_logpsi`'s first output);
        the ratios themselves are read from `systems`. Unless
        `get_normalizing_constant_ratios` depends on `params` in a way not
        visible in this file, the returned gradient is identically zero —
        please confirm the intended dependence.

        Args:
            state: Current state; differentiated with respect to `state.params`.
            systems: Systems providing the normalizing-constant ratios.

        Returns:
            `(grad, aux)`: the gradient, replaced by zeros if any entry is
            non-finite, and `{'normalizer_ratio_loss': loss}`.
        """

        def _ratio_penalty(params: WaveFunctionParameters):
            logpsi = compute_logpsi(self.wave_function, params, systems)[0]
            n_states = logpsi.shape[-1]

            log_ratios = get_normalizing_constant_ratios(systems)
            # Strict upper triangle: every unordered pair of distinct states once.
            rows, cols = jnp.triu_indices(n_states, k=1)
            mean_squared_ratio = jnp.mean(log_ratios[:, rows, cols] ** 2)
            return mean_squared_ratio * self.normalizer_regularization

        value_and_grad = jax.value_and_grad(_ratio_penalty, has_aux=False)
        loss, grad = pmean_if_pmap(value_and_grad(state.params))

        # Drop the whole gradient if any entry is non-finite.
        finite_grad = jax.lax.cond(
            is_tree_finite(grad),
            lambda: grad,
            lambda: jax.tree.map(jnp.zeros_like, grad),
        )
        return finite_grad, {'normalizer_ratio_loss': loss}

    @jit
    def mcmc_step(self, key: Array, state: VMCState[PS], systems: Systems):
        """Run the sampler once without touching the parameters.

        Args:
            key: Replicated PRNG key.
            state: Current state; only `state.params` is used.
            systems: Systems holding the walkers.

        Returns:
            `(systems, aux_data)` as returned by the sampler.
        """
        state_spec = state.partition_spec
        systems_spec = systems.partition_spec

        @shmap(
            in_specs=(REPLICATE_SPEC, state_spec, systems_spec),
            out_specs=(systems_spec, REPLICATE_SPEC),
        )
        def _sample(rng: Array, vmc_state: VMCState[PS], local: Systems):
            # Only the second half of the split is consumed by the sampler.
            _, sampler_key = jax.random.split(distribute_keys(rng))
            sampled, sampler_aux = self.sampler(sampler_key, vmc_state.params, local)
            return sampled, sampler_aux

        return _sample(key, state, systems)

    @jit
    def eval_step(self, key: Array, state: VMCState[PS], systems: Systems):
        """Sample once and evaluate energies (and overlaps) without updating parameters.

        Only supported for `ReweightingMode.NONE`; any other mode fails an
        assertion while the function is traced.

        Args:
            key: Replicated PRNG key.
            state: Current state; only `state.params` is used.
            systems: Systems holding the walkers.

        Returns:
            `(systems, EvalData)`. The overlap entries of `EvalData` are None
            when no overlap penalty is configured.
        """
        state_spec = state.partition_spec
        systems_spec = systems.partition_spec

        @shmap(
            in_specs=(REPLICATE_SPEC, state_spec, systems_spec),
            out_specs=(systems_spec, REPLICATE_SPEC),
            check_vma=False,
        )
        def _evaluate(rng: Array, vmc_state: VMCState[PS], local: Systems):
            assert self.reweighting_mode == ReweightingMode.NONE, (
                'Sample reweighting is not supported in eval step'
            )
            rng = distribute_keys(rng)

            # Move the walkers.
            rng, sampler_key = jax.random.split(rng)
            local, _ = self.sampler(sampler_key, vmc_state.params, local)

            # Local energies and their first two moments over the walker axis.
            _, energy_key = jax.random.split(rng)
            local_energies, _ = self.local_energy(vmc_state, local, energy_key)
            mean_energy = pmean(local_energies.mean(0))
            summed_energy = psum(local_energies.sum(0))
            summed_squared_energy = psum((local_energies**2).sum(0))

            # Overlap diagnostics are only available with an overlap penalty.
            pairwise = density_overlap = abs_wf_overlap = None
            if self.overlap_penalty is not None:
                params = vmc_state.params
                local = update_normalizing_constant_ratios(
                    self.wave_function,
                    params,
                    local,
                )[0]
                pairwise = self.overlap_penalty.pairwise_overlap(params, local)
                density_overlap = compute_log_density_overlap(
                    self.wave_function,
                    params,
                    local,
                )
                abs_wf_overlap = compute_log_abs_wf_overlap(
                    self.wave_function,
                    params,
                    local,
                )

            return local, EvalData(
                energy=mean_energy,
                energy_sum=summed_energy,
                energy_squared_sum=summed_squared_energy,
                overlap=pairwise,
                log_density_overlap=density_overlap,
                log_abs_wf_overlap=abs_wf_overlap,
            )

        return _evaluate(key, state, systems)

    @jit(donate_argnames=('state', 'systems'))
    def step(self, key: Array, state: VMCState[PS], systems: Systems):
        """Perform one optimisation step.

        In order: sample new walkers, evaluate and clip the local energies,
        accumulate the per-sample cotangent (`dL_dlogpsi`) and the auxiliary
        parameter gradients of the configured penalties, precondition, apply
        the optimizer update tentatively, and keep it only if the sanity checks
        in `keep_update` pass.

        `state` and `systems` are passed as `donate_argnames` to `jit`
        (presumably forwarded to `jax.jit`, in which case callers must not
        reuse the arguments after the call — TODO confirm).

        Args:
            key: Replicated PRNG key.
            state: Current optimisation state.
            systems: Systems holding the walkers and per-system data.

        Returns:
            `(state, systems, aux_data, keep_update)`: the new state, the
            updated systems, a dict of logged quantities and a boolean that is
            True when the parameter update was kept.

            Only `params`, the optimizer state and the preconditioner state are
            rolled back when the update is rejected; the gradient-norm EMA, the
            overlap state and `systems` are always the updated ones.
            `state.step` is not advanced here.
        """

        @shmap(
            in_specs=(REPLICATE_SPEC, state.partition_spec, systems.partition_spec),
            out_specs=(
                state.partition_spec,
                systems.partition_spec,
                REPLICATE_SPEC,
                REPLICATE_SPEC,
            ),
            check_vma=False,
        )
        def _step(key: Array, state: VMCState[PS], systems: Systems):
            key = distribute_keys(key)
            aux_data = {}

            # Sampling
            key, subkey = jax.random.split(key)
            systems, mcmc_aux = self.sampler(subkey, state.params, systems)
            aux_data |= {f'mcmc/{k}': v for k, v in mcmc_aux.items()}

            # Indicates whether E_local is reweighted, i.e. if true has shape (sample_src*walker, n_mols)
            _is_energy_reweighted = self.reweighting_mode.energy_reweighted
            key, subkey = jax.random.split(key)
            raw_e_l, energy_aux = self.local_energy(state, systems, subkey)

            # Compute reweighting factor
            # All-ones weights with the shape of the local energies; used
            # wherever no reweighting applies.
            _dummy_reweighting_factor = jnp.ones_like(raw_e_l)
            if self.reweighting_mode.any_reweighted:
                systems, normalizer_ratio_aux = update_normalizing_constant_ratios(
                    self.wave_function,
                    state.params,
                    systems,
                )
                aux_data |= normalizer_ratio_aux
                # If anything's reweighted, we need the weights
                reweighting_factor = compute_reweighting_factor(
                    self.wave_function,
                    state.params,
                    systems,
                )
            else:
                reweighting_factor = _dummy_reweighting_factor

            # Energy only needs weights if it is reweighted
            E_reweighting_factor = (
                reweighting_factor if _is_energy_reweighted else _dummy_reweighting_factor
            )

            # Compute the sample mask
            sample_mask, sample_mask_aux = compute_sample_mask(
                self.wave_function,
                state.params,
                systems,
                self.masking,
                reweighting_factor,
                self.reweighting_mode,
            )
            aux_data |= sample_mask_aux

            # Energy needs an unreweighted mask if its not reweighted (but the mask is)
            E_sample_mask = (
                sample_mask
                if _is_energy_reweighted or not self.reweighting_mode.any_reweighted
                else convert_reweighted_tensor_to_unreweighted(sample_mask, systems)
            )

            # Local energy
            # NOTE(review): `data_is_reweighted` is hard-coded to False even
            # when `_is_energy_reweighted` is True, whereas the preconditioner
            # below receives `sample_reweighting=_is_energy_reweighted`.
            # Confirm against `Clipping` that this is intended.
            clipped_e_l = self.clipping(
                raw_e_l,
                mask=E_sample_mask,
                reweighting_factor=E_reweighting_factor,
                data_is_reweighted=False,
            )

            # Adding up penalty terms
            # `auxiliary_grads` collects gradients that are already in
            # parameter space; `dL_dlogpsi` collects per-sample cotangents that
            # the preconditioner turns into a parameter gradient.
            auxiliary_grads = jax.tree.map(jnp.zeros_like, state.params)
            dL_dlogpsi = jnp.zeros_like(raw_e_l)

            # Overlap penalty
            overlap_state = state.overlap
            # Defaults to ones, i.e. no extra weighting of the energy
            # cotangent without an overlap penalty.
            overlap_energy_weights = jnp.ones_like(raw_e_l)
            if self.overlap_penalty is not None:
                assert overlap_state is not None
                (
                    ((dOverlap_dlogpsi, overlap_energy_weights), overlap_aux),
                    systems,
                    overlap_state,
                ) = self.overlap_penalty(
                    state.params,
                    systems,
                    clipped_e_l,
                    reweighting_factor,
                    sample_mask,
                    overlap_state,
                    state.step,
                    reweighting_mode=self.reweighting_mode,
                )

                # In case we only reweight the overlap, we need to compute an aux_grad instead of a cotangent
                if self.reweighting_mode == ReweightingMode.OVERLAP:
                    # The Identity preconditioner is applied without state
                    # (None) and with zero auxiliary gradients to convert the
                    # reweighted cotangent into a parameter gradient.
                    overlap_grads, *_ = Identity(self.wave_function).apply(
                        state.params,
                        systems,
                        dOverlap_dlogpsi,
                        reweighting_factor,
                        sample_mask,
                        None,
                        jax.tree.map(jnp.zeros_like, state.params),
                        sample_reweighting=True,
                    )
                    aux_data |= {
                        'overlap_grad_norm': tree_squared_norm(overlap_grads) ** 0.5,
                    }
                    auxiliary_grads = tree_add(auxiliary_grads, overlap_grads)
                elif self.reweighting_mode == ReweightingMode.OVERLAP_MEAN_ONLY:
                    # Divide the weights back out (zero where the weight is not
                    # positive) and convert to the unreweighted layout before
                    # adding to the cotangent.
                    dOverlap_dlogpsi = jnp.where(
                        reweighting_factor > 0,
                        dOverlap_dlogpsi / reweighting_factor,
                        0.0,
                    )
                    dOverlap_dlogpsi = convert_reweighted_tensor_to_unreweighted(
                        dOverlap_dlogpsi,
                        systems,
                    )
                    dL_dlogpsi += dOverlap_dlogpsi
                else:
                    dL_dlogpsi += dOverlap_dlogpsi
                aux_data |= {f'overlap/{k}': v for k, v in overlap_aux.items()}

                # The normalizer regularization is only evaluated inside the
                # overlap-penalty branch.
                if self.normalizer_regularization > 0.0:
                    # Compute the normalizer ratio loss and its gradient
                    normalizer_loss_grad, normalizer_aux = (
                        self.normalizer_ratio_loss_grad(
                            state,
                            systems,
                        )
                    )
                    auxiliary_grads = tree_add(auxiliary_grads, normalizer_loss_grad)
                    aux_data |= {
                        f'normalizer_regularization/{k}': v
                        for k, v in normalizer_aux.items()
                    }

            if self.spin_penalty is not None:
                (spin_grads, spin_aux), systems = self.spin_penalty(
                    state.params,
                    systems,
                    state.step,
                )
                auxiliary_grads = tree_add(auxiliary_grads, spin_grads)

                # 1-d entries are logged per structure/state further below
                # (and as their mean here); 0-d entries are logged as they are.
                spin_per_mol_metrics = [
                    (k, v) for k, v in spin_aux.items() if v.ndim == 1
                ]
                aux_data |= {k: v for k, v in spin_aux.items() if v.ndim == 0}
                aux_data |= {k: jnp.mean(v) for k, v in spin_per_mol_metrics}
            else:
                spin_per_mol_metrics = []

            # Energy penalty
            # Centered, clipped local energy, multiplied by the energy weights
            # and the weights returned by the overlap penalty.
            dE_dlogpsi = weighted_centering(
                clipped_e_l,
                E_sample_mask,
                E_reweighting_factor,
            )
            dE_dlogpsi *= E_reweighting_factor
            dE_dlogpsi *= overlap_energy_weights
            aux_data |= {'dE_dlogpsi': psum(jnp.sum(dE_dlogpsi**2)) ** 0.5}
            dL_dlogpsi += dE_dlogpsi

            # Determinant regularization
            det_per_mol_metrics = []
            if self.determinant_regularization > 0.0:
                # Compute the determinant loss and its gradient
                det_loss_grad, det_aux = self.determinant_loss_grad(state, systems)
                auxiliary_grads = tree_add(auxiliary_grads, det_loss_grad)
                det_per_mol_metrics = [
                    (f'determinant_regularization/{k}', v)
                    for k, v in det_aux.items()
                    if v.ndim == 1
                ]
                aux_data |= {
                    f'determinant_regularization/{k}': v
                    for k, v in det_aux.items()
                    if v.ndim == 0
                }

            # Preconditioning
            gradient, preconditioner_state, precond_aux = self.preconditioner.apply(
                state.params,
                systems,
                dL_dlogpsi,
                E_reweighting_factor,
                E_sample_mask,
                state.preconditioner,
                auxiliary_grads,
                sample_reweighting=_is_energy_reweighted,
            )
            aux_data |= {f'preconditioner/{k}': v for k, v in precond_aux.items()}

            # Update gradient ema
            # NOTE(review): the EMA is fed with `grad_norm` unconditionally and
            # is stored in the returned state even when the update is rejected
            # below. If `EMA.update` does not guard against non-finite inputs,
            # a single NaN gradient norm could persist in the EMA — confirm
            # against the `EMA` implementation.
            grad_norm = tree_squared_norm(gradient) ** 0.5
            gradient_ema = state.grad_norm_ema.update(grad_norm, 0)

            # Tentatively apply update
            updates, opt_state = self.optimizer.update(gradient, state.optimizer)  # type: ignore
            params = optax.apply_updates(state.params, updates)  # type: ignore

            # All conditions must hold for the update to be kept. Comparisons
            # with NaN evaluate to False, so NaN energies reject the update too.
            keep_update = jnp.all(
                jnp.array(
                    [
                        (grad_norm < 1.05 * state.grad_norm_ema.value())
                        | (state.step < 1000),  # Exploding grads
                        jnp.isfinite(grad_norm),  # NaN grads
                        weighted_mean(
                            clipped_e_l,
                            E_sample_mask,
                            E_reweighting_factor,
                        ).mean()
                        < 1e-3,  # Positive energies
                        weighted_std(
                            clipped_e_l,
                            E_sample_mask,
                            E_reweighting_factor,
                        ).mean()
                        < 1000,  # Excessive energy spread
                    ],
                ),
            )
            # Decide whether to keep the update
            params, opt_state, preconditioner_state = jax.lax.cond(
                keep_update,
                lambda: (params, opt_state, preconditioner_state),  # keep
                lambda: (state.params, state.optimizer, state.preconditioner),  # skip
            )

            # Effective sample size
            # (sum w)^2 / sum w^2 over the masked weights, then divided by the
            # psum of the local number of samples.
            masked_weights = jnp.where(sample_mask, reweighting_factor, 0.0)
            weights_sum = jnp.asarray(psum(masked_weights.sum(0)))
            weights_sq_sum = jnp.asarray(psum((masked_weights**2).sum(0)))
            ess_per_mol = jnp.where(
                weights_sq_sum > 0,
                (weights_sum**2) / weights_sq_sum,
                jnp.zeros_like(weights_sq_sum),
            )
            total_samples = jnp.asarray(
                psum(jnp.array(masked_weights.shape[0], dtype=jnp.float32)),
            )
            ess_per_mol = ess_per_mol / jnp.maximum(total_samples, 1.0)

            # Logging
            aux_grad_norm = tree_squared_norm(auxiliary_grads) ** 0.5
            aux_data |= {'aux_grad_norm': aux_grad_norm}
            n_unique_mols = systems.n_unique_mols
            mol_ids = np.asarray(systems.mol_ids)

            E_per_mol = weighted_mean(clipped_e_l, E_sample_mask, E_reweighting_factor)
            E = E_per_mol.mean()
            E_std_per_mol = weighted_std(clipped_e_l, E_sample_mask, E_reweighting_factor)
            E_std = E_std_per_mol.mean()

            # Weighted means of the auxiliary outputs of the local-energy function.
            E_components_per_mol = {
                key: weighted_mean(val, E_sample_mask, E_reweighting_factor)
                for key, val in energy_aux.items()
            }
            E_components = {key: val.mean() for key, val in E_components_per_mol.items()}
            aux_data |= E_components

            # Update the rolling average stored in the systems and read back
            # its current value.
            smooth_data = systems.get_mol_data(SMOOTH_DATA_KEY)
            smooth_data = smooth_data.update(SmoothData(E_per_mol, E_std_per_mol))
            systems = systems.set_mol_data(SMOOTH_DATA_KEY, smooth_data)
            smooth_data = smooth_data.value()
            E_per_mol_smooth = smooth_data.energy
            E_std_per_mol_smooth = smooth_data.std
            E_smooth = E_per_mol_smooth.mean()
            E_std_smooth = E_std_per_mol_smooth.mean()

            # Excitation energy: energy relative to the lowest energy among
            # the entries that share the same mol_id.
            ground_state_energy = jax.ops.segment_min(
                E_per_mol,
                mol_ids,
                num_segments=n_unique_mols,
            )
            excitation_energy_per_mol = E_per_mol - ground_state_energy[mol_ids]
            ground_state_energy_smooth = jax.ops.segment_min(
                E_per_mol_smooth,
                mol_ids,
                num_segments=n_unique_mols,
            )
            excitation_energy_per_mol_smooth = (
                E_per_mol_smooth - ground_state_energy_smooth[mol_ids]
            )

            reweighting_factor_mean = pmean_if_pmap(reweighting_factor.mean())
            reweighting_factor_max = pmax_if_pmap(reweighting_factor.max())

            aux_data |= {
                'E': E,
                'E_std': E_std,
                'E_smooth': E_smooth,
                'E_std_smooth': E_std_smooth,
                'grad_norm': grad_norm,
                'ESS': ess_per_mol.mean(),
                'reweighting_factor_mean': reweighting_factor_mean,
                'reweighting_factor_max': reweighting_factor_max,
            }

            # Log per mol data
            metrics = [
                ('E', E_per_mol),
                ('E_std', E_std_per_mol),
                ('E_smooth', E_per_mol_smooth),
                ('E_std_smooth', E_std_per_mol_smooth),
                ('ESS', ess_per_mol),
                ('excitation_energy', excitation_energy_per_mol),
                ('excitation_energy_smooth', excitation_energy_per_mol_smooth),
                *det_per_mol_metrics,
                *spin_per_mol_metrics,
            ]

            # NOTE: this loop rebinds `key`, which held the PRNG key until
            # here; the PRNG key is not used after this point.
            for key, array in metrics:
                # Split the flat per-entry axis by mol_id so that entries can
                # be addressed as [structure, state].
                unsegmented = unsegment_axis(array, mol_ids)
                # mean only for safety; values should be scalar
                aux_data |= {
                    f'{key}/structure_{i}/state_{j}': unsegmented[i, j].mean()
                    for i in np.unique(mol_ids)
                    for j in range(unsegmented.shape[1])
                }

            # `step` and `epoch` are carried over unchanged; the EMA and the
            # overlap state are replaced regardless of `keep_update`.
            return (
                state.replace(
                    params=params,
                    grad_norm_ema=gradient_ema,
                    optimizer=opt_state,
                    preconditioner=preconditioner_state,
                    overlap=overlap_state,
                ),
                systems,
                aux_data,
                keep_update,
            )

        return _step(key, state, systems)
