from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
import gymnasium as gym
import numpy as np
import random
import torch
import os
from stable_baselines3.common.callbacks import StopTrainingOnRewardThreshold
import torch.optim as optim
from custom_policies.custom_ddqn import CustomDDQN

# Gymnasium ids of the five Atari games that receive the Atari-specific
# configuration in get_environment_specific_settings. Every id is the game
# name followed by the "NoFrameskip-v4" suffix.
atari_5 = [
    f"{game}NoFrameskip-v4"
    for game in ("Qbert", "BattleZone", "DoubleDunk", "NameThisGame", "Phoenix")
]

def seed_everything(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN to deterministic mode so that repeated runs with the
    same seed select the same convolution algorithms.

    Args:
        seed: Seed value applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    # NOTE: the hash seed of the running interpreter is fixed at start-up;
    # this only affects Python processes spawned after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    # Seed every visible CUDA device, not only the current one. The call is
    # silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may choose different kernels from run to run,
    # which would defeat the deterministic flag set above.
    torch.backends.cudnn.benchmark = False

def _make_stacked_atari_env(environment_name, n_envs, seed):
    """Create an Atari vec env wrapped in VecFrameStack (4 frames) and VecTransposeImage."""
    atari_env = make_atari_env(environment_name, n_envs = n_envs, seed = seed)
    atari_env = VecFrameStack(atari_env, n_stack = 4)
    return VecTransposeImage(atari_env)


def get_environment_specific_settings(environment_name, n_envs: int = 1, seed: int = 0):
    """Return the model class, environments and hyperparameters for an environment.

    Supported environments are "CartPole-v1", "LunarLander-v2", "Acrobot-v1"
    and every id listed in ``atari_5``. The name is validated before any
    environment is created, so an unsupported name never allocates one.

    Args:
        environment_name: Id of the environment to configure.
        n_envs: Number of parallel environments, used for both the training
            and the evaluation vec env.
        seed: Seed passed to the environment factories.

    Returns:
        A 23-tuple, in this order: model_class, env, eval_env, dqn_policy,
        num_timesteps, reward_threshold, num_evals, callback_on_new_best,
        learning_rate, learning_starts, n_eval_episodes,
        exploration_initial_eps, exploration_fraction, exploration_final_eps,
        batch_size, buffer_size, policy_kwargs, max_grad_norm, train_freq,
        target_update_interval, gradient_steps, gamma, progress_bar.

    Raises:
        ValueError: If ``environment_name`` is not one of the supported ids.
    """
    # Keep the original precedence: the explicitly named environments are
    # matched before the Atari list.
    is_classic = environment_name in ("CartPole-v1", "LunarLander-v2", "Acrobot-v1")
    is_atari = not is_classic and environment_name in atari_5
    if not (is_classic or is_atari):
        raise ValueError("This environment is currently not supported.")

    # General settings & StableBaselines3 default parameters - these are
    # overwritten below if a given environment requires a different value.
    model_class = CustomDDQN
    dqn_policy = "MlpPolicy"
    num_evals = 100
    learning_rate = 0.0001
    learning_starts = 100
    n_eval_episodes = 5
    exploration_initial_eps = 1.
    exploration_fraction = .1
    exploration_final_eps = .05
    buffer_size = 1_000_000
    batch_size = 32
    policy_kwargs = None
    gradient_steps = 1
    max_grad_norm = 10.
    target_update_interval = 10_000
    train_freq = 4
    progress_bar = True
    gamma = 0.99

    # Build each environment exactly once; creating the generic vec envs
    # first and then replacing them for Atari would leave them unclosed.
    if is_atari:
        env = _make_stacked_atari_env(environment_name, n_envs, seed)
        eval_env = _make_stacked_atari_env(environment_name, n_envs, seed)
        # No reward-based early stopping for Atari.
        reward_threshold = np.inf
        callback_on_new_best = None
    else:
        env = make_vec_env(environment_name, n_envs = n_envs, seed = seed)
        eval_env = make_vec_env(environment_name, n_envs = n_envs, seed = seed)
        reward_threshold = gym.spec(environment_name).reward_threshold
        callback_on_new_best = StopTrainingOnRewardThreshold(reward_threshold = reward_threshold, verbose = 1)

    if environment_name == "CartPole-v1":

        num_timesteps = 50000
        learning_rate = 2.3e-3
        learning_starts = 1000
        batch_size = 64
        buffer_size = 100000
        target_update_interval = 10
        train_freq = 256
        gradient_steps = 128
        exploration_fraction = 0.16
        exploration_final_eps = 0.04
        policy_kwargs = dict(net_arch = [256, 256])
        num_evals = 100

    elif environment_name == "LunarLander-v2":

        # NOTE: float literal kept as-is so the returned value is unchanged.
        num_timesteps = 1e5
        learning_rate = 6.3e-4
        learning_starts = 1000
        batch_size = 128
        buffer_size = 50000
        target_update_interval = 250
        train_freq = 4
        gradient_steps = -1
        exploration_fraction = 0.12
        exploration_final_eps = 0.1
        policy_kwargs = dict(net_arch = [256, 256])
        num_evals = 100

    elif environment_name == "Acrobot-v1":

        # NOTE: float literal kept as-is so the returned value is unchanged.
        num_timesteps = 1e5
        learning_rate = 6.3e-4
        learning_starts = 1000
        batch_size = 128
        buffer_size = 50000
        target_update_interval = 250
        train_freq = 4
        gradient_steps = -1
        exploration_fraction = 0.12
        # Was written as a bare annotation (`name: 0.1`), which never
        # assigned the value and left the default of .05 in place.
        exploration_final_eps = 0.1
        policy_kwargs = dict(net_arch = [256, 256])
        num_evals = 100

    else:

        # Atari: the only remaining case after the validation above.
        dqn_policy = "CnnPolicy"
        num_timesteps = 50_000_000
        num_evals = 200
        learning_starts = 50_000
        n_eval_episodes = 1
        learning_rate = 0.0000625
        exploration_initial_eps = 1.
        exploration_fraction = .02
        exploration_final_eps = .01
        train_freq = 4
        progress_bar = True
        gradient_steps = 1
        max_grad_norm = np.inf
        gamma = 0.99
        target_update_interval = 30_000
        policy_kwargs = dict(
            optimizer_class = optim.RMSprop,
            optimizer_kwargs = dict(alpha = 0.95,
                                  eps = 0.01,
                                  momentum = 0.95,
                                  centered = True))

    return (
        model_class, env, eval_env, dqn_policy, num_timesteps, reward_threshold,
        num_evals, callback_on_new_best, learning_rate, learning_starts,
        n_eval_episodes, exploration_initial_eps, exploration_fraction,
        exploration_final_eps, batch_size, buffer_size, policy_kwargs,
        max_grad_norm, train_freq, target_update_interval, gradient_steps,
        gamma, progress_bar,
    )