from typing import Any, Dict

import numpy as onp
from jax import config
from jax.experimental.compilation_cache import compilation_cache
from seml.experiment import Experiment

from deixc import data_generation
from egxc.dataloading import BaseDataset, key_to_dataset
from egxc.dataloading.utils import IndexWrapper
from egxc.utils.typing import MethodKey

# Experiment handle; every config scope and captured function below registers on it.
ex = Experiment()

# Process-wide JAX settings, applied once at import time (before any compilation).
for _jax_option, _jax_value in (
    ('jax_enable_x64', True),
    ('jax_default_matmul_precision', 'float32'),
):
    config.update(_jax_option, _jax_value)
del _jax_option, _jax_value  # keep the loop variables out of the module namespace

# Directory for JAX's persistent compilation cache used by this script.
compilation_cache.set_cache_dir('./caches/jax_compile/compute_deixc_data')

# Extra dataset-constructor kwargs merged into every dataset config; both default to None.
DEFAULT_DATA_KWARGS = dict.fromkeys(
    ('initial_ref_density_method_key', 'initial_ref_density_method_kwargs')
)


@ex.config
def base_config():
    """Defaults shared by every run; named configs supply dataset/method specifics."""
    # Forwarded to the dataset constructor as dataset.kwargs.data_dir.
    data_dir = ''
    # Dataset selection; 'key' is not set here, the named configs provide it.
    dataset = dict(
        n_samples=None,  # None keeps the whole selected split
        seed=0,  # seeds both the dataset split and the random sub-sampling
        kwargs=dict(data_dir=data_dir, **DEFAULT_DATA_KWARGS),
    )


@ex.named_config
def dft():
    """Kohn-Sham DFT targets: wB97M-V / def2-TZVPD, density fitting, spin-restricted."""
    method = dict(
        key='ks_dft',
        # Method-specific kwargs; 'backend' selects the generator implementation.
        kwargs=dict(
            xc_str='wB97M-V',
            basis='def2-TZVPD',
            backend='pyscf',  # pyscf or custom
            use_eri_density_fitting=True,
            use_exchange_density_fitting=True,
            spin_restricted=True,
            quadrature_grid_level=1,
        ),
    )


@ex.named_config
def qm9():
    """QM9 dataset with heavy_atoms_thresh=4 and exclude_fluorine=False."""
    dataset = dict(
        key='qm9',
        kwargs=dict(
            heavy_atoms_thresh=4,
            exclude_fluorine=False,
        ),
    )


@ex.named_config
def md17():
    """MD17 dataset for a single molecule (ethanol by default)."""
    dataset = dict(
        key='md17',
        # Molecule name: aspirin, benzene, ethanol, malonaldehyde, toluene.
        kwargs=dict(name='ethanol'),
    )


@ex.named_config
def qmugs():
    """QMugs dataset, optionally restricted by a custom index file."""
    dataset = dict(
        key='qmugs',
        # Set custom_index_file_name to load a custom set of indices.
        kwargs=dict(custom_index_file_name=None),
    )


@ex.named_config
def qm40():
    """QM40 dataset with optional element filtering and per-bin subsampling."""
    dataset = dict(
        key='qm40',
        kwargs=dict(
            # Optional subsampling: keep at most N molecules per heavy-atom bin.
            # Applied after element filtering and before any global n_samples.
            samples_per_heavy_atom_bin=None,
            # Optional element filter: drop molecules containing any of these
            # atomic numbers. QM40 only contains H(1), C(6), N(7), O(8), F(9),
            # S(16), Cl(17); e.g. [9, 16, 17] removes the non-CHNO elements.
            exclude_elements=None,
        ),
    )


@ex.named_config
def threebpa():
    """3BPA dataset with default constructor kwargs (name spelled out: identifiers cannot start with a digit)."""
    dataset = dict(key='3bpa', kwargs=dict())


@ex.capture(prefix='dataset')  # type: ignore
def get_dataset(
    key: str, n_samples: int | None, seed: int, kwargs: Dict[str, Any]
) -> BaseDataset:
    """Build the dataset partition that targets are generated for.

    All arguments are injected by sacred from the ``dataset`` config entry.

    Args:
        key: Name of a dataset registered in ``key_to_dataset``.
        n_samples: If not ``None``, keep this many samples, chosen via a
            permutation seeded with ``seed``. Values larger than the dataset
            keep every sample (in permuted order). Must be non-negative.
        seed: Seed for both the dataset split and the random sub-sampling.
        kwargs: Keyword arguments forwarded to the dataset constructor.

    Returns:
        The selected partition, wrapped in an ``IndexWrapper`` when
        ``n_samples`` is given.

    Raises:
        KeyError: If ``key`` is not a registered dataset.
        ValueError: If ``n_samples`` is negative.
    """
    # Validate before the (potentially expensive) dataset construction. A
    # negative value would otherwise act as a slice bound below and silently
    # drop samples from the end instead of failing.
    if n_samples is not None and n_samples < 0:
        raise ValueError(f'n_samples must be non-negative, got {n_samples}')

    try:
        dataset_factory = key_to_dataset[key]
    except KeyError:
        raise KeyError(
            f'Unknown dataset key {key!r}; available: {list(key_to_dataset)}'
        ) from None
    dataset = dataset_factory(**kwargs)

    if key in ('qm9', 'md17'):
        # Keep the first partition of a split without a validation share.
        # MD17 ships pre-split into train/test, so this is its train split.
        # NOTE(review): for QM9 the original note said this "isolates heavy
        # atoms" — confirm against the QM9 dataset implementation.
        dataset, _, _ = dataset.random_split(val_fraction=0, seed=seed)
    elif key == 'qm40':
        # QM40 is an UnsplitDataset: reproduce the evaluation split
        # (1% train, 1% val) and keep the remaining ~98% test partition so
        # generated targets line up with evaluation.
        _, _, dataset = dataset.random_split(
            train_fraction=0.01, val_fraction=0.01, seed=seed
        )
    # Every other key (e.g. 'qmugs', which uses its own indices, or '3bpa')
    # is used exactly as constructed.

    if n_samples is not None:
        indices = onp.random.RandomState(seed).permutation(len(dataset))[:n_samples]
        dataset = IndexWrapper(dataset, indices.tolist())
    return dataset


class ExperimentWrapper:
    """Generate targets for one chunk of the configured dataset.

    ``__init__`` and ``init_generator`` are sacred-captured: ``n_chunks`` and
    ``chunk_id`` come from the root config, while ``key``/``kwargs`` of
    ``init_generator`` come from the ``method`` config entry.
    """

    @ex.capture
    def __init__(
        self,
        n_chunks: int,
        chunk_id: int,
    ) -> None:
        self.dataset = get_dataset()  # type: ignore
        self.n_chunks, self.chunk_id = n_chunks, chunk_id
        # Builds self.generator from self.dataset, so it must run last.
        self.init_generator()  # type: ignore

    @ex.capture(prefix='method')  # type: ignore
    def init_generator(self, key: MethodKey, kwargs: Dict[str, Any]) -> None:
        """Create ``self.generator`` for the configured method.

        Args:
            key: Method identifier; only ``'ks_dft'`` is implemented.
            kwargs: Method options. ``'workers'`` is removed; everything else
                (``'backend'`` included) is passed to the generator class.

        Raises:
            NotImplementedError: For ``'hf'`` and ``'ccsd'``.
            ValueError: For any other method key, or an unknown DFT backend.
            KeyError: If a ``'ks_dft'`` config has no ``'backend'`` entry.
        """
        ctor_kwargs = dict(kwargs)
        n_workers = ctor_kwargs.pop('workers', None)

        if key == 'hf':
            raise NotImplementedError('HF method is not implemented yet.')
        if key == 'ccsd':
            raise NotImplementedError('CCSD method is not implemented yet.')
        if key != 'ks_dft':
            raise ValueError(f'Unknown method: {key}')

        backend = ctor_kwargs['backend']
        if backend not in ('pyscf', 'custom'):
            raise ValueError(f'Unknown backend: {backend}')

        if backend == 'pyscf':
            # NOTE(review): a configured 'workers' value is discarded for this
            # backend — confirm that is intended.
            self.generator = data_generation.PyscfDftTargetGenerator(
                self.dataset, **ctor_kwargs
            )
            return

        self.generator = data_generation.CustomDftTargetGenerator(
            self.dataset, **ctor_kwargs
        )
        # The worker count is set on the instance after construction.
        if n_workers is not None:
            self.generator.workers = n_workers

    def __call__(self) -> None:
        """Invoke the generator with this wrapper's ``(chunk_id, n_chunks)``."""
        self.generator(self.chunk_id, self.n_chunks)


@ex.automain
def main():
    """Build the configured ExperimentWrapper and run it."""
    ExperimentWrapper()()  # type: ignore
