from jax import config as jax_config

# JAX reads these flags when its backend initializes, so they are set before any
# other module (including the project packages imported below) can touch JAX.
jax_config.update('jax_platform_name', 'gpu')  # must be first
# Enable 64-bit dtypes (float64/int64); per JAX docs arrays default to 32-bit otherwise.
jax_config.update('jax_enable_x64', True)
# Per JAX docs: default precision for matmul-like ops (avoids reduced-precision
# passes on accelerators for float32 inputs).
jax_config.update('jax_default_matmul_precision', 'float32')

# Imports below intentionally follow the config calls above (E402).
# NOTE(review): flake8 may not honour error codes on the file-level
# `flake8: noqa` form; per-file-ignores may be the reliable way — confirm.
# flake8: noqa: E402

import jax

# Sanity check: show which devices JAX picked up (expected to be GPU per the
# platform setting above).
print(jax.devices())

from typing import Tuple

import jax
import jax.numpy as jnp

from deixc.dataset import DEIXCDataset
from egxc.dataloading import QM9
from egxc.discretization import (
    GTOBasis,
    get_grid_fn,
    get_gto_grid_eval_fn,
    get_gto_preloader,
)
from egxc.systems import Grid, System
from egxc.systems.preload import preload_system_using_pyscf
from egxc.utils.typing import Alignment, ElectRepTensorType
from egxc.xc_energy.features import DensityFeatures
from egxc.xc_energy.functionals.base import XCModule
from egxc.xc_energy.functionals.classical import hybrid, mgga

# Configuration
# Looks like an anonymized placeholder; point it at the QM9 data directory
# before running.
DATA_DIR = 'ANONYMOUS_DIR'
# Forwarded to QM9 as `heavy_atoms_thresh`; presumably limits molecules by
# their number of heavy (non-hydrogen) atoms — confirm against QM9.
HEAVY_ATOMS_THRESH = 4
# Orbital basis set used for system preloading, grid basis-function evaluation
# and the dataset targets.
BASIS = 'def2-TZVP'
# Method whose stored results are used as reference targets (forwarded to
# DEIXCDataset as `method`).
METHOD = 'dft_B3LYP'
ENERGY_TOL = 0.05  # mEh; absolute tolerance on |E_xc - E_xc_target|
# Tolerance for the (currently disabled) XC-potential comparison; units are not
# stated here — presumably Hartree, TODO confirm.
POTENTIAL_TOL = 1e-6

# Electron-repulsion tensor representation; shared by the PySCF preloading and
# the hybrid functional so the two stay consistent.
ERT_TYPE = ElectRepTensorType.DENSITY_FITTED


def build_system(raw_input: Tuple) -> System:
    """Construct a :class:`System` with grid and Fock tensors.

    The molecule is preloaded through PySCF (basis ``BASIS``, spin-restricted,
    electron-repulsion tensors of type ``ERT_TYPE``, ``'minao'`` initial
    density guess). An integration grid is then generated for the atoms
    selected by ``atom_mask``. The GTO basis functions and their gradients are
    evaluated on that grid.

    Args:
        raw_input: ``(nuc_pos, atom_z, charge, spin, aux_data)`` tuple as
            yielded by the dataset.

    Returns:
        The :class:`System` assembled from the preloaded data and the grid.
    """
    nuc_pos, atom_z, charge, spin, aux_data = raw_input
    alignment = Alignment(atom=1, basis=1, grid=1)
    psys = preload_system_using_pyscf(
        idx=0,
        nuc_pos=nuc_pos,
        atom_z=atom_z,
        charge=charge,
        spin=spin,
        aux_data=aux_data,
        basis=BASIS,
        spin_restricted=True,
        alignment=alignment,
        base_initial_density_guess='minao',
        ert_type=ERT_TYPE,
    )

    # Apply the atom mask once instead of re-indexing at every call site.
    masked_z = psys.atom_z[psys.atom_mask]
    masked_pos = psys.nuc_pos[psys.atom_mask]
    # Convert to plain Python ints. Elements of a jax.Array are unhashable, so
    # set() on it would fail. This also keeps the element set consistent with
    # the per-atom tuple passed to the grid function.
    atom_numbers = tuple(int(z) for z in masked_z)

    grid_fn = get_grid_fn(1, set(atom_numbers), alignment.grid)
    coords, weights = grid_fn(masked_pos, atom_numbers)

    l_max, pvec_basis_fn_factory = get_gto_preloader(BASIS, set(atom_numbers))
    vec_basis_fns = GTOBasis.from_preloaded(pvec_basis_fn_factory(masked_z))
    basis_fn = get_gto_grid_eval_fn(1, l_max)
    aos, grad_aos = basis_fn(
        coords,
        masked_pos,
        vec_basis_fns.primitives,
        vec_basis_fns.compile_statics,
    )
    grid = Grid.create(coords, weights, aos, grad_aos)
    return System.from_preloaded(psys, grid)


def main() -> None:
    """Compare classical XC energies with the stored dataset targets.

    For every sample of the QM9 subset (wrapped in :class:`DEIXCDataset`):

    * The B3LYP hybrid functional is evaluated on the target density matrix.
      Its XC energy is compared with ``targets.xc_energy`` within
      ``ENERGY_TOL`` (mEh), and a MATCH/MISMATCH line is printed.
    * A meta-GGA functional (labelled "SCAN" in the output) is evaluated on the
      same density, and its energy is printed next to the target for reference.
    """
    dataset = QM9(
        data_dir=DATA_DIR, heavy_atoms_thresh=HEAVY_ATOMS_THRESH, exclude_fluorine=True
    )
    # dataset, _, _ = dataset.random_split(val_fraction=0, seed=0)
    dataset = DEIXCDataset(
        dataset,
        method=METHOD,
        basis=BASIS,
        align_scf_trajectory=None,
        shift_dispersion=False,
    )

    # The functionals do not depend on the sample, so build them once instead
    # of once per loop iteration.
    b3lyp_module = XCModule(
        hybrid.Hybrid(
            hybrid.HybridType.B3LYP,
            ERT_TYPE,
            True,
        ),
        DensityFeatures(spin_restricted=True),
    )
    scan_module = XCModule(
        mgga.MetaGGA(),
        # DEIXC(hidden_dim=4),
        DensityFeatures(spin_restricted=True),
    )

    for idx, raw_input, targets in dataset:
        system = build_system(raw_input)  # not `sys`: avoid shadowing the stdlib name
        density_matrix = targets.density_matrix
        grid = system.grid
        eri = system.fock_tensors.ert

        scan_params = scan_module.init(
            jax.random.PRNGKey(0), density_matrix, grid, eri_tensor=eri
        )
        scan_energy = scan_module.apply(
            scan_params, density_matrix, grid, eri_tensor=eri
        )
        # The B3LYP module is applied with an empty variable dict, so it is
        # expected to carry no parameters. Its former `init` call ran a full
        # forward pass whose result was discarded, so it has been dropped.
        energy = b3lyp_module.apply({}, density_matrix, grid, eri_tensor=eri)

        # TODO: re-enable the XC-potential check (compares against POTENTIAL_TOL).
        # potential = b3lyp_module.apply(
        #     {},
        #     density_matrix,
        #     grid,
        #     basis_mask=system.fock_tensors.basis_mask,
        #     eri_tensor=eri,
        #     method=XCModule.xc_potential,
        # )

        e_diff = float(jnp.abs(energy - targets.xc_energy))
        # v_diff = float(jnp.max(jnp.abs(potential - targets.xc_potential_matrix)))

        print(f'TARGETENERGY idx {idx}: {targets.xc_energy}')
        print(
            f'############################# SCAN ENERGY idx {idx}: {scan_energy}, target: {targets.xc_energy}; Difference: {scan_energy - targets.xc_energy}'
        )
        # ENERGY_TOL is in mEh; e_diff is in Hartree.
        if e_diff > ENERGY_TOL * 1e-3:
            print(f'[MISMATCH] idx {idx}: energy diff={e_diff * 1e3:.3f} mHa')
        else:
            print(f'[MATCH] idx {idx}: energy diff={e_diff * 1e3:.3f} mHa')

        # if v_diff > POTENTIAL_TOL:
        #     print(
        #         f"[MISMATCH] idx {idx}: potential max diff={v_diff}"
        #     )
        # else:
        #     print(f"[MATCH] idx {idx}: potential max diff={v_diff}")


if __name__ == '__main__':
    # Run the consistency check only when executed as a script, not on import.
    main()
