"""
Data structures for Gaussian-Type-Orbital (GTO) basis sets.

- Primitives P:
    A single uncontracted Gaussian function
    N (x-X_A)^{l_x} (y-Y_A)^{l_y} (z-Z_A)^{l_z} exp(-alpha |r-R_A|²)
    defined by its exponent alpha, angular momentum (l_x,l_y,l_z), and center R_A.

- Basis functions B:
    A linear combination of primitives with fixed contraction coefficients d_p.
    These contracted functions form the actual basis used in the molecular
    orbital expansion.

- Radial primitives RP:
    Unique Gaussian factors identified by their center R_A, total angular momentum l,
    and exponent alpha, but independent of the angular decomposition (l_x,l_y,l_z) or m.
    Each radial primitive represents the shared radial dependence across all angular
    components of a shell:
        R_{l,alpha}(r_A) = N_l(alpha) r_A^l exp(-alpha r_A^2)
    where N_l(alpha) is the appropriate normalization. All primitives within the same
    (R_A, l, alpha) tuple share this radial primitive.

- Shells S:
    A group of basis functions that differ only in the magnetic quantum number m,
    i.e. basis functions on a common center that share the same angular
    momentum quantum number l and the same set of radial primitives. For example,
    a p-shell contains the px, py, pz functions and a d-shell contains the dxz, dyz, dz², dx²-y²,
    and dxy functions with a shared radial component.

Additional Notes:
    We store the radial primitives of each shell separately, so primitives are never
    explicitly shared between shells, even when they have identical exponents. We then
    directly store (normalization factors * contraction coefficients) as primitive
    coefficients.
"""

import warnings
from dataclasses import dataclass as host_dataclass
from hashlib import sha256
from typing import Callable, Dict, List, NamedTuple, Set, Tuple

import jax.numpy as jnp
import numpy as onp
import pyscf
import pyscf.lib.exceptions
from flax.struct import dataclass
from pyscf import gto

from egxc.utils.constants import L_MAX
from egxc.utils.typing import (
    PRECISION,
    FloatRP,
    NpBoolLxT,
    NpFloatG,
    NpFloatRP,
    NpUIntA,
    NpUIntG,
    NpUIntRP,
    NpUIntS,
    NpUIntT,
    UIntRP,
)


@host_dataclass(frozen=True)
class AngularMomentumIndexing:
    """
    Compile time static indices for each angular momentum of varying lengths
    """

    _s_indices: NpUIntT  # uint16
    _p_indices: NpUIntT  # uint16
    _d_indices: NpUIntT  # uint16
    _f_indices: NpUIntT  # uint16

    @classmethod
    def from_angular_momentum_array(
        cls, angular_momenta: NpUIntT
    ) -> 'AngularMomentumIndexing':
        indices = (
            onp.where(angular_momenta == i)[0].astype(onp.uint16)
            for i in range(L_MAX + 1)
        )
        return cls(*indices)

    @classmethod
    def from_masks(cls, masks: NpBoolLxT) -> 'AngularMomentumIndexing':
        indices = tuple(
            onp.where(masks[i])[0].astype(onp.uint16) for i in range(masks.shape[0])
        )
        indices += (onp.array([], dtype=onp.uint16),) * (L_MAX + 1 - masks.shape[0])
        return cls(*indices)

    def pad(self, s: int, p: int, d: int, f: int) -> 'AngularMomentumIndexing':  # type: ignore # TODO: pad
        print(
            self._s_indices.shape,
            self._p_indices.shape,
            self._d_indices.shape,
            self._f_indices.shape,
        )  # TODO: pad
        # return AngularMomentumIndexing(
        #     onp.pad(self._s_indices, (0, s)),
        #     onp.pad(self._p_indices, (0, p)),
        #     onp.pad(self._d_indices, (0, d)),
        #     onp.pad(self._f_indices, (0, f)),
        # )

    def __getitem__(self, angular_momentum: int) -> NpUIntT:
        match angular_momentum:
            case 0:
                return self._s_indices
            case 1:
                return self._p_indices
            case 2:
                return self._d_indices
            case 3:
                return self._f_indices
            case _:
                raise ValueError(f'Invalid angular momentum: {angular_momentum}')


@host_dataclass(frozen=True)
class PreloadedGTOBasis:
    """
    Preloaded basis parameters prepared by cpu-workers.

    Host-side (numpy) bundle of everything needed to build a ``GTOBasis``: the
    compile-time static layout (consumed by ``GTOCompileStatics.from_preloaded``) and
    the radial-primitive arrays (consumed by ``GTOBasis.from_preloaded``).
    """

    hash_value: int  # identifies the basis layout; reused as the GTOCompileStatics hash
    max_angular_momentum_by_atom_idx: NpUIntA  # uint8, one entry per (non-padding) atom
    shell_atom_indices: NpUIntS  # atom index of every shell (uint16 from get_gto_preloader)
    angular_shell_indices: AngularMomentumIndexing  # per-l indices into the shell axis
    angular_basis_indices: AngularMomentumIndexing  # per-l indices into the basis axis
    radial_primitive_atom_indices: NpUIntRP  # atom index of every radial primitive
    radial_primitive_coeffs: NpFloatRP  # normalization factor * contraction coefficient
    radial_primitive_exponents: NpFloatRP  # Gaussian exponent of every radial primitive
    radial_primitive_shell_indices: NpUIntRP  # shell index of every radial primitive
    num_basis_fns: int  # total number of basis functions (sum of 2l + 1 over shells)

    # TODO: padding of the index arrays to fixed sizes is not implemented yet.

    def pad(self, basis):
        """
        Placeholder for padding the basis functions; currently a no-op returning None.
        """
        # TODO: pad the basis functions
        pass


@host_dataclass(frozen=True)
class GTOCompileStatics:
    """
    Compile time static information about the basis functions, and indexing.

    Kept deliberately small since it is meant to serve as compile-time static data.
    Hashing and equality look at ``hash_value`` only.
    """

    hash_value: int
    max_angular_momentum_by_atom_idx: NpUIntA  # uint8, one entry per atom
    shell_atom_indices: NpUIntS  # atom index of every shell
    angular_shell_indices: AngularMomentumIndexing  # indexes the shell dimension S < B
    angular_basis_indices: AngularMomentumIndexing  # indexes the basis dimension B
    num_basis_fns: int

    def __hash__(self) -> int:
        """
        Return the precomputed ``hash_value``.

        ``hash_value`` is derived from the (sorted!) atomic numbers of the molecule:
        molecules made of the same atoms share their basis-function layout and differ
        only in the atomic centers.
        """
        return self.hash_value

    def __eq__(self, other: object) -> bool:
        """Equal iff ``other`` is a GTOCompileStatics carrying the same ``hash_value``."""
        return isinstance(other, GTOCompileStatics) and (
            self.hash_value == other.hash_value
        )

    @classmethod
    def from_preloaded(cls, preloaded: PreloadedGTOBasis) -> 'GTOCompileStatics':
        """Pick the static (layout) fields out of a PreloadedGTOBasis without copying."""
        return cls(
            hash_value=preloaded.hash_value,
            max_angular_momentum_by_atom_idx=preloaded.max_angular_momentum_by_atom_idx,
            shell_atom_indices=preloaded.shell_atom_indices,
            angular_shell_indices=preloaded.angular_shell_indices,
            angular_basis_indices=preloaded.angular_basis_indices,
            num_basis_fns=preloaded.num_basis_fns,
        )


@dataclass
class RadialPrimitives:
    """
    Runtime arrays describing the radial primitives of a basis, one entry per primitive.
    """

    coeffs: FloatRP  # normalization factor * contraction coefficient
    exponents: FloatRP  # Gaussian exponent of each primitive
    shell_indices: UIntRP  # indicates to which shell the primitive belongs
    atom_indices: UIntRP  # indicates on which atom the primitive is centered

    @property
    def count(self) -> int:
        """Number of radial primitives, i.e. the leading dimension of ``coeffs``."""
        coeff_shape = self.coeffs.shape
        return coeff_shape[0]


@dataclass
class GTOBasis:
    """
    Gaussian-type-orbital basis-set constants for atomic orbital evaluation.

    Pairs the compile-time static structure of the basis with the runtime arrays of its
    radial primitives.

    Attributes:
        compile_statics: hashable layout information (angular-momentum index arrays,
            shell-to-atom mapping, number of basis functions); hashing/equality rely on
            its ``hash_value`` only, which makes it usable as static compile-time data.
        radial_primitives: JAX arrays with the coefficients, exponents and shell/atom
            indices of every radial primitive.

    NOTE(review): ``compile_statics`` is declared as a plain field of a
    ``flax.struct.dataclass``; confirm whether it should be marked
    ``pytree_node=False`` if instances are passed through transformed functions.
    """

    compile_statics: GTOCompileStatics
    radial_primitives: RadialPrimitives

    @classmethod
    def from_preloaded(cls, preloaded: PreloadedGTOBasis) -> 'GTOBasis':
        """
        Convert host-side preloaded data into a GTOBasis.

        The radial-primitive numpy arrays are turned into JAX arrays; the atom indices
        are explicitly stored as uint16.
        """
        statics = GTOCompileStatics.from_preloaded(preloaded)

        primitive_coeffs = jnp.asarray(preloaded.radial_primitive_coeffs)
        primitive_exponents = jnp.asarray(preloaded.radial_primitive_exponents)
        primitive_shells = jnp.asarray(preloaded.radial_primitive_shell_indices)
        primitive_atoms = jnp.asarray(
            preloaded.radial_primitive_atom_indices, dtype=jnp.uint16
        )
        primitives = RadialPrimitives(
            coeffs=primitive_coeffs,
            exponents=primitive_exponents,
            shell_indices=primitive_shells,
            atom_indices=primitive_atoms,
        )

        return cls(compile_statics=statics, radial_primitives=primitives)


# Signature of the factory returned by get_gto_preloader: maps an array of atomic
# numbers (zeros are treated as padding) to a PreloadedGTOBasis.
GTOPreloader = Callable[[NpUIntA], PreloadedGTOBasis]


def get_gto_preloader(
    basis_string: str, unique_elements: Set[int]
) -> Tuple[int, GTOPreloader]:
    """
    Returns a tuple of the max angular momentum and a factory function that constructs
    a PreloadedGTOBasis from the atomic numbers.

    The per-element basis data is queried from pyscf once, up front, for every element
    in ``unique_elements``; the returned factory only assembles numpy arrays.

    Args:
        basis_string: name of the basis set, passed verbatim to ``pyscf.gto.M``.
        unique_elements: atomic numbers of all elements the factory has to support.

    Returns:
        ``(max_angular_momentum, preloader)``: the largest shell angular momentum over
        ``unique_elements`` and a callable mapping an array of atomic numbers (zeros are
        padding, non-zero entries sorted ascending) to a ``PreloadedGTOBasis``.

    Raises:
        ValueError: if pyscf does not provide ``basis_string`` for one of the elements.
        AssertionError: if the basis contains shells with angular momentum above L_MAX.
    """

    # Period 4 ends with krypton (Z=36), so only heavier elements trigger the warning.
    if unique_elements and max(unique_elements) > 36:
        warnings.warn(
            'Only elements including up to period 4 are supported, '
            + f'got max_Z={max(unique_elements)}; '
            + 'mind that relativistic effects are not considered'
        )

    class AtomicBasisSetData(NamedTuple):  # On host / CPU
        num_shells: int
        max_angular_momentum: int
        angular_momenta: List[int]  # one entry per shell
        primitive_coeffs: List[NpFloatG]  # per shell: normalization * contraction coeff
        primitive_exponents: List[NpFloatG]  # per shell: primitive exponents
        primitive_shell_indices: List[int]  # per primitive: its (element-local) shell

    def get_basis_set_coeff_from_pyscf(z: int, basis_string: str) -> AtomicBasisSetData:
        """
        Collect the shells of element ``z`` in ``basis_string`` from pyscf.

        pyscf may store generally contracted shells, i.e. several contracted functions
        sharing one set of exponents (``bas_ctr_coeff`` of shape ``(nprim, nctr)``).
        Each contraction is emitted as its own shell here, so every shell has exactly
        one coefficient per exponent.
        """

        def atm(z: int) -> gto.Mole:
            """
            generate a pyscf molecule with a single atom of atomic number Z
            """
            try:
                return gto.M(atom=[(z, (0, 0, 0))], basis=basis_string, spin=z % 2)
            except pyscf.lib.exceptions.BasisNotFoundError as err:
                warnings.warn(
                    f'Warning: basis {basis_string} not found for Z={z} in pyscf'
                )
                raise ValueError(
                    f'Basis {basis_string} not found for Z={z} in pyscf'
                ) from err

        atom = atm(z)

        angular_momenta: List[int] = []
        primitive_coeffs: List[NpFloatG] = []
        primitive_exponents: List[NpFloatG] = []
        primitive_shell_indices: List[int] = []
        for i in range(atom.nbas):
            exps = atom.bas_exp(i)
            L = atom.bas_angular(i)
            norm = (2 * exps / onp.pi) ** (3 / 4) * (8 * exps) ** (L / 2)
            # one column of contraction coefficients per contracted function
            ctr = atom.bas_ctr_coeff(i).reshape(len(exps), -1)
            # NOTE(review): splitting assumes pyscf orders the AOs of a generally
            # contracted shell contraction-major (all m of the first contraction,
            # then the next one); confirm against ``atom.ao_labels()``.
            for ctr_column in ctr.T:
                shell_idx = len(angular_momenta)
                angular_momenta.append(L)
                primitive_exponents.append(exps)
                primitive_coeffs.append(ctr_column * norm)
                primitive_shell_indices.extend([shell_idx] * len(exps))

        return AtomicBasisSetData(
            len(angular_momenta),  # number of (single-contraction) shells
            max(angular_momenta),
            angular_momenta,
            primitive_coeffs,
            primitive_exponents,
            primitive_shell_indices,
        )

    lookup_table: Dict[int, AtomicBasisSetData] = {}

    max_angular_momentum = 0
    for z in unique_elements:
        lookup_table[z] = get_basis_set_coeff_from_pyscf(z, basis_string)
        max_angular_momentum = max(
            max_angular_momentum, lookup_table[z].max_angular_momentum
        )

    assert max_angular_momentum <= L_MAX, 'Only up to f-orbitals are supported'

    def _concat_basis_precision(parts: List[NpFloatG]) -> NpFloatRP:
        """Concatenate per-shell float arrays in basis precision (empty input allowed)."""
        if not parts:
            return onp.zeros(0, dtype=PRECISION.basis)
        return onp.concatenate(parts, dtype=PRECISION.basis)

    def out_fn(
        atom_z: NpUIntA,
    ) -> PreloadedGTOBasis:
        """
        Assemble the preloaded basis of a molecule from its atomic numbers.

        Args:
            atom_z: atomic numbers; zeros are padding and are dropped, the remaining
                entries must be sorted in non-decreasing order.

        Raises:
            AssertionError: if the non-padding atomic numbers are not sorted.
            KeyError: if an element was not part of ``unique_elements``.
        """
        # mask out padded atoms
        atom_z = onp.asarray(atom_z)
        atom_z = atom_z[atom_z != 0]

        assert onp.all(atom_z[:-1] <= atom_z[1:]), 'atom_indices must be increasing'
        # The same atoms in a different basis set have a different layout, so the basis
        # name is part of the hash. A fixed dtype makes the hash independent of the
        # caller's array dtype; the NUL byte separates the two parts unambiguously.
        hash_payload = (
            basis_string.encode() + b'\0' + atom_z.astype(onp.int64).tobytes()
        )
        hash_value = int(sha256(hash_payload).hexdigest(), 16)

        atom_indices: List[int] = []
        angular_momenta: List[int] = []
        primitive_atom_indices: List[int] = []
        primitive_coeffs: List[NpFloatG] = []
        primitive_exponents: List[NpFloatG] = []
        primitive_shell_indices: List[int] = []
        max_angular_momentum_by_atom_idx: List[int] = []

        shell_counter = 0
        for atom_idx, z in enumerate(atom_z.tolist()):
            z_int = int(z)
            try:
                current_lookup = lookup_table[z_int]
            except KeyError:
                raise KeyError(
                    f'Z={z_int} is not among the unique_elements of this preloader'
                ) from None
            max_angular_momentum_by_atom_idx.append(current_lookup.max_angular_momentum)
            for L, coeffs, exps in zip(
                current_lookup.angular_momenta,
                current_lookup.primitive_coeffs,
                current_lookup.primitive_exponents,
            ):
                nprim = int(coeffs.shape[0])

                atom_indices.append(atom_idx)
                angular_momenta.append(L)

                primitive_atom_indices.extend([atom_idx] * nprim)
                primitive_coeffs.append(coeffs)
                primitive_exponents.append(exps)
                primitive_shell_indices.extend([shell_counter] * nprim)

                shell_counter += 1

        a_indices = onp.asarray(atom_indices, dtype=onp.uint16)
        prim_a_indices = onp.asarray(primitive_atom_indices, dtype=onp.uint16)
        prim_shell_indices = onp.asarray(primitive_shell_indices, dtype=onp.uint16)
        assert onp.all(a_indices[1:] >= a_indices[:-1]), 'atom indices must be increasing'
        assert onp.all(prim_a_indices[1:] >= prim_a_indices[:-1]), (
            'primitive atom indices must be increasing'
        )
        assert onp.all(prim_shell_indices[1:] >= prim_shell_indices[:-1]), (
            'primitive shell indices must be increasing'
        )

        # angular momentum indexing: row l of the masks marks the shells with momentum l
        shell_angular_momentum_arr = onp.asarray(angular_momenta, dtype=onp.uint8)
        shell_angular_multiplicities_arr = 2 * shell_angular_momentum_arr + 1
        shell_angular_momentum_masks = (
            shell_angular_momentum_arr[None, :]
            == onp.arange(max_angular_momentum + 1)[:, None]
        )
        angular_shell_indices = AngularMomentumIndexing.from_masks(
            shell_angular_momentum_masks
        )
        # every shell expands into 2l + 1 consecutive basis functions
        shells_to_basis_masks = onp.repeat(
            shell_angular_momentum_masks, shell_angular_multiplicities_arr, axis=1
        )
        angular_basis_indices = AngularMomentumIndexing.from_masks(shells_to_basis_masks)

        return PreloadedGTOBasis(
            hash_value=hash_value,
            max_angular_momentum_by_atom_idx=onp.asarray(
                max_angular_momentum_by_atom_idx, dtype=onp.uint8
            ),
            shell_atom_indices=a_indices,
            angular_shell_indices=angular_shell_indices,
            angular_basis_indices=angular_basis_indices,
            radial_primitive_atom_indices=prim_a_indices,
            radial_primitive_coeffs=_concat_basis_precision(primitive_coeffs),
            radial_primitive_exponents=_concat_basis_precision(primitive_exponents),
            radial_primitive_shell_indices=prim_shell_indices,
            num_basis_fns=int(shell_angular_multiplicities_arr.sum()),
        )

    return max_angular_momentum, out_fn
