from dataclasses import dataclass
from typing import Any, List, Tuple

from hydra.core.config_store import ConfigStore
from omegaconf import MISSING


@dataclass
class AIGDatasetConfig:
    """Structured config for building ``data.AIG_Dataset_Collection``.

    ``_target_`` names the callable Hydra constructs when this node is passed
    to ``hydra.utils.instantiate``; the remaining fields are forwarded to it as
    keyword arguments. Fields defaulting to ``MISSING`` are mandatory and must
    be supplied (e.g. via YAML or command-line overrides) before being read.
    """

    _target_: str = "data.AIG_Dataset_Collection"
    # Mandatory. Annotated ``Any``, so OmegaConf performs no type validation
    # on it. Presumably the AIG inputs (paths or objects) — TODO confirm
    # against ``data.AIG_Dataset_Collection``.
    aigs: Any = MISSING
    return_action_mask: bool = True
    # Selector string; the set of valid values is defined by the target class,
    # not in this module.
    reward_type: str = "simple"
    const_node: bool = True
    # Presumably the RL discount factor — confirm against the dataset code.
    gamma: float = 0.99
    # Mandatory; no default is provided.
    embedding_size: int = MISSING
    num_workers: int = 1
    fragments: int = 1
    rebuild_cache: bool = False


@dataclass
class TTDatasetConfig:
    """Structured config for building ``data.TT_Dataset`` via Hydra instantiate."""

    _target_: str = "data.TT_Dataset"
    # Glob-style pattern; the ``**`` component suggests a recursive match over
    # subdirectories, but how it is expanded is up to ``data.TT_Dataset``.
    # NOTE(review): absolute, machine-specific default — likely overridden per
    # environment.
    path: str = "/data/truth_tables/**/*.txt"


@dataclass
class StorageConfig:
    """Base schema shared by the replay-buffer storage configs in this module.

    ``_target_`` is ``MISSING`` here, so this node cannot be instantiated as-is;
    each subclass sets it to a concrete storage class. All fields are forwarded
    to that class as keyword arguments when instantiated through Hydra.
    """

    _target_: str = MISSING
    # Forwarded to the storage class as ``max_size``.
    max_size: int = 10000000
    # Forwarded to the storage class as ``scratch_dir``; empty string by default.
    # NOTE(review): confirm every target storage class accepts this keyword.
    scratch_dir: str = ""


@dataclass
class ListStorageConfig(StorageConfig):
    """Storage config targeting ``torchrl.data.ListStorage``.

    Inherits ``max_size`` and ``scratch_dir`` from ``StorageConfig``.
    NOTE(review): Hydra passes every non-underscore field as a keyword
    argument, so ``scratch_dir`` will reach ``ListStorage.__init__``; verify the
    installed torchrl version accepts it, otherwise instantiation will fail
    with a TypeError.
    """

    _target_: str = "torchrl.data.ListStorage"


@dataclass
class LazyMemmapStorageConfig(StorageConfig):
    """Storage config targeting ``torchrl.data.LazyMemmapStorage``.

    Inherits ``max_size`` and ``scratch_dir`` (default ``""``) from
    ``StorageConfig``; both are forwarded to the storage constructor.
    """

    _target_: str = "torchrl.data.LazyMemmapStorage"


@dataclass
class SharedLazyMemmapStorageConfig(StorageConfig):
    """Storage config targeting ``rl.replay_buffer.SharedLazyMemmap``.

    Inherits ``max_size`` from ``StorageConfig`` and overrides the inherited
    ``scratch_dir`` default (``""``) with ``"/"``.
    """

    _target_: str = "rl.replay_buffer.SharedLazyMemmap"
    # NOTE(review): "/" is the filesystem root; confirm this default is
    # intended rather than a placeholder that callers are expected to override.
    scratch_dir: str = "/"


@dataclass
class ReplayBufferConfig:
    """Structured config for ``torchrl.data.ReplayBuffer``.

    ``storage`` and ``batch_size`` are mandatory (``MISSING``). Because
    ``storage`` is itself a node carrying a ``_target_``, Hydra's default
    recursive instantiation builds the storage object first and passes the
    result in as the ``storage`` argument.
    """

    _target_: str = "torchrl.data.ReplayBuffer"
    # Annotated with the base schema, presumably so any StorageConfig subclass
    # (e.g. one selected from the ``data/storage`` group) can be plugged in.
    storage: StorageConfig = MISSING
    batch_size: int = MISSING
    # Forwarded as ``prefetch``; presumably the number of batches prepared
    # ahead of time — confirm against the torchrl version in use.
    prefetch: int = 10


@dataclass
class TensorDictReplayBufferConfig(ReplayBufferConfig):
    """Replay-buffer config targeting ``torchrl.data.TensorDictReplayBuffer``.

    Shares all fields (``storage``, ``batch_size``, ``prefetch``) with
    ``ReplayBufferConfig``; only the instantiation target differs.
    """

    _target_: str = "torchrl.data.TensorDictReplayBuffer"


def register_data_configs() -> None:
    """Register the data-related structured config schemas with Hydra.

    Every schema is stored in the global ``ConfigStore`` under a
    ``(group, name)`` pair, which makes it available to Hydra's config
    composition (for example, as an entry in a ``defaults`` list).
    """
    config_store = ConfigStore.instance()

    # (group, name, schema) triples, registered in exactly this order.
    registrations = (
        ("data", "base_aig_dataset", AIGDatasetConfig),
        ("data", "base_tt_dataset", TTDatasetConfig),
        ("data", "base_storage", StorageConfig),
        ("data/storage", "base_list_storage", ListStorageConfig),
        ("data/storage", "base_lazy_memmap_storage", LazyMemmapStorageConfig),
        (
            "data/storage",
            "base_shared_lazy_memmap_storage",
            SharedLazyMemmapStorageConfig,
        ),
        ("data", "base_replay_buffer", ReplayBufferConfig),
        ("data", "base_tensordict_replay_buffer", TensorDictReplayBufferConfig),
    )
    for group_path, entry_name, schema in registrations:
        config_store.store(group=group_path, name=entry_name, node=schema)

    # Kept as a separate, final call so ``DBConfig`` is only looked up after
    # the schemas above are stored.
    # NOTE(review): ``DBConfig`` is neither defined nor imported in the visible
    # part of this module; confirm it is in scope at call time, otherwise this
    # line raises NameError.
    config_store.store(group="data", name="base_db_config", node=DBConfig)
