from typing import Tuple

import jax
import jax.numpy as jnp
import pytest
from utils import set_jax_testing_config, system_from_preloaded

from deixc import orbital_transforms
from deixc.dataset import DEIXCDataset, DEIXCTargets
from egxc.dataloading import (
    DatasetEnsemble,
    get_preload_transform,
    get_psys_and_dataloaders,
    key_to_dataset,
)
from egxc.discretization import get_gto_preloader
from egxc.solver.fock import FockMatrix
from egxc.systems import Grid, System
from egxc.training.loss.density import get_density_mean_field_error_fn
from egxc.utils import linalg
from egxc.utils.typing import Alignment, NnParams, NpFloatBxB
from egxc.xc_energy import XCModule, functionals
from egxc.xc_energy.features import DensityFeatures

# Configure JAX for testing (helper from the local `utils` module) before any
# module-level data below is loaded.
set_jax_testing_config()
# Apply the `quick` pytest marker to every test in this module.
pytestmark = pytest.mark.quick


# Settings shared by the dataset construction and the tests below.
BASIS = 'def2-SVP'
XC = 'SCAN'  # functional used for the DEIXCDataset (target) KS-DFT calculation
BACKEND = 'custom'  # pyscf or custom
GRID_LEVEL = 1  # quadrature grid level for both the initial and target calculations
IS_DENSITY_FITTED = True  # passed to the preload transform, FockMatrix and density error fn


def _ks_dft_method_kwargs(xc_str: str, backend: str) -> dict:
    """Return KS-DFT settings shared by the initial-density and target calculations.

    Both calculations use the module-level ``BASIS`` and ``GRID_LEVEL`` with
    spin-restricted, density-fitted (ERI and exchange) settings; only the
    functional and the backend differ. Key order matches the original literals.
    """
    return {
        'xc_str': xc_str,
        'basis': BASIS,
        'backend': backend,
        'use_eri_density_fitting': True,
        'use_exchange_density_fitting': True,
        'spin_restricted': True,
        'quadrature_grid_level': GRID_LEVEL,
    }


def init_iterator():
    """Build the QM9 test data and return an iterator over its training split.

    The base QM9 dataset uses LDA (via PySCF) for its initial reference
    densities, is filtered with ``exclude_fluorine=True`` and
    ``heavy_atoms_thresh=4``, and is wrapped in a ``DEIXCDataset`` whose targets
    come from a ``XC`` calculation on ``BACKEND``. The data is split with a 10%
    validation fraction (seed 0) and loaded with batch size 1, without shuffling.

    Returns:
        An iterator over ``dataloaders.train``.
    """
    qm9 = key_to_dataset['qm9'](
        initial_ref_density_method_key='ks_dft',
        initial_ref_density_method_kwargs=_ks_dft_method_kwargs('LDA', 'pyscf'),
        data_dir='ANONYMOUS_DIR',
        exclude_fluorine=True,
        heavy_atoms_thresh=4,
    )
    # Read before `qm9` is rebound to the wrapping DEIXCDataset below.
    unique_elements = qm9.unique_elements
    qm9 = DEIXCDataset(
        qm9,
        method_key='ks_dft',
        method_kwargs=_ks_dft_method_kwargs(XC, BACKEND),
        align_scf_trajectory=10,
        shift_dispersion=False,
    )
    qm9 = DatasetEnsemble.infer_split(qm9, data_split_seed=0, val_fraction=0.1)

    preload_transform = get_preload_transform(
        batch_size=1,
        basis=BASIS,
        spin_restricted=True,
        alignment=Alignment(atom=1, grid=256),
        use_density_fitting=IS_DENSITY_FITTED,
        base_initial_density_guess='minao',
        basis_fn_preloader=get_gto_preloader(BASIS, unique_elements)[1],
        center=False,
    )
    _, dataloaders = get_psys_and_dataloaders(
        qm9,
        preload_transform,
        shuffle=False,
        workers=1,
        worker_buffer_size=1,
        shuffling_seed=0,
    )
    return iter(dataloaders.train)


# Module-level iterator over the training split. Every call to
# `get_sample_and_targets_and_densities` consumes the next sample from it.
ITERATOR = init_iterator()


def get_sample_and_targets_and_densities() -> Tuple[
    System, DEIXCTargets, Tuple[NpFloatBxB, NpFloatBxB]
]:
    """Consume the next sample from the module-level ``ITERATOR``.

    Returns:
        The sample converted to a ``System`` via ``system_from_preloaded``, its
        ``DEIXCTargets``, and the preloaded ``initial_density_matrices``.

    Raises:
        RuntimeError: if the training dataloader has no samples left. A bare
            ``StopIteration`` would otherwise escape with no message.
    """
    try:
        psys, _, targets = next(ITERATOR)
    except StopIteration as err:
        raise RuntimeError(
            'The training dataloader is exhausted: the tests in this module '
            'request more training samples than it provides.'
        ) from err
    # NOTE(review): 1 and 512 look like alignment settings for
    # system_from_preloaded; confirm against the local `utils` module.
    sys = system_from_preloaded(psys, BASIS, 1, 512)
    return sys, targets, psys.initial_density_matrices


def get_model_and_params(
    xc: str, sys: System, density_matrix: NpFloatBxB
) -> Tuple[XCModule, NnParams]:
    """Create an ``XCModule`` for functional ``xc`` and its parameters.

    The module pairs ``functionals.get_functional(name=xc)`` with
    spin-restricted ``DensityFeatures``. For ``xc == 'lda'`` no initialization
    is performed and an empty parameter dict is returned. Otherwise the module
    is initialized with ``PRNGKey(0)`` on ``density_matrix`` and ``sys.grid``.
    """
    functional = functionals.get_functional(name=xc)
    features = DensityFeatures(spin_restricted=True)
    xc_module = XCModule(functional=functional, feature_fn=features)
    nn_params = (
        {}
        if xc == 'lda'
        else xc_module.init(jax.random.PRNGKey(0), density_matrix, sys.grid)
    )
    return xc_module, nn_params


# Shared fixtures, built once at import time: the first training sample (its
# initial density matrices are discarded) and a 'scan' model initialized on
# the sample's target density matrix.
SYS, TARGETS, _ = get_sample_and_targets_and_densities()
MODEL, PARAMS = get_model_and_params('scan', SYS, TARGETS.density_matrix)  # type: ignore


def test_xc_potential_consistency():
    """Compare the two ways of obtaining the XC potential.

    ``xc_potential`` and the potential half of
    ``xc_potential_and_linear_response`` must agree when both are evaluated at
    the target density matrix.
    """
    dm = TARGETS.density_matrix
    grid = SYS.grid
    basis_mask = SYS.fock_tensors.basis_mask

    potential_only = MODEL.apply(
        PARAMS, dm, grid, basis_mask, method=MODEL.xc_potential
    )
    potential_with_response, _ = MODEL.apply(
        PARAMS,
        dm,
        dm,
        grid,
        basis_mask,
        method=MODEL.xc_potential_and_linear_response,
    )
    assert jnp.allclose(potential_only, potential_with_response)  # type: ignore


def test_density_gradient_direction():
    """Check that the XC potential yields a descent direction for the XC energy.

    For each of the first ``n_steps`` entries of ``TARGETS.density_matrices``:
    evaluate the XC potential, convert it to an orbital-rotation gradient, and
    convert that into a normalized AO density perturbation. A step of size
    ``EPSILON`` along the perturbation must lower the XC energy; a step against
    it must raise the energy.
    """
    EPSILON = 0.0001
    n_steps = 5

    def energy_at(density_matrix):
        # XC energy of the shared MODEL/PARAMS on the shared grid.
        return MODEL.apply(
            PARAMS,
            density_matrix,
            SYS.grid,
            method=MODEL.xc_energy,
        )

    for idx in range(n_steps):
        P = TARGETS.density_matrices[idx]
        C = TARGETS.mo_coeffs[idx]
        V = MODEL.apply(
            PARAMS,
            P,
            SYS.grid,
            SYS.fock_tensors.basis_mask,
            method=MODEL.xc_potential,
        )

        orb_rot_grad = orbital_transforms.dm_gradient_to_orbital_rotation_gradient(
            V, C, TARGETS.n_occ
        )
        delta_density = (
            orbital_transforms.ao_density_perturbation_from_occupied_virtual_rotation(
                orb_rot_grad,
                C,
                SYS.fock_tensors.overlap,
                TARGETS.n_occ,
                normalize=True,
            )
        )
        E_0 = energy_at(P)
        E_1 = energy_at(P + EPSILON * delta_density)
        E_2 = energy_at(P - EPSILON * delta_density)
        # Messages state the relation that actually caused the failure (a tie fails too)
        # and which trajectory step it occurred at.
        assert E_1 < E_0, (
            f'step {idx}: xc energy should decrease along the direction of steepest '
            f'descent but E_1={E_1} >= E_0={E_0}'
        )
        assert E_2 > E_0, (
            f'step {idx}: xc energy should increase in the opposite direction '
            f'but E_2={E_2} <= E_0={E_0}'
        )


def test_zero_gradient_at_equilibrium():
    """Check that the stored LDA density reproduces itself under an LDA Fock build.

    A fresh training sample is drawn; its initial density matrices are unpacked
    as a MINAO guess and an LDA density. For each density, one Fock build,
    diagonalization, and Fock rebuild is performed, and the occupied-virtual
    orbital-rotation gradient is computed. The LDA density must be reproduced
    (negligible density error), and its gradient norm must be smaller than the
    MINAO guess's.
    """
    sys, targets, (minao_density, lda_density) = get_sample_and_targets_and_densities()
    # Rebuild the system with `grad_aos=None`; presumably AO gradients are not
    # needed for the LDA functional used below.
    sys = System(
        grid=Grid(
            grad_aos=None,
            **{k: v for k, v in sys.grid.__dict__.items() if k != 'grad_aos'},
        ),
        **{k: v for k, v in sys.__dict__.items() if k != 'grid'},
    )
    functional, _ = get_model_and_params('lda', sys, lda_density)
    fock_matrix = FockMatrix(
        xc_module=functional,
        use_density_fitting=IS_DENSITY_FITTED,
        spin_restricted=True,
    )
    # Depends only on the overlap, so compute it once for both densities.
    X = linalg.transformation_matrix(sys.fock_tensors.overlap)

    def get_fock_matrix(density_matrix: NpFloatBxB):
        # The LDA functional has no parameters, hence the empty variables dict.
        return fock_matrix.apply(
            {},
            sys._nuc_pos,
            density_matrix,
            sys,
            method=fock_matrix.fock_matrix,
        )

    def compute_ov_gradient(density_matrix: NpFloatBxB):
        """Return (ov orbital-rotation gradient, updated density matrix).

        Builds F from ``density_matrix``, diagonalizes it to get orbitals C and
        the density P, rebuilds F from P, and converts (F, C) into the gradient.
        """
        F = get_fock_matrix(density_matrix)
        _, C = linalg.modified_generalized_eigenvalue_problem(F, X)
        P = linalg.coeffs_to_density_matrix(C, targets.occupancies)
        F = get_fock_matrix(P)
        orb_rot_grad = orbital_transforms.dm_gradient_to_orbital_rotation_gradient(
            F, C, targets.n_occ
        )
        return orb_rot_grad, P

    initial_ov_gradient, _ = compute_ov_gradient(minao_density)
    equilibrium_ov_gradient, lda_density_2 = compute_ov_gradient(lda_density)

    initial_norm = jnp.linalg.norm(initial_ov_gradient)
    equilibrium_norm = jnp.linalg.norm(equilibrium_ov_gradient)

    density_mean_field_error_fn = get_density_mean_field_error_fn(
        spin_restricted=True,
        use_density_fitting=IS_DENSITY_FITTED,
        scale_per_electron=False,
    )
    # Scaled by 1e3; the message below reports the result in mHa.
    density_mae = (
        density_mean_field_error_fn(
            sys,
            lda_density,  # type: ignore
            lda_density_2,
        )
        * 1e3
    )
    # Scientific notation: with a 1e-6 threshold, `:.6f` would round failing
    # values to a single significant digit (or none).
    assert density_mae < 1e-6, (
        f'LDA density matrix is not the equilibrium density matrix but the MAE is {density_mae:.3e} mHa, which is too large'
    )
    assert initial_norm > equilibrium_norm, (
        f'In equilibrium, the ov gradient should vanish but it is {equilibrium_norm} compared to {initial_norm}'
    )
