from __future__ import annotations

import time
from typing import Callable, List, Literal, ParamSpec, Tuple, TypeVar

import jax
import jax.numpy as jnp
import numpy as np
from jax import config
from jax.experimental.compilation_cache import compilation_cache
from seml.experiment import Experiment

from egxc.dataloading.datasets.qm9 import QM9
from egxc.discretization import GTOBasis, get_gto_grid_eval_fn, get_gto_preloader
from egxc.discretization.grids.quadrature import get_grid_fn
from egxc.utils.typing import cast_to_integer_tuple

# JAX/global config
# NOTE: everything below runs at import time and mutates process-wide JAX state.
#
# Directory for JAX's persistent compilation cache; the path is relative to the
# current working directory of the process.
# NOTE(review): with a persistent cache enabled, the "including compile" timings
# printed by `main` may reflect on-disk cache hits when the benchmark is re-run
# -- confirm this is intended, or clear the directory before measuring.
compilation_cache.set_cache_dir('./caches/jax_compile/benchmark')
# Select the GPU backend as the default platform.
config.update('jax_platform_name', 'gpu')
# Enable 64-bit dtypes (JAX defaults to 32-bit).
config.update('jax_enable_x64', True)
# Default precision used for matrix multiplications.
config.update('jax_default_matmul_precision', 'float32')


# Experiment object; its decorators register `default_config` and `main` below.
ex = Experiment()


# Default configuration of the benchmark.
#
# Each local variable assigned inside this function becomes a configuration
# entry, and entries are passed to `main` by parameter name. Keep the variable
# names in sync with the signature of `main` and keep the body limited to plain
# assignments.
@ex.config
def default_config():
    # Data
    data_dir = 'ANONYMOUS_DIR'  # forwarded to the QM9 dataset constructor
    seed = 0  # seeds the random choice of benchmarked molecules
    exclude_fluorine = True  # forwarded to the QM9 dataset constructor
    heavy_atoms_thresh = 9  # only used for dataset metadata/splits

    # Benchmark
    num_samples = 10  # number of molecules to time (capped at the dataset size)
    grid_level = 1  # forwarded to get_grid_fn
    grid_alignment = 512  # forwarded to get_grid_fn
    deriv = 1  # 0: values, 1: values+gradients

    # Basis sets to benchmark (each one is timed independently)
    bases = ['def2-TZVPD']


def _fixed_indices(n: int, total: int, seed: int) -> List[int]:
    """Return a reproducible, ascending selection of indices from ``range(total)``.

    A ``RandomState`` seeded with ``seed`` shuffles ``range(total)``; the first
    ``n`` entries of that shuffle are kept and returned in sorted order, so the
    same ``(n, total, seed)`` always yields the same list.

    Args:
        n: Number of leading entries of the shuffle to keep.
        total: Size of the index range to draw from.
        seed: Seed for the NumPy ``RandomState``.

    Returns:
        Sorted list of the selected indices as plain Python ints.
    """
    shuffled = np.random.RandomState(seed).permutation(total)
    return sorted(shuffled[:n].tolist())


P = ParamSpec('P')
T = TypeVar('T')


def _timeit(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Tuple[float, T]:
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    jax.block_until_ready(out)
    t1 = time.perf_counter()
    return (t1 - t0), out


# NOTE: the helpers are defined *before* `main` on purpose: `@ex.automain` may
# run `main` as soon as it is decorated, so everything it needs must exist.
def _summary_line(label: str, values, value_fmt: str, extrema_fmt: str = '') -> str:
    """Format ``label: mean=…  median=…  min=…  max=…`` for a series of numbers.

    Args:
        label: Text placed in front of the colon.
        values: Non-empty sequence of numbers (converted with ``np.asarray``).
        value_fmt: Format spec applied to the mean and the median.
        extrema_fmt: Format spec applied to min and max; the empty default
            prints them unformatted.

    Returns:
        The formatted one-line summary.
    """
    arr = np.asarray(values)
    return (
        f'{label}: mean={np.mean(arr):{value_fmt}}  median={np.median(arr):{value_fmt}}'
        f'  min={np.min(arr):{extrema_fmt}}  max={np.max(arr):{extrema_fmt}}'
    )


def _run_timing_pass(
    samples: list,
    sample_grids: list,
    vec_assembly: Callable,
    basis_fn: Callable,
    t_build: float,
    report_statics: bool = False,
) -> Tuple[List[float], List[float], list, list]:
    """Time per-sample preloading and basis evaluation for every sample once.

    Args:
        samples: ``(nuc_pos, atom_z)`` pairs, one per molecule.
        sample_grids: ``(coords, weights)`` pairs aligned with ``samples``;
            only the coordinates are used.
        vec_assembly: Per-sample preloader returned by ``get_gto_preloader``.
        basis_fn: Evaluation function returned by ``get_gto_grid_eval_fn``.
        t_build: Seconds spent building ``vec_assembly``; added to the first
            sample's preload time so the build cost is counted once per pass.
        report_statics: When true, also collect the basis size and hash of
            every sample's compile statics and print each hash.

    Returns:
        ``(preload_times, eval_times, basis_sizes, hashes)``; times are in
        seconds, and the last two lists stay empty unless ``report_statics``.
    """
    preload_times: List[float] = []
    eval_times: List[float] = []
    basis_sizes = []
    hashes = []

    for k, ((nuc_pos, atom_z), (grid, _)) in enumerate(zip(samples, sample_grids)):
        # Preloading part 2: construct the per-sample basis from the assembly.
        t_pre, preloaded = _timeit(vec_assembly, atom_z)
        vec_basis = GTOBasis.from_preloaded(preloaded)
        # Include the (one-off) build time in the first sample only.
        preload_times.append(t_build + t_pre if k == 0 else t_pre)

        if report_statics:
            basis_sizes.append(vec_basis.compile_statics.num_basis_fns)
            hashes.append(vec_basis.compile_statics.hash_value)
            print('### hash value: ', vec_basis.compile_statics.hash_value, flush=True)

        # Evaluation: the host->device conversion of the positions is kept
        # outside the timed call.
        nuc_pos_jnp = jnp.asarray(nuc_pos)
        t_eval, _ = _timeit(
            basis_fn,
            grid,
            nuc_pos_jnp,
            vec_basis.radial_primitives,
            vec_basis.compile_statics,
        )
        eval_times.append(t_eval)

    return preload_times, eval_times, basis_sizes, hashes


@ex.automain
def main(
    data_dir: str,
    seed: int,
    exclude_fluorine: bool,
    heavy_atoms_thresh: int | Literal['debug', 'debug_larger'],
    num_samples: int,
    grid_level: int,
    grid_alignment: int,
    deriv: int,
    bases: List[str],
):
    """Benchmark AO (basis function) evaluation on a fixed subset of QM9.

    For every basis set in ``bases`` the selected molecules are processed
    twice: the first pass includes whatever compilation the evaluation
    function triggers, the second pass re-runs the same inputs and is reported
    as "excluding compile times". Summary statistics are printed to stdout.

    Args:
        data_dir: Directory passed to the QM9 dataset.
        seed: Seed used to pick the benchmarked molecules reproducibly.
        exclude_fluorine: Passed to the QM9 dataset.
        heavy_atoms_thresh: Passed to the QM9 dataset (used for splits only).
        num_samples: Number of molecules to benchmark; capped at the dataset
            size. Must be at least 1.
        grid_level: Passed to ``get_grid_fn``.
        grid_alignment: Passed to ``get_grid_fn``.
        deriv: Derivative order passed to ``get_gto_grid_eval_fn``
            (0: values, 1: values+gradients).
        bases: Names of the basis sets to benchmark.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
        RuntimeError: If the dataset is empty.
    """
    # Fail fast: zero samples would crash later inside the NumPy reductions,
    # and a negative value would silently select `total - |n|` molecules.
    if num_samples < 1:
        raise ValueError(f'num_samples must be a positive integer, got {num_samples!r}.')

    # Load dataset metadata and pick fixed samples
    ds = QM9(
        data_dir=data_dir,
        heavy_atoms_thresh=heavy_atoms_thresh,  # used for splits only
        exclude_fluorine=exclude_fluorine,
    )
    total_len = len(ds)
    if total_len == 0:
        raise RuntimeError('QM9 dataset seems empty. Check data_dir or internet access.')
    sel_idx = _fixed_indices(min(num_samples, total_len), total_len, seed)

    # Preload basic sample info (nuclear positions and atomic numbers)
    samples = []
    for i in sel_idx:
        _, (nuc_pos, atom_z, _, _, _), _ = ds[i]
        samples.append((np.asarray(nuc_pos), np.asarray(atom_z, dtype=np.uint32)))

    # Elements vocabulary for basis assembly
    elements_vocab = set(ds.unique_elements)

    print(f'Benchmarking AO evaluation on {len(samples)} QM9 molecules')
    print(f'Indices: {sel_idx}')
    print(f'Elements vocabulary: {sorted(elements_vocab)}')
    print(f'Grid level: {grid_level}, alignment: {grid_alignment}')
    print(f'Derivative order: {deriv}')

    # Build the quadrature grid function once and compute the grids of all
    # selected samples up front, so grid construction is never timed.
    grid_fn = get_grid_fn(grid_level, elements_vocab, grid_alignment)
    sample_grids = []
    grid_sizes = []
    for nuc_pos, atom_z in samples:
        coords, weights = grid_fn(nuc_pos, cast_to_integer_tuple(atom_z))
        sample_grids.append((coords, weights))
        grid_sizes.append(int(coords.shape[0]))
    print(_summary_line('Grid points stats (N)', grid_sizes, '.0f'))

    for basis_str in bases:
        print('\n=== Basis:', basis_str, '===')

        # Preloading part 1: build vec-basis assembly (per basis across vocab)
        t_build, (max_L, vec_assembly) = _timeit(
            get_gto_preloader, basis_str, elements_vocab
        )
        # Prepare AO kernel
        basis_fn = get_gto_grid_eval_fn(deriv=deriv, max_angular_momentum=max_L)

        # First pass: evaluation times include compilation.
        preload_times, eval_times, _, _ = _run_timing_pass(
            samples, sample_grids, vec_assembly, basis_fn, t_build
        )
        print(_summary_line('Preload times   (ms)', np.array(preload_times) * 1e3, '.2f', '.2f'))
        print(
            _summary_line(
                'Evaluation including compile times(ms)',
                np.array(eval_times) * 1e3,
                '.2f',
                '.2f',
            )
        )

        # Second pass over the same inputs: compiled code is reused.
        # NOTE(review): `t_build` is added to the first preload time of this
        # pass as well, mirroring the first pass -- confirm that is desired.
        preload_times, eval_times, number_of_basis_fns, hashes = _run_timing_pass(
            samples, sample_grids, vec_assembly, basis_fn, t_build, report_statics=True
        )
        print('=== Excluding compile times ===')
        print(_summary_line('Preload times   (ms)', np.array(preload_times) * 1e3, '.2f', '.2f'))
        print(_summary_line('Evaluation times(ms)', np.array(eval_times) * 1e3, '.2f', '.2f'))
        print(_summary_line('Basis size       (B)', number_of_basis_fns, '.1f'))
        print(f'number of unique hashes / number of compilations: {len(set(hashes))}')
        # Clear between bases to avoid cross-compilation bias
        jax.clear_caches()
