from functools import partial
from typing import Callable, Literal, Tuple

import jax
import jax.numpy as jnp

from deixc.data_generation.utils import BaseDftDeixcTargetGenerator, DeiXCTargets, Timer
from deixc.orbital_transforms import (
    ao_density_perturbation_from_occupied_virtual_rotation,
    ao_to_mo,
    dm_gradient_to_orbital_rotation_gradient,
)
from deixc.scf import (
    DerivativeInformedSelfConsistentFieldSolver,
    DerivativeInformedSolver,
)
from egxc import dataloading
from egxc.dataloading import io
from egxc.discretization import get_grid_fn, get_gto_grid_eval_fn, get_gto_preloader
from egxc.systems import Grid, System, examples, nuclear_energy_fn
from egxc.utils.linalg import modified_generalized_eigenvalue_problem
from egxc.utils.typing import (
    Alignment,
    BoolB,
    FloatBxB,
    FloatBxBxBxB,
    FloatQxBxB,
    FloatRefSCFxBxB,
    FloatRefSCFxOxV,
)
from egxc.xc_energy import DensityFeatures, XCModule, get_functional
from egxc.xc_energy.functionals.classical.hybrid import Hybrid

# Empty parameter collection passed as the first argument of every
# `scf_solver.apply(...)` call in this module.
NO_FREE_PARAMS = {}
# Small numerical constant; not referenced in the visible part of this module.
EPSILON = 1e-12

# Upper bound on the absolute difference between the last two entries of the
# XC energy array returned by the SCF kernel. The check multiplies that
# difference by 1e3 before comparing it against this value.
CONVERGENCE_THRESHOLD = 1e-4  # mHa  (1e-7 Hartree)


def compute_sample_targets_fn_factory(
    scf_solver: DerivativeInformedSelfConsistentFieldSolver,
    include_forces: bool = False,
    include_d3_dispersion: bool = False,
    include_hessian: bool = False,
) -> Callable[[FloatBxB, System], DeiXCTargets]:
    """Build a function that computes DeiXCTargets for a single system.

    The returned callable is a plain Python function: it dispatches to
    JIT-compiled kernels (SCF, linear response), performs a host-side
    convergence check, records wall-clock timings and assembles the result.

    Parameters
    ----------
    scf_solver
        SCF solver capable of returning derivative-informed quantities.
    include_forces
        Placeholder flag for computing nuclear forces (currently unsupported).
    include_d3_dispersion
        Placeholder flag for computing D3 dispersion corrections (currently unsupported).
    include_hessian
        Placeholder flag for computing second derivatives (currently unsupported).

    Returns
    -------
    Callable[[FloatBxB, System], DeiXCTargets]
        A function that maps an initial density guess and a system to
        DeiXCTargets. Forces, D3 dispersion and the density Hessian diagonal
        are always set to ``None``.

    Raises
    ------
    AssertionError
        If any of the ``include_*`` flags is set.
    """
    assert (
        not include_forces and not include_d3_dispersion and not include_hessian
    ), (
        'forces, D3 dispersion and Hessian targets are not implemented yet'
    )  # TODO: Implement these

    @jax.jit
    def scf_kernel(initial_density_matrix: FloatBxB, sys: System):
        """Run the derivative-informed SCF procedure and collect intermediate tensors."""
        (e_hj, e_xc), (Cs, Ps, Fs, Vs) = scf_solver.apply(
            NO_FREE_PARAMS, initial_density_matrix, sys
        )
        xc_energies = e_xc
        total_energies = e_hj + e_xc + nuclear_energy_fn(sys._nuc_pos, sys)
        return (Cs, Ps, Fs, Vs, total_energies, xc_energies)

    # NOTE(review): the first argument of both linear-response helpers is named
    # `P0`, but the only callers below pass the MO coefficients. Confirm which
    # quantity `xc_potential_and_linear_response` actually expects.
    @partial(jax.vmap, in_axes=(0, 0, None, None))
    def _local_lin_resp_fn(
        P0: FloatBxB,
        dP: FloatBxB,
        grid: Grid,
        basis_mask: BoolB,
    ) -> FloatBxB:
        """Evaluate the XC potential linear response for local reference functionals."""
        _, lin_resp = scf_solver.apply(
            NO_FREE_PARAMS,
            P0,
            dP,
            grid,
            basis_mask,
            method=DerivativeInformedSolver.xc_potential_and_linear_response,
        )
        return lin_resp  # type: ignore

    @partial(jax.vmap, in_axes=(0, 0, None, None, None))
    def _hybrid_lin_resp_fn(
        P0: FloatBxB,
        dP: FloatBxB,
        grid: Grid,
        basis_mask: BoolB,
        eri_tensor: FloatBxBxBxB | FloatQxBxB,
    ) -> FloatBxB:
        """Evaluate the reference hybrid XC potential linear response including ERI tensors."""
        _, lin_resp = scf_solver.apply(
            NO_FREE_PARAMS,
            P0,
            dP,
            grid,
            basis_mask,
            method=DerivativeInformedSolver.xc_potential_and_linear_response,
            eri_tensor=eri_tensor,  # non_local_kwargs
        )
        return lin_resp  # type: ignore

    # `n_occ` and `is_hybrid` are static: `is_hybrid` selects a Python branch
    # at trace time, so it cannot be a traced value.
    @partial(jax.jit, static_argnames=('n_occ', 'is_hybrid'))
    def direct_minimization_directions_and_linear_response_kernel(
        mo_coeffs: FloatRefSCFxBxB,
        xc_potential_matrices: FloatRefSCFxBxB,
        sys: System,
        n_occ: int,
        is_hybrid: bool,
    ) -> Tuple[FloatRefSCFxOxV, FloatRefSCFxOxV]:
        """Construct orbital rotation gradients and map them through the XC response."""
        orbital_rotation_gradient: FloatRefSCFxOxV = jax.vmap(
            dm_gradient_to_orbital_rotation_gradient, in_axes=(0, 0, None)
        )(xc_potential_matrices, mo_coeffs, n_occ)
        # mo to ao
        perturbations = jax.vmap(
            ao_density_perturbation_from_occupied_virtual_rotation,
            in_axes=(0, 0, None, None, None),
        )(
            orbital_rotation_gradient,
            mo_coeffs,
            sys.fock_tensors.overlap,
            n_occ,
            True,
        )
        if is_hybrid:
            linear_responses_xc_pot = _hybrid_lin_resp_fn(
                mo_coeffs,
                perturbations,
                sys.grid,
                sys.fock_tensors.basis_mask,
                sys.fock_tensors.ert,
            )
        else:
            linear_responses_xc_pot = _local_lin_resp_fn(
                mo_coeffs,
                perturbations,
                sys.grid,
                sys.fock_tensors.basis_mask,
            )
        # ao to mo
        linear_responses_xc_pot = jax.vmap(ao_to_mo, in_axes=(0, 0, None))(
            linear_responses_xc_pot, mo_coeffs, n_occ
        )
        return (
            orbital_rotation_gradient,
            linear_responses_xc_pot,
        )  # TODO: add gradient sanity check

    def compute_sample_targets(initial_density_matrix: FloatBxB, sys: System):
        """Compute DeiXC training targets given an AO density guess and molecular system.

        Raises
        ------
        RuntimeError
            If the last two XC energies differ by more than
            ``CONVERGENCE_THRESHOLD`` (after scaling by 1e3).
        """
        timer = Timer().set()
        scf_outputs = scf_kernel(initial_density_matrix, sys)
        # `block_until_ready` accepts a pytree, so one call waits for every
        # SCF output before the timing is recorded.
        jax.block_until_ready(scf_outputs)
        (
            mo_coeffs,
            density_matrices,
            fock_matrices,
            xc_potential_matrices,
            total_energies,
            xc_energies,
        ) = scf_outputs
        energy_volatility = abs(xc_energies[-2] - xc_energies[-1])
        if energy_volatility * 1e3 > CONVERGENCE_THRESHOLD:
            raise RuntimeError(
                f'Energy volatility is too high: {energy_volatility * 1e3:.6f} mHa',
            )
        print(f'Volatility: {energy_volatility * 1e3:.6f} mHa')
        timer.log('scf_calculation')
        # Converted to a Python int because it is a static argument of the
        # jitted kernel below.
        n_occ = int(jnp.count_nonzero(sys.fock_tensors.occupancies))

        is_hybrid = isinstance(scf_solver.xc_module.functional, Hybrid)
        timer.set()
        xc_minimization_directions, linear_responses_xc_pot = (
            direct_minimization_directions_and_linear_response_kernel(
                mo_coeffs,
                xc_potential_matrices,
                sys,
                n_occ,
                is_hybrid,
            )
        )
        jax.block_until_ready((linear_responses_xc_pot, xc_minimization_directions))
        timer.log('linear_response_calculation')
        timer.set()
        orbital_energies, _ = modified_generalized_eigenvalue_problem(
            fock_matrices[-1], sys.fock_tensors.diagonal_overlap
        )
        # Wait for the result like the other timed sections do; with JAX's
        # asynchronous dispatch the timer could otherwise stop before the
        # computation has finished.
        jax.block_until_ready(orbital_energies)
        timer.log('orbital_energies_calculation')

        out = DeiXCTargets(
            mo_coeffs=mo_coeffs,
            total_energies=total_energies,
            xc_energies=xc_energies,
            xc_potential_matrices=xc_potential_matrices,
            linear_response_xc_pot=linear_responses_xc_pot,
            orbital_energies=orbital_energies,
            density_hessian_diagonal=None,
            forces=None,
            d3_dispersion_energy=None,
            d3_dispersion_forces=None,
            compute_costs=timer.get_timings(),
        )
        out.test_recomputations(
            xc_minimization_directions, density_matrices, xc_potential_matrices
        )
        return out

    return compute_sample_targets


def get_jax_transform_and_dataloader(
    dataset: dataloading.BaseDataset,
    basis: str,
    use_density_fitting: bool,
    workers: int,
    worker_buffer_size: int,
    quadrature_grid_level: int,
    alignment: Alignment = Alignment(1, 1, 1024),
    spin_restricted: bool = True,
    base_initial_density_guess: Literal['minao'] = 'minao',
    center: bool = False,
    shuffle: bool = False,
    random_seed: int = 0,
) -> Tuple[dataloading.ToJaxTransform, dataloading.dataloader.GrainDataLoaderWrapper]:
    """Create the per-sample dataloader and the matching main-thread JAX transform.

    The dataloader applies a preload transform built with ``batch_size=1``.
    The returned transform is built from a grid function for
    ``quadrature_grid_level`` and a GTO grid evaluation function created with
    ``deriv=1``.

    Returns
    -------
    tuple
        ``(main_thread_transform, dataloader)``, in that order.
    """
    # The preloader also reports the maximum angular momentum, which the
    # grid evaluation function needs further down.
    max_l, gto_preloader = get_gto_preloader(basis, dataset.unique_elements)

    worker_transform = dataloading.get_preload_transform(
        batch_size=1,
        basis=basis,
        spin_restricted=spin_restricted,
        alignment=alignment,
        use_density_fitting=use_density_fitting,
        base_initial_density_guess=base_initial_density_guess,
        center=center,
        basis_fn_preloader=gto_preloader,
    )
    loader = dataloading.dataloader.get_individual_dataloader(
        dataset,
        worker_transform,
        shuffle,
        workers,
        worker_buffer_size,
        random_seed,
    )

    quadrature_fn = get_grid_fn(
        quadrature_grid_level, dataset.unique_elements, alignment.grid
    )
    gto_eval_fn = get_gto_grid_eval_fn(deriv=1, max_angular_momentum=max_l)

    return dataloading.get_jax_transform(quadrature_fn, gto_eval_fn), loader


class CustomDftTargetGenerator(BaseDftDeixcTargetGenerator):
    """Target generator that computes DeiXC targets for one chunk of the dataset.

    Calling an instance builds the dataloader, functional and SCF solver from
    ``self.kwargs``, then iterates over the chunk. Samples whose auxiliary
    data already exists in ``self.aux_dir`` are skipped. A failure on one
    sample is printed and the loop moves on to the next sample.
    """

    aux_data_key = 'deixc'
    backend = 'custom'
    spin_restricted = True
    n_cycles = 20
    workers = 4
    worker_buffer_size = 1

    def __call__(self, chunk_id: int, n_chunks: int) -> None:
        """Compute and save targets for chunk ``chunk_id`` out of ``n_chunks``."""
        first, last = self.get_chunk_indices(chunk_id, n_chunks)
        chunk = dataloading.IndexWrapper(self.dataset, list(range(first, last)))

        # A single flag drives both Coulomb and exchange density fitting.
        use_density_fitting = self.kwargs['use_eri_density_fitting']
        assert use_density_fitting == self.kwargs['use_exchange_density_fitting'], (
            'presently exchange and coulomb terms rely on the same ert_tensor and can only be used together'
        )

        basis = self.kwargs['basis']
        to_jax_transform, dataloader = get_jax_transform_and_dataloader(
            chunk,
            basis,
            use_density_fitting,
            self.workers,
            self.worker_buffer_size,
            self.kwargs['quadrature_grid_level'],
            spin_restricted=self.spin_restricted,
        )

        scf_solver = DerivativeInformedSelfConsistentFieldSolver(
            XCModule(
                get_functional(
                    self.kwargs['xc_str'],
                    spin_restricted=self.spin_restricted,
                    use_density_fitting=use_density_fitting,
                ),
                DensityFeatures(self.spin_restricted),
            ),
            self.spin_restricted,
            self.n_cycles,
            use_density_fitting,
        )

        # The solver is initialised once on the built-in 'water' example,
        # using its overlap matrix in place of an initial density.
        init_system = examples.get(
            'water',
            basis,
            Alignment(),
            use_density_fitting,
            self.spin_restricted,
        )
        scf_solver.init(
            jax.random.PRNGKey(0), init_system.fock_tensors.overlap, init_system
        )
        compute_sample_targets = compute_sample_targets_fn_factory(scf_solver)

        for psys, pvec_basis_fns, _ in dataloader:
            if io.auxiliary_data_exists(self.aux_dir, psys.idx):
                print(
                    f'Skipping {psys.idx}, already exists in {self.aux_dir}', flush=True
                )
                continue
            try:
                initial_guess, system = to_jax_transform(psys, pvec_basis_fns)
                # TODO: check initial guess
                targets = compute_sample_targets(initial_guess[0], system)
                self.save_deixc_targets(int(psys.idx), targets)
            except Exception as e:
                # Best effort: report the failing sample and carry on.
                print(f'Error computing DEI-XC targets for {psys.idx}: {e}', flush=True)
