"""
    Based on osrl-lib by Zuxin Liu and Zijian Guo (https://github.com/liuzuxin/OSRL.git), licensed under Apache 2.0 and MIT.
"""

from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

from pyrallis import field


@dataclass
class BCQLTrainConfig:
    """Default settings for a BCQL training run.

    Every field has a default, so the config can be constructed with no
    arguments and selectively overridden. Fields whose default is ``None`` are
    annotated ``Optional[...]`` so the declared type agrees with the default,
    and the network layer widths are annotated as lists of ``int`` to match
    their integer defaults.

    Field order is unchanged from the original definition; it determines the
    positional ``__init__`` signature and the key order of ``asdict``.
    """

    # wandb params
    project: str = "MOSDB-baselines"
    group: Optional[str] = None
    name: Optional[str] = None
    prefix: Optional[str] = "BCQL"
    suffix: Optional[str] = ""
    log_root_dir: Optional[str] = "../benchmark"
    verbose: bool = True

    # dataset params
    # NOTE(review): presumably dataset filtering/augmentation knobs where None
    # means "not applied" -- confirm against the data-loading code.
    outliers_percent: Optional[float] = None
    noise_scale: Optional[float] = None
    inpaint_ranges: Optional[Tuple[Tuple[float, float, float, float], ...]] = None
    epsilon: Optional[float] = None
    density: float = 1.0

    # training params
    task: str = "FreightFrankaCloseDrawer"
    dataset: Optional[str] = None
    seed: int = 0
    device: str = "cuda:0"
    threads: int = 4
    reward_scale: float = 0.1
    cost_scale: float = 1.0
    actor_lr: float = 0.001
    critic_lr: float = 0.001
    vae_lr: float = 0.001
    phi: float = 0.05
    lmbda: float = 0.75
    beta: float = 0.5
    cost_limit: int = 25
    episode_len: int = 1000
    batch_size: int = 64
    num_workers: int = 8
    update_steps: int = 100000
    train_epoch: int = 1
    centralized_training: bool = True

    # model params
    # List defaults go through pyrallis' ``field(..., is_mutable=True)`` rather
    # than a bare list literal, which ``dataclass`` rejects as a default.
    a_hidden_sizes: List[int] = field(default=[256, 256], is_mutable=True)
    c_hidden_sizes: List[int] = field(default=[256, 256], is_mutable=True)
    vae_hidden_sizes: int = 400
    sample_action_num: int = 10
    gamma: float = 0.99
    tau: float = 0.005
    num_q: int = 2
    num_qc: int = 2
    # NOTE(review): presumably [Kp, Ki, Kd] controller gains -- confirm against
    # the trainer that consumes this field.
    PID: List[float] = field(default=[0.1, 0.003, 0.001], is_mutable=True)

    # evaluation params
    eval_episodes: int = 1
    eval_every: int = 2500
    save_model: bool = False


@dataclass
class BCQLSafetyAntMultiGoal1v0Config(BCQLTrainConfig):
    """BCQL config registered for "SafetyAntMultiGoal1-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLSafetyPointMultiGoal1v0Config(BCQLTrainConfig):
    """BCQL config registered for "SafetyPointMultiGoal1-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLSafetyAntMultiGoal2v0Config(BCQLTrainConfig):
    """BCQL config registered for "SafetyAntMultiGoal2-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLSafetyPointMultiGoal2v0Config(BCQLTrainConfig):
    """BCQL config registered for "SafetyPointMultiGoal2-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLSafety2x4AntVelocityv0Config(BCQLTrainConfig):
    """BCQL config registered for "Safety2x4AntVelocity-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLSafety4x2AntVelocityv0Config(BCQLTrainConfig):
    """BCQL config registered for "Safety4x2AntVelocity-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLSafety2x3HalfCheetahVelocityv0Config(BCQLTrainConfig):
    """BCQL config registered for "Safety2x3HalfCheetahVelocity-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLSafety6x1HalfCheetahVelocityv0Config(BCQLTrainConfig):
    """BCQL config registered for "Safety6x1HalfCheetahVelocity-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLSafety2x3Walker2dVelocityv0Config(BCQLTrainConfig):
    """BCQL config registered for "Safety2x3Walker2dVelocity-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLSafety3x1HopperVelocityv0Config(BCQLTrainConfig):
    """BCQL config registered for "Safety3x1HopperVelocity-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLSafety2x1SwimmerVelocityv0Config(BCQLTrainConfig):
    """BCQL config registered for "Safety2x1SwimmerVelocity-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLSafety98HumanoidVelocityv0Config(BCQLTrainConfig):
    """BCQL config registered for "Safety98HumanoidVelocity-v0".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLShadowHandOverSafejointConfig(BCQLTrainConfig):
    """BCQL config registered for "ShadowHandOver_Safe_joint".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLShadowHandOverSafefingerConfig(BCQLTrainConfig):
    """BCQL config registered for "ShadowHandOver_Safe_finger".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLShadowHandCatchOver2UnderarmSafejointConfig(BCQLTrainConfig):
    """BCQL config registered for "ShadowHandCatchOver2Underarm_Safe_joint".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLShadowHandCatchOver2UnderarmSafefingerConfig(BCQLTrainConfig):
    """BCQL config registered for "ShadowHandCatchOver2Underarm_Safe_finger".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLFreightFrankaCloseDrawerConfig(BCQLTrainConfig):
    """BCQL config registered for "FreightFrankaCloseDrawer".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """

@dataclass
class BCQLFreightFrankaPickAndPlaceConfig(BCQLTrainConfig):
    """BCQL config registered for "FreightFrankaPickAndPlace".

    Declares no fields of its own; every default comes from BCQLTrainConfig.
    """


# Registry mapping a task identifier string to the BCQL config class for that
# task. The values are classes, not instances; each one is a subclass of
# BCQLTrainConfig defined in this module.
# NOTE(review): presumably indexed by the task name chosen at launch time --
# verify against the training script that imports this mapping.
BCQL_DEFAULT_CONFIG = {
    # bullet_safety_gym
    # NOTE(review): suite label kept as found; confirm the MultiGoal tasks
    # really belong to this suite.
    "SafetyAntMultiGoal1-v0": BCQLSafetyAntMultiGoal1v0Config,
    "SafetyPointMultiGoal1-v0": BCQLSafetyPointMultiGoal1v0Config,
    "SafetyAntMultiGoal2-v0": BCQLSafetyAntMultiGoal2v0Config,
    "SafetyPointMultiGoal2-v0": BCQLSafetyPointMultiGoal2v0Config,
    # safety_gymnasium
    "Safety2x4AntVelocity-v0": BCQLSafety2x4AntVelocityv0Config,
    "Safety4x2AntVelocity-v0": BCQLSafety4x2AntVelocityv0Config,
    "Safety2x3HalfCheetahVelocity-v0": BCQLSafety2x3HalfCheetahVelocityv0Config,
    "Safety6x1HalfCheetahVelocity-v0": BCQLSafety6x1HalfCheetahVelocityv0Config,
    "Safety2x3Walker2dVelocity-v0": BCQLSafety2x3Walker2dVelocityv0Config,
    "Safety3x1HopperVelocity-v0": BCQLSafety3x1HopperVelocityv0Config,
    "Safety2x1SwimmerVelocity-v0": BCQLSafety2x1SwimmerVelocityv0Config,
    "Safety98HumanoidVelocity-v0": BCQLSafety98HumanoidVelocityv0Config,
    # safe_isaac_gym
    "ShadowHandOver_Safe_joint": BCQLShadowHandOverSafejointConfig,
    "ShadowHandOver_Safe_finger": BCQLShadowHandOverSafefingerConfig,
    "ShadowHandCatchOver2Underarm_Safe_joint": BCQLShadowHandCatchOver2UnderarmSafejointConfig,
    "ShadowHandCatchOver2Underarm_Safe_finger": BCQLShadowHandCatchOver2UnderarmSafefingerConfig,
    "FreightFrankaCloseDrawer": BCQLFreightFrankaCloseDrawerConfig,
    "FreightFrankaPickAndPlace": BCQLFreightFrankaPickAndPlaceConfig,
}