"""
    Based on osrl-lib by Zuxin Liu and Zijian Guo (https://github.com/liuzuxin/OSRL.git), licensed under Apache 2.0 and MIT.
"""

from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

from pyrallis import field


@dataclass
class COptiDICETrainConfig:
    """Base hyper-parameter set for COptiDICE training runs.

    Every field carries a default, so ``COptiDICETrainConfig()`` is valid on
    its own; the per-task config classes in this module derive from it.
    Fields whose default is ``None`` are annotated ``Optional[...]`` so the
    annotation matches the values the config can actually hold.
    """

    # wandb params
    project: str = "MOSDB-baselines"
    group: Optional[str] = None
    name: Optional[str] = None
    prefix: Optional[str] = "COptiDICE"
    suffix: Optional[str] = ""
    log_root_dir: Optional[str] = "../benchmark"
    verbose: bool = True

    # dataset params
    # NOTE(review): None presumably means "transform disabled" for these
    # dataset options — confirm against the dataset loader.
    outliers_percent: Optional[float] = None
    noise_scale: Optional[float] = None
    inpaint_ranges: Optional[Tuple[Tuple[float, float, float, float], ...]] = None
    epsilon: Optional[float] = None
    density: float = 1.0

    # training params
    task: str = "FreightFrankaCloseDrawer"
    dataset: Optional[str] = None
    seed: int = 0
    device: str = "cuda:0"
    threads: int = 4
    reward_scale: float = 0.1
    cost_scale: float = 1
    actor_lr: float = 0.0001
    critic_lr: float = 0.0001
    scalar_lr: float = 0.0001
    cost_limit: int = 25
    episode_len: int = 1000
    batch_size: int = 64
    num_workers: int = 8
    update_steps: int = 100000
    train_epoch: int = 1
    centralized_training: bool = True

    # model params
    # Hidden sizes are whole numbers (the defaults are ints); annotating them
    # as List[int] keeps type-driven parsing of overrides (e.g. pyrallis CLI
    # values) from turning them into floats.
    # dataclasses reject a bare list default; pyrallis' field(..., is_mutable=True)
    # is used to permit one.
    a_hidden_sizes: List[int] = field(default=[256, 256], is_mutable=True)
    c_hidden_sizes: List[int] = field(default=[256, 256], is_mutable=True)
    alpha: float = 0.5
    gamma: float = 0.99
    cost_ub_epsilon: float = 0.01
    f_type: str = "softchi"
    num_nu: int = 2
    num_chi: int = 2

    # evaluation params
    eval_episodes: int = 1
    eval_every: int = 2500
    save_model: bool = False


@dataclass
class COptiDICSafetyAntMultiGoal1v0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"SafetyAntMultiGoal1-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "SafetyAntMultiGoal1-v0"

@dataclass
class COptiDICSafetyPointMultiGoal1v0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"SafetyPointMultiGoal1-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "SafetyPointMultiGoal1-v0"

@dataclass
class COptiDICSafetyAntMultiGoal2v0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"SafetyAntMultiGoal2-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "SafetyAntMultiGoal2-v0"

@dataclass
class COptiDICSafetyPointMultiGoal2v0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"SafetyPointMultiGoal2-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "SafetyPointMultiGoal2-v0"

@dataclass
class COptiDICSafety2x4AntVelocityv0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"Safety2x4AntVelocity-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "Safety2x4AntVelocity-v0"

@dataclass
class COptiDICSafety4x2AntVelocityv0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"Safety4x2AntVelocity-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "Safety4x2AntVelocity-v0"

@dataclass
class COptiDICSafety2x3HalfCheetahVelocityv0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"Safety2x3HalfCheetahVelocity-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "Safety2x3HalfCheetahVelocity-v0"

@dataclass
class COptiDICSafety6x1HalfCheetahVelocityv0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"Safety6x1HalfCheetahVelocity-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "Safety6x1HalfCheetahVelocity-v0"

@dataclass
class COptiDICSafety2x3Walker2dVelocityv0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"Safety2x3Walker2dVelocity-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "Safety2x3Walker2dVelocity-v0"

@dataclass
class COptiDICSafety3x1HopperVelocityv0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"Safety3x1HopperVelocity-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "Safety3x1HopperVelocity-v0"

@dataclass
class COptiDICSafety2x1SwimmerVelocityv0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"Safety2x1SwimmerVelocity-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "Safety2x1SwimmerVelocity-v0"

@dataclass
class COptiDICSafety98HumanoidVelocityv0Config(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"Safety98HumanoidVelocity-v0"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "Safety98HumanoidVelocity-v0"

@dataclass
class COptiDICShadowHandOverSafejointConfig(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"ShadowHandOver_Safe_joint"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "ShadowHandOver_Safe_joint"

@dataclass
class COptiDICShadowHandOverSafefingerConfig(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"ShadowHandOver_Safe_finger"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "ShadowHandOver_Safe_finger"

@dataclass
class COptiDICShadowHandCatchOver2UnderarmSafejointConfig(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"ShadowHandCatchOver2Underarm_Safe_joint"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "ShadowHandCatchOver2Underarm_Safe_joint"

@dataclass
class COptiDICShadowHandCatchOver2UnderarmSafefingerConfig(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"ShadowHandCatchOver2Underarm_Safe_finger"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "ShadowHandCatchOver2Underarm_Safe_finger"

@dataclass
class COptiDICFreightFrankaCloseDrawerConfig(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"FreightFrankaCloseDrawer"``."""

    # Same value the base class already uses; spelled out so this config
    # names its task explicitly.
    task: str = "FreightFrankaCloseDrawer"

@dataclass
class COptiDICFreightFrankaPickAndPlaceConfig(COptiDICETrainConfig):
    """COptiDICE defaults registered under ``"FreightFrankaPickAndPlace"``."""

    # Name the task explicitly; otherwise the inherited base default
    # ("FreightFrankaCloseDrawer") would be reported for this config.
    task: str = "FreightFrankaPickAndPlace"


# Lookup table: task name -> config class holding that task's COptiDICE defaults.
COptiDICE_DEFAULT_CONFIG: Dict[str, type] = {
    # --- multi-goal tasks ---
    # NOTE(review): this group was labelled "bullet_safety_gym"; the ids look
    # like Safety-Gymnasium multi-agent tasks instead — confirm the source suite.
    "SafetyAntMultiGoal1-v0": COptiDICSafetyAntMultiGoal1v0Config,
    "SafetyPointMultiGoal1-v0": COptiDICSafetyPointMultiGoal1v0Config,
    "SafetyAntMultiGoal2-v0": COptiDICSafetyAntMultiGoal2v0Config,
    "SafetyPointMultiGoal2-v0": COptiDICSafetyPointMultiGoal2v0Config,
    # --- safety_gymnasium: multi-agent velocity tasks ---
    "Safety2x4AntVelocity-v0": COptiDICSafety2x4AntVelocityv0Config,
    "Safety4x2AntVelocity-v0": COptiDICSafety4x2AntVelocityv0Config,
    "Safety2x3HalfCheetahVelocity-v0": COptiDICSafety2x3HalfCheetahVelocityv0Config,
    "Safety6x1HalfCheetahVelocity-v0": COptiDICSafety6x1HalfCheetahVelocityv0Config,
    "Safety2x3Walker2dVelocity-v0": COptiDICSafety2x3Walker2dVelocityv0Config,
    "Safety3x1HopperVelocity-v0": COptiDICSafety3x1HopperVelocityv0Config,
    "Safety2x1SwimmerVelocity-v0": COptiDICSafety2x1SwimmerVelocityv0Config,
    "Safety98HumanoidVelocity-v0": COptiDICSafety98HumanoidVelocityv0Config,
    # --- safe_isaac_gym: ShadowHand and FreightFranka tasks ---
    "ShadowHandOver_Safe_joint": COptiDICShadowHandOverSafejointConfig,
    "ShadowHandOver_Safe_finger": COptiDICShadowHandOverSafefingerConfig,
    "ShadowHandCatchOver2Underarm_Safe_joint": COptiDICShadowHandCatchOver2UnderarmSafejointConfig,
    "ShadowHandCatchOver2Underarm_Safe_finger": COptiDICShadowHandCatchOver2UnderarmSafefingerConfig,
    "FreightFrankaCloseDrawer": COptiDICFreightFrankaCloseDrawerConfig,
    "FreightFrankaPickAndPlace": COptiDICFreightFrankaPickAndPlaceConfig,
}