import warnings
from typing import Literal

from deixc.data_generation.utils.targets import DeiXCTargets
from egxc.data_generation.generator import BaseGenerator
from egxc.dataloading import io


class BaseDeixcTargetGenerator(BaseGenerator):
    """
    Common base for every generator that produces DEI-XC targets.

    Adds a single persistence helper on top of ``BaseGenerator``.
    """

    def save_deixc_targets(self, idx: int, data: DeiXCTargets) -> None:
        """
        Write one sample's DEI-XC targets to the auxiliary data directory.

        Args:
            idx: Index of the sample; forwarded unchanged to
                ``io.auxiliary_data_save``.
            data: Targets to store; serialised via ``data.to_dict()`` before
                being handed to ``io.auxiliary_data_save``.
        """
        target_dir = self.aux_dir
        payload = data.to_dict()
        io.auxiliary_data_save(target_dir, idx, payload)


class BaseDftDeixcTargetGenerator(BaseDeixcTargetGenerator):
    """
    Base class for DFT-based DEI-XC target generators.

    Attributes:
        backend: Which DFT implementation to use, annotated as either
            ``'pyscf'`` or ``'custom'``. Only an annotation is declared
            here; no default value is provided.
        method_key: Class-level string ``'ks_dft'`` shared by all subclasses
            (presumably identifies the Kohn-Sham DFT method to downstream
            code — TODO confirm against callers).
    """

    backend: Literal['pyscf', 'custom']
    method_key = 'ks_dft'

    def __post_init__(self):
        """
        Run any inherited post-init hook, then warn about the spin assumption.
        """
        # Overriding ``__post_init__`` without chaining would silently skip
        # whatever initialisation an ancestor performs in its own hook.
        # ``getattr`` keeps this safe when no ancestor defines the hook.
        parent_post_init = getattr(super(), '__post_init__', None)
        if parent_post_init is not None:
            parent_post_init()
        warnings.warn('DFT data generation presently assumes spin-restriction')
