import functools
from typing import TYPE_CHECKING, NamedTuple

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.struct import PyTreeNode
from jaxtyping import Array, Float

from neural_pfaffian.linalg import (
    antisymmetric_block_diagonal,
    skewsymmetric_quadratic,
    skewsymmetric_quadratic_with_antisymmetric_block_identity,
    slog_pfaffian,
)
from neural_pfaffian.nn.antisymmetrizer.pfaffian import (
    PerNucOrbitals,
    PfaffianOrbitals,
    PfaffianPretrainingState,
    max_orbitals,
    orbital_mask,
    pfaffian_pretraining_loss,
)
from neural_pfaffian.nn.envelope import Envelope
from neural_pfaffian.nn.module import ParamMeta, ParamTypes, ReparamModule
from neural_pfaffian.nn.wave_function import AntisymmetrizerP
from neural_pfaffian.systems import Systems, SystemsWithPretrainTarget
from neural_pfaffian.utils import EMA, itemgetter
from neural_pfaffian.utils.jax_utils import vectorize, vmap
from neural_pfaffian.utils.segment_utils import unsegment_axis
from neural_pfaffian.utils.tree_utils import tree_stack

if TYPE_CHECKING:
    from collections.abc import Sequence


class LowRankPfaffianOrbitals(PyTreeNode):
    """Pfaffian orbitals stored as a shared base plus per-determinant low-rank factors.

    The base quantities (``orbitals``, ``antisymmetrizer``, ``orb_A_orb_product``)
    carry no determinant axis; the two ``*updates`` fields carry one (``det``) and
    hold the low-rank factors that distinguish the determinants.
    """

    # Field order matters: instances are also constructed positionally.
    orbitals: Float[Array, 'mols elec orbitals']
    antisymmetrizer: Float[Array, 'mols orbitals orbitals']
    antisymmetrizer_updates: Float[Array, 'mols det orbitals rank']
    orb_A_orb_product: Float[Array, 'mols elec elec']
    updates: Float[Array, 'mols det elec rank']

    def to_pfaffian_orbitals(self):
        """Materialise the low-rank representation as dense ``PfaffianOrbitals``.

        A new axis is inserted at position -3 of every base quantity so that it
        broadcasts against the determinant axis of the low-rank terms.

        Returns:
            ``PfaffianOrbitals`` whose ``antisymmetrizer`` and
            ``orb_A_orb_product`` include the low-rank corrections and whose
            ``orbitals`` keep a singleton axis at position -3.
        """
        low_rank_factors = self.antisymmetrizer_updates
        # Block matrix built from ``rank // 2`` blocks; it sits in the middle of
        # the quadratic form applied to the low-rank factors.
        block_matrix = antisymmetric_block_diagonal(
            low_rank_factors.shape[-1] // 2,
            low_rank_factors.dtype,
        )
        base_antisymmetrizer = jnp.expand_dims(self.antisymmetrizer, axis=-3)
        dense_antisymmetrizer = base_antisymmetrizer + skewsymmetric_quadratic(
            low_rank_factors,
            block_matrix,
        )

        base_product = jnp.expand_dims(self.orb_A_orb_product, axis=-3)
        dense_product = (
            base_product
            + skewsymmetric_quadratic_with_antisymmetric_block_identity(self.updates)
        )

        return PfaffianOrbitals(
            orbitals=jnp.expand_dims(self.orbitals, axis=-3),
            antisymmetrizer=dense_antisymmetrizer,
            orb_A_orb_product=dense_product,
        )


class LowRankCoreOrbitals(NamedTuple):
    """Excitation-independent orbital data passed between the two Pfaffian stages.

    Produced by ``UnexcitedLRPfaffian.__call__`` and unpacked positionally by
    ``ExcitedLRPfaffian.__call__``, so the field order must not change.
    The four orbital lists are zipped (strictly) against the per-group system
    data, i.e. they hold one array per group.
    """

    # Rows [:n_up] of each entry form the upper-left block of the orbital matrix.
    up_orbs: list[Float[Array, 'n_mol_per_group electrons orbitals 1']]
    # Rows [n_up:] of each entry form the lower-right block.
    down_orbs: list[Float[Array, 'n_mol_per_group electrons orbitals 1']]
    # Rows [:n_up] of each entry form the upper-right block.
    up_down_orbs: list[Float[Array, 'n_mol_per_group electrons orbitals 1']]
    # Rows [n_up:] of each entry form the lower-left block.
    down_up_orbs: list[Float[Array, 'n_mol_per_group electrons orbitals 1']]
    # Per-nucleus coefficients of the extra orbital row appended for odd n_elec.
    fill_vec: Float[Array, 'nuclei 2*max_orb']
    # Metadata returned by ``reparam`` for ``fill_vec``; its ``param_type`` is
    # used to split ``fill_vec`` into groups.
    fill_vec_meta: ParamMeta


class UnexcitedLRPfaffian(ReparamModule):
    """Build the excitation-independent ("core") orbitals of the low-rank Pfaffian.

    Evaluates four separate ``PerNucOrbitals`` modules on the electron embeddings
    and registers the ``fill_coeffs`` parameter, returning everything bundled in a
    ``LowRankCoreOrbitals``.
    """

    # Orbital counts keyed by charge; forwarded to ``max_orbitals`` and ``PerNucOrbitals``.
    orb_per_charge: dict[str, int]
    # Envelope template; a copy with an overridden ``pi_init`` is given to each orbital module.
    envelope: Envelope

    @nn.compact
    def __call__(self, systems: Systems, elec_embeddings: Float[Array, 'electrons dim']):
        """Compute the core orbitals for ``systems``.

        Args:
            systems: Batch of systems; ``systems.n_nuc`` sizes the fill parameter.
            elec_embeddings: Per-electron embeddings fed to every orbital module.

        Returns:
            ``LowRankCoreOrbitals`` with the four orbital outputs, the fill
            vector and the fill vector's parameter metadata.
        """
        max_orb = max_orbitals(self.orb_per_charge)
        # The four modules below are distinct, unnamed submodules. Inside
        # ``nn.compact`` their auto-generated names follow creation order, so
        # reordering these statements would rename the parameters.
        # The leading ``1`` is presumably the determinant count (the outputs carry
        # a trailing axis of size 1) — TODO confirm against ``PerNucOrbitals``.
        up_orbs = PerNucOrbitals(
            1,
            self.orb_per_charge,
            self.envelope.copy(pi_init=1.0),
            'orb',
            None,
        )(systems, elec_embeddings)
        down_orbs = PerNucOrbitals(
            1,
            self.orb_per_charge,
            self.envelope.copy(pi_init=1.0),
            'orb',
            None,
        )(systems, elec_embeddings)
        up_down_orbs = PerNucOrbitals(
            1,
            self.orb_per_charge,
            self.envelope.copy(pi_init=1.0),
            'orb',
            None,
        )(systems, elec_embeddings)
        # NOTE(review): this is the only module with ``pi_init=1e-3``; the
        # ``up_down`` counterpart above uses 1.0. Confirm the asymmetry between
        # the two cross-spin blocks is intentional.
        down_up_orbs = PerNucOrbitals(
            1,
            self.orb_per_charge,
            self.envelope.copy(pi_init=1e-3),
            'orb',
            None,
        )(systems, elec_embeddings)
        # If n_elec is odd, we need an extra orbital. The coefficients are
        # registered unconditionally; the consumer decides whether to use them.
        fill_vec, fill_vec_meta = self.reparam(
            'fill_coeffs',
            jax.nn.initializers.normal(1, dtype=jnp.float32),
            (systems.n_nuc, 2 * max_orb),
            param_type=ParamTypes.NUCLEI,
        )
        return LowRankCoreOrbitals(
            up_orbs,
            down_orbs,
            up_down_orbs,
            down_up_orbs,
            fill_vec,
            fill_vec_meta,
        )


class ExcitedLRPfaffian(ReparamModule):
    """Combine core orbitals with a shared antisymmetrizer and per-state low-rank updates.

    For every group of systems this module assembles the full orbital matrix,
    builds an antisymmetric matrix ``A`` from the ``antisymmetrizer`` parameter,
    selects the low-rank ``antisymmetrizer_updates`` belonging to each molecule's
    excitation index and returns the result as ``LowRankPfaffianOrbitals``.
    """

    # Number of determinants; last axis of the ``antisymmetrizer_updates`` parameter.
    determinants: int
    # Orbital counts keyed by charge; forwarded to ``max_orbitals`` / ``orbital_mask``.
    orb_per_charge: dict[str, int]
    # Rank of the low-rank update. ``LowRankPfaffianOrbitals.to_pfaffian_orbitals``
    # uses ``rank // 2`` blocks, so this is presumably expected to be even — TODO confirm.
    rank: int

    # Size of the state axis of ``antisymmetrizer_updates``; must be set to a
    # positive value before the module is called.
    max_num_states: int | None = None

    @nn.compact
    def __call__(self, systems: Systems, core_orbitals: LowRankCoreOrbitals):
        """Build the low-rank Pfaffian orbitals for every group in ``systems``.

        Args:
            systems: Batch of systems providing grouping, spins/charges and the
                per-molecule excitation indices.
            core_orbitals: Output of ``UnexcitedLRPfaffian``.

        Returns:
            A list with one ``LowRankPfaffianOrbitals`` per group, each batched
            over the molecules of that group.

        Raises:
            AssertionError: If ``max_num_states`` is ``None`` or not positive.
        """
        assert self.max_num_states is not None and self.max_num_states > 0, (
            'max_num_states must be set'
        )
        max_orb = max_orbitals(self.orb_per_charge)

        # One (2, max_orb, max_orb) block per nucleus pair. Index 0 of the size-2
        # axis becomes the off-diagonal (cross-spin) block and index 1 feeds both
        # diagonal (same-spin) blocks, see ``_orbitals`` below.
        A, A_meta = self.reparam(
            'antisymmetrizer',
            jax.nn.initializers.normal(1, dtype=jnp.float32),
            (systems.n_nn, 2, max_orb, max_orb),
            param_type=ParamTypes.NUCLEI_NUCLEI,
            bias=True,
        )

        A_updates, A_updates_meta = self.reparam(
            'antisymmetrizer_updates',
            jax.nn.initializers.normal(1, dtype=jnp.float32),
            (
                systems.n_nuc,  # nuclei
                self.max_num_states,  # states, indexed by the excitation below
                max_orb,  # orbitals per nucleus (before masking)
                2,  # merged with nuclei/orbitals into the full orbital axis
                self.rank,  # rank of the low-rank update
                self.determinants,  # determinants
            ),
            param_sharing_axis=-1,
            param_type=ParamTypes.NUCLEI,
        )

        up_orbs, down_orbs, up_down_orbs, down_up_orbs, fill_vec, fill_vec_meta = (
            core_orbitals
        )
        result: list[LowRankPfaffianOrbitals] = []
        # Note: the loop targets ``A`` and ``A_updates`` shadow the full parameters
        # above; the ``zip`` arguments are evaluated once before the first
        # iteration, so the grouping still operates on the full parameters.
        for up, down, up_down, down_up, fill, A, A_updates, (
            spins,
            charges,
        ), excitation in zip(
            up_orbs,
            down_orbs,
            up_down_orbs,
            down_up_orbs,
            systems.group(fill_vec, fill_vec_meta.param_type.value.chunk_fn),
            systems.group(A, A_meta.param_type.value.chunk_fn),
            systems.group(A_updates, A_updates_meta.param_type.value.chunk_fn),
            systems.unique_spins_and_charges,
            systems.grouped_excitations,
            strict=True,
        ):
            n_mol, n_elec, n_up, n_nuc = up.shape[0], sum(spins), spins[0], len(charges)
            orb_mask = orbital_mask(self.orb_per_charge, charges)
            # Mask for axes of length 2 * len(orb_mask).
            # NOTE(review): ``np.repeat`` duplicates every entry in place
            # (m0, m0, m1, m1, ...), while the axes it is applied to below are laid
            # out as ``(two nuc orb)`` / a flattened ``(nuc, 2*max_orb)``. Confirm
            # this interleaving is intended rather than ``np.tile(orb_mask, 2)``.
            full_mask = np.repeat(orb_mask, 2)
            up, down, up_down, down_up = (
                x.reshape(n_mol, n_elec, -1) for x in (up, down, up_down, down_up)
            )  # squeezes the n_det dim

            @vmap  # vmap over different molecules
            def _orbitals(
                up: Array,
                down: Array,
                up_down: Array,
                down_up: Array,
                A: Array,
                A_updates: Array,
                fill: Array,
                excitation: Array,
            ):
                """Per-molecule construction; closes over this group's static sizes and masks."""
                # construct full orbitals
                # The first ``n_up`` rows are taken from the ``up``/``up_down``
                # outputs, the remaining rows from ``down_up``/``down``.
                uu, dd, ud, du = up[:n_up], down[n_up:], up_down[:n_up], down_up[n_up:]
                orbitals = jnp.block([[uu, ud], [du, dd]])  # (n_elec, 2*n_orbs)
                # Pad additional orbital if n_elec is odd
                # The fill row is always appended and then cut off again for an
                # even electron count.
                fill = fill.reshape(1, -1)[:, full_mask]
                orbitals = jnp.concatenate([orbitals, fill], axis=0)
                if n_elec % 2 == 0:
                    orbitals = orbitals[:n_elec]

                # A: (2*n_orbs, 2*n_orbs)
                # Unfold the nucleus-pair axis into a (nuc*orb, nuc*orb) matrix
                # for each of the two parameter slices.
                A = einops.rearrange(
                    A,
                    '(n1 n2) two o1 o2 -> two (n1 o1) (n2 o2)',
                    n1=n_nuc,
                )
                A = A[:, orb_mask, :][:, :, orb_mask]  # remove unused orbitals
                A_offdiag, A_diag = A[0], A[1]
                # The two same-spin blocks come from the upper and lower triangle
                # of the same slice; subtracting the transpose makes each block
                # antisymmetric (the diagonal cancels).
                A_uu, A_dd = jnp.triu(A_diag), jnp.tril(A_diag)
                A_uu, A_dd = (A_uu - A_uu.mT), (A_dd - A_dd.mT)
                # Antisymmetric by construction: the lower-left block is the
                # negated transpose of the upper-right block.
                A = jnp.block([[A_uu, A_offdiag], [-A_offdiag.mT, A_dd]])
                # Scale by the matrix dimension.
                A /= A.shape[-1]

                # Product
                # Computed from float64 casts of the inputs (the cast is only
                # effective when JAX 64-bit mode is enabled).
                orb_A_orb_product = skewsymmetric_quadratic(
                    orbitals.astype(jnp.float64),
                    A.astype(jnp.float64),
                )

                # Updates

                # Select the updates for the current state
                # ``excitation`` indexes the ``max_num_states`` axis.
                A_updates = A_updates[:, excitation, ...]
                A_updates = einops.rearrange(
                    A_updates,
                    'nuc orb two rank det -> det (two nuc orb) rank',
                )[:, full_mask]
                # Scale by the square root of the (masked) orbital dimension.
                A_updates /= jnp.sqrt(A_updates.shape[1])
                # Project the update factors onto the orbitals:
                # (elec, orb) x (det, orb, rank) -> (det, elec, rank).
                updates = jnp.einsum(
                    'no,dor->dnr',
                    orbitals.astype(jnp.float64),
                    A_updates.astype(jnp.float64),
                )
                # Positional order follows the field order of LowRankPfaffianOrbitals.
                return LowRankPfaffianOrbitals(
                    orbitals,
                    A,
                    A_updates,
                    orb_A_orb_product,
                    updates,
                )

            result.append(
                _orbitals(up, down, up_down, down_up, A, A_updates, fill, excitation),
            )
        return result


class LowRankPfaffian(
    ReparamModule,
    AntisymmetrizerP[
        list[LowRankPfaffianOrbitals],
        LowRankCoreOrbitals,
        EMA[PfaffianPretrainingState],
    ],
):
    """Pfaffian antisymmetrizer whose determinants differ by low-rank updates.

    The module is a thin composition of ``UnexcitedLRPfaffian`` (core orbitals)
    and ``ExcitedLRPfaffian`` (antisymmetrizer and per-state updates), plus the
    evaluation of the wave function from the resulting orbitals and the
    pretraining utilities that match them against reference orbitals.
    """

    # Number of determinants, forwarded to ``ExcitedLRPfaffian``.
    determinants: int
    # Orbital counts keyed by charge.
    orb_per_charge: dict[str, int]
    # Envelope template, forwarded to ``UnexcitedLRPfaffian``.
    envelope: Envelope
    # Rank of the low-rank update, forwarded to ``ExcitedLRPfaffian``.
    rank: int

    # Forwarded to ``pfaffian_pretraining_loss`` as ``steps``.
    hf_match_steps: int
    # Forwarded to ``pfaffian_pretraining_loss`` as ``learning_rate``.
    hf_match_lr: float
    # Forwarded to ``pfaffian_pretraining_loss`` as ``orb_weight``.
    hf_match_orbitals: float
    # Forwarded to ``pfaffian_pretraining_loss`` as ``antisym_weight``.
    hf_match_antisymmetrizer: float
    # Forwarded to ``pfaffian_pretraining_loss`` as ``ema``.
    hf_match_ema: float
    # Forwarded to ``EMA.init`` as ``initial_bias_strength``.
    hf_match_init_bias: float

    # Must be set to a positive value before the module is called.
    max_num_states: int | None = None

    def setup(self):
        """Create the two submodules; arguments are passed positionally in field order."""
        self._unexcited_lr_pfaffian = UnexcitedLRPfaffian(
            self.orb_per_charge,
            self.envelope,
        )
        self._excited_lr_pfaffian = ExcitedLRPfaffian(
            self.determinants,
            self.orb_per_charge,
            self.rank,
            self.max_num_states,
        )

    def __call__(self, systems: Systems, elec_embeddings: Float[Array, 'electrons dim']):
        """Compute the low-rank Pfaffian orbitals for ``systems``.

        Equivalent to ``apply_excitation(systems, core_orbitals(systems, elec_embeddings))``.

        Raises:
            AssertionError: If ``max_num_states`` is ``None`` or not positive.
        """
        assert self.max_num_states is not None and self.max_num_states > 0, (
            'max_num_states must be set'
        )

        core_orbitals = self._unexcited_lr_pfaffian(systems, elec_embeddings)
        return self._excited_lr_pfaffian(systems, core_orbitals)

    def core_orbitals(
        self,
        systems: Systems,
        elec_embeddings: Float[Array, 'electrons dim'],
    ):
        """Return only the excitation-independent ``LowRankCoreOrbitals``."""
        return self._unexcited_lr_pfaffian(systems, elec_embeddings)

    def apply_excitation(self, systems: Systems, core_orbitals):
        """Turn previously computed core orbitals into ``LowRankPfaffianOrbitals``."""
        return self._excited_lr_pfaffian(systems, core_orbitals)

    def to_slog_psi(self, systems: Systems, orbitals: list[LowRankPfaffianOrbitals]):
        """Evaluate sign and log-magnitude of the wave function from the orbitals.

        Args:
            systems: Batch of systems; supplies the output dtype and the
                permutation back to the original batch order.
            orbitals: One ``LowRankPfaffianOrbitals`` per group.

        Returns:
            ``(sign, log_psi)`` in the original batch order; ``sign`` is cast to
            int32 and ``log_psi`` to the dtype of ``systems.electrons``.
        """
        dtype = systems.electrons.dtype
        signs, logpsis = [], []

        for orb in orbitals:
            # Insert a determinant axis into the base product so that it lines up
            # with the per-determinant updates, then take the Pfaffian per determinant.
            sign, logpsi = slog_pfaffian(
                _updated_antisymmetrizer(
                    orb.orb_A_orb_product[:, None],
                    orb.updates,
                ).astype(jnp.float64),
            )
            # Signed sum over the determinant axis, carried out in log space.
            logpsi, sign = jax.nn.logsumexp(logpsi, axis=1, b=sign, return_sign=True)
            signs.append(sign)
            logpsis.append(logpsi)
        # Groups were processed in "unique" order; restore the batch order.
        order = systems.inverse_unique_indices
        sign = jnp.concatenate(signs)[order]
        log_psi = jnp.concatenate(logpsis)[order]
        return sign.astype(jnp.int32), log_psi.astype(dtype)

    def effective_determinants(
        self,
        systems: Systems,
        orbitals: list[LowRankPfaffianOrbitals],
    ) -> tuple[Float[Array, ''], dict[str, Float[Array, ' n_mols']]]:
        """Measure how evenly the determinants contribute in magnitude.

        The log-magnitudes of the per-determinant Pfaffians are normalised over
        the determinant axis; signs are ignored.

        Returns:
            A tuple ``(loss, metrics)``. ``loss`` is the batch mean of the negative
            sum of the normalised log-weights, which is smallest for equal
            weights. ``metrics`` holds, in original batch order, the
            ``'perplexity'`` (exponential of the entropy of the weights) and the
            ``'min'`` / ``'max'`` weight multiplied by the number of determinants
            (both equal 1 for equal weights).
        """
        perplexities, mins, maxs = [], [], []
        determinant_losses = []
        for orb in orbitals:  # type:ignore
            # Add dimension for the number of determinants
            logpsi = slog_pfaffian(
                _updated_antisymmetrizer(
                    orb.orb_A_orb_product[:, None],
                    orb.updates,
                ).astype(jnp.float64),
            )[1]

            # Log of the weights normalised to sum to one over the determinants.
            norm_logpsi = logpsi - jax.nn.logsumexp(logpsi, axis=1, keepdims=True)

            # Metrics
            mininmum = jnp.min(jnp.exp(norm_logpsi), axis=1) * logpsi.shape[1]
            mins.append(mininmum)

            maximum = jnp.max(jnp.exp(norm_logpsi), axis=1) * logpsi.shape[1]
            maxs.append(maximum)

            entropy = -jnp.sum(jnp.exp(norm_logpsi) * norm_logpsi, axis=1)
            perplexity = jnp.exp(entropy)
            perplexities.append(perplexity)

            # Loss
            determinant_losses.append(-jnp.sum(norm_logpsi, axis=1))

        order = systems.inverse_unique_indices

        # The loss is averaged, so it does not need to be reordered.
        return jnp.mean(jnp.concatenate(determinant_losses)), {
            'perplexity': jnp.concatenate(perplexities)[order],
            'min': jnp.concatenate(mins)[order],
            'max': jnp.concatenate(maxs)[order],
        }

    def match_hf_orbitals(
        self,
        systems: SystemsWithPretrainTarget,
        orbitals: list[LowRankPfaffianOrbitals],  # grouped by molecules
    ):
        """Compute the pretraining loss against the reference orbitals.

        For every group the targets, cached states and dense orbitals are
        regrouped by molecule id and passed to ``pfaffian_pretraining_loss``.

        Args:
            systems: Batch with pretraining targets (``systems.targets``) and the
                cached per-system states (``systems.cache``).
            orbitals: One ``LowRankPfaffianOrbitals`` per group.

        Returns:
            A tuple ``(loss, states, aux)``: the summed loss divided by the number
            of distinct molecules, the updated states as a list in original batch
            order, and an empty dict.
        """
        state: Sequence[EMA[PfaffianPretrainingState]] = systems.cache
        targets = systems.targets

        loss_fn = functools.partial(
            pfaffian_pretraining_loss,
            orb_weight=self.hf_match_orbitals,
            antisym_weight=self.hf_match_antisymmetrizer,
            learning_rate=self.hf_match_lr,
            steps=self.hf_match_steps,
            ema=self.hf_match_ema,
        )

        out_state: Sequence[EMA[PfaffianPretrainingState]] = []
        loss = jnp.zeros((), dtype=jnp.float32)
        n_diff_mols = jnp.zeros((), dtype=jnp.int32)

        for idx, pfaff_orbs in zip(systems.unique_indices, orbitals, strict=True):
            # Pick this group's entries out of the batch-ordered sequences.
            getter = itemgetter(*idx)
            target_i, state_i, mol_ids_i = (
                getter(targets),
                getter(state),
                getter(systems.mol_ids),
            )

            n_segments = len(set(mol_ids_i))
            n_diff_mols += n_segments
            # Assumes every molecule id occurs equally often within the group —
            # TODO confirm.
            segment_size = len(mol_ids_i) // n_segments
            unsegment = functools.partial(
                unsegment_axis,
                segment_ids=np.array(mol_ids_i),
                indices_are_grouped=False,
                num_segments=n_segments,
            )

            # Stack the molecules in the first dimension
            target_i = tree_stack(*target_i)
            target_i = jax.tree.map(unsegment, target_i)

            state_i = tree_stack(*state_i)
            state_i = jax.tree.map(unsegment, state_i)
            # Share the same orbital rotation for all molecules
            # (only the first entry of each segment is kept).
            state_i = jax.tree.map(lambda x: x[:, 0, ...], state_i)

            pfaff_orbs = pfaff_orbs.to_pfaffian_orbitals()
            # For pfaff_orbs, we expect to see the molecules in the -4 dim. Thus, we should move it to the front
            pfaff_orbs = jax.tree.map(lambda x: jnp.moveaxis(x, -4, 0), pfaff_orbs)
            pfaff_orbs = jax.tree.map(unsegment, pfaff_orbs)

            # Matching
            # One loss evaluation per segment (leading axis of all arguments).
            loss_i, state_i = jax.vmap(loss_fn)(
                target_i.molecular_orbital_permutation,
                target_i.molecular_orbitals,
                pfaff_orbs,
                state_i,
            )
            loss += loss_i.sum()

            # out_state is now sorted by the unique indices and not by the original batch!
            # NOTE(review): states are appended segment by segment, which matches
            # the order of ``idx`` only if equal molecule ids are contiguous
            # within the group — confirm.
            for i in range(n_segments):
                out_state.extend(
                    [jax.tree.map(lambda x, i=i: x[i], state_i)] * segment_size,
                )
        # invert the order of the unique indices
        out_state = itemgetter(*systems.inverse_unique_indices)(out_state)
        # NOTE(review): with no groups at all this divides zero by zero.
        return (
            loss / n_diff_mols,
            list(out_state),
            {},
        )

    def init_systems(self, key: Array, systems: SystemsWithPretrainTarget):
        """Attach a freshly initialised pretraining state to every sub-configuration.

        Args:
            key: Currently unused by this implementation.
            systems: Batch with pretraining targets.

        Returns:
            ``systems`` with ``cache`` replaced by a tuple holding one
            zero-initialised ``EMA[PfaffianPretrainingState]`` per entry of
            ``systems.sub_configs``.

        Raises:
            AssertionError: If fewer Pfaffian orbitals than reference orbitals are
                available and a truncated orbital has a non-zero permutation
                entry, or if the electron count is odd and the last permutation
                row is non-zero.
        """
        states: list[EMA[PfaffianPretrainingState]] = []
        for sub_sys in systems.sub_configs:
            n_el = sub_sys.n_elec
            # Twice the number of unmasked orbitals.
            n_orbs = (
                orbital_mask(self.orb_per_charge, tuple(sub_sys.flat_charges)).sum() * 2
            )
            # Only the first target of the sub-configuration is inspected.
            target = sub_sys.targets[0]
            n_mo_orbs = target.molecular_orbitals.shape[-1]

            # Check for truncation of active orbitals
            if n_orbs < n_mo_orbs:
                # only works unjitted
                spin_pf_orbs = n_orbs // 2
                spin_mo_orbs = n_mo_orbs // 2
                # Row indices of the reference orbitals beyond the available
                # Pfaffian orbitals, for the first and the second half of the rows.
                truncation = list(range(spin_pf_orbs, spin_mo_orbs)) + list(
                    range(spin_mo_orbs + spin_pf_orbs, 2 * spin_mo_orbs),
                )
                assert jnp.all(
                    target.molecular_orbital_permutation[..., truncation, :] == 0,
                ), 'Active orbitals must not be truncated'
            if n_el % 2 == 1:
                # Account for the extra fill row appended for odd electron counts.
                n_el += 1
                assert jnp.all(target.molecular_orbital_permutation[..., -1, :] == 0), (
                    'Active orbital would be replaced by dummy orbital'
                )

            state = EMA[PfaffianPretrainingState].init(
                PfaffianPretrainingState(
                    orb_rotation=jnp.zeros((n_el, n_el), dtype=jnp.float32),
                    orb_matching=jnp.zeros((2, n_orbs, n_orbs), dtype=jnp.float32),
                ),
                initial_bias_strength=self.hf_match_init_bias,
            )
            states.append(state)
        return systems.replace(cache=tuple(states))


@vectorize(signature='(n,n),(n,r)->(n,n)')
def _updated_antisymmetrizer(
    orb_A_orb_product: Float[Array, 'elec elec'],
    update: Float[Array, 'elec rank'],
):
    """Add the low-rank correction built from ``update`` to the base orbital product.

    The function is written for a single ``(elec, elec)`` product and a single
    ``(elec, rank)`` factor; the ``vectorize`` signature applies it over any
    leading axes.
    """
    low_rank_correction = skewsymmetric_quadratic_with_antisymmetric_block_identity(
        update,
    )
    return orb_A_orb_product + low_rank_correction
