import numpy as np
import torch
import os
import sys
import random    
import torch.multiprocessing as mp
import argparse

import gym
from envs.env import *
from dqn.library import *


# Command-line options. Note that parse_args() runs at import time, so the
# module-level `args` is available to the script body below.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--path',
    default='.',
    help="Root directory that contains dqn/data and dqn/models"
)
parser.add_argument(
    '--exp',
    default='sop',
    help="Experiment type: 'sop' or 'sf'"
)
parser.add_argument(
    '--eval_type',
    default=None,
    help="Evaluation type forwarded to the trainer (also part of the output file name)"
)
parser.add_argument(
    '--task',
    type=int,
    default=0,
    help="Index of the task to train on"
)
parser.add_argument(
    '--run',
    type=int,
    default=0,
    help="Run index; only used to name the output file"
)
args = parser.parse_args()


if __name__ == '__main__':
    # Train one agent ('sop' or 'sf') on a single composite task, reusing the
    # pretrained models under <path>/dqn/models/, then save the data returned
    # by the trainer under <path>/dqn/data/.
    mp.set_start_method('spawn', force=True)

    # Supported experiments -> (trainer, sub-directory of pretrained models).
    # Any other --exp used to load every model and then exit without training
    # or saving anything, so reject it up front.
    experiments = {
        'sop': (train_sop, 'base_evfs'),
        'sf': (train_sf, 'base_sfs'),
    }
    if args.exp not in experiments:
        sys.exit("unknown --exp {!r}; expected one of {}".format(
            args.exp, sorted(experiments)))
    train_fn, models_subdir = experiments[args.exp]

    # Each row is a 0/1 mask over env.all_goals: the positions holding 1 are
    # the goals that make up that task (presumably one column per goal).
    tasks = np.array([
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0],
            [1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1],
            [0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0]
            ])
    # Validate before the model loading below. Otherwise a negative index
    # would silently wrap around, and a too-large one would only fail after
    # every model had been loaded.
    if not 0 <= args.task < len(tasks):
        sys.exit("--task must be in [0, {}), got {}".format(len(tasks), args.task))

    path = '{}/dqn/data'.format(args.path)
    path_models = '{}/dqn/models/{}'.format(args.path, models_subdir)
    # NOTE(review): torch.save below writes a torch (pickle/zip) archive, not
    # HDF5, despite the '.h5' suffix. The name is kept so existing files match.
    data_path = '{}/exp1_{}_{}_{}_{}.h5'.format(
        path, args.exp, args.eval_type, args.task, args.run)
    print('data_path: ', data_path)

    env_key = "MiniGrid-PickUpObj-v0"
    # Goal-agnostic env: used to read the full goal list and passed to load()
    # when restoring value models. It is replaced by a task env further down.
    env = make_env(env_key)

    print('Loading ...')
    tasks_sop, values_sop = {}, {}
    # sorted() makes dict order independent of the filesystem listing order.
    for name in sorted(os.listdir(path_models)):
        # splitext splits off only the last suffix, so names such as
        # 'a.v2.bin' no longer crash the two-way unpacking of name.split('.').
        stem, ext = os.path.splitext(name)
        model_file = os.path.join(path_models, name)
        if ext == '.bin':
            tasks_sop[stem] = torch.load(model_file)
        else:
            values_sop[stem] = load(model_file, env)

    task = tasks[args.task]
    goals = [env.all_goals[i] for i in np.flatnonzero(task == 1)]
    env = make_env(env_key, goals=goals)
    print(goals)

    # Training hyper-parameters, shared by both experiments.
    max_episodes = int(1e5)
    eps_timesteps = int(1e5)
    eps_initial = 0.5
    replay_buffer_size = int(1e4)
    gamma = 0.95
    batch_size = 256
    learning_rate = 1e-3

    # The trainer returns (task, model, data); only data is persisted here.
    # NOTE(review): data_path is also passed as save_logs and is then
    # overwritten by torch.save. Confirm this is intended.
    _, _, data = train_fn(
        env,
        learned=(tasks_sop, values_sop),
        max_episodes=max_episodes,
        eps_timesteps=eps_timesteps,
        replay_buffer_size=replay_buffer_size,
        eps_initial=eps_initial,
        batch_size=batch_size,
        learning_rate=learning_rate,
        gamma=gamma,
        eval_type=args.eval_type,
        save_logs=data_path,
    )
    torch.save(data, data_path)
