import os
from time import time
from typing import Any, Dict, NamedTuple, Set, Tuple

import numpy as onp
from grain._src.python.data_sources import SupportsIndex
from grain.python import RandomAccessDataSource

from egxc.dataloading import io
from egxc.utils.typing import MethodKey, NpFloatAx3, NpFloatBxB, NpUIntA

# Backing container of a dataset; its concrete type is chosen by subclasses
# (BaseDataset itself only calls len() on it).
Data = Any
# Total molecular charge (assumed to be in units of the elementary charge — TODO confirm).
Charge = int
# Spin of the molecule; NOTE(review): convention (2S vs. multiplicity) not visible here — confirm.
Spin = int


class Targets(NamedTuple):
    """
    Supervision targets attached to a single dataset sample.

    Every field is optional and is ``None`` when that target is unavailable.
    """

    # Total energy; expected in hartree.
    energy: float | None
    # Forces on the nuclei; expected in hartree / angstrom.
    nuc_forces: NpFloatAx3 | None
    # Reference density matrix.
    # TODO: adapt to allow for unrestricted spin
    density_matrix: NpFloatBxB | None


# Model input for one sample. Presumably
# (nuclear positions, atomic numbers, charge, spin, optional initial-guess density matrix);
# element meanings are inferred from the annotated types and BaseDataset.load_aux —
# TODO confirm against the concrete __getitem__ implementations.
RawInput = Tuple[NpFloatAx3, NpUIntA, Charge, Spin, NpFloatBxB | None]
# One sample as returned by BaseDataset.__getitem__: (index, inputs, targets),
# where the first element is presumably the sample index that was requested.
RawSample = Tuple[SupportsIndex, RawInput, Targets]


class BaseDataset(RandomAccessDataSource):
    """
    Common base class for the random-access datasets in this module.

    Subclasses populate ``data`` (whose length is reported by ``__len__``) and
    implement ``__getitem__`` so that it returns a ``RawSample``. The shared
    parameters (directory, unit labels, initial-guess density settings) are set
    with ``init_params`` or copied from another dataset with
    ``copy_params_from_dataset``.
    """

    data: Data  # backing container; only len() is used by this base class
    directory: str  # dataset root; auxiliary files live under ``<directory>/aux``
    unique_elements: Set[int]  # presumably atomic numbers present — TODO confirm

    energy_unit: str
    distance_unit: str

    # When the key is set, load_aux reads precomputed initial-guess density
    # matrices from the auxiliary directory identified by this key and kwargs.
    initial_ref_density_method_key: MethodKey | None
    initial_ref_density_method_kwargs: Dict[str, Any] | None

    def init_params(
        self,
        directory: str,
        energy_unit: str,
        distance_unit: str,
        initial_ref_density_method_key: MethodKey | None = None,
        initial_ref_density_method_kwargs: Dict[str, Any] | None = None,
    ) -> None:
        """
        Set the dataset-wide parameters shared by all subclasses.

        Args:
            directory: Root directory of the dataset.
            energy_unit: Energy unit label (stored verbatim, not interpreted here).
            distance_unit: Distance unit label (stored verbatim, not interpreted here).
            initial_ref_density_method_key: Optional key selecting precomputed
                initial-guess density matrices (see ``load_aux``).
            initial_ref_density_method_kwargs: Keyword arguments that further
                identify the initial-guess method; required by ``load_aux``
                whenever the key is set.

        Note:
            ``unique_elements`` is not set here; subclasses are responsible for it.
        """
        self.directory = directory
        self.energy_unit = energy_unit
        self.distance_unit = distance_unit
        self.initial_ref_density_method_key = initial_ref_density_method_key
        self.initial_ref_density_method_kwargs = initial_ref_density_method_kwargs

    def copy_params_from_dataset(self, dataset: 'BaseDataset') -> None:
        """
        Copy the shared parameters, including ``unique_elements``, from ``dataset``.

        The attributes are assigned by reference, so mutable values such as
        ``unique_elements`` and the kwargs dict are shared with ``dataset``.
        """
        self.directory = dataset.directory
        self.unique_elements = dataset.unique_elements
        self.energy_unit = dataset.energy_unit
        self.distance_unit = dataset.distance_unit
        self.initial_ref_density_method_key = dataset.initial_ref_density_method_key
        self.initial_ref_density_method_kwargs = dataset.initial_ref_density_method_kwargs

    def __getitem__(self, idx: SupportsIndex) -> RawSample:
        """
        Return the sample at ``idx``. Abstract; child classes must implement it.

        Implement it as if it were the ``__getitem__`` of a single, unsplit
        dataset. This may be counterintuitive for the split datasets.
        """
        raise NotImplementedError

    def sort_atoms(
        self, nuc_pos: NpFloatAx3, atom_z: NpUIntA, nuc_forces: NpFloatAx3 | None = None
    ) -> Tuple:
        """
        Reorder atoms by ascending ``atom_z``, presumably atomic numbers.

        The sort is stable, so atoms with equal ``atom_z`` keep their relative
        order. The same permutation is applied to positions and, if given, forces.

        Returns:
            ``(nuc_pos, atom_z, nuc_forces)`` when ``nuc_forces`` is given,
            otherwise ``(nuc_pos, atom_z)``.
        """
        # kind='stable' is equivalent to stable=True, but the ``stable`` keyword
        # only exists in NumPy >= 2.0 whereas ``kind`` also works on NumPy 1.x.
        order = onp.argsort(atom_z, kind='stable')
        # e.g. atom_z [1, 8, 6, 1, 1] -> atom_z[order] == [1, 1, 1, 6, 8]
        nuc_pos = nuc_pos[order]
        atom_z = atom_z[order]
        if nuc_forces is None:
            return nuc_pos, atom_z
        return nuc_pos, atom_z, nuc_forces[order]

    def __len__(self) -> int:
        """Return the number of entries in ``self.data``."""
        return len(self.data)

    def timed_get_item(self, idx: int) -> float:
        """
        Return the time in seconds taken to fetch ``self[idx]``.

        The fetched sample is discarded. Uses ``time.perf_counter``, a monotonic,
        high-resolution clock meant for measuring durations, instead of the
        wall clock, which can jump.
        """
        from time import perf_counter

        start = perf_counter()
        self[idx]
        return perf_counter() - start

    @property
    def auxiliary_data_directory(self) -> str:
        """Directory holding auxiliary data: ``<directory>/aux``."""
        return os.path.join(self.directory, 'aux')

    def load_aux(self, idx: SupportsIndex) -> NpFloatBxB | None:
        """
        Load the precomputed initial-guess density matrix for sample ``idx``.

        The matrix is read from ``<density_dir>/<idx>.npz`` (array key
        ``'density_matrix'``). ``density_dir`` is resolved by
        ``io.auxiliary_data_directory`` from the auxiliary directory, the
        ``'initial_guess'`` category, the method key and the method kwargs.

        Returns:
            The density matrix, or ``None`` when no initial-guess method key is
            configured.

        Raises:
            ValueError: If a method key is set but its kwargs are ``None``.
        """
        if self.initial_ref_density_method_key is None:
            return None
        # Explicit check instead of ``assert``: asserts are stripped under -O,
        # which would turn this into an obscure TypeError on ``**None``.
        if self.initial_ref_density_method_kwargs is None:
            raise ValueError(
                'initial_ref_density_method_kwargs must be set when '
                'initial_ref_density_method_key is given.'
            )
        density_dir = io.auxiliary_data_directory(
            self.auxiliary_data_directory,
            'initial_guess',
            self.initial_ref_density_method_key,
            **self.initial_ref_density_method_kwargs,
        )
        path = os.path.join(density_dir, f'{idx}.npz')
        # onp.load on an .npz archive returns an NpzFile that holds the file open.
        # The context manager closes it. Indexing the NpzFile reads the member
        # into memory, so the returned array stays valid after closing.
        with onp.load(path) as data:
            return onp.asarray(data['density_matrix'])

    def infer_split(self, **kwargs) -> Tuple['BaseDataset', 'BaseDataset', 'BaseDataset']:
        """
        Return the train, validation and test datasets.

        Raises:
            NotImplementedError: Always; subclasses that support splitting override it.
        """
        raise NotImplementedError(
            'This method should be implemented in subclasses that support splitting.'
        )


class PresplitDataset(BaseDataset):
    """
    Dataset that provides its train/val/test split through a parameterless ``split()``.
    """

    def split(self) -> Tuple[BaseDataset, BaseDataset, BaseDataset]:
        """
        Return the train, validation and test datasets.

        Abstract: subclasses must override this method.

        Raises:
            NotImplementedError: If the subclass does not override it.
        """
        # Informative message, consistent with BaseDataset.infer_split.
        raise NotImplementedError(f'{type(self).__name__} must implement split().')


class PartiallySplitDataset(BaseDataset):
    """
    Datasets with preexisting train and test splits but no validation split.

    ``random_split`` is expected to carve the validation set out of the
    training data (presumably — TODO confirm against the implementations).
    """

    def random_split(
        self, val_fraction: float, seed: int
    ) -> Tuple[BaseDataset, BaseDataset, BaseDataset]:
        """
        Return the train, validation and test datasets.

        Abstract: subclasses must override this method.

        Args:
            val_fraction: Fraction of samples assigned to the validation set.
            seed: Seed for the random split.

        Raises:
            NotImplementedError: If the subclass does not override it.
        """
        # Informative message, consistent with BaseDataset.infer_split.
        raise NotImplementedError(
            f'{type(self).__name__} must implement random_split().'
        )


class UnsplitDataset(BaseDataset):
    """
    Dataset without predefined splits; train/val/test subsets come from ``random_split``.
    """

    def random_split(
        self, train_fraction: float, val_fraction: float, seed: int
    ) -> Tuple[BaseDataset, BaseDataset, BaseDataset]:
        """
        Return the train, validation and test datasets.

        Abstract: subclasses must override this method.

        Args:
            train_fraction: Fraction of samples assigned to the training set.
            val_fraction: Fraction of samples assigned to the validation set.
                The test set presumably receives the remainder — TODO confirm.
            seed: Seed for the random split.

        Raises:
            NotImplementedError: If the subclass does not override it.
        """
        # Informative message, consistent with BaseDataset.infer_split.
        raise NotImplementedError(
            f'{type(self).__name__} must implement random_split().'
        )
