import os
import warnings
from typing import Any, Literal, Tuple

import numpy as onp
from numpy.typing import ArrayLike
from pyscf import dft, gto, scf

from egxc.dataloading import io
from egxc.dataloading.datasets.base import BaseDataset
from egxc.systems.preload import get_aux_basis
from egxc.utils.typing import AuxDataKey, MethodKey, NpFloatBxB


class BaseGenerator:
    """
    Base class for auxiliary data generators.

    Subclasses set the class attributes ``method_key``, ``aux_data_key`` and
    ``backend`` and implement ``__call__`` to process one chunk of
    ``dataset``. Results are written under ``self.aux_dir``, which is created
    on construction.
    """

    method_key: MethodKey
    aux_data_key: AuxDataKey
    backend: Literal['pyscf', 'custom']

    def __init__(self, dataset: BaseDataset, **kwargs: Any):
        """
        Args:
            dataset: Dataset whose samples are processed; its
                ``auxiliary_data_directory`` is the root for the output.
            **kwargs: Generator options. They are stored on ``self.kwargs``
                and also forwarded to ``io.auxiliary_data_directory`` when
                building the output directory path.
        """
        self.dataset = dataset
        self.kwargs = kwargs

        self.aux_dir = io.auxiliary_data_directory(
            dataset.auxiliary_data_directory,
            self.aux_data_key,
            self.method_key,
            **self.kwargs,
        )
        os.makedirs(self.aux_dir, exist_ok=True)

    def __call__(self, chunk_id: int, n_chunks: int) -> None:
        """Process chunk ``chunk_id`` out of ``n_chunks``; implemented by subclasses."""
        raise NotImplementedError('This method should be implemented by subclasses.')

    def get_chunk_indices(self, chunk_id: int, n_chunks: int) -> Tuple[int, int]:
        """
        Return the half-open index range ``[start, end)`` of chunk ``chunk_id``.

        The dataset is split into ``n_chunks`` contiguous, near-equal parts.
        A ``chunk_id`` at or beyond ``n_chunks`` yields an empty range.

        Raises:
            ValueError: if ``n_chunks`` is not positive.
        """
        if n_chunks <= 0:
            raise ValueError(f'n_chunks must be positive, got {n_chunks}')
        n_samples = len(self.dataset)
        start = chunk_id * n_samples // n_chunks
        # Clamp so an out-of-range chunk_id cannot run past the dataset end.
        end = min((chunk_id + 1) * n_samples // n_chunks, n_samples)
        return start, end

    def save_initial_guess(
        self, idx: int, density_matrix: NpFloatBxB, energy: float
    ) -> None:
        """
        Save a density matrix and energy for sample ``idx`` into ``self.aux_dir``.

        Raises:
            TypeError: if ``density_matrix`` is not float64 (checked explicitly
                rather than via ``assert`` so it survives ``python -O``).
        """
        out = {
            'energy': onp.asarray(energy),
            'density_matrix': onp.asarray(density_matrix),
        }
        dm_dtype = out['density_matrix'].dtype
        if dm_dtype != onp.float64:
            raise TypeError(f'density_matrix must be float64, got {dm_dtype}')
        io.auxiliary_data_save(self.aux_dir, idx, out)


class BaseDftGenerator(BaseGenerator):
    """
    Base class for DFT-based DEI-XC target generators.

    ``BaseGenerator`` is a plain class, not a dataclass, so Python never
    invokes ``__post_init__`` on its own. ``__init__`` below calls it
    explicitly after the base initialisation so the warning is actually
    emitted.
    """

    backend: Literal['pyscf', 'custom']
    method_key = 'ks_dft'

    def __init__(self, dataset: BaseDataset, **kwargs: Any):
        super().__init__(dataset, **kwargs)
        self.__post_init__()

    def __post_init__(self):
        """Emit setup-time warnings; invoked at the end of ``__init__``."""
        warnings.warn('DFT data generation presently assumes spin-restriction')


def get_pyscf_mol_and_ks_meanfield(
    atom_z: ArrayLike,
    nuc_pos: ArrayLike,
    basis: str,
    spin: int,
    charge: int,
    xc_str: str,
    use_eri_density_fitting: bool,
    use_exchange_density_fitting: bool,
    spin_restricted: bool,
    quadrature_grid_level: int,
) -> Tuple[gto.Mole, scf.hf.SCF]:
    """
    Build a PySCF molecule and a Kohn-Sham mean-field object (not yet run).

    Args:
        atom_z: Per-atom identifiers (atomic numbers), paired with ``nuc_pos``.
        nuc_pos: Per-atom positions, passed to PySCF as-is; PySCF's default
            ``unit`` applies. NOTE(review): confirm the dataset units match.
        basis: Orbital basis set name forwarded to ``gto.M``.
        spin: Forwarded to ``gto.M(spin=...)``.
        charge: Total molecular charge forwarded to ``gto.M``.
        xc_str: Exchange-correlation functional string for the KS object.
        use_eri_density_fitting: Enable density fitting of the two-electron
            integrals.
        use_exchange_density_fitting: When density fitting is enabled, also
            fit exchange (K). Otherwise only Coulomb (J) is fitted.
        spin_restricted: Use ``dft.RKS`` when True, ``dft.UKS`` otherwise.
        quadrature_grid_level: Value assigned to ``mf.grids.level``.

    Returns:
        ``(mol, mf)``. ``mf.kernel()`` has not been called.
    """
    mol = gto.M(
        atom=list(zip(atom_z, nuc_pos)),  # type: ignore
        basis=basis,
        spin=spin,
        charge=charge,
    )
    ks_class = dft.RKS if spin_restricted else dft.UKS
    mf = ks_class(mol, xc=xc_str)
    mf.grids.level = quadrature_grid_level
    if use_eri_density_fitting:
        only_coulomb = not use_exchange_density_fitting
        aux_basis = get_aux_basis(mol.basis, only_coulomb)
        # The aux basis is chosen for J-only or J+K fitting, so the fitting
        # mode must agree: only_dfj=True fits J alone, False fits J and K.
        mf = mf.density_fit(auxbasis=aux_basis, only_dfj=only_coulomb)
    return mol, mf  # type: ignore


class PyscfDftTargetGenerator(BaseDftGenerator):
    """
    Generate KS-DFT initial-guess data (density matrix and total energy) with PySCF.

    Reads these entries from ``self.kwargs``: ``basis``, ``xc_str``,
    ``use_eri_density_fitting``, ``use_exchange_density_fitting``,
    ``spin_restricted`` and ``quadrature_grid_level``.
    """

    aux_data_key = 'initial_guess'
    backend = 'pyscf'

    def __call__(self, chunk_id: int, n_chunks: int) -> None:
        """
        Run an SCF for every sample in chunk ``chunk_id`` of ``n_chunks`` and save it.

        Samples whose auxiliary data already exist in ``self.aux_dir`` are
        skipped, so re-running a chunk only computes the missing samples.
        A warning is emitted when the SCF does not converge. The
        (unconverged) result is still saved.
        """
        start, end = self.get_chunk_indices(chunk_id, n_chunks)
        for i in range(start, end):
            sample = self.dataset[i]
            raw_idx, (nuc_pos, atom_z, charge, spin, _), _ = sample
            # Use the same integer key for the existence check and the save,
            # so the skip logic finds files this method wrote.
            idx = int(raw_idx)
            if io.auxiliary_data_exists(self.aux_dir, idx):
                print(f'Skipping {idx}, already exists in {self.aux_dir}', flush=True)
                continue
            print(f'Recomputing sample {idx}...', flush=True)
            _, mf = get_pyscf_mol_and_ks_meanfield(
                atom_z,
                nuc_pos,
                self.kwargs['basis'],
                spin,
                charge,
                self.kwargs['xc_str'],
                self.kwargs['use_eri_density_fitting'],
                self.kwargs['use_exchange_density_fitting'],
                self.kwargs['spin_restricted'],
                self.kwargs['quadrature_grid_level'],
            )
            mf.kernel()
            if not mf.converged:
                warnings.warn(
                    f'SCF did not converge for sample {idx}; saving unconverged result.'
                )
            self.save_initial_guess(idx, mf.make_rdm1(), mf.e_tot)
