"""
    Based on osrl-lib by Zuxin Liu and Zijian Guo (https://github.com/liuzuxin/OSRL.git), licensed under Apache 2.0 and MIT.
"""

from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

from pyrallis import field


@dataclass
class BEARLTrainConfig:
    """Base training configuration for the BEARL baseline.

    Every per-task config in this module subclasses this dataclass and
    inherits its defaults. Mutable list defaults are declared through
    ``pyrallis.field(..., is_mutable=True)``. Type annotations matter here:
    pyrallis decodes values given on the command line or in a config file
    according to them. That is why ``None`` defaults are annotated
    ``Optional[...]`` and layer widths are ``int``.
    """

    # self-designed args (none yet)

    # wandb params
    project: str = "MOSDB-baselines"
    group: Optional[str] = None
    name: Optional[str] = None
    prefix: Optional[str] = "BEARL"
    suffix: Optional[str] = ""
    log_root_dir: Optional[str] = "../benchmark"
    verbose: bool = True

    # dataset params
    # NOTE(review): a None default presumably disables the corresponding
    # dataset transformation — confirm in the training script.
    outliers_percent: Optional[float] = None
    noise_scale: Optional[float] = None
    inpaint_ranges: Optional[Tuple[Tuple[float, float, float, float], ...]] = None
    epsilon: Optional[float] = None
    density: float = 1.0

    # training params
    task: str = "FreightFrankaCloseDrawer"
    dataset: Optional[str] = None
    seed: int = 0
    device: str = "cuda:0"
    threads: int = 4
    reward_scale: float = 0.1
    cost_scale: float = 1
    actor_lr: float = 0.001
    critic_lr: float = 0.001
    vae_lr: float = 0.001
    cost_limit: int = 25
    episode_len: int = 1000
    batch_size: int = 64
    num_workers: int = 8
    update_steps: int = 100000
    train_epoch: int = 1
    centralized_training: bool = True

    # model params
    # Layer widths are integers; pyrallis decodes overrides by annotation, so
    # a float annotation would turn e.g. "[128, 128]" into [128.0, 128.0].
    a_hidden_sizes: List[int] = field(default=[256, 256], is_mutable=True)
    c_hidden_sizes: List[int] = field(default=[256, 256], is_mutable=True)
    vae_hidden_sizes: int = 400
    sample_action_num: int = 10
    gamma: float = 0.99
    tau: float = 0.005
    beta: float = 0.5
    lmbda: float = 0.75
    mmd_sigma: float = 50
    target_mmd_thresh: float = 0.05
    num_samples_mmd_match: int = 10
    start_update_policy_step: int = 0
    kernel: str = "gaussian"  # or "laplacian"
    num_q: int = 2
    num_qc: int = 2
    # NOTE(review): presumably the (P, I, D) gains of a PID Lagrangian cost
    # controller — confirm against the BEARL trainer.
    PID: List[float] = field(default=[0.1, 0.003, 0.001], is_mutable=True)

    # evaluation params
    eval_episodes: int = 1
    eval_every: int = 2500
    save_model: bool = False


# Per-task configurations. Each one pins ``task`` to the name it is
# registered under in BEARL_DEFAULT_CONFIG, so an instance reports the task
# it is meant for. All other defaults are inherited from BEARLTrainConfig.


@dataclass
class BEARLSafetyAntMultiGoal1v0Config(BEARLTrainConfig):
    task: str = "SafetyAntMultiGoal1-v0"


@dataclass
class BEARLSafetyPointMultiGoal1v0Config(BEARLTrainConfig):
    task: str = "SafetyPointMultiGoal1-v0"


@dataclass
class BEARLSafetyAntMultiGoal2v0Config(BEARLTrainConfig):
    task: str = "SafetyAntMultiGoal2-v0"


@dataclass
class BEARLSafetyPointMultiGoal2v0Config(BEARLTrainConfig):
    task: str = "SafetyPointMultiGoal2-v0"


@dataclass
class BEARLSafety2x4AntVelocityv0Config(BEARLTrainConfig):
    task: str = "Safety2x4AntVelocity-v0"


@dataclass
class BEARLSafety4x2AntVelocityv0Config(BEARLTrainConfig):
    task: str = "Safety4x2AntVelocity-v0"


@dataclass
class BEARLSafety2x3HalfCheetahVelocityv0Config(BEARLTrainConfig):
    task: str = "Safety2x3HalfCheetahVelocity-v0"


@dataclass
class BEARLSafety6x1HalfCheetahVelocityv0Config(BEARLTrainConfig):
    task: str = "Safety6x1HalfCheetahVelocity-v0"


@dataclass
class BEARLSafety2x3Walker2dVelocityv0Config(BEARLTrainConfig):
    task: str = "Safety2x3Walker2dVelocity-v0"


@dataclass
class BEARLSafety3x1HopperVelocityv0Config(BEARLTrainConfig):
    task: str = "Safety3x1HopperVelocity-v0"


@dataclass
class BEARLSafety2x1SwimmerVelocityv0Config(BEARLTrainConfig):
    task: str = "Safety2x1SwimmerVelocity-v0"


@dataclass
class BEARLSafety98HumanoidVelocityv0Config(BEARLTrainConfig):
    task: str = "Safety98HumanoidVelocity-v0"


@dataclass
class BEARLShadowHandOverSafejointConfig(BEARLTrainConfig):
    task: str = "ShadowHandOver_Safe_joint"


@dataclass
class BEARLShadowHandOverSafefingerConfig(BEARLTrainConfig):
    task: str = "ShadowHandOver_Safe_finger"


@dataclass
class BEARLShadowHandCatchOver2UnderarmSafejointConfig(BEARLTrainConfig):
    task: str = "ShadowHandCatchOver2Underarm_Safe_joint"


@dataclass
class BEARLShadowHandCatchOver2UnderarmSafefingerConfig(BEARLTrainConfig):
    task: str = "ShadowHandCatchOver2Underarm_Safe_finger"


@dataclass
class BEARLFreightFrankaCloseDrawerConfig(BEARLTrainConfig):
    # Same value as the base default; stated explicitly for uniformity.
    task: str = "FreightFrankaCloseDrawer"


@dataclass
class BEARLFreightFrankaPickAndPlaceConfig(BEARLTrainConfig):
    task: str = "FreightFrankaPickAndPlace"


# Maps each supported task name to its BEARL config *class* (not an instance);
# callers instantiate the class to obtain that task's default configuration.
BEARL_DEFAULT_CONFIG: Dict[str, type] = {
    # safety_gymnasium: multi-agent navigation (MultiGoal) tasks.
    # NOTE(review): this group was labelled "bullet_safety_gym", but these
    # task names match safety_gymnasium's multi-agent suite — confirm.
    "SafetyAntMultiGoal1-v0": BEARLSafetyAntMultiGoal1v0Config,
    "SafetyPointMultiGoal1-v0": BEARLSafetyPointMultiGoal1v0Config,
    "SafetyAntMultiGoal2-v0": BEARLSafetyAntMultiGoal2v0Config,
    "SafetyPointMultiGoal2-v0": BEARLSafetyPointMultiGoal2v0Config,
    # safety_gymnasium: multi-agent velocity tasks
    "Safety2x4AntVelocity-v0": BEARLSafety2x4AntVelocityv0Config,
    "Safety4x2AntVelocity-v0": BEARLSafety4x2AntVelocityv0Config,
    "Safety2x3HalfCheetahVelocity-v0": BEARLSafety2x3HalfCheetahVelocityv0Config,
    "Safety6x1HalfCheetahVelocity-v0": BEARLSafety6x1HalfCheetahVelocityv0Config,
    "Safety2x3Walker2dVelocity-v0": BEARLSafety2x3Walker2dVelocityv0Config,
    "Safety3x1HopperVelocity-v0": BEARLSafety3x1HopperVelocityv0Config,
    "Safety2x1SwimmerVelocity-v0": BEARLSafety2x1SwimmerVelocityv0Config,
    "Safety98HumanoidVelocity-v0": BEARLSafety98HumanoidVelocityv0Config,
    # safe_isaac_gym
    "ShadowHandOver_Safe_joint": BEARLShadowHandOverSafejointConfig,
    "ShadowHandOver_Safe_finger": BEARLShadowHandOverSafefingerConfig,
    "ShadowHandCatchOver2Underarm_Safe_joint": BEARLShadowHandCatchOver2UnderarmSafejointConfig,
    "ShadowHandCatchOver2Underarm_Safe_finger": BEARLShadowHandCatchOver2UnderarmSafefingerConfig,
    "FreightFrankaCloseDrawer": BEARLFreightFrankaCloseDrawerConfig,
    "FreightFrankaPickAndPlace": BEARLFreightFrankaPickAndPlaceConfig,
}
