"""
    Based on osrl-lib by Zuxin Liu and Zijian Guo (https://github.com/liuzuxin/OSRL.git), licensed under Apache 2.0 and MIT.
"""

from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

from pyrallis import field


@dataclass
class BCTrainConfig:
    """Base hyperparameters for the behavior-cloning (BC) baseline.

    The task-specific subclasses registered in ``BC_DEFAULT_CONFIG`` inherit
    every default below without overriding anything.
    """

    # wandb params
    project: str = "MOSDB-baselines"
    group: Optional[str] = None
    name: Optional[str] = None
    prefix: Optional[str] = "BC"
    suffix: Optional[str] = ""
    log_root_dir: Optional[str] = "../benchmark"
    verbose: bool = True

    # dataset params
    # NOTE(review): None presumably means "option not applied" — confirm
    # against the dataset-processing code that consumes these fields.
    outliers_percent: Optional[float] = None
    noise_scale: Optional[float] = None
    inpaint_ranges: Optional[Tuple[Tuple[float, float, float, float], ...]] = None
    epsilon: Optional[float] = None
    density: float = 1.0

    # training params
    task: str = "FreightFrankaCloseDrawer"
    dataset: Optional[str] = None
    seed: int = 0
    device: str = "cuda:0"
    threads: int = 4
    actor_lr: float = 0.001
    cost_limit: int = 25
    episode_len: int = 1000
    batch_size: int = 64
    num_workers: int = 8
    bc_mode: str = "all"  # "all", "safe", "risky", "frontier", "boundary", "multi-task"
    update_steps: int = 100000
    centralized_training: bool = True

    # model params
    # Layer widths are integer unit counts; annotating them as int (previously
    # List[float]) makes CLI/config overrides decode to ints instead of e.g. 128.0.
    a_hidden_sizes: List[int] = field(default=[256, 256], is_mutable=True)
    gamma: float = 1.0

    # evaluation params
    eval_episodes: int = 1
    eval_every: int = 2500
    save_model: bool = False


@dataclass
class BCSafetyAntMultiGoal1v0Config(BCTrainConfig):
    """BC defaults for ``SafetyAntMultiGoal1-v0``; inherits BCTrainConfig unchanged."""

@dataclass
class BCSafetyPointMultiGoal1v0Config(BCTrainConfig):
    """BC defaults for ``SafetyPointMultiGoal1-v0``; inherits BCTrainConfig unchanged."""

@dataclass
class BCSafetyAntMultiGoal2v0Config(BCTrainConfig):
    """BC defaults for ``SafetyAntMultiGoal2-v0``; inherits BCTrainConfig unchanged."""

@dataclass
class BCSafetyPointMultiGoal2v0Config(BCTrainConfig):
    """BC defaults for ``SafetyPointMultiGoal2-v0``; inherits BCTrainConfig unchanged."""

@dataclass
class BCSafety2x4AntVelocityv0Config(BCTrainConfig):
    """BC defaults for ``Safety2x4AntVelocity-v0``; inherits BCTrainConfig unchanged."""

@dataclass
class BCSafety4x2AntVelocityv0Config(BCTrainConfig):
    """BC defaults for ``Safety4x2AntVelocity-v0``; inherits BCTrainConfig unchanged."""

@dataclass
class BCSafety2x3HalfCheetahVelocityv0Config(BCTrainConfig):
    """BC defaults for ``Safety2x3HalfCheetahVelocity-v0``; no overrides."""

@dataclass
class BCSafety6x1HalfCheetahVelocityv0Config(BCTrainConfig):
    """BC defaults for ``Safety6x1HalfCheetahVelocity-v0``; no overrides."""

@dataclass
class BCSafety2x3Walker2dVelocityv0Config(BCTrainConfig):
    """BC defaults for ``Safety2x3Walker2dVelocity-v0``; no overrides."""

@dataclass
class BCSafety3x1HopperVelocityv0Config(BCTrainConfig):
    """BC defaults for ``Safety3x1HopperVelocity-v0``; no overrides."""

@dataclass
class BCSafety2x1SwimmerVelocityv0Config(BCTrainConfig):
    """BC defaults for ``Safety2x1SwimmerVelocity-v0``; no overrides."""

@dataclass
class BCSafety98HumanoidVelocityv0Config(BCTrainConfig):
    """BC defaults for ``Safety98HumanoidVelocity-v0``; no overrides."""

@dataclass
class BCShadowHandOverSafejointConfig(BCTrainConfig):
    """BC defaults for ``ShadowHandOver_Safe_joint``; no overrides."""

@dataclass
class BCShadowHandOverSafefingerConfig(BCTrainConfig):
    """BC defaults for ``ShadowHandOver_Safe_finger``; no overrides."""

@dataclass
class BCShadowHandCatchOver2UnderarmSafejointConfig(BCTrainConfig):
    """BC defaults for ``ShadowHandCatchOver2Underarm_Safe_joint``; no overrides."""

@dataclass
class BCShadowHandCatchOver2UnderarmSafefingerConfig(BCTrainConfig):
    """BC defaults for ``ShadowHandCatchOver2Underarm_Safe_finger``; no overrides."""

@dataclass
class BCFreightFrankaCloseDrawerConfig(BCTrainConfig):
    """BC defaults for ``FreightFrankaCloseDrawer``; no overrides."""

@dataclass
class BCFreightFrankaPickAndPlaceConfig(BCTrainConfig):
    """BC defaults for ``FreightFrankaPickAndPlace``; no overrides."""


# Task name -> default BC config class for that task.
BC_DEFAULT_CONFIG = {
    # safety_gymnasium: multi-agent navigation (MultiGoal) tasks
    "SafetyAntMultiGoal1-v0": BCSafetyAntMultiGoal1v0Config,
    "SafetyPointMultiGoal1-v0": BCSafetyPointMultiGoal1v0Config,
    "SafetyAntMultiGoal2-v0": BCSafetyAntMultiGoal2v0Config,
    "SafetyPointMultiGoal2-v0": BCSafetyPointMultiGoal2v0Config,
    # safety_gymnasium: multi-agent velocity tasks
    "Safety2x4AntVelocity-v0": BCSafety2x4AntVelocityv0Config,
    "Safety4x2AntVelocity-v0": BCSafety4x2AntVelocityv0Config,
    "Safety2x3HalfCheetahVelocity-v0": BCSafety2x3HalfCheetahVelocityv0Config,
    "Safety6x1HalfCheetahVelocity-v0": BCSafety6x1HalfCheetahVelocityv0Config,
    "Safety2x3Walker2dVelocity-v0": BCSafety2x3Walker2dVelocityv0Config,
    "Safety3x1HopperVelocity-v0": BCSafety3x1HopperVelocityv0Config,
    "Safety2x1SwimmerVelocity-v0": BCSafety2x1SwimmerVelocityv0Config,
    "Safety98HumanoidVelocity-v0": BCSafety98HumanoidVelocityv0Config,
    # safe_isaac_gym
    "ShadowHandOver_Safe_joint": BCShadowHandOverSafejointConfig,
    "ShadowHandOver_Safe_finger": BCShadowHandOverSafefingerConfig,
    "ShadowHandCatchOver2Underarm_Safe_joint": BCShadowHandCatchOver2UnderarmSafejointConfig,
    "ShadowHandCatchOver2Underarm_Safe_finger": BCShadowHandCatchOver2UnderarmSafefingerConfig,
    "FreightFrankaCloseDrawer": BCFreightFrankaCloseDrawerConfig,
    "FreightFrankaPickAndPlace": BCFreightFrankaPickAndPlaceConfig,
}