from dataclasses import dataclass


# Known environment names, grouped by benchmark suite.

# SMAC (v1) map names.
SMACV1_ENV_NAMES = [
    "2c_vs_64zg",
    "5m_vs_6m",
    "6h_vs_8z",
    "corridor",
]

# SMACv2 names are "<race>_<scenario>": every race is paired with every
# scenario, race-major (all protoss entries first, then terran, then zerg).
SMACV2_ENV_NAMES = [
    f"{race}_{scenario}"
    for race in ("protoss", "terran", "zerg")
    for scenario in ("5_vs_5", "10_vs_10", "10_vs_11", "20_vs_20", "20_vs_23")
]

# Multi-agent MuJoCo base environment names.
MAMUJOCO_ENV_NAMES = [
    "Hopper-v2",
    "Ant-v2",
    "HalfCheetah-v2",
]


@dataclass
class Args:
    """Experiment configuration; every field has a default value.

    Field order is significant: it fixes the positional parameter order of
    the generated ``__init__`` (and the ``repr``). The per-field notes below
    are read from the identifiers only. Confirm their meanings against the
    code that consumes this config.
    """

    # Algorithm identifier (presumably selects the training procedure).
    algo: str = "misodice"
    # Environment name. The default appears in SMACV2_ENV_NAMES; no
    # validation happens here.
    env_name: str = "protoss_5_vs_5"
    # Device string, presumably passed to the tensor library. TODO confirm.
    device: str = "cuda"
    # NOTE(review): the meaning of `exsize` is not evident from this file.
    exsize: int = 200
    seed: int = 0

    # Network width.
    hidden_size: int = 256

    # Optimizer learning rates (0.0003 == 3e-4).
    critic_lr: float = 0.0003
    actor_lr: float = 0.0003

    # Two gradient-regularization coefficients, stored as an immutable tuple
    # so the default is safe to share between instances.
    grad_reg_coeffs: tuple = (0.1, 1e-4)

    gamma: float = 0.99
    alpha: float = 0.05

    # Toggles for a bias term in the final layer of the cost / critic heads.
    use_last_layer_bias_cost: bool = False
    use_last_layer_bias_critic: bool = False

    # Weight-initializer name.
    kernel_initializer: str = "he_normal"

    n_minibatches: int = 512
    use_llm: bool = False
    n_epochs: int = 100