import warnings
from dataclasses import dataclass
from typing import Tuple

import numpy as onp
from cachetools import LRUCache, cached
from flax.struct import dataclass as flax_dataclass
from numpy.typing import ArrayLike
from pyscf import df, dft, gto
from scipy import linalg

from egxc.utils import pad
from egxc.utils.constants import ANGSTROM_TO_BOHR, MIN_EVP_DIAGONAL_PADDING
from egxc.utils.typing import (
    PRECISION,
    Alignment,
    BaseInitialGuess,
    CompileStaticInt,
    CompileStaticStr,
    NpBool2xB,
    NpBoolA,
    NpBoolB,
    NpDensityMatrix,
    NpFloatAx3,
    NpFloatBxB,
    NpFloatBxBxBxB,
    NpFloatN,
    NpFloatNx3,
    NpFloatNxB,
    NpFloatNxBx3,
    NpFloatQxBxB,
    NpUIntA,
    NpUIntB,
)


def get_aux_basis(basis: str, only_coulomb: bool) -> str:
    """
    Choose an auxiliary (density-fitting) basis set for an orbital basis.

    Any basis name containing ``def2`` (matched case-insensitively, since PySCF
    basis names are case-insensitive) maps to the ``def2-universal`` fitting
    sets. Otherwise PySCF's predefined auxiliary basis for ``basis`` is used,
    falling back to ``'weigend'`` when PySCF knows none.

    Args:
        basis: Orbital basis set name.
        only_coulomb: If True, select a J-fitting set (``-jfit``) instead of a
            JK-fitting set (``-jkfit``).

    Returns:
        Name of the auxiliary basis set.
    """
    if isinstance(basis, str) and 'def2' in basis.lower():
        return 'def2-universal-jfit' if only_coulomb else 'def2-universal-jkfit'

    # predefined_auxbasis needs a Mole; a minimal H2 molecule in this basis is
    # built only for that lookup.
    mol = gto.M(atom='H 0 0 0; H 0 0 1', basis=basis)
    aux_basis = df.addons.predefined_auxbasis(mol, basis)
    if aux_basis is None:
        return 'weigend'
    if only_coulomb and '-jkfit' in aux_basis:
        return aux_basis.replace('-jkfit', '-jfit')
    return aux_basis


def compute_electron_occupancy(
    spin: int, n_electrons: int, n_basis_functions: int, spin_restricted: bool
) -> NpUIntB | NpBool2xB:
    """
    Build aufbau orbital occupation numbers (lowest orbitals filled first).

    Args:
        spin: ``n_alpha - n_beta``, the same value passed as ``spin`` to
            PySCF's ``gto.M`` elsewhere in this module.
        n_electrons: Total number of electrons.
        n_basis_functions: Number of orbitals ``B``; length of each occupation
            vector.
        spin_restricted: If True, return one vector with doubly occupied
            orbitals; otherwise return separate alpha and beta vectors.

    Returns:
        ``uint8`` array of shape ``(B,)`` with entries 0/2 (restricted), or of
        shape ``(2, B)`` with entries 0/1 where row 0 is alpha (spin up) and
        row 1 is beta (spin down).

    Raises:
        AssertionError: if ``spin`` is odd in the restricted case.
        ValueError: if ``spin`` is inconsistent with ``n_electrons`` in the
            unrestricted case.
    """
    B = n_basis_functions
    if spin_restricted:
        assert spin % 2 == 0, 'Spin must be even for restricted'
        occ = onp.zeros(B, dtype=onp.uint8)
        occ[: n_electrons // 2] = 2
        return occ

    # The alpha/beta split must honour the requested spin:
    # n_alpha = (N + spin) / 2, n_beta = N - n_alpha. An even split would turn
    # e.g. triplet O2 into a singlet and, for odd N, put the surplus electron
    # in the beta channel, contradicting PySCF's initial guess (alpha >= beta).
    if abs(spin) > n_electrons or (n_electrons - spin) % 2 != 0:
        raise ValueError(
            f'spin={spin} is inconsistent with n_electrons={n_electrons}'
        )
    n_up = (n_electrons + spin) // 2
    n_down = n_electrons - n_up

    occ = onp.zeros((2, B), dtype=onp.uint8)
    occ[0, :n_up] = 1
    occ[1, :n_down] = 1
    return occ


def compute_electron_repulsion_tensor(
    mol: gto.Mole,
    use_density_fitting: bool,
    only_coulomb: bool,
):
    """
    Build the electron repulsion integrals for ``mol``.

    Without density fitting this returns the full 4-index AO tensor from
    ``mol.intor('int2e')``. With density fitting it returns a 3-index tensor
    ``X`` of shape ``(naux, nao, nao)`` such that
    ``sum_P X[P, i, j] * X[P, k, l] = (ij|P) (P|Q)^-1 (Q|kl)``.

    Args:
        mol: PySCF molecule.
        use_density_fitting: Whether to build the density-fitted 3-index tensor.
        only_coulomb: Forwarded to ``get_aux_basis`` to pick a J-fitting
            auxiliary basis; unused without density fitting.
    """
    if not use_density_fitting:
        return mol.intor('int2e')

    # TODO: might need to change this for wB97M-V
    aux_basis = get_aux_basis(mol.basis, only_coulomb=only_coulomb)
    auxmol = df.addons.make_auxmol(mol, aux_basis)
    n_ao = mol.nao
    n_aux = auxmol.nao

    # Three-center integrals (ij|P), flattened to (nao * nao, naux).
    three_center = df.incore.aux_e2(mol, auxmol, intor='int3c2e')
    three_center = three_center.reshape(n_ao * n_ao, n_aux)

    # Coulomb metric (P|Q) = U^T U; scipy's cholesky returns the upper factor U.
    metric_upper = linalg.cholesky(auxmol.intor('int2c2e'))

    # Solve U^T X = (P|ij) so that X^T X = (ij|P) (P|Q)^-1 (Q|kl).
    fitted = linalg.solve_triangular(metric_upper.T, three_center.T, lower=True)
    return fitted.reshape(n_aux, n_ao, n_ao)


@dataclass(frozen=True)
class PreloadFockTensors:
    """
    Immutable bundle of tensors that depend only on the structure and the
    basis set, and which are used to assemble the Fock matrix.
    """

    # True for real basis functions, False for alignment padding.
    basis_mask: NpBoolB
    core_hamiltonian: NpFloatBxB
    # Either the full 4-index ERI tensor or a density-fitted 3-index tensor.
    electron_repulsion_tensor: NpFloatBxBxBxB | NpFloatQxBxB
    # (B,) for spin-restricted, (2, B) for spin-unrestricted occupations.
    occupancies: NpUIntB | NpBool2xB
    overlap: NpFloatBxB
    # Short-/long-range ERIs; only populated when range separation is requested.
    eri_sr_tensor: NpFloatBxBxBxB | NpFloatQxBxB | None = None
    eri_lr_tensor: NpFloatBxBxBxB | NpFloatQxBxB | None = None

    def __repr__(self) -> str:
        def shape_of(tensor):
            return getattr(tensor, 'shape', None)

        sections = (
            '##### PreloadFockTensors ##### \n',
            f'# basis mask =\n{self.basis_mask},\n',
            f'# Overlap Matrix S =\n{self.overlap},\n',
            f'# Core Hamiltonian H = \n{self.core_hamiltonian},\n',
            f'# Orbital Occupations =\n{self.occupancies},\n',
            f'# Electron Repulsion Tensor (shape: {self.electron_repulsion_tensor.shape})\n',
            f'# Short-range ERI shape: {shape_of(self.eri_sr_tensor)}\n',
            f'# Long-range ERI shape: {shape_of(self.eri_lr_tensor)}',
        )
        return ''.join(sections)


@cached(LRUCache(maxsize=float('inf')))
def cached_core_hamiltonian(mol: gto.Mole) -> NpFloatBxB:
    """
    Return the one-electron core Hamiltonian (kinetic + nuclear attraction).

    Results are memoised in an unbounded LRU cache keyed on ``mol``, so the
    same array object is returned on repeated calls and should not be
    modified in place.
    """
    kinetic = mol.intor('int1e_kin')
    nuclear_attraction = mol.intor('int1e_nuc')
    return kinetic + nuclear_attraction


def preload_fock_tensors_using_pyscf(
    mol: gto.Mole,
    spin: int,
    n_electrons: int,
    spin_restricted: bool,
    use_density_fitting: bool,
    alignment: Alignment,
    range_separation: float | None = None,
) -> PreloadFockTensors:
    """
    Compute, and optionally pad, the Fock-matrix building blocks for ``mol``.

    Args:
        mol: PySCF molecule defining geometry and orbital basis.
        spin: ``n_alpha - n_beta``; forwarded to ``compute_electron_occupancy``.
        n_electrons: Total number of electrons.
        spin_restricted: Build restricted (B,) or unrestricted (2, B) occupancies.
        use_density_fitting: Use a density-fitted 3-index ERI tensor.
        alignment: If ``alignment.is_aligned``, pad all basis axes (and the
            auxiliary axis of a density-fitted tensor) to the alignment size.
        range_separation: If given, the range-separation parameter omega used to
            additionally build short- and long-range 4-index ERIs.

    Returns:
        PreloadFockTensors. Padded basis functions are masked out in
        ``basis_mask``, get a unit overlap diagonal, large distinct
        core-Hamiltonian diagonal entries and zero ERIs/occupations.

    Raises:
        AssertionError: if range separation is combined with density fitting.
    """
    n_basis = mol.nao

    overlap = mol.intor('int1e_ovlp')
    core_hamiltonian = cached_core_hamiltonian(mol)
    ert = compute_electron_repulsion_tensor(mol, use_density_fitting, only_coulomb=True)

    eri_sr, eri_lr = None, None
    # TODO: this needs to be adapted for wB97M-V references
    if range_separation is not None:
        assert not use_density_fitting, (
            'Range-separated ERIs only implemented for non-density-fitted tensors'
        )
        with mol.with_short_range_coulomb(range_separation):
            eri_sr = mol.intor('int2e')
        # with_long_range_coulomb uses |omega| just as with_short_range_coulomb
        # uses -|omega|; plain with_range_coulomb would silently yield the
        # short-range operator again for a negative omega.
        with mol.with_long_range_coulomb(range_separation):
            eri_lr = mol.intor('int2e')

    occupancies = compute_electron_occupancy(spin, n_electrons, n_basis, spin_restricted)

    basis_mask = onp.ones(n_basis, dtype=bool)
    if alignment.is_aligned:
        b_pad = pad.calc_padding_size(n_basis, alignment.basis)
        basis_mask = onp.pad(basis_mask, (0, b_pad))
        is_padding = ~basis_mask  # boolean pair-indexing below selects the padded diagonal

        overlap = onp.pad(overlap, ((0, b_pad), (0, b_pad)))
        overlap[is_padding, is_padding] = 1.0

        core_hamiltonian = onp.pad(core_hamiltonian, ((0, b_pad), (0, b_pad)))
        # pad diagonal with large non-equal values to avoid issues in the
        # generalized eigenvalue problem of the Fock matrix
        core_hamiltonian[is_padding, is_padding] = onp.arange(
            MIN_EVP_DIAGONAL_PADDING, MIN_EVP_DIAGONAL_PADDING + b_pad
        )

        pad_4d = ((0, b_pad),) * 4
        if ert.ndim == 4:
            ert = onp.pad(ert, pad_4d)
        else:  # density-fitted (Q, B, B): also align the auxiliary axis
            aux_pad = pad.calc_padding_size(ert.shape[0], alignment.basis)
            ert = onp.pad(ert, ((0, aux_pad), (0, b_pad), (0, b_pad)))

        # The range-separated ERIs always come from mol.intor('int2e') and are
        # therefore 4-index, independent of how ``ert`` was built.
        if eri_sr is not None:
            eri_sr = onp.pad(eri_sr, pad_4d)
        if eri_lr is not None:
            eri_lr = onp.pad(eri_lr, pad_4d)

        if occupancies.ndim == 1:  # spin-restricted
            occupancies = onp.pad(occupancies, (0, b_pad))
        else:  # spin-unrestricted
            occupancies = onp.pad(occupancies, ((0, 0), (0, b_pad)))

    return PreloadFockTensors(
        basis_mask=basis_mask,
        core_hamiltonian=core_hamiltonian,
        electron_repulsion_tensor=ert,
        occupancies=occupancies,  # type: ignore
        overlap=overlap,
        eri_sr_tensor=eri_sr,
        eri_lr_tensor=eri_lr,
    )


@flax_dataclass
class PreloadGrid:
    """
    CPU-side (NumPy) quadrature grid together with the basis-function values
    evaluated on it.
    """

    coords: NpFloatNx3  # grid point coordinates (N x 3)
    weights: NpFloatN  # quadrature weights (N,)
    aos: NpFloatNxB  # basis-function values on the grid (N x B)
    grad_aos: NpFloatNxBx3 | None = None  # optional basis-function gradients (N x B x 3)

    @classmethod
    def create(
        cls,
        coords: NpFloatNx3,
        weights: NpFloatN,
        aos: NpFloatNxB,
        grad_aos: NpFloatNxBx3 | None,
        alignment: Alignment,
    ) -> 'PreloadGrid':
        """
        Build a ``PreloadGrid``, padding grid and basis axes when
        ``alignment.is_aligned``.

        Grid coordinates and weights are padded by ``pad.pad_quadrature_grid``;
        ``aos``/``grad_aos`` are padded in 'edge' mode along the grid and basis
        axes (the gradient axis is left untouched).
        """
        if not alignment.is_aligned:
            return cls(coords, weights, aos, grad_aos)

        assert isinstance(alignment.atom, int), 'Atom alignment must be an integer'
        n_points, n_basis = aos.shape
        point_pad = pad.calc_padding_size(n_points, alignment.grid)
        basis_pad = pad.calc_padding_size(n_basis, alignment.basis)
        coords, weights = pad.pad_quadrature_grid(  # type: ignore
            alignment.grid, coords, weights, backend='numpy'
        )
        # NOTE(review): 'edge' mode copies the last grid point and the last basis
        # function into the padding instead of zero-filling; confirm downstream
        # masking makes the duplicated basis columns harmless.
        aos = onp.pad(aos, ((0, point_pad), (0, basis_pad)), mode='edge')
        if grad_aos is not None:
            grad_aos = onp.pad(
                grad_aos, ((0, point_pad), (0, basis_pad), (0, 0)), mode='edge'
            )
        return cls(coords, weights, aos, grad_aos)


@dataclass(frozen=True)
class PreloadSystem:
    """
    Immutable, jit-friendly CPU-side description of a single system, consumed
    when constructing the GPU-side System objects.
    """

    idx: int  # position of the system in its dataset
    # Nuclear positions; presumably Angstrom (scaled by ANGSTROM_TO_BOHR when
    # the PySCF molecule is built).
    nuc_pos: NpFloatAx3
    atom_z: NpUIntA  # atomic numbers  # TODO: remove PermutationInvariantHashableArray
    atom_mask: NpBoolA  # True for real atoms, False for alignment padding
    charge: int
    spin: int
    fock_tensors: PreloadFockTensors  # TODO: implement integral engine
    basis: CompileStaticStr
    grid_alignment: CompileStaticInt
    # Initial guess, optionally followed by a reference density matrix.
    initial_density_matrices: (
        Tuple[NpDensityMatrix] | Tuple[NpDensityMatrix, NpDensityMatrix]
    )

    def __post_init__(self):
        # Only the real (unmasked) atoms must be in ascending order.
        real_z = self.atom_z[self.atom_mask]  # TODO: check padding
        assert (real_z[1:] >= real_z[:-1]).all(), (
            f'Atom numbers must be sorted: {real_z}'
        )

    @property
    def number_of_basis_fns(self) -> CompileStaticInt:
        """Number of basis functions, including any alignment padding."""
        mask = self.fock_tensors.basis_mask
        return len(mask)


def _molecule_cache_key(
    idx: int,
    atom_z: NpUIntA,
    nuc_pos: NpFloatAx3,
    basis: str,
    charge: int,
    spin: int,
) -> tuple:
    """Hashable cache key covering every input that affects the built molecule."""
    atom_z = onp.asarray(atom_z)
    nuc_pos = onp.asarray(nuc_pos)
    return (
        idx,
        basis,
        charge,
        spin,
        atom_z.dtype.str,
        atom_z.tobytes(),
        nuc_pos.dtype.str,
        nuc_pos.shape,
        nuc_pos.tobytes(),
    )


# Keyed on all inputs (not just ``idx``): the same index with a different basis,
# charge, spin or geometry (e.g. another dataset split) must not hit a stale entry.
@cached(LRUCache(maxsize=float('inf')), key=_molecule_cache_key)
def cached_molecule(
    idx: int,
    atom_z: NpUIntA,
    nuc_pos: NpFloatAx3,
    basis: str,
    charge: int,
    spin: int,
) -> gto.Mole:
    """
    Build (or reuse) a PySCF molecule.

    ``nuc_pos`` is multiplied by ``ANGSTROM_TO_BOHR`` and the molecule is
    created with ``unit='Bohr'``, so positions are presumably given in Angstrom.
    The returned ``Mole`` is shared between cache hits.
    """
    mol = gto.M(
        atom=list(zip(atom_z, nuc_pos * ANGSTROM_TO_BOHR)),
        basis=basis,
        charge=charge,
        spin=spin,
        unit='Bohr',
    )
    return mol


def preload_system_using_pyscf(
    idx: int,
    nuc_pos: ArrayLike,  # nuclei positions FloatAx3  # type: ignore
    atom_z: ArrayLike,  # atomic numbers IntA  # type: ignore
    charge: int,
    spin: int,
    reference_density: NpFloatBxB | None,
    basis: str,
    spin_restricted: bool,
    alignment: Alignment,
    base_initial_density_guess: BaseInitialGuess,
    use_density_fitting: bool,
    center: bool = False,
    cache_pyscf_mole: bool = True,
    range_separation: float | None = None,
) -> PreloadSystem:
    """
    Build a ``PreloadSystem`` for one molecule using PySCF on the CPU.

    Args:
        idx: Index of the system in its dataset.
        nuc_pos: Nuclear positions (A x 3); scaled by ``ANGSTROM_TO_BOHR`` when
            the PySCF molecule is built, so presumably Angstrom.
        atom_z: Atomic numbers (A,); must be sorted in ascending order.
        charge: Total molecular charge.
        spin: ``n_alpha - n_beta`` (passed to PySCF as ``spin``).
        reference_density: Optional density matrix appended after the PySCF
            initial guess. A 2-D (restricted) reference used in an
            unrestricted calculation is split evenly into alpha and beta.
        basis: Orbital basis set name.
        spin_restricted: Build restricted (RKS) or unrestricted (UKS) quantities.
        alignment: Padding alignment for atoms, basis functions and grid.
        base_initial_density_guess: ``key`` for PySCF's ``get_init_guess``.
        use_density_fitting: Use a density-fitted ERI tensor.
        center: Shift positions so their mean lies at the origin. The caller's
            array is never modified.
        cache_pyscf_mole: Reuse PySCF molecules via ``cached_molecule``.
        range_separation: Optional omega for short-/long-range ERIs.

    Returns:
        The (optionally padded) ``PreloadSystem``.
    """
    nuc_pos: NpFloatAx3 = onp.asarray(nuc_pos)
    atom_z: NpUIntA = onp.asarray(atom_z, dtype=onp.uint8)
    assert onp.all(atom_z == onp.sort(atom_z)), 'Atom numbers must be sorted'

    if center:
        # Rebind instead of ``-=``: onp.asarray may alias the caller's array
        # (which would be silently modified), and in-place subtraction of a
        # float mean fails for integer-typed positions.
        nuc_pos = nuc_pos - nuc_pos.mean(axis=0)

    if cache_pyscf_mole:
        mol = cached_molecule(idx, atom_z, nuc_pos, basis, charge, spin)
    else:
        mol = gto.M(
            atom=list(zip(atom_z, nuc_pos * ANGSTROM_TO_BOHR)),
            basis=basis,
            charge=charge,
            spin=spin,
            unit='Bohr',
        )

    # Sum in a signed 64-bit type: summing uint8 yields an unsigned integer, and
    # subtracting a negative charge (anions) from it overflows / raises under
    # NumPy 2 promotion rules.
    n_electrons = int(atom_z.sum(dtype=onp.int64) - charge)

    fock_tensors = preload_fock_tensors_using_pyscf(
        mol,
        spin,
        n_electrons,
        spin_restricted,
        use_density_fitting,
        alignment,
        range_separation,
    )
    mf = dft.RKS(mol) if spin_restricted else dft.UKS(mol)

    density_matrices = [mf.get_init_guess(key=base_initial_density_guess)]
    if reference_density is not None:
        # Convert a restricted reference regardless of alignment so that all
        # matrices of an unrestricted system have the same (2, B, B) layout.
        if not spin_restricted and onp.ndim(reference_density) == 2:
            warnings.warn(
                'Using spin-restricted density matrix as reference for spin-unrestricted calculation.',
                stacklevel=2,
            )
            reference_density = 0.5 * onp.stack(
                (reference_density, reference_density), axis=0
            )
        density_matrices.append(reference_density)
    initial_density_matrices = tuple(density_matrices)

    atom_mask = onp.ones_like(atom_z, dtype=bool)
    if alignment.is_aligned:
        n_atoms = atom_mask.sum()
        atom_padding = pad.calc_padding_size(n_atoms, alignment.atom)  # type: ignore

        atom_mask = onp.pad(
            atom_mask, (0, atom_padding), mode='constant', constant_values=False
        )
        atom_z = onp.pad(atom_z, (0, atom_padding))
        nuc_pos = onp.pad(nuc_pos, ((0, atom_padding), (0, 0)))

        n_basis = initial_density_matrices[0].shape[-1]
        basis_padding = pad.calc_padding_size(n_basis, alignment.basis)
        if spin_restricted:
            pad_width = ((0, basis_padding), (0, basis_padding))
        else:
            assert initial_density_matrices[0].ndim == 3, (
                'Unrestricted density matrix must be 3D'
            )
            pad_width = ((0, 0), (0, basis_padding), (0, basis_padding))
        initial_density_matrices = tuple(
            onp.pad(P, pad_width) for P in initial_density_matrices
        )

    return PreloadSystem(
        idx=idx,
        nuc_pos=nuc_pos.astype(PRECISION.forces),
        atom_z=atom_z.astype(onp.uint8),
        atom_mask=atom_mask,
        charge=charge,
        spin=spin,
        fock_tensors=fock_tensors,
        basis=basis,
        grid_alignment=alignment.grid,
        initial_density_matrices=initial_density_matrices,
    )
