from time import sleep, time
from typing import Any, Dict, Literal

from jax import config
from jax.experimental.compilation_cache import compilation_cache
from seml.experiment import Experiment

from deixc.dataset import DEIXCDataset
from egxc import dataloading
from egxc.discretization import get_gto_preloader
from egxc.systems import PreloadSystem
from egxc.utils.typing import Alignment

compilation_cache.set_cache_dir('./caches/jax_compile/benchmark')

# Global JAX flags, applied in this order at import time.
# Optional debugging flags that can be appended to the table when needed:
#   ('jax_explain_cache_misses', True), ('jax_debug_nans', True)
for _flag, _value in (
    ('jax_platform_name', 'gpu'),
    ('jax_enable_x64', True),
    ('jax_default_matmul_precision', 'float32'),
):
    config.update(_flag, _value)
del _flag, _value  # keep the loop temporaries out of the module namespace

CheckpointType = Literal['scratch', 'static_pretrain', 'static_train', 'dynamic_train']
ex = Experiment()


DATA_DIR = 'ANONYMOUS_DIR'


@ex.config
def default_config():
    """Default configuration for the dataloader throughput benchmark.

    Registered via ``ex.config``. Each nested dict below is read by the
    ``ExperimentWrapper`` method decorated with the matching
    ``ex.capture(prefix=...)``; ``expected_device_step_time`` is read by
    ``ExperimentWrapper.__call__``.
    """
    # Seed forwarded to base['seed'].
    run_seed = 0
    # Simulated device time per batch; ExperimentWrapper.__call__ sleeps this long after each batch.
    expected_device_step_time = 0.22  # seconds
    # Forwarded to the dataset as data_set_kwargs['heavy_atoms_thresh'] (presumably a heavy-atom-count filter; confirm in the dataset class).
    heavy_atoms_thresh = 5
    # Dataloader worker count, forwarded as data['workers'].
    workers = 8
    # Forwarded as data['worker_buffer_size'].
    worker_buffer_size = 1
    # Basis set used for basis['name'], the initial reference density and data['deixc_kwargs']['basis'].
    basis_name = 'def2-TZVPD'
    # DFT functional name; combined into data['deixc_kwargs']['method'] as 'dft_' + functional.
    functional = 'B3LYP'

    """
    Benchmark 27.09.25: BASE 0 ms / GRID (pyscf) 300 ms / FOCK (pyscf) 60 ms / Fock + GRID (pyscf) 600 ms
    - def2-SVP
    - heavy_atoms_thresh = 7
    - workers = 8
    - expected_device_step_time = 0.3 seconds

    Benchmark 09.09.25: FOCK (pyscf) 20 ms
    - expected_device_step_time = 0.3 seconds
    Benchmark 09.09.25: FOCK (pyscf) 100 ms
    - expected_device_step_time = 0.22 seconds

    Benchmark 12.09.25: FOCK (pyscf) 200 ms
    - def2-TZVPD / B3LYP
    - expected_device_step_time = 0.22 seconds
    - heavy_atoms_thresh = 5
    - workers = 8
    - expected_device_step_time = 0.3 seconds
    Benchmark 12.09.25: FOCK (pyscf) 10 ms
    - def2-SVP / B3LYP
    """
    # NOTE(review): the string above is a free-form benchmark log; it assigns
    # nothing, so it adds no config entry. Its first 12.09.25 entry lists
    # expected_device_step_time twice (0.22 and 0.3 seconds) -- confirm which
    # value that run actually used.

    # Read by ExperimentWrapper.set_base.
    base = {
        'seed': run_seed,
        'test': False,
        'use_density_fitting': True,
        'spin_restricted': True,
    }
    # Read by ExperimentWrapper.set_alignment and passed to Alignment(atom, basis, grid).
    alignment = {
        'atom': 1,
        'basis': 1,
        'grid': 512,
    }
    # Read by ExperimentWrapper.set_basis.
    basis = {
        'name': basis_name,
        'derivative': 1,
        'backend': 'pyscf',  # jax or pyscf  # TODO: fix jax backend
    }
    # Read by ExperimentWrapper.set_quadrature.
    quadrature = {
        'level': 1,
    }
    # Read by ExperimentWrapper.init_dataloading. If initial_ref_density_method
    # is None, initial_ref_density_basis must also be None and both
    # interpolation bounds must be 0.0 (enforced there by asserts).
    initial_density_guess = {
        'key': 'minao',
        'initial_ref_density_method': 'dft_lda',
        'initial_ref_density_basis': basis_name,
        'min_ref_density_interpolation': 0.5,
        'max_ref_density_interpolation': 1.0,
        'noise_eps': 0.0,
    }
    # Read by ExperimentWrapper._init_dataset and ExperimentWrapper._init_dataloader.
    data = {
        'key': 'qm9',
        'data_set_kwargs': {
            'data_dir': DATA_DIR,
            'heavy_atoms_thresh': heavy_atoms_thresh,
            'exclude_fluorine': True,
        },
        'workers': workers,
        'worker_buffer_size': worker_buffer_size,
        'batch_size': 1,
        'shuffle': True,
        'seed': 0,  # random seed for splitting the dataset
        # Expanded as keyword arguments into dataloading.get_preload_transform.
        'preload': {
            'center': False,
        },
        # Expanded as keyword arguments into DatasetEnsemble.infer_split.
        'split': {'val_fraction': 0.1},
        # Expanded as keyword arguments into DEIXCDataset.
        'deixc_kwargs': {
            'method': 'dft_' + functional,
            'basis': basis_name,
            'align_scf_trajectory': 8,
            'shift_dispersion': False,
        },
    }


class ExperimentWrapper:
    """Build the dataloaders from the experiment config and benchmark them.

    ``__init__`` reads the ``base``, ``alignment``, ``quadrature``, ``basis``,
    ``initial_density_guess`` and ``data`` config sections through the
    ``ex.capture(prefix=...)`` methods below. Calling the instance iterates
    the training dataloader, sleeping ``expected_device_step_time`` seconds
    per batch to stand in for device work, and prints a running average of
    how long each batch took to arrive from the loader.
    """

    seed: int
    test: bool
    ao_backend: Literal['pyscf', 'jax']
    # physics model
    basis_str: str
    basis_derivative: int
    use_density_fitting: bool  # density fitted or exact
    # data
    dataloaders: dataloading.DataLoaders
    init_psys: PreloadSystem
    # padding
    alignment: Alignment
    # initial guess
    initial_density_guess_key: Literal['minao']
    initial_ref_density_method: str | None
    initial_ref_density_basis: str | None
    interpolation_min: float
    interpolation_max: float
    noise_eps: float
    # solver
    grid_level: int  # quadrature
    spin_restricted: bool
    data_split_seed: int

    def __init__(self):
        # Order matters: _init_dataloader (reached via init_dataloading) reads
        # basis_str, spin_restricted, alignment and use_density_fitting, so
        # those setters must run first.
        self.set_base()  # type: ignore
        self.set_alignment()  # type: ignore
        self.set_quadrature()  # type: ignore
        self.set_basis()  # type: ignore
        self.init_dataloading()  # type: ignore

    @ex.capture(prefix='base')  # type: ignore
    def set_base(
        self,
        test: bool,
        seed: int,
        use_density_fitting: bool,
        spin_restricted: bool,
    ):
        """Store the run-level flags from the ``base`` config section."""
        self.test = test
        self.seed = seed
        self.use_density_fitting = use_density_fitting
        self.spin_restricted = spin_restricted

    @ex.capture(prefix='alignment')  # type: ignore
    def set_alignment(self, atom: int, basis: int, grid: int) -> None:
        """Store the ``alignment`` config section as an ``Alignment`` (used for padding)."""
        self.alignment = Alignment(atom, basis, grid)

    @ex.capture(prefix='quadrature')  # type: ignore
    def set_quadrature(self, level: int) -> None:
        """Store the quadrature grid level (called from ``__init__``)."""
        self.grid_level = level

    @ex.capture(prefix='basis')  # type: ignore
    def set_basis(
        self, name: str, derivative: int, backend: Literal['pyscf', 'jax']
    ) -> None:
        """Store the basis-set name, ``derivative`` setting and AO backend."""
        self.basis_str = name
        self.basis_derivative = derivative
        self.ao_backend = backend

    @ex.capture(prefix='initial_density_guess')  # type: ignore
    def init_dataloading(
        self,
        key: Literal['minao'],
        initial_ref_density_method: str | None,
        initial_ref_density_basis: str | None,
        min_ref_density_interpolation: float,
        max_ref_density_interpolation: float,
        noise_eps: float = 0.0,
    ) -> None:
        """Validate and store the initial-density-guess settings, then build the data.

        Raises:
            AssertionError: if exactly one of the reference method / basis is
                given, or if no reference is given but an interpolation bound
                is non-zero.
        """
        assert (initial_ref_density_method is None) == (
            initial_ref_density_basis is None
        ), 'reference method and basis must be provided together or not at all'
        if initial_ref_density_method is None:
            assert min_ref_density_interpolation == 0.0, (
                'min_ref_density_interpolation must be 0.0 when no reference density is provided'
            )
            assert max_ref_density_interpolation == 0.0, (
                'max_ref_density_interpolation must be 0.0 when no reference density is provided'
            )
        self.initial_density_guess_key = key
        self.initial_ref_density_method = initial_ref_density_method
        self.initial_ref_density_basis = initial_ref_density_basis
        # These were validated but never stored, which left the declared
        # ``interpolation_min`` / ``noise_eps`` attributes unset.
        self.interpolation_min = min_ref_density_interpolation
        self.interpolation_max = max_ref_density_interpolation
        self.noise_eps = noise_eps
        self._init_dataset()  # type: ignore

    @ex.capture(prefix='data')  # type: ignore
    def _init_dataset(
        self,
        key: str,
        data_set_kwargs: Dict[str, Any],
        deixc_kwargs: Dict[str, Any],
    ) -> None:
        """Instantiate the dataset registered under ``key`` (case-insensitive),
        wrap it in a ``DEIXCDataset`` and then build the dataloaders."""
        dataset: dataloading.BaseDataset = dataloading.key_to_dataset[key.lower()](
            initial_ref_density_method=self.initial_ref_density_method,
            initial_ref_density_basis=self.initial_ref_density_basis,
            **data_set_kwargs,
        )
        self.method = deixc_kwargs['method']
        self.dataset = DEIXCDataset(dataset, **deixc_kwargs)
        self._init_dataloader()  # type: ignore

    @ex.capture(prefix='data')  # type: ignore
    def _init_dataloader(
        self,
        split: Dict[str, Any],
        batch_size: int,
        shuffle: bool,
        preload: Dict[str, bool],
        workers: int | None,
        worker_buffer_size: int,
        seed: int,
    ) -> None:
        """Split the dataset, build the preload transform and the dataloaders.

        Also stores the initial ``PreloadSystem`` and its preloaded basis
        functions. ``seed`` is used both for the split and for the loaders.
        """
        self.data_split_seed = seed
        dataset_ensemble = dataloading.DatasetEnsemble.infer_split(
            self.dataset, data_split_seed=self.data_split_seed, **split
        )
        # Unpacked as (max angular momentum, basis-function preloader).
        self.max_angular_momentum, basis_fn_preloader = get_gto_preloader(
            self.basis_str, self.dataset.unique_elements
        )
        preload_transformations = dataloading.get_preload_transform(
            batch_size,
            self.basis_str,
            self.spin_restricted,
            self.alignment,
            self.use_density_fitting,
            base_initial_density_guess=self.initial_density_guess_key,
            basis_fn_preloader=basis_fn_preloader,
            **preload,
        )
        self.init_psys, self.dataloaders = dataloading.get_psys_and_dataloaders(
            dataset_ensemble,
            preload_transformations,
            shuffle,
            workers,
            worker_buffer_size,
            seed,
        )
        self.init_pvec_basis_fns = basis_fn_preloader(self.init_psys.atom_z)

    @ex.capture
    def __call__(self, expected_device_step_time: float, epochs: int = 10) -> None:
        """Iterate the training loader ``epochs`` times while emulating device work.

        For each batch, the wall-clock time spent waiting on the loader is
        measured (excluding the simulated device step). The loop then sleeps
        ``expected_device_step_time`` seconds and prints a bias-corrected
        exponential moving average (decay 0.9) of that wait time. The first
        wait of each epoch also includes creating the loader iterator.
        """
        decay = 0.9
        for epoch in range(epochs):
            print(f'Epoch {epoch}')
            ema = 0.0
            steps = 0
            t0 = time()
            for psys, preloaded_basis_fns, targets in self.dataloaders.train:
                t1 = time()
                sleep(expected_device_step_time)
                steps += 1
                ema = decay * ema + (1.0 - decay) * (t1 - t0)
                # A zero-initialised EMA underestimates during the first steps.
                # Dividing by (1 - decay**steps) removes that bias, so the
                # first reading equals the first measured wait.
                running_avg_time = ema / (1.0 - decay**steps)
                print(f'running_avg_time: {running_avg_time:.2f}s', flush=True)
                t0 = time()


@ex.automain
def main():
    """Entry point: build the dataloaders from the config, then run the benchmark."""
    ExperimentWrapper()()  # type: ignore
