"""
Container for molecular structures taken from
H. Helal and A. Fitzgibbon's
"MESS: Modern Electronic Structure Simulations" package
https://arxiv.org/abs/2406.03121
"""

from typing import List, Tuple

import jax
import jax.numpy as jnp
from flax.struct import dataclass
from periodictable import elements
from pyscf import gto

from egxc.systems.preload import PreloadGrid, PreloadSystem
from egxc.utils.constants import ANGSTROM_TO_BOHR
from egxc.utils.linalg import transformation_matrix
from egxc.utils.typing import (
    PRECISION,
    Bool2xB,
    BoolA,
    BoolB,
    Float1,
    FloatAx3,
    FloatBxB,
    FloatBxBxBxB,
    FloatN,
    FloatNx3,
    FloatNxB,
    FloatNxBx3,
    FloatQxBxB,
    IntA,
    UInt1,
    UIntB,
)


@dataclass
class FockTensors:
    """
    Tensors that are constant for a given Structure and basis set which are
    used to calculate the Fock matrix.

    The first six fields are required. The two additional ERI tensors
    (``eri_sr_tensor`` / ``eri_lr_tensor``) are optional and default to None.
    """

    # Boolean mask along the basis-function axis.
    # NOTE(review): presumably True marks real (non-padded) basis functions — confirm.
    basis_mask: BoolB
    # Core Hamiltonian matrix; System.from_preloaded casts it to PRECISION.solver.
    core_hamiltonian: FloatBxB
    # Electron repulsion integrals; System.from_preloaded casts them to
    # PRECISION.eri_tensor.
    # NOTE(review): the two annotated shapes look like a 3-index (density-fitted)
    # and a full 4-index tensor — confirm against the producer of this field.
    electron_repulsion_tensor: FloatQxBxB | FloatBxBxBxB
    # Orbital occupations; System.from_preloaded stores them as uint8.
    # System.to_pyscf treats a 1-D array as spin 0 and otherwise reads
    # row 0 minus row 1 as the spin (i.e. rows are the two spin channels).
    occupancies: UIntB | Bool2xB
    # Overlap matrix; System.from_preloaded casts it to PRECISION.solver.
    overlap: FloatBxB
    # Result of transformation_matrix(overlap), as built in System.from_preloaded.
    # NOTE(review): presumably the orthogonalising transform of the overlap — confirm.
    diagonal_overlap: FloatBxB
    # Optional extra ERI tensors, same layout options as electron_repulsion_tensor.
    # NOTE(review): "sr"/"lr" presumably mean short-/long-range (range-separated
    # exchange) — confirm.
    eri_sr_tensor: FloatQxBxB | FloatBxBxBxB | None = None
    eri_lr_tensor: FloatQxBxB | FloatBxBxBxB | None = None

    @property
    def ert(self) -> FloatQxBxB | FloatBxBxBxB:
        """Short alias for ``electron_repulsion_tensor``."""
        # abbreviation / alias for electron_repulsion_tensor
        return self.electron_repulsion_tensor


@dataclass
class Grid:
    """
    Quadrature grid for numerical integration: point coordinates and weights,
    plus the atomic orbitals (and optionally their gradients) evaluated at the
    grid points.
    """

    coords: FloatNx3
    weights: FloatN
    aos: FloatNxB
    grad_aos: FloatNxBx3 | None

    @classmethod
    def create(
        cls, coords: FloatNx3, weights: FloatN, aos: FloatNxB, grad_aos: FloatNxBx3 | None
    ) -> 'Grid':
        """
        Build a grid, overwriting orbital data at zero-weight points with 1.

        Every row of ``aos`` (and of ``grad_aos`` when given) whose quadrature
        weight is exactly zero is replaced by ones; ``coords`` and ``weights``
        are stored unchanged.
        NOTE(review): presumably this keeps padded grid points from producing
        zeros/NaNs in downstream quantities — confirm.
        """
        zero_weight = weights == 0
        filled_aos = jnp.where(zero_weight[:, None], 1, aos)
        filled_grad_aos = (
            None
            if grad_aos is None
            else jnp.where(zero_weight[:, None, None], 1, grad_aos)
        )
        return cls(
            coords=coords,
            weights=weights,
            aos=filled_aos,
            grad_aos=filled_grad_aos,
        )

    @classmethod
    def from_preloaded(cls, pgrid: PreloadGrid) -> 'Grid':
        """
        Convert a PreloadGrid into a Grid of jax arrays.

        The arrays are taken as stored: the zero-weight overwrite performed by
        ``create`` is not applied here.
        """
        coords = jnp.asarray(pgrid.coords)
        weights = jnp.asarray(pgrid.weights)
        aos = jnp.asarray(pgrid.aos)
        grad_aos = None
        if pgrid.grad_aos is not None:
            grad_aos = jnp.asarray(pgrid.grad_aos)
        return cls(coords=coords, weights=weights, aos=aos, grad_aos=grad_aos)

    @classmethod
    def empty(cls) -> 'Grid':
        """
        Empty grid for quicker system construction in pytests
        """
        return cls(
            coords=jnp.array([]),
            weights=jnp.array([]),
            aos=jnp.array([]),
            grad_aos=None,
        )


@dataclass
class System:
    """
    A molecular system: nuclear geometry, atomic numbers, an atom mask, the
    precomputed Fock-matrix tensors and the integration grid.
    """

    # Nuclear positions; to_pyscf scales them by ANGSTROM_TO_BOHR, i.e. they are
    # treated as Angstrom.
    _nuc_pos: FloatAx3  # TODO: rename to non force nuc_pos ?
    atom_z: IntA  # atomic numbers where Z=255 are masked out
    # Boolean mask over atoms; entries that are True are kept by to_pyscf.
    atom_mask: BoolA
    fock_tensors: FockTensors
    grid: Grid

    @property
    def n_atoms(self) -> int:
        """Length of ``atom_z``, i.e. including any masked-out entries."""
        return len(self.atom_z)

    @property
    def n_electrons(self) -> UInt1:
        """Total number of electrons, obtained by summing all occupancies."""
        return jnp.sum(self.fock_tensors.occupancies)

    @property
    def atomic_symbol(self) -> List[str]:
        """
        Element symbols of the unmasked atoms, in storage order.

        Masked-out entries are skipped (as in ``to_pyscf``) because their
        placeholder atomic number is not a valid element. Each atomic number is
        converted to a plain Python int before the periodic-table lookup so the
        lookup does not depend on array scalars being usable as keys.
        """
        return [elements[int(z)].symbol for z in self.atom_z[self.atom_mask]]

    @classmethod
    def from_preloaded(
        cls,
        psys: PreloadSystem,
        grid: Grid,
    ) -> 'System':
        """
        Build a System of jax arrays from a PreloadSystem and a ready-made Grid.

        Core Hamiltonian and overlap are cast to ``PRECISION.solver``, all ERI
        tensors to ``PRECISION.eri_tensor`` and the occupancies to uint8.
        ``diagonal_overlap`` is computed here from the overlap matrix via
        ``transformation_matrix``. Optional ERI tensors that are None stay None.
        """
        pft = psys.fock_tensors
        fock_tensors = FockTensors(
            basis_mask=jnp.asarray(pft.basis_mask),
            core_hamiltonian=jnp.asarray(pft.core_hamiltonian, dtype=PRECISION.solver),
            electron_repulsion_tensor=jnp.asarray(
                pft.electron_repulsion_tensor,
                dtype=PRECISION.eri_tensor,
            ),
            eri_sr_tensor=(
                jnp.asarray(pft.eri_sr_tensor, dtype=PRECISION.eri_tensor)
                if pft.eri_sr_tensor is not None
                else None
            ),
            eri_lr_tensor=(
                jnp.asarray(pft.eri_lr_tensor, dtype=PRECISION.eri_tensor)
                if pft.eri_lr_tensor is not None
                else None
            ),
            occupancies=jnp.asarray(pft.occupancies, dtype=jnp.uint8),
            overlap=jnp.asarray(pft.overlap, dtype=PRECISION.solver),
            diagonal_overlap=transformation_matrix(pft.overlap.astype(PRECISION.solver)),
        )

        return cls(
            jnp.asarray(psys.nuc_pos),
            atom_z=jnp.asarray(psys.atom_z),
            atom_mask=jnp.asarray(psys.atom_mask),
            fock_tensors=fock_tensors,
            grid=grid,
        )

    def to_pyscf(self, basis: str) -> gto.Mole:
        """Convert to a PySCF molecule

        Masked-out atoms are dropped. The charge is the sum of the remaining
        atomic numbers minus the total occupancy. The spin is 0 for 1-D
        occupancies and otherwise the occupancy sum of row 0 minus that of
        row 1.

        Args:
            basis: basis-set name forwarded to ``gto.M``.

        Returns:
            The molecule returned by ``gto.M`` with coordinates given in Bohr.
        """
        atom_z = self.atom_z[self.atom_mask]
        nuc_pos = self._nuc_pos[self.atom_mask]

        occ = self.fock_tensors.occupancies
        # Each count is turned into a Python int *before* subtracting: the
        # occupancies are unsigned, so an array-level difference that should be
        # negative (anion, or more row-1 than row-0 electrons) would wrap around
        # to a huge positive value. This also hands PySCF plain ints for both
        # charge and spin.
        if occ.ndim == 1:
            spin = 0
        else:
            spin = int(occ[0].sum()) - int(occ[1].sum())
        charge = int(jnp.sum(atom_z)) - int(occ.sum())

        # manual conversion to bohr, to aid ao test precision by avoiding pySCF's conversion
        positions_bohr = nuc_pos * ANGSTROM_TO_BOHR
        mol = gto.M(
            atom=[(int(z), pos) for z, pos in zip(atom_z, positions_bohr)],
            basis=basis,
            charge=charge,
            spin=spin,
            unit='Bohr',
        )
        return mol


def nuclear_energy_fn(nuc_pos: FloatAx3, sys: System) -> Float1:
    """
    Nuclear electrostatic interaction energy.
    Assumes that the input positions are in Angstrom

    Sums Z_i * Z_j / |R_i - R_j| over all atom pairs i < j for which both atoms
    are unmasked and the charge product is non-zero. Distances are converted to
    Bohr so the energy is in Hartree.

    The function is safe to differentiate with respect to ``nuc_pos``: pairs
    that do not contribute are given a dummy separation before the norm is
    taken, so a zero distance between excluded atoms cannot leak NaNs into the
    gradient.
    """
    idx, jdx = jnp.triu_indices(sys.n_atoms, 1)
    rij = nuc_pos[idx, :] - nuc_pos[jdx, :]
    rij *= ANGSTROM_TO_BOHR  # convert distances to Bohr s.t. energies are in Hartree

    # Form the charge product in floating point so a narrow integer dtype of
    # atom_z cannot overflow.
    # NOTE(review): atom_z's dtype is not visible here — confirm whether it is narrow.
    charges = sys.atom_z.astype(rij.dtype)
    a_a_mask = sys.atom_mask[idx] * sys.atom_mask[jdx]
    u = charges[idx] * charges[jdx] * a_a_mask
    pair_is_active = u != 0

    # Masking only the *result* of 1 / |rij| is not enough under autodiff: the
    # unselected branch is still differentiated, and at rij == 0 that yields
    # 0 * inf = NaN. Replace the displacement of inactive pairs first.
    safe_rij = jnp.where(pair_is_active[:, None], rij, 1.0)
    inv_dist = jnp.where(pair_is_active, 1 / jnp.linalg.norm(safe_rij, axis=1), 0)
    return jnp.sum(u * inv_dist)


def nuclear_energy_and_force(nuc_pos: FloatAx3, sys: System) -> Tuple[Float1, FloatAx3]:
    """
    Nuclear repulsion energy together with the forces on the nuclei.

    The force is the negative gradient of ``nuclear_energy_fn`` with respect to
    ``nuc_pos``; since that function takes positions in Angstrom and returns
    Hartree, the forces come out in Hartree per Angstrom.

    Returns:
        Tuple ``(energy, forces)`` where ``forces`` has the shape of ``nuc_pos``.
    """
    energy_with_gradient = jax.value_and_grad(nuclear_energy_fn)
    energy, position_gradient = energy_with_gradient(nuc_pos, sys)
    forces = -position_gradient
    return energy, forces
