"""
    Based on osrl-lib by Zuxin Liu and Zijian Guo (https://github.com/liuzuxin/OSRL.git), licensed under Apache 2.0 and MIT.
"""

from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

from pyrallis import field


@dataclass
class CPQTrainConfig:
    """Default run settings and hyperparameters for CPQ training.

    Every field carries a default, so ``CPQTrainConfig()`` can be built with
    no arguments. Fields whose default is ``None`` are annotated
    ``Optional[...]`` so that the declared type matches the stored value.
    The per-task config classes in this module subclass this one without
    overriding anything.
    """

    # self-designed args (none defined yet)

    # wandb params
    project: str = "MOSDB-baselines"
    group: Optional[str] = None
    name: Optional[str] = None
    prefix: Optional[str] = "CPQ"
    suffix: Optional[str] = ""
    log_root_dir: Optional[str] = "../benchmark"
    verbose: bool = True

    # dataset params
    # NOTE(review): the four None-default knobs below presumably mean
    # "not applied" when left as None -- confirm against the dataset loader.
    outliers_percent: Optional[float] = None
    noise_scale: Optional[float] = None
    inpaint_ranges: Optional[Tuple[Tuple[float, float, float, float], ...]] = None
    epsilon: Optional[float] = None
    density: float = 1.0

    # training params
    task: str = "FreightFrankaCloseDrawer"
    dataset: Optional[str] = None
    seed: int = 0
    device: str = "cuda:0"
    threads: int = 4
    reward_scale: float = 0.1
    cost_scale: float = 1
    actor_lr: float = 0.0001
    critic_lr: float = 0.001
    alpha_lr: float = 0.0001
    vae_lr: float = 0.001
    cost_limit: int = 25
    episode_len: int = 1000
    batch_size: int = 64
    num_workers: int = 8
    update_steps: int = 100000
    train_epoch: int = 1
    centralized_training: bool = True

    # model params
    # A bare list default is rejected by ``dataclass``, hence ``field(...)``.
    # Annotated as ints to agree with the integer defaults and with
    # ``vae_hidden_sizes``.
    a_hidden_sizes: List[int] = field(default=[256, 256], is_mutable=True)
    c_hidden_sizes: List[int] = field(default=[256, 256], is_mutable=True)
    vae_hidden_sizes: int = 400
    sample_action_num: int = 10
    gamma: float = 0.99
    tau: float = 0.005
    beta: float = 0.5
    num_q: int = 2
    num_qc: int = 2
    qc_scalar: float = 1.5

    # evaluation params
    eval_episodes: int = 1
    eval_every: int = 2500
    save_model: bool = False


@dataclass
class CPQSafetyAntMultiGoal1v0Config(CPQTrainConfig):
    """CPQ settings registered for the ``SafetyAntMultiGoal1-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQSafetyPointMultiGoal1v0Config(CPQTrainConfig):
    """CPQ settings registered for the ``SafetyPointMultiGoal1-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQSafetyAntMultiGoal2v0Config(CPQTrainConfig):
    """CPQ settings registered for the ``SafetyAntMultiGoal2-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQSafetyPointMultiGoal2v0Config(CPQTrainConfig):
    """CPQ settings registered for the ``SafetyPointMultiGoal2-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQSafety2x4AntVelocityv0Config(CPQTrainConfig):
    """CPQ settings registered for the ``Safety2x4AntVelocity-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQSafety4x2AntVelocityv0Config(CPQTrainConfig):
    """CPQ settings registered for the ``Safety4x2AntVelocity-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQSafety2x3HalfCheetahVelocityv0Config(CPQTrainConfig):
    """CPQ settings registered for the ``Safety2x3HalfCheetahVelocity-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQSafety6x1HalfCheetahVelocityv0Config(CPQTrainConfig):
    """CPQ settings registered for the ``Safety6x1HalfCheetahVelocity-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQSafety2x3Walker2dVelocityv0Config(CPQTrainConfig):
    """CPQ settings registered for the ``Safety2x3Walker2dVelocity-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQSafety3x1HopperVelocityv0Config(CPQTrainConfig):
    """CPQ settings registered for the ``Safety3x1HopperVelocity-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQSafety2x1SwimmerVelocityv0Config(CPQTrainConfig):
    """CPQ settings registered for the ``Safety2x1SwimmerVelocity-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQSafety98HumanoidVelocityv0Config(CPQTrainConfig):
    """CPQ settings registered for the ``Safety98HumanoidVelocity-v0`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQShadowHandOverSafejointConfig(CPQTrainConfig):
    """CPQ settings registered for the ``ShadowHandOver_Safe_joint`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQShadowHandOverSafefingerConfig(CPQTrainConfig):
    """CPQ settings registered for the ``ShadowHandOver_Safe_finger`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQShadowHandCatchOver2UnderarmSafejointConfig(CPQTrainConfig):
    """CPQ settings registered for ``ShadowHandCatchOver2Underarm_Safe_joint``.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQShadowHandCatchOver2UnderarmSafefingerConfig(CPQTrainConfig):
    """CPQ settings registered for ``ShadowHandCatchOver2Underarm_Safe_finger``.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """

@dataclass
class CPQFreightFrankaCloseDrawerConfig(CPQTrainConfig):
    """CPQ settings registered for the ``FreightFrankaCloseDrawer`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``, whose ``task`` default already equals this task id.
    """

@dataclass
class CPQFreightFrankaPickAndPlaceConfig(CPQTrainConfig):
    """CPQ settings registered for the ``FreightFrankaPickAndPlace`` task.

    Adds no fields and overrides none; every value is inherited from
    ``CPQTrainConfig``. NOTE(review): ``task`` is inherited too, so a
    directly constructed instance still reports "FreightFrankaCloseDrawer"
    -- confirm callers set ``task`` explicitly.
    """


# Lookup table from task id string to its CPQ config class. The table is
# assembled from one sub-dict per environment suite; unpacking keeps the same
# keys, values and insertion order as a single flat literal.
CPQ_DEFAULT_CONFIG = {
    # bullet_safety_gym
    # NOTE(review): suite label carried over unchanged; these MultiGoal ids
    # look like safety_gymnasium multi-agent tasks -- confirm.
    **{
        "SafetyAntMultiGoal1-v0": CPQSafetyAntMultiGoal1v0Config,
        "SafetyPointMultiGoal1-v0": CPQSafetyPointMultiGoal1v0Config,
        "SafetyAntMultiGoal2-v0": CPQSafetyAntMultiGoal2v0Config,
        "SafetyPointMultiGoal2-v0": CPQSafetyPointMultiGoal2v0Config,
    },
    # safety_gymnasium
    **{
        "Safety2x4AntVelocity-v0": CPQSafety2x4AntVelocityv0Config,
        "Safety4x2AntVelocity-v0": CPQSafety4x2AntVelocityv0Config,
        "Safety2x3HalfCheetahVelocity-v0": CPQSafety2x3HalfCheetahVelocityv0Config,
        "Safety6x1HalfCheetahVelocity-v0": CPQSafety6x1HalfCheetahVelocityv0Config,
        "Safety2x3Walker2dVelocity-v0": CPQSafety2x3Walker2dVelocityv0Config,
        "Safety3x1HopperVelocity-v0": CPQSafety3x1HopperVelocityv0Config,
        "Safety2x1SwimmerVelocity-v0": CPQSafety2x1SwimmerVelocityv0Config,
        "Safety98HumanoidVelocity-v0": CPQSafety98HumanoidVelocityv0Config,
    },
    # safe_isaac_gym
    **{
        "ShadowHandOver_Safe_joint": CPQShadowHandOverSafejointConfig,
        "ShadowHandOver_Safe_finger": CPQShadowHandOverSafefingerConfig,
        "ShadowHandCatchOver2Underarm_Safe_joint": CPQShadowHandCatchOver2UnderarmSafejointConfig,
        "ShadowHandCatchOver2Underarm_Safe_finger": CPQShadowHandCatchOver2UnderarmSafefingerConfig,
        "FreightFrankaCloseDrawer": CPQFreightFrankaCloseDrawerConfig,
        "FreightFrankaPickAndPlace": CPQFreightFrankaPickAndPlaceConfig,
    },
}