import time
import utils
import argparse
import datetime
import numpy as np
import torch, gc
import wandb, pdb

from rl_algorithm.ddqn.agent import DDQNAgent

# Start from a clean slate: run a full garbage-collection pass, then ask
# PyTorch to release the unused memory held by its CUDA caching allocator.
gc.collect()
torch.cuda.empty_cache()

# Parse arguments
# NOTE: help strings use argparse's ``%(default)s`` placeholder so that the
# documented default is always the real ``default=`` value and cannot drift.
parser = argparse.ArgumentParser()

# General parameters
parser.add_argument(
    "--env", required=True, help="name of the environment to train on (REQUIRED)"
)
parser.add_argument(
    "--model",
    default=None,
    help="name of the model (default: {ENV}_{ALGO}_{EXPLORATION_TYPE}_{TIME})",
)
parser.add_argument(
    "--trains",
    type=int,
    default=5,
    help="number of trains (default: %(default)s)",
)
parser.add_argument(
    "--seed",
    type=int,
    default=-1,
    help="specific seed, only honoured with --trains 1 "
    "(default: %(default)s, meaning not set)",
)
parser.add_argument(
    "--save-interval",
    type=int,
    default=25,
    help="number of updates between two saves "
    "(default: %(default)s, 0 means no saving)",
)
parser.add_argument(
    "--frames",
    type=int,
    default=2 * 10**6,
    help="number of frames of training (default: %(default)s)",
)

# Parameters for main algorithm
parser.add_argument(
    "--max-memory",
    type=int,
    default=500000,
    help="Maximum experiences stored (default: %(default)s)",
)
parser.add_argument(
    "--lr",
    type=float,
    default=0.0001,
    help="learning rate (default: %(default)s)",
)
parser.add_argument(
    "--exploration_type",
    type=str,
    default="epsilon",
    help="exploration_type (default: %(default)s)",
)
parser.add_argument(
    "--argmax",
    action="store_true",
    default=False,
    help="action with highest probability is selected",
)
parser.add_argument(
    "--procs",
    type=int,
    default=1,
    help="number of processes (default: %(default)s)",
)
parser.add_argument(
    "--seed_type",
    type=int,
    default=0,
    help="seed group for multi-run training: run t is seeded with "
    "10 * seed_type + t (default: %(default)s)",
)
parser.add_argument(
    "--algorithm",
    type=str,
    default="ddqn",
    help="RL algorithm to train; only 'ddqn' is handled (default: %(default)s)",
)
parser.add_argument(
    "--reset_ver",
    type=int,
    default=-10,
    help="reset version: -10 = no reset, 0 = reset last 2 layers, "
    "2 = reset all layers (default: %(default)s)",
)
parser.add_argument(
    "--reset_itv",
    type=int,
    default=200000,
    help="reset interval (default: %(default)s)",
)
parser.add_argument(
    "--reset_rr",
    type=int,
    default=1,
    help="replay ratio (default: %(default)s)",
)
parser.add_argument(
    "--reset_multi",
    type=int,
    default=4,
    help="number of ensemble (default: %(default)s)",
)
parser.add_argument(
    "--reset_ww",
    type=float,
    default=50,
    help="beta (default: %(default)s)",
)
parser.add_argument(
    "--train_interval",
    type=int,
    default=1,
    help="train_interval (default: %(default)s)",
)
parser.add_argument(
    "--no_reset",
    type=int,
    default=0,
    help="no_reset (default: %(default)s)",
)


# Read the command line into the shared ``args`` namespace.
args = parser.parse_args()

# Register the MiniGrid environments through the project helper.
utils.register_minigrid_envs()

# Reproducibility: turn off the cuDNN autotuner and request deterministic
# cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Resolve the run directories. Without an explicit --model, the run is named
# after the environment, algorithm, exploration type and a launch timestamp.
date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
default_model_name = "_".join(
    (args.env, args.algorithm, args.exploration_type, date)
)

if args.model:
    model_name = args.model
else:
    model_name = default_model_name
model_dir = utils.get_model_dir(model_name)
result_dir = utils.get_model_dir(args.env + "_plot_raw_data")

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Wall-clock reference passed to the agent, and one return curve per run.
start_time = time.time()

return_per_frame = []
test_return_per_frame = []

# Run ``args.trains`` independent training runs and keep every run's curves.
#
# Fail fast on an algorithm this script cannot build: previously ``agent`` was
# simply left unbound and the loop died with a NameError after the
# environments had already been created.
if args.algorithm != "ddqn":
    raise ValueError(
        "unsupported algorithm {!r}: only 'ddqn' is available".format(
            args.algorithm
        )
    )

for t in range(args.trains):
    print("-------trains: {}---------".format(t))

    # Pick the seed for this run.
    if args.trains == 1 and args.seed > -1:
        # A single run with an explicitly requested seed uses it verbatim.
        seed = args.seed
    else:
        # Otherwise derive one seed per run from the seed group. This used to
        # happen only for exactly 5 runs, leaving ``seed`` unbound (NameError)
        # for every other run count; the 5-run behaviour is unchanged.
        seed = 10 * args.seed_type + t
        # ``args`` is handed to the agent below, so it sees the run index here.
        args.seed = t

    utils.seed(seed)
    # Training and evaluation environments are built from the same seed.
    env = utils.make_env(args.env, seed)
    eval_env = utils.make_env(args.env, seed)

    # Per-run accumulators; they are passed to the agent on every call below
    # (presumably filled in place -- confirm in DDQNAgent.collect_experiences).
    return_per_frame_, test_return_per_frame_ = [], []
    num_frames = 0
    episode = 0

    # Load observations preprocessor
    obs_space, preprocess_obss = utils.get_obss_preprocessor(
        env.observation_space
    )

    agent = DDQNAgent(
        env=env,
        eval_env=eval_env,
        device=device,
        args=args,
        preprocess_obs=preprocess_obss,
        model_dir=model_dir,
        env_size=0,
    )

    # Keep collecting until the frame budget for this run is used up.
    while num_frames < args.frames:
        update_start_time = time.time()
        logs = agent.collect_experiences(
            start_time=start_time,
            episode=episode,
            num_frames=num_frames,
            return_per_frame_=return_per_frame_,
            test_return_per_frame_=test_return_per_frame_,
        )
        update_end_time = time.time()

        # The agent reports the updated running frame count in its logs.
        num_frames = logs["num_frames"]

        # Counts calls to collect_experiences; passed back in as ``episode``.
        episode += 1

    return_per_frame.append(return_per_frame_)
    test_return_per_frame.append(test_return_per_frame_)