import os
import random
import argparse

import gym
import numpy as np
import torch
from dqn import DQNAgent
import os
from cartpole import CartPoleEnv


# One seed value shared by the Python, NumPy and PyTorch RNG seeding below;
# it is also handed to DQNAgent when the agent is constructed.
seed = 777

# Initialize the seed
def seed_torch(seed):
    """Seed PyTorch's random number generator with ``seed``.

    When the cuDNN backend reports itself as enabled, the CUDA generator is
    seeded as well and cuDNN is switched to its reproducible configuration
    (benchmark off, deterministic on).

    Args:
        seed: Integer seed forwarded to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed``.
    """
    torch.manual_seed(seed)

    cudnn = torch.backends.cudnn
    if not cudnn.enabled:
        # Nothing cuDNN/CUDA related to configure.
        return

    torch.cuda.manual_seed(seed)
    # Benchmark mode is turned off and deterministic mode turned on so that
    # cuDNN is asked for reproducible behaviour.
    cudnn.benchmark = False
    cudnn.deterministic = True

# Seed every RNG this script touches (Python's random, NumPy, PyTorch) with the
# same value. These calls sit at module level, so they run as soon as the file
# is executed or imported, before the agent is built.
random.seed(seed)
np.random.seed(seed)
seed_torch(seed)

if __name__ == "__main__":
    # Command-line switches; both are plain flags that default to False.
    parser = argparse.ArgumentParser()
    parser.add_argument("--is_test", action="store_true", help="Test mode, do not train")
    parser.add_argument("--is_poison", action="store_true", help="Poison mode, training backdoored model")
    args = parser.parse_args()
    is_test, is_poison = args.is_test, args.is_poison

    # Experiment settings. The first one is given to agent.train(); the next
    # three are passed positionally to DQNAgent; folder_name is passed by keyword.
    num_frames = 10000
    memory_size = 1000
    batch_size = 128
    target_update = 100
    folder_name = "cartPole_exp"
    # Episode step cap used for every environment created below and for testing.
    max_episode_steps = 500

    if is_test:
        # Evaluation run on the project's own CartPoleEnv.
        # True  -> render off-screen ("rgb_array"); per the original note the
        #          video is stored under the cartPole_exp folder.
        # False -> render in a live "human" window instead.
        is_record_video = True
        render_mode = "rgb_array" if is_record_video else "human"

        # The stock alternative would be
        # gym.make("CartPole-v1", max_episode_steps=..., render_mode=...).
        env = CartPoleEnv(render_mode=render_mode, max_episode_steps=max_episode_steps)

        # NOTE(review): unlike the training path, is_test / is_poison are not
        # forwarded here, so DQNAgent's own defaults apply -- confirm intended.
        agent = DQNAgent(env, memory_size, batch_size, target_update, seed, folder_name=folder_name)
        agent.test(is_record_video=is_record_video, max_episode_steps=max_episode_steps)
    else:
        # Training run on the registered gym CartPole-v1 environment.
        env = gym.make("CartPole-v1", max_episode_steps=max_episode_steps, render_mode="rgb_array")
        agent = DQNAgent(
            env,
            memory_size,
            batch_size,
            target_update,
            seed,
            folder_name=folder_name,
            is_test=is_test,
            is_poison=is_poison,
        )
        agent.train(num_frames)