import os
import sys
import random
import numpy as np
import torch
import gymnasium as gym
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)

from algorithms import ATARI_AGENTS
from utils.atari_config import parse_args
from utils.utils import Logger


def make_env(env_name, seed, resize=84):
    """Return a zero-argument factory that builds one preprocessed Atari env.

    The returned ``thunk`` is the form ``gym.vector.SyncVectorEnv`` consumes.
    Wrappers are applied innermost first:

    * ``RecordEpisodeStatistics``: sits directly on the raw env, so it sees
      rewards before ``ClipRewardEnv`` and episode ends before the ones
      ``EpisodicLifeEnv`` introduces.
    * ``NoopResetEnv(noop_max=30)``
    * ``MaxAndSkipEnv(skip=4)``
    * ``EpisodicLifeEnv``
    * ``FireResetEnv``: only if the game's action meanings include "FIRE".
    * ``ClipRewardEnv``
    * Image observations only: resize to ``resize`` x ``resize``, convert to
      grayscale, then stack the last 4 observations.

    NOTE(review): ``MaxAndSkipEnv`` assumes the base env does not already skip
    frames (e.g. ``*NoFrameskip-v4`` ids). Confirm this for the env ids in use.

    Args:
        env_name: environment id passed to ``gym.make``.
        seed: seed for the action space's sampler only. The env's own RNG is
            not seeded here (presumably the caller uses ``reset(seed=...)``;
            TODO confirm).
        resize: side length of the square frame after resizing.

    Returns:
        A callable taking no arguments that constructs and returns the env.
    """

    def thunk():
        env = gym.make(env_name)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        # Only image observations (H, W[, C]) can be resized, grayscaled and
        # frame-stacked. A flat vector such as a RAM observation of shape
        # (128,) has a non-empty shape too, so "len(shape)" alone is not
        # enough to identify pixel observations.
        if len(env.observation_space.shape) >= 2:
            env = gym.wrappers.ResizeObservation(env, (resize, resize))
            env = gym.wrappers.GrayScaleObservation(env)
            env = gym.wrappers.FrameStack(env, 4)
        env.action_space.seed(seed)
        return env

    return thunk


def _train_and_evaluate(args):
    """Build the train/eval vector envs, run the configured agent, then close the envs.

    Returns whatever ``agent.run()`` returns; ``main`` unpacks it as
    ``(avg_reward, std_reward)``.
    """
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env, args.seed + i, args.resize) for i in range(args.num_envs)]
    )
    eval_env = None
    try:
        assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
        # Single-env vector env for evaluation, built with the same seed as
        # training env 0.
        eval_env = gym.vector.SyncVectorEnv(
            [make_env(args.env, args.seed, args.resize)]
        )
        agent = ATARI_AGENTS[args.algo](args, envs, eval_env)
        return agent.run()
    finally:
        # Always release the environments, even if training raised.
        # NOTE(review): if agent.run() already closes them, this relies on
        # gymnasium's VectorEnv.close() being a no-op once closed. Confirm for
        # the installed version.
        if eval_env is not None:
            eval_env.close()
        envs.close()


def main(args, stdout):
    """Train and evaluate one Atari agent while teeing console output to a log.

    Args:
        args: parsed configuration from ``parse_args()``. Fields read here:
            ``seed``, ``cuda``, ``cuda_deterministic``, ``log_path``,
            ``save_path``, ``time``, ``env``, ``resize``, ``num_envs``,
            ``algo``. ``args.device`` is written before the agent is built.
        stdout: the original console stream. For the duration of the run,
            ``sys.stdout`` is replaced by a ``Logger`` wrapping it. It is
            restored to ``stdout`` afterwards, even on error.
    """
    # Seed the global RNGs first so nothing constructed below draws from an
    # unseeded generator.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.cuda_deterministic
    args.device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    logger = Logger(stdout, args.log_path)
    sys.stdout = logger
    try:
        print("============================================================")
        print("saving at:", args.save_path, args.time)
        print("============================================================")
        avg_reward, std_reward = _train_and_evaluate(args)
        print("============================================================")
        print("saving at:", args.save_path, "avg reward:", avg_reward, std_reward)
        print("============================================================")
    finally:
        # Detach the logger before closing it, so later output (the next run,
        # or anything printed at shutdown) never reaches a closed Logger.
        sys.stdout = stdout
        logger.close()


if __name__ == "__main__":
    # main() swaps sys.stdout for a Logger, so capture the real console once
    # and hand that same stream to every run.
    console = sys.stdout
    num_runs = 2  # number of back-to-back training runs to launch
    for _ in range(num_runs):
        # Arguments are re-parsed for every run rather than reused.
        run_args = parse_args()
        main(run_args, console)
