from torch import nn
from stable_baselines3.common.utils import get_schedule_fn
from typing import Any, ClassVar, Optional, TypeVar, Union
from stable_baselines3.common.noise import NormalActionNoise
import numpy as np
# Step count per environment, keyed by the short environment name.
# NOTE(review): presumably the total number of training timesteps handed to
# the learner -- confirm against the code that reads this table.
env_timestep = {
    "ant": 1_000_000,
    "hopper": 1_000_000,
    "humanoid": 1_000_000,
    "lunar": 500_000,
    "pendulum": 100_000,
    "reacher": 500_000,
    "walker": 1_000_000,
}
def _env_config(**overrides):
    """Return a fresh settings dict: the shared defaults with *overrides* applied.

    A new dict is built on every call, so the entries of ``env_args`` never
    share state with each other. Overriding an existing key keeps that key's
    position in the dict.
    """
    config = {
        "n_envs": 1,
        "learning_rate": 1e-3,
        "buffer_size": 1_000_000,
        "learning_starts": 10000,
        "batch_size": 256,
        "tau": 0.005,
        "gamma": 0.99,
        "train_freq": (1, "step"),
        "gradient_steps": 1,
    }
    config.update(overrides)
    return config


# Training settings per environment, keyed by the short environment name.
# Only "lunar" and "pendulum" deviate from the shared defaults: both use a
# smaller buffer_size and a lower gamma.
# NOTE(review): apart from n_envs, the key names look like stable-baselines3
# off-policy constructor arguments -- confirm how the caller consumes them.
env_args = {
    "ant": _env_config(),
    "hopper": _env_config(),
    "humanoid": _env_config(),
    "lunar": _env_config(buffer_size=200000, gamma=0.98),
    "pendulum": _env_config(buffer_size=200000, gamma=0.98),
    "reacher": _env_config(),
    "walker": _env_config(),
}
# Extra per-algorithm settings, looked up as alg_args[algorithm][environment].
# Every algorithm lists the same seven environments; "vanilla" adds nothing.
alg_args = {
    "vanilla": {
        "ant": {},
        "hopper": {},
        "humanoid": {},
        "lunar": {},
        "reacher": {},
        "pendulum": {},
        "walker": {},
    },
    # Identical everywhere except "pendulum", whose lamT/lamS are 10x larger.
    "caps": {
        "ant": {"caps_sigma": 0.2, "caps_lamT": 0.1, "caps_lamS": 0.5},
        "hopper": {"caps_sigma": 0.2, "caps_lamT": 0.1, "caps_lamS": 0.5},
        "humanoid": {"caps_sigma": 0.2, "caps_lamT": 0.1, "caps_lamS": 0.5},
        "lunar": {"caps_sigma": 0.2, "caps_lamT": 0.1, "caps_lamS": 0.5},
        "reacher": {"caps_sigma": 0.2, "caps_lamT": 0.1, "caps_lamS": 0.5},
        "pendulum": {"caps_sigma": 0.2, "caps_lamT": 1.0, "caps_lamS": 5.0},
        "walker": {"caps_sigma": 0.2, "caps_lamT": 0.1, "caps_lamS": 0.5},
    },
    # lamP is 2.0 throughout; lamS and lamT vary by environment.
    "asap": {
        "ant": {"asap_lamP": 2.0, "asap_lamS": 0.3, "asap_lamT": 0.05},
        "hopper": {"asap_lamP": 2.0, "asap_lamS": 0.3, "asap_lamT": 0.07},
        "humanoid": {"asap_lamP": 2.0, "asap_lamS": 0.3, "asap_lamT": 0.05},
        "lunar": {"asap_lamP": 2.0, "asap_lamS": 0.03, "asap_lamT": 0.005},
        "reacher": {"asap_lamP": 2.0, "asap_lamS": 0.1, "asap_lamT": 0.1},
        "pendulum": {"asap_lamP": 2.0, "asap_lamS": 0.03, "asap_lamT": 0.005},
        "walker": {"asap_lamP": 2.0, "asap_lamS": 0.3, "asap_lamT": 0.05},
    },
    "grad": {
        "ant": {"grad_lamT": 1.0},
        "hopper": {"grad_lamT": 1.0},
        "humanoid": {"grad_lamT": 1.0},
        "lunar": {"grad_lamT": 1.0},
        "reacher": {"grad_lamT": 1.0},
        "pendulum": {"grad_lamT": 1.0},
        "walker": {"grad_lamT": 1.0},
    },
    # NOTE(review): "pave" uses the same "grad_lamT" key as "grad" -- confirm
    # that the consumer really expects this name and it is not a copy-paste slip.
    "pave": {
        "ant": {"grad_lamT": 1.0},
        "hopper": {"grad_lamT": 1.0},
        "humanoid": {"grad_lamT": 1.0},
        "lunar": {"grad_lamT": 1.0},
        "reacher": {"grad_lamT": 1.0},
        "pendulum": {"grad_lamT": 1.0},
        "walker": {"grad_lamT": 1.0},
    },
}