"""
    Based on osrl-lib by Zuxin Liu and Zijian Guo (https://github.com/liuzuxin/OSRL.git), licensed under Apache 2.0 and MIT.
"""

from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple


@dataclass
class CDTTrainConfig:
    """Default hyperparameters for CDT training and evaluation (adapted from OSRL).

    Task-specific subclasses inherit from this class and override fields as
    needed. Fields annotated ``Optional`` default to ``None``; how ``None`` is
    interpreted (e.g. "feature disabled") is decided by the consuming training
    script, which is not part of this module.
    """

    # self-designed args
    seq_len: int = 10

    # wandb params
    project: str = "MOSDB-baselines"
    group: Optional[str] = None
    name: Optional[str] = None
    prefix: Optional[str] = "CDT"
    suffix: Optional[str] = ""
    log_root_dir: Optional[str] = "../benchmark"
    verbose: bool = True

    # dataset params
    outliers_percent: Optional[float] = None
    noise_scale: Optional[float] = None
    inpaint_ranges: Optional[Tuple[Tuple[float, float], ...]] = None
    epsilon: Optional[float] = None
    density: float = 1

    # model params
    embedding_dim: int = 128
    num_layers: int = 3
    num_heads: int = 8
    action_head_layers: int = 1
    attention_dropout: float = 0.1
    residual_dropout: float = 0.1
    embedding_dropout: float = 0.1
    time_emb: bool = True

    # training params
    task: str = "FreightFrankaCloseDrawer"
    dataset: Optional[str] = None
    learning_rate: float = 1e-4
    betas: Tuple[float, float] = (0.9, 0.999)
    weight_decay: float = 1e-4
    clip_grad: Optional[float] = 0.25
    lr_warmup_steps: int = 500
    reward_scale: float = 0.1
    cost_scale: float = 1
    batch_size: int = 64
    num_workers: int = 8
    update_steps: int = 100000
    train_epoch: int = 1
    centralized_training: bool = True

    # evaluation params
    # Tuple of (reward, cost) target pairs. The trailing comma is required:
    # ``((5, 25))`` is merely the flat tuple ``(5, 25)``, which does not match
    # the declared tuple-of-pairs type and cannot be unpacked pair-wise.
    target_returns: Tuple[Tuple[float, ...], ...] = ((5, 25),)
    eval_episodes: int = 1
    eval_every: int = 2500
    cost_limit: int = 25
    save_model: bool = False

    # general params
    seed: int = 0
    device: str = "cuda:0"
    threads: int = 6
    # augmentation param
    pf_sample: bool = False
    beta: float = 1.0
    augment_percent: float = 0.2
    deg: int = 1
    # maximum absolute value of reward for the augmented trajs
    max_reward: float = 600
    # minimum reward above the PF curve
    min_reward: float = 1
    # the max decrease of ret between the associated traj
    # w.r.t the nearest pf traj
    max_rew_decrease: float = 100.0
    # model mode params
    use_rew: bool = True
    use_cost: bool = True
    cost_transform: bool = False
    cost_prefix: bool = False
    add_cost_feat: bool = False
    mul_cost_feat: bool = False
    cat_cost_feat: bool = False
    loss_cost_weight: float = 0.02
    loss_state_weight: float = 0
    cost_reverse: bool = False
    # pf only mode param
    pf_only: bool = False
    rmin: float = 300
    cost_bins: int = 3
    npb: int = 5
    cost_sample: bool = False
    linear: bool = True  # linear or inverse
    start_sampling: bool = False
    prob: float = 0.2
    stochastic: bool = True
    init_temperature: float = 0.1
    no_entropy: bool = False
    # random augmentation
    random_aug: float = 0
    aug_rmin: float = 400
    aug_rmax: float = 500
    aug_cmin: float = -2
    aug_cmax: float = 25
    cgap: float = 5
    rstd: float = 1
    cstd: float = 0.2
    max_npb: int = 10
    min_npb: int = 2
    ct_max: float = 70


@dataclass
class CDTSafetyAntMultiGoal1v0Config(CDTTrainConfig):
    """CDT config for ``SafetyAntMultiGoal1-v0``; inherits every default unchanged."""


@dataclass
class CDTSafetyPointMultiGoal1v0Config(CDTTrainConfig):
    """CDT config for ``SafetyPointMultiGoal1-v0``; inherits every default unchanged."""


@dataclass
class CDTSafetyAntMultiGoal2v0Config(CDTTrainConfig):
    """CDT config for ``SafetyAntMultiGoal2-v0``; inherits every default unchanged."""


@dataclass
class CDTSafetyPointMultiGoal2v0Config(CDTTrainConfig):
    """CDT config for ``SafetyPointMultiGoal2-v0``; inherits every default unchanged."""

@dataclass
class CDTSafety2x4AntVelocityv0Config(CDTTrainConfig):
    """CDT config for ``Safety2x4AntVelocity-v0``; inherits every default unchanged."""


@dataclass
class CDTSafety4x2AntVelocityv0Config(CDTTrainConfig):
    """CDT config for ``Safety4x2AntVelocity-v0``; inherits every default unchanged."""


@dataclass
class CDTSafety2x3HalfCheetahVelocityv0Config(CDTTrainConfig):
    """CDT config for ``Safety2x3HalfCheetahVelocity-v0``; inherits every default unchanged."""


@dataclass
class CDTSafety6x1HalfCheetahVelocityv0Config(CDTTrainConfig):
    """CDT config for ``Safety6x1HalfCheetahVelocity-v0``; inherits every default unchanged."""


@dataclass
class CDTSafety2x3Walker2dVelocityv0Config(CDTTrainConfig):
    """CDT config for ``Safety2x3Walker2dVelocity-v0``; inherits every default unchanged."""


@dataclass
class CDTSafety3x1HopperVelocityv0Config(CDTTrainConfig):
    """CDT config for ``Safety3x1HopperVelocity-v0``; inherits every default unchanged."""


@dataclass
class CDTSafety2x1SwimmerVelocityv0Config(CDTTrainConfig):
    """CDT config for ``Safety2x1SwimmerVelocity-v0``; inherits every default unchanged."""


@dataclass
class CDTSafety98HumanoidVelocityv0Config(CDTTrainConfig):
    """CDT config for ``Safety98HumanoidVelocity-v0``; inherits every default unchanged."""

@dataclass
class CDTShadowHandOverSafejointConfig(CDTTrainConfig):
    """CDT config for ``ShadowHandOver_Safe_joint``; inherits every default unchanged."""
    # Optional override, currently disabled: seq_len: int = 8


@dataclass
class CDTShadowHandOverSafefingerConfig(CDTTrainConfig):
    """CDT config for ``ShadowHandOver_Safe_finger``; inherits every default unchanged."""
    # Optional override, currently disabled: seq_len: int = 8


@dataclass
class CDTShadowHandCatchOver2UnderarmSafejointConfig(CDTTrainConfig):
    """CDT config for ``ShadowHandCatchOver2Underarm_Safe_joint``; inherits every default unchanged."""
    # Optional override, currently disabled: seq_len: int = 8


@dataclass
class CDTShadowHandCatchOver2UnderarmSafefingerConfig(CDTTrainConfig):
    """CDT config for ``ShadowHandCatchOver2Underarm_Safe_finger``; inherits every default unchanged."""
    # Optional override, currently disabled: seq_len: int = 8


@dataclass
class CDTFreightFrankaCloseDrawerConfig(CDTTrainConfig):
    """CDT config for ``FreightFrankaCloseDrawer``; inherits every default unchanged."""
    # Optional override, currently disabled: seq_len: int = 8


@dataclass
class CDTFreightFrankaPickAndPlaceConfig(CDTTrainConfig):
    """CDT config for ``FreightFrankaPickAndPlace``; inherits every default unchanged."""
    # Optional override, currently disabled: seq_len: int = 8


# Registry mapping each supported task name to its CDT config class.
# Written as ordered (task name, config class) pairs grouped by benchmark
# suite; dict() keeps this insertion order.
CDT_DEFAULT_CONFIG = dict(
    [
        # bullet_safety_gym
        ("SafetyAntMultiGoal1-v0", CDTSafetyAntMultiGoal1v0Config),
        ("SafetyPointMultiGoal1-v0", CDTSafetyPointMultiGoal1v0Config),
        ("SafetyAntMultiGoal2-v0", CDTSafetyAntMultiGoal2v0Config),
        ("SafetyPointMultiGoal2-v0", CDTSafetyPointMultiGoal2v0Config),
        # safety_gymnasium
        ("Safety2x4AntVelocity-v0", CDTSafety2x4AntVelocityv0Config),
        ("Safety4x2AntVelocity-v0", CDTSafety4x2AntVelocityv0Config),
        ("Safety2x3HalfCheetahVelocity-v0", CDTSafety2x3HalfCheetahVelocityv0Config),
        ("Safety6x1HalfCheetahVelocity-v0", CDTSafety6x1HalfCheetahVelocityv0Config),
        ("Safety2x3Walker2dVelocity-v0", CDTSafety2x3Walker2dVelocityv0Config),
        ("Safety3x1HopperVelocity-v0", CDTSafety3x1HopperVelocityv0Config),
        ("Safety2x1SwimmerVelocity-v0", CDTSafety2x1SwimmerVelocityv0Config),
        ("Safety98HumanoidVelocity-v0", CDTSafety98HumanoidVelocityv0Config),
        # safe_isaac_gym
        ("ShadowHandOver_Safe_joint", CDTShadowHandOverSafejointConfig),
        ("ShadowHandOver_Safe_finger", CDTShadowHandOverSafefingerConfig),
        (
            "ShadowHandCatchOver2Underarm_Safe_joint",
            CDTShadowHandCatchOver2UnderarmSafejointConfig,
        ),
        (
            "ShadowHandCatchOver2Underarm_Safe_finger",
            CDTShadowHandCatchOver2UnderarmSafefingerConfig,
        ),
        ("FreightFrankaCloseDrawer", CDTFreightFrankaCloseDrawerConfig),
        ("FreightFrankaPickAndPlace", CDTFreightFrankaPickAndPlaceConfig),
    ]
)
