from typing import Tuple

import jax
import jax.numpy as jnp
import optax
import pytest
from jax import random
from utils import (
    assert_either_abs_or_rel_close,
    assert_is_close,
    call_module_as_function,
    set_jax_testing_config,
    system_from_preloaded,
)

from egxc.solver import fock, scf
from egxc.systems import examples
from egxc.systems.base import nuclear_energy_fn
from egxc.utils import linalg
from egxc.utils.typing import Alignment
from egxc.xc_energy.features import DensityFeatures
from egxc.xc_energy.functionals.classical import mgga
from egxc.xc_energy.functionals.learnable import XCDiff

# Apply the shared JAX test configuration from the local `utils` module at import
# time, i.e. before any test in this module runs. The float64 dtype assertions
# below presumably rely on it enabling x64 mode — confirm in utils.
set_jax_testing_config()
# Basis set used for every preloaded example system built in this module.
BASIS = '6-31G(d)'


def _get_aspirin_system(
    alignment: Alignment, spin_restricted: bool, use_density_fitting: bool
):
    """Load the preloaded aspirin example in ``BASIS`` and wrap it as a system.

    The same ``alignment`` is used for the preloaded data and, via its ``grid`` and
    ``basis`` fields, for the system built from it.

    Returns:
        ``(system, preloaded, initial_density)`` where ``initial_density`` is the
        first entry of ``preloaded.initial_density_matrices``.
    """
    preloaded = examples.get_preloaded(
        'aspirin',
        BASIS,
        use_density_fitting=use_density_fitting,
        spin_restricted=spin_restricted,
        alignment=alignment,
    )
    grid_align, basis_align = alignment.grid, alignment.basis
    system = system_from_preloaded(
        preloaded, BASIS, 1, grid_alignment=grid_align, basis_alignment=basis_align
    )
    return system, preloaded, preloaded.initial_density_matrices[0]


@pytest.mark.parametrize(
    'alignment',
    [Alignment(1, 10, 1), Alignment(12, 1, 1), Alignment(1, 1, 512 * 100)],
    ids=['basis', 'atom', 'grid'],
)
def test_fock_matrix(alignment: Alignment, spin_restricted: bool = True):
    """Padding must not change the Fock matrix or the resulting density update.

    For aspirin, the three contributions (H, J, V) returned by
    ``FockMatrix.fock_matrix_contributions`` are computed once with the
    parametrized ``alignment`` and once with the default ``Alignment()``.

    The test then checks that:
    - the padded basis size is a multiple of ``alignment.basis``;
    - after cropping the padded basis rows/columns, H, J and V match;
    - the padded Fock matrix has no coupling between real and padded basis
      functions;
    - diagonalising both Fock matrices gives matching density matrices and
      eigenvalues.
    """
    use_density_fitting = True
    xc_mod = fock.XCModule(XCDiff(hidden_dim=8), DensityFeatures(spin_restricted))
    fock_module = fock.FockMatrix(xc_mod, use_density_fitting, spin_restricted)

    sys_pad, psys_pad, P0_pad = _get_aspirin_system(
        alignment, spin_restricted, use_density_fitting
    )

    H_pad, J_pad, V_pad = call_module_as_function(  # type: ignore
        fock_module, sys_pad._nuc_pos, P0_pad, sys_pad, method='fock_matrix_contributions'
    )
    # The padded basis dimension must be a multiple of the requested alignment.
    assert H_pad.shape[0] % alignment.basis == 0, (
        f'H.shape[0] % alignment.basis = {H_pad.shape[0] % alignment.basis}'
    )

    sys, psys, P0 = _get_aspirin_system(Alignment(), spin_restricted, use_density_fitting)
    H, J, V = call_module_as_function(  # type: ignore
        fock_module, sys._nuc_pos, P0, sys, method='fock_matrix_contributions'
    )
    F = H + J + V
    # Assemble the padded Fock matrix BEFORE cropping the contributions: it is
    # checked and diagonalised at its full padded size below.
    F_pad = H_pad + J_pad + V_pad
    basis_pad_size = H_pad.shape[0] - H.shape[0]

    if basis_pad_size > 0:
        # Drop the trailing padded basis rows/columns for the element-wise comparison.
        H_pad = H_pad[:-basis_pad_size, :-basis_pad_size]
        J_pad = J_pad[:-basis_pad_size, :-basis_pad_size]  # type: ignore
        V_pad = V_pad[:-basis_pad_size, :-basis_pad_size]

    assert_is_close(H, H_pad, name='H_pad and H')
    assert_is_close(J, J_pad, name='J_pad and J')  # type: ignore
    assert_is_close(V, V_pad, name='V_pad and V')

    if basis_pad_size > 0:
        # Padded basis functions must not couple to the real ones.
        assert (F_pad[-basis_pad_size:, :-basis_pad_size] == 0).all()  # bottom-left block
        assert (F_pad[:-basis_pad_size, -basis_pad_size:] == 0).all()  # top-right block

    def new_density_matrix(F, X, occupancies):
        """Diagonalise F with transformation X; return (density matrix, eigenvalues)."""
        eigenvalues, C = linalg.modified_generalized_eigenvalue_problem(F, X)
        return linalg.coeffs_to_density_matrix(C, occupancies), eigenvalues

    X_pad = linalg.transformation_matrix(
        psys_pad.fock_tensors.overlap  # type: ignore
    )  # type: ignore
    P_pad, eigvals_pad = new_density_matrix(
        F_pad, X_pad, sys_pad.fock_tensors.occupancies
    )
    X = linalg.transformation_matrix(
        psys.fock_tensors.overlap  # type: ignore
    )
    P, eigvals = new_density_matrix(F, X, sys.fock_tensors.occupancies)
    if basis_pad_size > 0:
        P_pad = P_pad[:-basis_pad_size, :-basis_pad_size]
        # NOTE(review): assumes the eigenvalues belonging to padded basis functions
        # sort to the end of the spectrum — confirm in
        # linalg.modified_generalized_eigenvalue_problem.
        eigvals_pad = eigvals_pad[:-basis_pad_size]
    assert_either_abs_or_rel_close(P, P_pad, name='P_pad and P', absolute_tolerance=2e-14)
    assert_either_abs_or_rel_close(eigvals, eigvals_pad, name='eigvals_pad and eigvals')


@pytest.mark.parametrize('alignment', [Alignment(1, 1, 128 * 100)], ids=['grid'])
def test_xc_potential(alignment):
    """The SCAN XC potential for aspirin must be unchanged by padding (max abs diff < 2e-14)."""
    spin_restricted = True
    use_density_fitting = True
    xc_module = fock.XCModule(mgga.MetaGGA('scan'), DensityFeatures(spin_restricted))

    def potential_for(align):
        # Build the aspirin system for `align` and evaluate the module's
        # `xc_potential` method on its initial density matrix.
        system, _, density = _get_aspirin_system(
            align, spin_restricted, use_density_fitting
        )
        return call_module_as_function(
            xc_module,
            density,
            system.grid,
            system.fock_tensors.basis_mask,
            method='xc_potential',
        )

    padded_potential = potential_for(alignment)
    reference_potential = potential_for(Alignment())
    max_abs_diff = jnp.abs(padded_potential - reference_potential).max()  # type: ignore
    assert max_abs_diff < 2e-14, f'V_xc_pad and V_xc do not match: {max_abs_diff}'


def _run_scf(alignment: Alignment, spin_restricted: bool) -> Tuple[float, float]:
    """
    Run SCF cycle for a given alignment and spin restriction.

    Uses the preloaded 'h2o' example in ``BASIS`` with density fitting, an
    ``XCDiff(hidden_dim=8)`` functional, and a 'DIIS' SCF solver constructed
    with ``15`` as its second argument (presumably the number of SCF cycles).

    Returns:
        ``(e_final, initial_xc_energy)``.
        - ``e_final`` is the sum of the two energy arrays returned by the solver
          at their last entry, plus the nuclear energy.
        - ``initial_xc_energy`` is the XC module evaluated on the first initial
          density matrix. It is asserted to be float64.
    """
    use_density_fitting = True
    xc_mod = fock.XCModule(XCDiff(hidden_dim=8), DensityFeatures(spin_restricted))
    scf_solver = scf.SelfConsistentFieldSolver(
        xc_mod, 15, use_density_fitting, spin_restricted, 'DIIS'
    )
    psys = examples.get_preloaded(
        'h2o',
        BASIS,
        use_density_fitting=use_density_fitting,
        spin_restricted=spin_restricted,
        alignment=alignment,
    )
    # Pass the grid alignment by keyword, like every other system_from_preloaded
    # call in this file, so it cannot silently bind to a different positional
    # parameter.
    system = system_from_preloaded(
        psys, BASIS, 1, grid_alignment=alignment.grid, basis_alignment=alignment.basis
    )
    P0 = psys.initial_density_matrices[0]

    initial_xc_energy = call_module_as_function(xc_mod, P0, system.grid)
    assert initial_xc_energy.dtype == jnp.float64  # type: ignore

    energies, _ = call_module_as_function(scf_solver, P0, system)

    e_final = (energies[0] + energies[1])[-1] + nuclear_energy_fn(system._nuc_pos, system)
    return e_final, initial_xc_energy  # type: ignore


@pytest.mark.parametrize(
    'spin_restricted', [True, False], ids=['restricted', 'unrestricted']
)
def test_scf_cycle_grid_padding(spin_restricted):
    """Grid-only padding must leave the initial XC and final SCF energies unchanged (to 1e-12)."""
    e_ref, xc_ref = _run_scf(Alignment(), spin_restricted)
    # Only the grid alignment is raised (to 512); atom and basis alignment stay at 1.
    e_pad, xc_pad = _run_scf(Alignment(1, 1, 512), spin_restricted)

    assert not jnp.isnan(e_pad)

    xc_delta = xc_pad - xc_ref
    assert abs(xc_delta) < 1e-12, (
        f'Initial xc energies should not be affected by padding, got delta of {xc_delta}'
    )

    e_delta = e_pad - e_ref
    assert abs(e_delta) < 1e-12, (
        f'Final energies should not be affected by padding, got delta of {e_delta}'
    )


@pytest.mark.parametrize(
    'spin_restricted', [True, False], ids=['restricted', 'unrestricted']
)
def test_scf_cycle_atom_padding(spin_restricted):
    """Atom-wise padding must leave the initial XC and final SCF energies unchanged (to 1e-12)."""
    unpadded = _run_scf(Alignment(), spin_restricted)

    # Atom alignment 12 (the 'atom' case in test_fock_matrix); basis and grid stay at 1.
    padded = _run_scf(Alignment(12, 1, 1), spin_restricted)
    assert not jnp.isnan(padded[0])
    assert abs(padded[1] - unpadded[1]) < 1e-12, (  # type: ignore
        f'Initial xc energies should not be affected by padding, got delta of {padded[1] - unpadded[1]}'
    )
    # Use the same absolute tolerance as the other padding tests: exact float
    # equality after a full SCF run on differently shaped arrays is brittle.
    assert abs(padded[0] - unpadded[0]) < 1e-12, (  # type: ignore
        f'Final energies should not be affected by padding, got {padded[0]} and {unpadded[0]}'
    )


@pytest.mark.parametrize(
    'spin_restricted', [True, False], ids=['restricted', 'unrestricted']
)
def test_scf_cycle_basis_padding(spin_restricted):
    """Basis-only padding must leave the initial XC and final SCF energies unchanged (to 1e-12)."""
    e_ref, xc_ref = _run_scf(Alignment(), spin_restricted)
    # Only the basis alignment is raised (to 3); atom and grid alignment stay at 1.
    e_pad, xc_pad = _run_scf(Alignment(1, 3, 1), spin_restricted)

    assert not jnp.isnan(e_pad)

    xc_delta = xc_pad - xc_ref
    assert abs(xc_delta) < 1e-12, (  # type: ignore
        f'Initial xc energies should not be affected by padding, got delta of {xc_delta}'
    )

    e_delta = e_pad - e_ref
    assert abs(e_delta) < 1e-12, (  # type: ignore
        f'Final energies should not be affected by padding, got {e_pad} and {e_ref}'
    )


@pytest.mark.parametrize(
    'spin_restricted', [True, False], ids=['restricted', 'unrestricted']
)
def test_scf_cycle_full_padding(spin_restricted):
    """Simultaneous atom, basis and grid padding must leave both energies unchanged (to 1e-12)."""
    unpadded = _run_scf(Alignment(), spin_restricted)

    # Pad all three dimensions at once: atom alignment 4, basis 12, grid 512.
    padded = _run_scf(Alignment(4, 12, 512), spin_restricted)
    assert not jnp.isnan(padded[0])
    assert abs(padded[1] - unpadded[1]) < 1e-12, (  # type: ignore
        f'Initial xc energies should not be affected by padding, got delta of {padded[1] - unpadded[1]}'
    )
    assert abs(padded[0] - unpadded[0]) < 1e-12, (  # type: ignore
        f'Final energies should not be affected by padding, got {padded[0]} and {unpadded[0]}'
    )


@pytest.mark.slow
@pytest.mark.parametrize(
    'spin_restricted', [True, False], ids=['restricted', 'unrestricted']
)
def test_update_step(spin_restricted):
    """Padding must not change the training dynamics of a learnable XC functional.

    Runs two jitted Adam steps on a squared-energy-error loss for H2O, once
    without padding and once each with grid-, atom- and basis-wise padding.

    Every run must decrease the loss from step 0 to step 1. Each padded run's
    loss at every step must also match the unpadded run.
    """

    def update_step(alignment: Alignment):
        """Run two optimisation steps for ``alignment``; return the per-step losses."""
        use_density_fitting = True
        xc_mod = fock.XCModule(XCDiff(hidden_dim=8), DensityFeatures(spin_restricted))
        scf_solver = scf.SelfConsistentFieldSolver(
            xc_mod, 5, use_density_fitting, spin_restricted, 'DIIS'
        )
        psys = examples.get_preloaded(
            'h2o',
            BASIS,
            use_density_fitting=use_density_fitting,
            spin_restricted=spin_restricted,
            alignment=alignment,
        )

        sys = system_from_preloaded(
            psys, BASIS, 1, grid_alignment=alignment.grid, basis_alignment=alignment.basis
        )

        # Both the XC energy and the XC potential must be computed in float64.
        params = xc_mod.init(
            random.PRNGKey(0), psys.initial_density_matrices[0], sys.grid
        )
        assert (
            xc_mod.apply(params, psys.initial_density_matrices[0], sys.grid).dtype  # type: ignore
            == jnp.float64
        )
        assert (
            xc_mod.apply(
                params,
                psys.initial_density_matrices[0],
                sys.grid,
                sys.fock_tensors.basis_mask,
                method=xc_mod.xc_potential,
            ).dtype  # type: ignore
            == jnp.float64
        )

        params = scf_solver.init(random.PRNGKey(0), psys.initial_density_matrices[0], sys)

        def loss_fn(params):
            TARGET_ENERGY = -76.38566321214728  # RKS B3LYP/6-31G(d) from PySCF
            energies, _ = scf_solver.apply(params, psys.initial_density_matrices[0], sys)
            e_final = (energies[0] + energies[1])[-1] + nuclear_energy_fn(
                sys._nuc_pos, sys
            )
            loss = (e_final - TARGET_ENERGY) ** 2
            return loss, e_final

        opt = optax.adam(
            1e-3
        )  # deliberately small learning rate, to guarantee loss decrease
        opt_state = opt.init(params)

        @jax.jit
        def step(params, opt_state):
            (loss, energy), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
            updates, opt_state = opt.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss, energy

        loss_list = []
        for i in range(2):
            params, opt_state, loss, energy = step(params, opt_state)
            print(f'Iteration {i}, loss: {loss}, energy: {energy}')
            loss_list.append(loss)

        return loss_list

    def assert_matches_baseline(baseline, losses, label):
        """Assert that every step's loss in ``losses`` matches ``baseline``."""
        # Compare each padded run against the baseline individually and in
        # absolute value. A summed, one-sided check lets opposite-sign
        # deviations cancel and never catches losses that are too small.
        # The small relative term accounts for the loss being a squared energy
        # error, whose absolute differences grow with its magnitude.
        for i, (reference, value) in enumerate(zip(baseline, losses)):
            tolerance = 1e-12 + 1e-10 * abs(reference)
            assert abs(value - reference) <= tolerance, (
                f'Loss deviates with {label} padding at iteration {i}: '
                f'{value} vs unpadded {reference}'
            )

    loss1 = update_step(Alignment(1, 1, 1))
    assert loss1[1] < loss1[0], f'Loss is not decreasing without padding {loss1}'

    loss2 = update_step(Alignment(1, 1, 512))
    assert loss2[1] < loss2[0], f'Loss is not decreasing with grid padding {loss2}'
    loss3 = update_step(Alignment(12, 1, 1))
    assert loss3[1] < loss3[0], f'Loss is not decreasing with atom padding {loss3}'
    loss4 = update_step(Alignment(1, 4, 1))
    assert loss4[1] < loss4[0], f'Loss is not decreasing with basis padding {loss4}'

    assert_matches_baseline(loss1, loss2, 'grid')
    assert_matches_baseline(loss1, loss3, 'atom')
    assert_matches_baseline(loss1, loss4, 'basis')
