import numpy as np
import torch
import os
import sys
import random    
import torch.multiprocessing as mp
import argparse

import gym
from envs.env import *
from dqn.library import *


# Command-line interface. `parser` and `args` stay at module level because the
# entry-point block below reads `args` directly.
parser = argparse.ArgumentParser(
    description="Train a DQN agent on one goal subset of a MiniGrid pick-up task."
)
parser.add_argument(
    '--path',
    default='.',
    help="Base directory; results are written under <path>/dqn/data"
)
parser.add_argument(
    '--exp',
    default='dqn',
    help="Experiment name, used as a component of the output file name"
)
parser.add_argument(
    '--task',
    type=int,
    default=0,
    help="Index of the task (row of the goal-mask table) to train on"
)
parser.add_argument(
    '--run',
    type=int,
    default=0,
    help="Run number, used as a component of the output file name"
)
args = parser.parse_args()


if __name__ == '__main__':
    # Entry point: pick one goal subset of the MiniGrid pick-up environment,
    # train a DQN agent on it and save the data returned by train_dqn.
    mp.set_start_method('spawn', force=True)

    # Each row is a binary mask over env.all_goals; a 1 at column i selects
    # env.all_goals[i] as a goal for that task.
    tasks = np.array([
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0],
            [1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1],
            [0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0]
            ])

    # Reject bad task indices up front: a too-large index would otherwise fail
    # with a bare IndexError, and a negative one would silently wrap around to
    # a different row of the table.
    t = args.task
    if not 0 <= t < len(tasks):
        parser.error('--task must be between 0 and {}, got {}'.format(len(tasks) - 1, t))
    task = tasks[t]

    path = '{}/dqn/data'.format(args.path)
    # Make sure the output directory exists before anything tries to write
    # into it; exist_ok keeps repeated runs from failing.
    os.makedirs(path, exist_ok=True)
    data_path = '{}/exp_{}_{}_{}.h5'.format(path,args.exp,args.task,args.run)   
    print('data_path: ',data_path)

    env_key="MiniGrid-PickUpObj-v0"
    # A first, default environment is built only to read its list of all
    # goals; the training environment is then rebuilt with the chosen subset.
    env = make_env(env_key)
    goals = [env.all_goals[i] for i in np.where(task==1)[0]]
    env = make_env(env_key, goals=goals)
    print(goals)

    # Training hyperparameters.
    # NOTE(review): eval_interval is defined here but is not passed to
    # train_dqn below — confirm whether it is meant to be.
    eval_interval=100
    max_episodes=int(1e5)
    eps_timesteps=int(1e5)
    eps_initial = 0.5
    replay_buffer_size=int(1e4)
    gamma=0.95
    batch_size=256
    learning_rate=1e-3
    
    # Only the third returned value is used here. The first two are bound to
    # throwaway names so the goal mask `task` above is not overwritten.
    # NOTE(review): train_dqn also receives data_path as save_logs; what it
    # writes there is not visible from this file.
    _trained_task, _model, data = train_dqn(env, max_episodes=max_episodes, eps_timesteps=eps_timesteps, replay_buffer_size=replay_buffer_size,
                                    eps_initial=eps_initial, batch_size=batch_size, learning_rate=learning_rate, gamma=gamma, save_logs=data_path)
    # The file is written with torch.save even though its name ends in .h5.
    torch.save(data, data_path)
