from typing import Tuple

import einops
from pyscf import dft, gto

from egxc.discretization import (
    GTOBasis,
    get_grid_fn,
    get_gto_grid_eval_fn,
    get_gto_preloader,
)
from egxc.systems import Grid, System
from egxc.systems.preload import PreloadGrid, PreloadSystem
from egxc.utils.constants import ANGSTROM_TO_BOHR
from egxc.utils.pad import pad_ao_values
from egxc.utils.typing import Alignment, cast_to_integer_tuple


def system_from_preloaded(
    psys: PreloadSystem,
    basis: str,
    level: int,
    grid_alignment: int,
    basis_alignment: int = 1,
) -> System:
    """
    Builds a system from a preloaded system, by adding a grid and basis functions.
    This method is used for testing purposes only.

    Only the atoms selected by ``psys.atom_mask`` are used for both the grid
    and the basis functions.

    Args:
        psys: The preloaded system.
        basis: Name of the GTO basis, forwarded to ``get_gto_preloader``.
        level: Grid level, forwarded to ``get_grid_fn``.
        grid_alignment: Alignment forwarded to ``get_grid_fn``.
        basis_alignment: Alignment forwarded to ``pad_ao_values`` for the AO
            values and their gradients.

    Returns:
        The ``System`` produced by ``System.from_preloaded`` from ``psys`` and
        the newly built ``Grid``.
    """
    # Select the unmasked atoms once so that the grid and the basis are built
    # for exactly the same atoms and elements.
    atom_z = psys.atom_z[psys.atom_mask]
    nuc_pos = psys.nuc_pos[psys.atom_mask]
    # Masked-out entries of atom_z (presumably padding) must not contribute
    # elements. The basis preloader below already used only the masked set.
    elements = set(atom_z)

    grid_fn = get_grid_fn(level, elements, grid_alignment)
    coords, weights = grid_fn(nuc_pos, cast_to_integer_tuple(atom_z))

    l_max, pvec_basis_fn_factory = get_gto_preloader(basis, elements)
    vec_basis_fns = GTOBasis.from_preloaded(pvec_basis_fn_factory(atom_z))
    # The first argument is presumably the derivative order, since both AO
    # values and their gradients are returned. TODO confirm.
    basis_fn = get_gto_grid_eval_fn(1, l_max)
    aos, grad_aos = basis_fn(
        coords,
        nuc_pos,
        vec_basis_fns.radial_primitives,
        vec_basis_fns.compile_statics,
    )
    aos, grad_aos = pad_ao_values(aos, grad_aos, basis_alignment)
    grid = Grid.create(coords, weights, aos, grad_aos)
    return System.from_preloaded(psys, grid)


def vec_basis_fns_from_preloaded(psys: PreloadSystem, basis: str) -> Tuple[int, GTOBasis]:
    """
    Builds the GTO basis functions of a preloaded system for testing purposes.

    Only the atoms selected by ``psys.atom_mask`` are considered.

    Returns:
        A tuple ``(max_angular_momentum, basis_functions)``. The first entry is
        the value reported by ``get_gto_preloader``. The second is the
        ``GTOBasis`` built from the preloaded basis functions.
    """
    real_atom_z = psys.atom_z[psys.atom_mask]
    l_max, factory = get_gto_preloader(basis, set(real_atom_z))
    gto_basis = GTOBasis.from_preloaded(factory(real_atom_z))
    return l_max, gto_basis


def preload_grid_using_pyscf(
    psys: PreloadSystem, spin_restricted: bool, grid_level: int, alignment: Alignment
) -> PreloadGrid:
    """
    Pyscf based preloading of the grid for testing purposes.

    This function does the following:
    - Builds a PySCF molecule from the unmasked atoms of ``psys``.
    - Builds the PySCF DFT integration grid at ``grid_level``.
    - Evaluates the spherical AOs and their first derivatives on that grid.

    Returns:
        A ``PreloadGrid`` created from the grid coordinates, grid weights, AO
        values and AO gradients (derivative component moved to the last axis),
        using ``alignment``.
    """
    mask = psys.atom_mask
    # NOTE(review): gto.M is called without a ``unit`` argument, so PySCF reads
    # these coordinates as Angstrom. Confirm that scaling by ANGSTROM_TO_BOHR
    # gives the intended positions for psys.nuc_pos.
    atoms = list(zip(psys.atom_z[mask], psys.nuc_pos[mask] * ANGSTROM_TO_BOHR))
    mol = gto.M(atom=atoms, basis=psys.basis, charge=psys.charge, spin=psys.spin)

    ks_cls = dft.RKS if spin_restricted else dft.UKS
    mf = ks_cls(mol)
    mf.grids.level = grid_level
    mf.grids.build()
    grid_coords, grid_weights = mf.grids.coords, mf.grids.weights

    # 'GTOval_sph_deriv1' stacks the AO values (index 0) and their three
    # derivatives (indices 1..3) along the leading axis.
    ao_data = mol.eval_gto('GTOval_sph_deriv1', grid_coords)
    ao_values = ao_data[0]
    ao_grads = einops.rearrange(ao_data[1:], 's n b -> n b s')
    return PreloadGrid.create(grid_coords, grid_weights, ao_values, ao_grads, alignment)  # type: ignore
