# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the CC BY-NC 4.0 license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp
from pathlib import Path

from metamotivo.data_loading.base import BaseDataConfig


class DMCDataConfig(BaseDataConfig):
    name: tp.Literal["dmc"] = "dmc"
    domain: tp.Literal["walker", "cheetah", "quadruped", "pointmass"]
    dataset_expl_agent: str = "rnd"
    load_n_episodes: int = 5_000

    def build(self, buffer_device, batch_size, frame_stack, relabel_fn=None) -> tp.Dict:
        domain = self.domain
        if domain == "pointmass":
            domain = "point_mass_maze"
        path = Path(self.dataset_root) / f"{domain}/{self.dataset_expl_agent}/buffer"
        return super().build_from_path(path, buffer_device, batch_size, frame_stack, relabel_fn)
