import threading
from typing import Any, Dict, Sequence, Tuple

import grain.python as grain
import numpy as onp

from deixc.dataset import DEIXCDataset, DEIXCTargets, RawSample
from egxc.dataloading import IndexWrapper, io
from egxc.dataloading.dataloader import (
    DataLoaders,
    DatasetEnsemble,
    get_psys_and_dataloaders,
)
from egxc.dataloading.datasets.base import BaseDataset, SupportsIndex
from egxc.systems.preload import PreloadSystem
from egxc.utils.typing import MethodKey

from .custom import CustomDftTargetGenerator
from .pyscf_based import PyscfDftTargetGenerator


class LazyDEIXCDataset(DEIXCDataset):
    """Dataset wrapper that regenerates DEI-XC auxiliary targets on demand.

    On item access the auxiliary target file of the requested sample is
    looked up on disk. If it is missing, it is generated for that single
    sample before being loaded. Generation is serialized through a
    per-instance lock.

    NOTE(review): ``self._dataset``, ``self.auxiliary_data_directory``,
    ``self.deixc_ref_method_key``, ``self.deixc_ref_method_kwargs``,
    ``self.align_scf_trajectory`` and ``self.shift_dispersion`` are read here
    but never assigned in this class; they are presumably set by
    ``DEIXCDataset.__init__`` -- confirm against the base class.
    """

    def __init__(
        self,
        dataset: BaseDataset,
        method_key: MethodKey,
        method_kwargs: Dict[str, Any],
        align_scf_trajectory: int,
        shift_dispersion: bool,
    ):
        """Create a lazy dataset around ``dataset`` with the provided settings.

        All arguments are forwarded unchanged to ``DEIXCDataset.__init__``.
        """
        super().__init__(
            dataset,
            method_key,
            method_kwargs,
            align_scf_trajectory,
            shift_dispersion,
        )
        # Serializes target generation within this instance.
        self._generation_lock = threading.Lock()
        # Sample indices whose targets are currently being generated.
        self._generating_indices: set[int] = set()

    def __getstate__(self) -> Dict[str, Any]:
        """Return a picklable copy of the instance state.

        ``threading.Lock`` objects cannot be pickled, so the lock is replaced
        by ``None`` and re-created in ``__setstate__``.
        """
        state = self.__dict__.copy()
        state['_generation_lock'] = None
        # NOTE(review): in-flight markers are carried over to the unpickled
        # copy, which will never clear them itself -- confirm this is intended.
        state['_generating_indices'] = list(self._generating_indices)
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore the instance state and create a fresh generation lock."""
        self.__dict__.update(state)
        self._generation_lock = threading.Lock()
        self._generating_indices = set(state.get('_generating_indices', ()))

    def infer_split(
        self,
        train_fraction: float | None = None,
        val_fraction: float | None = None,
        data_split_seed: int = 0,
    ) -> tuple['LazyDEIXCDataset', 'LazyDEIXCDataset', 'LazyDEIXCDataset']:
        """Split like the base class, but return lazy datasets.

        The split itself is delegated to ``DEIXCDataset.infer_split``. The
        underlying ``_dataset`` of each resulting split is then wrapped in a
        new ``LazyDEIXCDataset`` that reuses this instance's settings.

        Returns:
            The ``(train, val, test)`` datasets as ``LazyDEIXCDataset``.
        """
        train_ds, val_ds, test_ds = super().infer_split(
            train_fraction=train_fraction,
            val_fraction=val_fraction,
            data_split_seed=data_split_seed,
        )
        return tuple(
            LazyDEIXCDataset(
                ds._dataset,  # type: ignore[attr-defined]
                self.deixc_ref_method_key,
                self.deixc_ref_method_kwargs,
                self.align_scf_trajectory,
                self.shift_dispersion,
            )
            for ds in (train_ds, val_ds, test_ds)
        )

    def __getitem__(self, idx: SupportsIndex) -> RawSample:
        """Return raw sample, regenerating auxiliary targets if they are missing.

        Returns:
            The tuple ``(idx0, raw_input, targets)``, where ``idx0`` and
            ``raw_input`` come from the wrapped dataset and ``targets`` is
            built from the auxiliary file on disk.
        """
        idx0, raw_input, _ = self._dataset[idx]
        aux_dir = io.auxiliary_data_directory(
            self.auxiliary_data_directory,
            'deixc',
            self.deixc_ref_method_key,
            **self.deixc_ref_method_kwargs,
        )
        if not io.auxiliary_data_exists(aux_dir, idx0):
            self._compute_deixc_target(int(idx), int(idx0), aux_dir)
        path = io.auxiliary_data_path(aux_dir, idx0)
        # SECURITY: allow_pickle=True lets the loaded file execute arbitrary
        # code during unpickling; only safe while the auxiliary directory
        # contains trusted, locally generated files.
        data = onp.load(path, allow_pickle=True)
        targets = DEIXCTargets.create(
            data,
            self.align_scf_trajectory,
            self.shift_dispersion,
        )
        return idx0, raw_input, targets

    def _compute_deixc_target(
        self, local_idx: int, sample_idx: int, aux_dir: str
    ) -> None:
        """Compute the DEI-XC targets for ``sample_idx``.

        Args:
            local_idx: Position of the sample within ``self._dataset``.
            sample_idx: Sample index used to locate the auxiliary file.
            aux_dir: Auxiliary data directory the caller checked.

        Raises:
            FileNotFoundError: If ``sample_idx`` is marked as being generated.
            ValueError: If the configured ``backend`` is not ``'custom'`` or
                ``'pyscf'``.
        """
        with self._generation_lock:
            # The caller checked for the file before taking the lock. Another
            # thread may have finished generating it while we were waiting,
            # so check again instead of recomputing the same targets.
            if io.auxiliary_data_exists(aux_dir, sample_idx):
                return
            if sample_idx in self._generating_indices:
                raise FileNotFoundError(
                    f'DEI-XC targets for sample {sample_idx} are currently being generated.'
                )
            # Restrict generation to the one requested sample.
            subset = IndexWrapper(self._dataset, [local_idx])
            # Copy so that the default below does not leak into the stored
            # method kwargs.
            method_kwargs = dict(self.deixc_ref_method_kwargs)
            method_kwargs.setdefault('spin_restricted', True)
            backend = method_kwargs.get('backend', 'custom')
            if backend == 'custom':
                generator = CustomDftTargetGenerator(subset, **method_kwargs)
            elif backend == 'pyscf':
                generator = PyscfDftTargetGenerator(subset, **method_kwargs)
            else:
                raise ValueError(f'Unknown DEI-XC backend: {backend}')
            # Pin the generator's worker settings where it exposes them.
            if hasattr(generator, 'workers'):
                generator.workers = 0  # type: ignore[attr-defined]
            if hasattr(generator, 'worker_buffer_size'):
                generator.worker_buffer_size = 1  # type: ignore[attr-defined]
            self._generating_indices.add(sample_idx)
            try:
                # NOTE(review): presumably processes the index range [0, 1)
                # of ``subset`` -- confirm against the generator's __call__.
                generator(0, 1)
            finally:
                # Always clear the marker, even when generation fails.
                self._generating_indices.discard(sample_idx)


def get_psys_and_lazy_dataloader(
    datasets: DatasetEnsemble,
    transformations: Sequence[grain.Transformation],
    shuffle: bool,
    shuffling_seed: int,
    n_test_samples: int | None = None,
) -> Tuple[PreloadSystem, DataLoaders]:
    """Build the preload system and data loaders with fixed worker settings.

    Thin wrapper around ``get_psys_and_dataloaders``: every argument is
    forwarded unchanged, while ``workers`` is pinned to ``0`` and
    ``worker_buffer_size`` to ``1``.

    Returns:
        Whatever ``get_psys_and_dataloaders`` returns for these options.
    """
    loader_options = {
        'datasets': datasets,
        'transformations': transformations,
        'shuffle': shuffle,
        'shuffling_seed': shuffling_seed,
        # Fixed here rather than exposed to the caller.
        'workers': 0,
        'worker_buffer_size': 1,
        'n_test_samples': n_test_samples,
    }
    return get_psys_and_dataloaders(**loader_options)
