import dataclasses
import typing


@dataclasses.dataclass
class Config:
    """Flat collection of experiment settings and hyperparameters.

    Every field carries a default, so ``Config()`` can be constructed without
    arguments and individual values overridden by keyword. Fields are grouped
    into sections by the ``# ---`` comments below; the grouping is purely
    organizational and has no runtime effect.
    """

    seed: int = 2  # seed of the experiment
    wandb_project_name: str = ""  # wandb project name
    wandb_tags: str = ""  # wandb tags
    log_wandb: bool = True  # if toggled, this experiment will be tracked with Weights and Biases

    # name of the principal type to use (presumably one of the principal
    # variants that have a section below -- TODO confirm accepted values)
    principal: str = "LLM"

    # --- Multiple principals ---
    agent_lr: float = 3e-4  # the learning rate of the agent optimizer
    total_principal_steps: int = 10000  # the number of episodes
    capture_video: bool = False  # whether to capture videos of the agent performances
    video_freq: int = 10  # capture video every how many episodes?
    principal_lr: float = 1e-4  # the learning rate of the principal optimizer
    eps_per_tax_rate: int = 5
    num_val_episodes: int = 3
    bandit_num_discretized_rates: int = 11
    principal_gets_aux: bool = True

    # --- Epsilon Greedy ---
    epsilon: float = 0.1

    # --- UCB ---
    ucb_coef: float = 0.2

    # --- Dual-RL only ---
    principal_ent_coef: float = 0.2  # coefficient of entropy loss term for Dual-RL principal
    num_tax_annealment_episodes: int = 50  # number of episodes over which to linearly anneal maximum tax cap to 1
    initial_max_tax_rate: float = 1.0  # maximum tax cap to start annealing from
    dual_rl_principal_num_policy_updates_per_collection_update: int = 4
    principal_clip_coef: float = 0.2  # the surrogate clipping coefficient for the principal
    dual_rl_num_hidden_layers: int = 2
    dual_rl_use_running_mean: bool = False
    dual_rl_hidden_dim: int = 128

    # --- AID only ---
    aid_hidden_dimension: int = 256
    aid_sigmoid_shift: int = 0
    aid_num_hidden_layers: int = 2

    # --- LLM only ---
    temperature: float = 0.01
    llm_model: str = "gemini-1.5-flash"
    llm_prompt_style: str = ""
    llm_gets_aux: bool = False

    # ------------------------------------------------------------

    saved_core_path: str = ""
    saved_heads_path: str = ""
    env_name: str = "commons_harvest__open"
    log_locally: bool = False  # if toggled, log messages will be printed to stderr
    save_model: bool = False  # whether to save model parameters
    save_model_freq: int = 50  # save model parameters every how many episodes?

    # --- Usually only changed for pretraining nets. ---
    freeze_agent_net_core: bool = True  # whether to freeze the main body of agent nets
    freeze_whole_agent_net: bool = False
    num_parallel_games: int = 1  # the number of parallel game environments
    episode_length: int = 1000  # the number of steps in an episode
    sampling_horizon: int = 1000  # the number of timesteps between policy update iterations
    minibatch_size: int = 128  # size of minibatches when training policy network

    # --- Not currently being changed. ---
    reset_agent_nets: bool = True  # whether to reset agent nets to random initialization
    adam_eps: float = 1e-5  # eps value for all adam optimizers
    gamma: float = 0.998  # the discount factor gamma
    gae_lambda: float = 0.98  # the lambda for the general advantage estimation
    num_policy_updates_per_collection_update: int = 4  # the K epochs to update the policy
    norm_adv: bool = True  # Toggles advantages normalization
    clip_coef: float = 0.2  # the surrogate clipping coefficient
    value_clip_coef: float = 0.2  # value estimate clipping coefficient
    clip_vloss: bool = True  # Toggles whether or not to use a clipped loss for the value function, as per the paper.

    agent_ent_coef: float = 0.025  # coefficient of entropy loss term for agents
    vf_coef: float = 0.5  # coefficient of the value function
    principal_vf_coef: float = 0.5  # coefficient of the value function (presumably for the principal's loss)
    max_grad_norm: float = 0.5  # maximum norm for gradient clipping shared by agents and principals
    # None is the default, so the annotation must admit it.
    target_kl: typing.Optional[float] = None  # the target KL divergence threshold

    # --- Likely never changed. ---
    log_file: typing.Optional[str] = None  # the file to log to relative to the Globals.LOG_DIR
    cuda: bool = True  # if toggled, cuda will be enabled by default
    wandb_entity: str = "lad"  # entity (team) of wandb project
    flush_interval: int = 1  # TODO: purpose undocumented; confirm what consumes this value
    algorithm: str = "ppo"
