import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    """Configuration of a Q-network training experiment.

    Fields are grouped into general run settings, algorithm hyperparameters,
    ensemble/architecture options and sparse-training options. The string
    literal that follows each field is that field's description.
    """

    # splitext removes only a real extension, so a script name that does not
    # end in ".py" is kept intact instead of losing its last three characters.
    exp_name: str = os.path.splitext(os.path.basename(__file__))[0]
    """the name of this experiment"""
    seed: int = 42
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = True
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "neurotrails_atari"
    """the wandb's project name"""
    # Optional because the default is None (no entity given).
    wandb_entity: Optional[str] = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances in eval_env (check out `videos` folder)"""
    save_model: bool = False
    """whether to save model into the `runs/{run_name}` folder"""
    hf_entity: str = ""
    """the user or org name of the model repository from the Hugging Face Hub"""

    # Algorithm specific arguments
    env_id: str = "BreakoutNoFrameskip-v4"
    """the id of the environment"""
    total_timesteps: int = 10_000_000
    """total timesteps of the experiments"""
    learning_rate: float = 1e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1
    """the number of parallel game environments"""
    buffer_size: int = 1_000_000
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 1.0
    """the target network update rate"""
    target_network_frequency: int = 1000
    """the timesteps it takes to update the target network"""
    batch_size: int = 32
    """the batch size of sample from the replay memory"""
    start_e: float = 1.0
    """the starting epsilon for exploration"""
    end_e: float = 0.01
    """the ending epsilon for exploration"""
    exploration_fraction: float = 0.10
    """the fraction of `total-timesteps` it takes from start-e to go end-e"""
    learning_starts: int = 80_000
    """timestep to start learning"""
    train_frequency: int = 4
    """the frequency of training"""

    ## added by neurotrails
    console_log: str = "default"
    """the console log file on server, so we can find it on wandb in the args"""
    eval_freq: int = 100_000
    """the frequency of evaluation in timesteps"""
    hidden_dim: int = 512
    """the hidden dimension of the Q network(s)"""
    linear_layers: int = 0
    """the number of additional hidden linear layers in the Q network(s)"""
    num_ensemble: int = 1
    """the ensemble size for the Q network(s). When blocks_in_head > 0, num_ensemble is the number of heads."""
    blocks_in_head: int = -1
    """the number of blocks (layer+ReLU) in each head for the TreeQNetwork. -1 means a full ensemble (EnsembleQNetwork). Default architecture has 3+2+linear_layers blocks."""
    joint_training: bool = False
    """whether to train the ensemble jointly (the mean Q values) or independently (each Q compared with target separately)"""
    joint_sampling: bool = True
    """whether to sample actions jointly (by mean Q values) or independently (by each member, could need more envs or more timesteps)"""

    ### sparsity
    density: float = 1.0
    """The density of the sparse network. Final density if density_decay != 'constant'."""
    dst_update_freq: int = 1_000
    """Number of train steps between mask updates."""
    growth: str = "random"
    """Growth mode. Options: 'gradient', 'random'."""
    prune: str = "magnitude"
    """Pruning mode. Options: 'magnitude', 'magnitude_soft'."""
    reinit: str = "no"
    """Weight reinit mode. Options: 'no', 'zero'."""
    mix: float = 0.0                # to remove
    """Mix ...."""
    redistribution: str = "none"    # to remove
    """Redistribution mode. Options: 'momentum', 'magnitude', 'nonzeros', 'none'."""
    prune_rate: float = 0.1
    """Pruning rate (fraction removed per update)."""
    prune_rate_decay: str = "cosine"
    """Prune-rate schedule. Options: 'cosine', 'linear', 'constant'."""
    density_decay: str = "constant"
    """Density schedule. Options: 'constant', 'linear', 'cosine'."""
    initial_density: float = 0.999
    """Initial density when density_decay != 'constant'."""
    fix: bool = False
    """Fixed sparse topology throughout (static sparse training)."""
    sparse_init: str = "Multi_Output"
    """Sparse init scheme (e.g., 'Multi_Output', 'ER', 'ERK', 'Lottery')."""
    temperature_decay: str = "constant"
    """Temperature schedule. Options: 'constant', 'linear'."""
    temperature: float = 3.0
    """Final temperature for soft sampling."""
    init_temperature: float = 1.0
    """Initial temperature when temperature_decay != 'constant'."""
