from typing import Literal, Optional, Union

from flax.struct import dataclass

from jaxgcrl.agents import CRL, HCRL, PPO, SAC, TD3

from .env import legal_envs

# Agent configurations: any one of the agent classes imported from
# `jaxgcrl.agents` above. Used as the type of `Config.agent`.
AgentConfig = Union[CRL, HCRL, PPO, SAC, TD3]


@dataclass
class RunConfig:
    """General run configs.

    `env` is the only field without a default and must always be supplied.

    Args:
        env: the name of the environment to use; restricted to `legal_envs`
        total_env_steps: the total number of environment steps to run
        episode_length: the maximum length of an episode
            NOTE: `num_envs * (episode_length - 1)` must be divisible by
            `batch_size` due to the way data is stored in replay buffer.
            (`batch_size` is not a field of this class; presumably it is
            defined on the agent config -- confirm.)
        eval_env: the name of the environment to use for evaluation;
            restricted to `legal_envs`, defaults to None
            NOTE(review): presumably `env` is used for evaluation when this
            is None -- confirm against the training code.
        num_envs: the number of parallel environments to use for rollouts
            NOTE: `num_envs` must be divisible by the total number of chips since each
            chip gets `num_envs // total_number_of_chips` environments to roll out
            NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since
            data generated by `num_envs` parallel envs gets used for gradient
            updates over `num_minibatches` of data, where each minibatch has a
            leading dimension of `batch_size`
        num_eval_envs: the number of envs to use for evaluation. Each env will run 1
            episode, and all envs run in parallel during eval.
        action_repeat: the number of timesteps to repeat an action
        num_evals: the number of evals to run during the entire training run.
            Increasing the number of evals increases total training time
        seed: random seed
        backend: optional backend name, one of "mjx", "spring", "positional"
            or "generalized"; defaults to None
            NOTE(review): presumably the physics backend handed to the
            environment, with None meaning the environment's own default --
            confirm.
        exp_name: the name of the experiment for logging
        log_wandb: NOTE(review): presumably enables logging to wandb -- confirm
        wandb_project_name: NOTE(review): presumably the wandb project to log
            to -- confirm
        wandb_group: NOTE(review): presumably the wandb group of the run --
            confirm
        wandb_mode: wandb mode, either "online" or "offline"
        visualization_interval: render frequency
            NOTE(review): the unit (evals, epochs, ...) is not visible in
            this file -- confirm.
        vis_length: NOTE(review): presumably the number of steps rendered per
            visualization -- confirm
        checkpoint_logdir: optional string, defaults to None
            NOTE(review): presumably the directory checkpoints are written
            to, with None disabling checkpointing -- confirm.
        max_devices_per_host: maximum number of chips to use per host process
        cuda: NOTE(review): presumably toggles running on CUDA devices --
            confirm where this flag is read
    """

    # environment to use
    env: Literal[legal_envs]

    # total number of environment steps to run
    total_env_steps: int = 50_000_000

    # maximum length of an episode
    episode_length: int = 1001

    # environment to use for evaluation
    eval_env: Optional[Literal[legal_envs]] = None

    # number of envs to run in parallel during training
    num_envs: int = 256

    # number of envs to run in parallel during evaluation
    num_eval_envs: int = 256

    # number of timesteps to repeat an action
    action_repeat: int = 1

    # total number of evals during training
    num_evals: int = 200

    # random seed
    seed: int = 0
    # optional backend name, limited to the literals listed in the annotation
    backend: Optional[Literal["mjx", "spring", "positional", "generalized"]] = None

    # wandb logging
    exp_name: str = "run"
    log_wandb: bool = True
    wandb_project_name: str = "jaxgcrl"
    wandb_group: str = "."

    # online or offline
    wandb_mode: Literal["online", "offline"] = "online"

    # render frequency
    visualization_interval: int = 5

    vis_length: int = 1000
    checkpoint_logdir: Optional[str] = None
    # maximum number of chips to use per host process
    max_devices_per_host: int = 1
    cuda: bool = True


@dataclass
class Config:
    """Top-level configuration pairing an agent with the general run settings.

    Both fields are required; neither has a default.

    Args:
        agent: the agent configuration, one of the types in `AgentConfig`
        run: the general run configuration (`RunConfig`)
    """

    # agent type
    agent: AgentConfig
    # run config
    run: RunConfig
