from typing import Any, Dict

import numpy as onp
from jax import config
from jax.experimental.compilation_cache import compilation_cache
from seml.experiment import Experiment

from egxc.data_generation.generator import PyscfDftTargetGenerator
from egxc.dataloading import BaseDataset, key_to_dataset
from egxc.dataloading.utils import IndexWrapper
from egxc.utils.typing import MethodKey

# Experiment object (sacred-style API) on which every config scope and captured
# function in this module is registered.
ex = Experiment()

# Enable 64-bit (double-precision) types in JAX; done at import time, before any
# computation runs.
config.update('jax_enable_x64', True)
# Default precision for JAX matmuls. Presumably 'float32' is chosen to rule out
# reduced-precision (TF32/bfloat16) matmuls on accelerators — confirm.
config.update('jax_default_matmul_precision', 'float32')
# Directory for JAX's persistent compilation cache (relative to the working dir).
compilation_cache.set_cache_dir('./caches/jax_compile/compute_reference_densities')

# Dataset-constructor kwargs that `base_config` splices into `dataset['kwargs']`
# (via `**` unpacking, so this module-level dict itself is never mutated).
# Both default to None; their meaning is defined by the dataset classes.
DEFAULT_DATA_KWARGS = dict(
    initial_ref_density_method_key=None,
    initial_ref_density_method_kwargs=None,
)


@ex.config
def base_config():
    """Base config shared by all runs.

    This is a sacred config scope: every local variable assigned here becomes
    a config entry, so do not add helper variables or type annotations.
    `dataset['key']`, `method`, `n_chunks` and `chunk_id` are not set here;
    they must come from a named config (e.g. `qm9`, `dft`) or, presumably,
    the seml run configuration.
    """
    data_dir = ''  # forwarded to the dataset constructor as `data_dir`
    dataset = {
        # None keeps the whole dataset; an int keeps that many random samples.
        'n_samples': None,
        'seed': 0,  # seeds the random sub-sampling (and the qm9 split)
        'kwargs': {  # forwarded as **kwargs to the dataset constructor
            'data_dir': data_dir,
            **DEFAULT_DATA_KWARGS,
        },
    }


@ex.named_config
def dft():
    """KS-DFT reference targets (wB97M-V / def2-TZVPD) via the PySCF backend."""
    method = {
        'key': 'ks_dft',
        # Method-specific kwargs; the whole dict (including 'backend') is
        # forwarded as **kwargs to the target generator.
        'kwargs': {
            'xc_str': 'wB97M-V',  # exchange-correlation functional
            'basis': 'def2-TZVPD',  # basis set
            'backend': 'pyscf',  # pyscf or custom (custom is not implemented yet)
            'use_eri_density_fitting': True,
            'use_exchange_density_fitting': True,
            'spin_restricted': True,
            # presumably the DFT integration-grid level — confirm in the generator
            'quadrature_grid_level': 3,
        },
    }


@ex.named_config
def qm9():
    """QM9 dataset; `get_dataset` keeps only the first split of `random_split`."""
    dataset = {
        'key': 'qm9',
        'kwargs': {
            # presumably a threshold on the number of heavy atoms per molecule
            # — confirm against the QM9 dataset class
            'heavy_atoms_thresh': 4,
            'exclude_fluorine': False,
        },
    }


@ex.named_config
def threebpa():
    """3BPA dataset; no 3BPA-specific constructor kwargs."""
    dataset = {'key': '3bpa', 'kwargs': {}}


@ex.capture(prefix='dataset')  # type: ignore
def get_dataset(
    key: str, n_samples: int | None, seed: int, kwargs: Dict[str, Any]
) -> BaseDataset:
    """Build the dataset described by the `dataset` config entry.

    All arguments are filled by sacred from the `dataset` sub-config.

    Args:
        key: Name looked up in `key_to_dataset` to select the dataset class.
        n_samples: If given, keep only this many samples, drawn via a random
            permutation seeded with `seed` (all samples if it exceeds the
            dataset size). None keeps the whole dataset.
        seed: Seed for the qm9 `random_split` and for the sub-sampling.
        kwargs: Keyword arguments forwarded to the dataset constructor.

    Returns:
        The constructed, possibly split and/or sub-sampled, dataset.

    Raises:
        ValueError: If `n_samples` is negative.
    """
    # Fail before the (potentially expensive) dataset construction. A negative
    # value would otherwise slice `[:-k]` and silently keep all but k samples.
    if n_samples is not None and n_samples < 0:
        raise ValueError(f'n_samples must be non-negative or None, got {n_samples}')

    dataset = key_to_dataset[key](**kwargs)
    if key == 'qm9':
        # Keep only the first of the three returned splits (val_fraction=0).
        # Original note: "isolate heavy atoms" — presumably the split applies
        # the heavy-atom filter; confirm against the QM9 `random_split`.
        dataset, _, _ = dataset.random_split(val_fraction=0, seed=seed)
    if n_samples is not None:
        # Seeded permutation -> the same subset for the same seed on every run.
        indices = onp.random.RandomState(seed).permutation(len(dataset))[:n_samples]
        dataset = IndexWrapper(dataset, indices.tolist())
    return dataset


class ExperimentWrapper:
    """Build the dataset and target generator from the config and run one chunk.

    Both the constructor and `init_generator` are sacred-captured, so their
    arguments are filled from the experiment config (`n_chunks`/`chunk_id`
    at the top level, `key`/`kwargs` under `method`).
    """

    @ex.capture
    def __init__(
        self,
        n_chunks: int,
        chunk_id: int,
    ) -> None:
        # Chunk bookkeeping consumed by `__call__`.
        self.chunk_id = chunk_id
        self.n_chunks = n_chunks
        # Captured call: arguments come from the `dataset` config entry.
        self.dataset = get_dataset()  # type: ignore
        # Captured call (`method` prefix); reads `self.dataset` set above.
        self.init_generator()  # type: ignore

    @ex.capture(prefix='method')  # type: ignore
    def init_generator(self, key: MethodKey, kwargs: Dict[str, Any]) -> None:
        """Create `self.generator` for the configured method.

        Args:
            key: Method name; only 'ks_dft' is implemented.
            kwargs: Method-specific options. For 'ks_dft', `kwargs['backend']`
                selects the implementation and the whole dict (backend
                included) is forwarded to the generator.

        Raises:
            NotImplementedError: For 'hf', 'ccsd', or the 'custom' backend.
            ValueError: For an unknown method key or backend.
            KeyError: If 'ks_dft' is requested without a 'backend' entry.
        """
        if key == 'hf':
            raise NotImplementedError('HF method is not implemented yet.')
        if key == 'ccsd':
            raise NotImplementedError('CCSD method is not implemented yet.')
        if key != 'ks_dft':
            raise ValueError(f'Unknown method: {key}')

        backend = kwargs['backend']
        if backend == 'custom':
            raise NotImplementedError('Custom backend is not implemented yet.')
        if backend != 'pyscf':
            raise ValueError(f'Unknown backend: {backend}')
        # NOTE(review): `backend` itself is forwarded too; presumably the
        # generator accepts it — confirm against PyscfDftTargetGenerator.
        self.generator = PyscfDftTargetGenerator(self.dataset, **kwargs)

    def __call__(self) -> None:
        """Run the generator on chunk `chunk_id` out of `n_chunks`."""
        chunk, total = self.chunk_id, self.n_chunks
        self.generator(chunk, total)


@ex.automain
def main():
    """Entry point: build an `ExperimentWrapper` from the config and run it."""
    ExperimentWrapper()()  # type: ignore
