from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

from pyrallis import field


@dataclass
class WSACTrainConfig:
    """Base set of WSAC training options shared by every task preset.

    The task-specific config classes in this module subclass this one and
    override a handful of defaults.  Fields are grouped by the section
    comments below (wandb, dataset, training, model, evaluation).

    Every field whose default is ``None`` is annotated ``Optional[...]`` so
    that the declared type agrees with the default value.

    NOTE(review): ``a_hidden_sizes`` / ``c_hidden_sizes`` are annotated
    ``List[float]`` although their defaults are integers -- confirm whether
    ``List[int]`` is intended before changing the annotation.
    """

    # wandb params
    project: str = "PXY-WSAC"
    group: Optional[str] = None
    name: Optional[str] = None
    prefix: Optional[str] = "WSAC"
    suffix: Optional[str] = ""
    logdir: Optional[str] = "logs"
    verbose: bool = True
    # dataset params
    outliers_percent: Optional[float] = None
    noise_scale: Optional[float] = None
    inpaint_ranges: Optional[Tuple[Tuple[float, float, float, float], ...]] = None
    epsilon: Optional[float] = None
    density: float = 1.0
    # training params
    task: str = "OfflineCarCircle-v0"
    dataset: Optional[str] = None
    seed: int = 0
    device: str = "cpu"
    threads: int = 4
    reward_scale: float = 0.1
    cost_scale: float = 1
    actor_lr: float = 0.0001
    critic_lr: float = 0.0003
    scalar_lr: float = 0.0001
    cost_limit: int = 40
    episode_len: int = 300
    batch_size: int = 512
    update_steps: int = 30_000
    num_workers: int = 8
    num_q: int = 1
    num_qc: int = 1
    qc_scalar: float = 1.5
    # model params
    a_hidden_sizes: List[float] = field(default=[256, 256], is_mutable=True)
    c_hidden_sizes: List[float] = field(default=[256, 256], is_mutable=True)
    alpha: float = 0.5
    gamma: float = 0.99
    cost_ub_epsilon: float = 0.01
    f_type: str = "softchi"
    num_nu: int = 2
    num_chi: int = 2
    tau: float = 0.005
    # evaluation params
    eval_episodes: int = 20
    eval_every: int = 500
    beta_r: float = 10
    beta_c: float = 30
    lambda_: float = 2.0
    lambda_max: float = 1000
    act_times: int = 10
    qc_ub: float = 30

@dataclass
class WSACCarCircleConfig(WSACTrainConfig):
    """Preset for ``OfflineCarCircle-v0``.

    Declares no fields of its own: every value, including ``task``, is the
    default inherited from ``WSACTrainConfig``.
    """


@dataclass
class WSACAntRunConfig(WSACTrainConfig):
    """Preset for ``OfflineAntRun-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineAntRun-v0"
    episode_len: int = 200


@dataclass
class WSACDroneRunConfig(WSACTrainConfig):
    """Preset for ``OfflineDroneRun-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineDroneRun-v0"
    episode_len: int = 200


@dataclass
class WSACDroneCircleConfig(WSACTrainConfig):
    """Preset for ``OfflineDroneCircle-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineDroneCircle-v0"
    episode_len: int = 300


@dataclass
class WSACCarRunConfig(WSACTrainConfig):
    """Preset for ``OfflineCarRun-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineCarRun-v0"
    episode_len: int = 200


@dataclass
class WSACAntCircleConfig(WSACTrainConfig):
    """Preset for ``OfflineAntCircle-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineAntCircle-v0"
    episode_len: int = 500


@dataclass
class WSACBallRunConfig(WSACTrainConfig):
    """Preset for ``OfflineBallRun-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineBallRun-v0"
    episode_len: int = 100


@dataclass
class WSACBallCircleConfig(WSACTrainConfig):
    """Preset for ``OfflineBallCircle-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineBallCircle-v0"
    episode_len: int = 200


@dataclass
class WSACCarButton1Config(WSACTrainConfig):
    """Preset for ``OfflineCarButton1Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineCarButton1Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACCarButton2Config(WSACTrainConfig):
    """Preset for ``OfflineCarButton2Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineCarButton2Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACCarCircle1Config(WSACTrainConfig):
    """Preset for ``OfflineCarCircle1Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineCarCircle1Gymnasium-v0"
    episode_len: int = 500


@dataclass
class WSACCarCircle2Config(WSACTrainConfig):
    """Preset for ``OfflineCarCircle2Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineCarCircle2Gymnasium-v0"
    episode_len: int = 500


@dataclass
class WSACCarGoal1Config(WSACTrainConfig):
    """Preset for ``OfflineCarGoal1Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineCarGoal1Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACCarGoal2Config(WSACTrainConfig):
    """Preset for ``OfflineCarGoal2Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineCarGoal2Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACCarPush1Config(WSACTrainConfig):
    """Preset for ``OfflineCarPush1Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineCarPush1Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACCarPush2Config(WSACTrainConfig):
    """Preset for ``OfflineCarPush2Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineCarPush2Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACPointButton1Config(WSACTrainConfig):
    """Preset for ``OfflinePointButton1Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflinePointButton1Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACPointButton2Config(WSACTrainConfig):
    """Preset for ``OfflinePointButton2Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflinePointButton2Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACPointCircle1Config(WSACTrainConfig):
    """Preset for ``OfflinePointCircle1Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflinePointCircle1Gymnasium-v0"
    episode_len: int = 500


@dataclass
class WSACPointCircle2Config(WSACTrainConfig):
    """Preset for ``OfflinePointCircle2Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflinePointCircle2Gymnasium-v0"
    episode_len: int = 500


@dataclass
class WSACPointGoal1Config(WSACTrainConfig):
    """Preset for ``OfflinePointGoal1Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflinePointGoal1Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACPointGoal2Config(WSACTrainConfig):
    """Preset for ``OfflinePointGoal2Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflinePointGoal2Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACPointPush1Config(WSACTrainConfig):
    """Preset for ``OfflinePointPush1Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflinePointPush1Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACPointPush2Config(WSACTrainConfig):
    """Preset for ``OfflinePointPush2Gymnasium-v0``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflinePointPush2Gymnasium-v0"
    episode_len: int = 1000


@dataclass
class WSACAntVelocityConfig(WSACTrainConfig):
    """Preset for ``OfflineAntVelocityGymnasium-v1``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineAntVelocityGymnasium-v1"
    episode_len: int = 1000


@dataclass
class WSACHalfCheetahVelocityConfig(WSACTrainConfig):
    """Preset for ``OfflineHalfCheetahVelocityGymnasium-v1``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineHalfCheetahVelocityGymnasium-v1"
    episode_len: int = 1000


@dataclass
class WSACHopperVelocityConfig(WSACTrainConfig):
    """Preset for ``OfflineHopperVelocityGymnasium-v1``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineHopperVelocityGymnasium-v1"
    episode_len: int = 1000


@dataclass
class WSACSwimmerVelocityConfig(WSACTrainConfig):
    """Preset for ``OfflineSwimmerVelocityGymnasium-v1``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineSwimmerVelocityGymnasium-v1"
    episode_len: int = 1000


@dataclass
class WSACWalker2dVelocityConfig(WSACTrainConfig):
    """Preset for ``OfflineWalker2dVelocityGymnasium-v1``.

    Sets ``task`` and ``episode_len``; every other field keeps the default
    inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineWalker2dVelocityGymnasium-v1"
    episode_len: int = 1000


@dataclass
class WSACEasySparseConfig(WSACTrainConfig):
    """Preset for ``OfflineMetadrive-easysparse-v0``.

    Sets ``task``, ``episode_len`` and ``update_steps``; every other field
    keeps the default inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineMetadrive-easysparse-v0"
    episode_len: int = 1000
    update_steps: int = 200_000


@dataclass
class WSACEasyMeanConfig(WSACTrainConfig):
    """Preset for ``OfflineMetadrive-easymean-v0``.

    Sets ``task``, ``episode_len`` and ``update_steps``; every other field
    keeps the default inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineMetadrive-easymean-v0"
    episode_len: int = 1000
    update_steps: int = 200_000


@dataclass
class WSACEasyDenseConfig(WSACTrainConfig):
    """Preset for ``OfflineMetadrive-easydense-v0``.

    Sets ``task``, ``episode_len`` and ``update_steps``; every other field
    keeps the default inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineMetadrive-easydense-v0"
    episode_len: int = 1000
    update_steps: int = 200_000


@dataclass
class WSACMediumSparseConfig(WSACTrainConfig):
    """Preset for ``OfflineMetadrive-mediumsparse-v0``.

    Sets ``task``, ``episode_len`` and ``update_steps``; every other field
    keeps the default inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineMetadrive-mediumsparse-v0"
    episode_len: int = 1000
    update_steps: int = 200_000


@dataclass
class WSACMediumMeanConfig(WSACTrainConfig):
    """Preset for ``OfflineMetadrive-mediummean-v0``.

    Sets ``task``, ``episode_len`` and ``update_steps``; every other field
    keeps the default inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineMetadrive-mediummean-v0"
    episode_len: int = 1000
    update_steps: int = 200_000


@dataclass
class WSACMediumDenseConfig(WSACTrainConfig):
    """Preset for ``OfflineMetadrive-mediumdense-v0``.

    Sets ``task``, ``episode_len`` and ``update_steps``; every other field
    keeps the default inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineMetadrive-mediumdense-v0"
    episode_len: int = 1000
    update_steps: int = 200_000


@dataclass
class WSACHardSparseConfig(WSACTrainConfig):
    """Preset for ``OfflineMetadrive-hardsparse-v0``.

    Sets ``task``, ``episode_len`` and ``update_steps``; every other field
    keeps the default inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineMetadrive-hardsparse-v0"
    episode_len: int = 1000
    update_steps: int = 200_000


@dataclass
class WSACHardMeanConfig(WSACTrainConfig):
    """Preset for ``OfflineMetadrive-hardmean-v0``.

    Sets ``task``, ``episode_len`` and ``update_steps``; every other field
    keeps the default inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineMetadrive-hardmean-v0"
    episode_len: int = 1000
    update_steps: int = 200_000


@dataclass
class WSACHardDenseConfig(WSACTrainConfig):
    """Preset for ``OfflineMetadrive-harddense-v0``.

    Sets ``task``, ``episode_len`` and ``update_steps``; every other field
    keeps the default inherited from ``WSACTrainConfig``.
    """

    # training params
    task: str = "OfflineMetadrive-harddense-v0"
    episode_len: int = 1000
    update_steps: int = 200_000

# Registry mapping each task id to its WSAC config class.  Every key is read
# from the class's own ``task`` default, so a key cannot drift away from the
# config it points to.  Insertion order follows the listing below.
WSAC_DEFAULT_CONFIG: Dict[str, type] = {
    config_cls.task: config_cls
    for config_cls in (
        # bullet_safety_gym
        WSACCarCircleConfig,
        WSACAntRunConfig,
        WSACDroneRunConfig,
        WSACDroneCircleConfig,
        WSACCarRunConfig,
        WSACAntCircleConfig,
        WSACBallCircleConfig,
        WSACBallRunConfig,
        # safety_gymnasium: car
        WSACCarButton1Config,
        WSACCarButton2Config,
        WSACCarCircle1Config,
        WSACCarCircle2Config,
        WSACCarGoal1Config,
        WSACCarGoal2Config,
        WSACCarPush1Config,
        WSACCarPush2Config,
        # safety_gymnasium: point
        WSACPointButton1Config,
        WSACPointButton2Config,
        WSACPointCircle1Config,
        WSACPointCircle2Config,
        WSACPointGoal1Config,
        WSACPointGoal2Config,
        WSACPointPush1Config,
        WSACPointPush2Config,
        # safety_gymnasium: velocity
        WSACAntVelocityConfig,
        WSACHalfCheetahVelocityConfig,
        WSACHopperVelocityConfig,
        WSACSwimmerVelocityConfig,
        WSACWalker2dVelocityConfig,
        # safe_metadrive
        WSACEasySparseConfig,
        WSACEasyMeanConfig,
        WSACEasyDenseConfig,
        WSACMediumSparseConfig,
        WSACMediumMeanConfig,
        WSACMediumDenseConfig,
        WSACHardSparseConfig,
        WSACHardMeanConfig,
        WSACHardDenseConfig,
    )
}
