
import torch
import numpy as np
import os
import itertools
import string

from dqn.library import *
from envs.env import *
from matplotlib import pyplot as plt

def _load_max_models(directory, env):
    """Load one pretrained model per goal from every file in *directory*.

    The goal key is the last ``_``-separated token of the file name's
    second-to-last ``.``-separated part, e.g. ``x_y.dqn`` -> ``y``.  Each file
    is read with ``load(path, env)`` and only the entry for its own goal key
    is kept.

    Returns:
        dict mapping goal key -> loaded model.  NOTE(review): presumably torch
        modules; callers rely on them exposing ``state_dict()``.
    """
    max_model = {}
    # Sorted so that, should two files map to the same goal key, the file that
    # wins is deterministic instead of depending on filesystem ordering.
    for name in sorted(os.listdir(directory)):
        print(name)
        goal = name.split(".")[-2].split("_")[-1]
        model = load(os.path.join(directory, name), env)
        max_model[goal] = model[goal]
    return max_model


def _build_tasks(env, base):
    """Return the list of goal sets to generate models for.

    The first task is always the empty goal set.  With *base* set, each mask
    returned by ``get_bases(len(env.all_goals))`` becomes one task made of the
    goals whose mask entry equals 1.  Otherwise there is one task per object
    colour (that colour paired with every object type), followed by one task
    per object type (every colour paired with that type).
    """
    tasks = [[]]
    if base:
        bases = get_bases(len(env.all_goals))
        tasks += [[env.all_goals[i] for i in np.where(mask == 1)[0]] for mask in bases]
    else:
        tasks += [list(itertools.product([color], OBJ_TYPES)) for color in OBJ_COLORS]
        tasks += [list(itertools.product(OBJ_COLORS, [obj_type])) for obj_type in OBJ_TYPES]
    return tasks


def _generate_task_models(env_key, goals, max_model, null_goal, num_goals=16,
                          max_steps=100, max_episodes=10000):
    """Collect per-terminal-state weights for one task by rolling out ``max_model``.

    A fresh env restricted to *goals* is created. Episodes of at most
    *max_steps* steps are then run until *num_goals* distinct terminal states
    have been recorded (counting *null_goal*, which is always present).

    Args:
        env_key: environment id passed to ``make_env``.
        goals: goal set for this task, forwarded as ``make_env(..., goals=goals)``.
        max_model: goal hash -> model; used to pick goals/actions and as the
            source of the recorded weights.
        null_goal: hash of an all-zero observation; its weights seed the result.
        num_goals: number of entries to collect in the returned ``model`` dict.
            NOTE(review): 16 is inherited from the original script; presumably
            it equals ``len(max_model)`` -- confirm.
        max_steps: per-episode step limit; an episode that does not finish
            within it records nothing.
        max_episodes: safety cap on the number of episodes.  Without it the
            loop never terminates when fewer than *num_goals* terminal states
            are reachable for this task.

    Returns:
        (model, task): ``model`` maps each recorded terminal-state hash to a
        ``state_dict()``; ``task`` maps each terminal-state hash reached during
        rollouts to 1 if that episode ended with positive reward, else 0.
        ``task`` has no entry for *null_goal* unless it was actually reached.
    """
    model = {null_goal: max_model[null_goal].state_dict()}
    task = {}
    env = make_env(env_key, goals=goals)

    episodes = 0
    while len(model) < num_goals:
        if episodes >= max_episodes:
            # Best effort: keep what was collected but say so loudly.
            print('warning: only', len(model), 'of', num_goals,
                  'goals collected after', episodes, 'episodes for task', goals)
            break
        episodes += 1

        obs = env.reset()
        chosen_goal = select_goal(max_model, obs)
        for _step in range(max_steps):
            action = select_action(max_model, obs, chosen_goal)
            obs_, reward, done, _info = env.step(action)
            if done:
                reached = to_hash(obs_)
                # Unrewarded terminal states take the weights of the null
                # (all-zero observation) goal.
                source = reached if reward > 0 else to_hash(obs_ * 0)
                model[reached] = max_model[source].state_dict()
                task[reached] = int(reward > 0)
                break
            obs = obs_

    return model, task


def main(base=False):
    """Generate and save a ``.dqn`` (weights) and ``.bin`` (labels) file per task.

    Args:
        base: if True, tasks come from ``get_bases`` and are written under
            ``./dqn/models/base_evfs``; otherwise there is one task per object
            colour and per object type, written under
            ``./dqn/models/not_base_evfs``.  Task ``t`` is saved as
            ``<letter t>.dqn`` / ``<letter t>.bin``.

    Raises:
        ValueError: if there are more tasks than single-letter file names.
    """
    path = './dqn/models'
    letters = string.ascii_lowercase
    env_key = "MiniGrid-PickUpObj-v0"
    env = make_env(env_key)

    print('Loading ...')
    max_model = _load_max_models('./dqn/object', env)

    path += "/base_evfs" if base else "/not_base_evfs"
    tasks = _build_tasks(env, base)
    # One letter per output file: fail before any rollout rather than with an
    # IndexError after all the earlier tasks have been generated.
    if len(tasks) > len(letters):
        raise ValueError("%d tasks but only %d file letters available"
                         % (len(tasks), len(letters)))
    # NOTE(review): unclear whether ``save`` creates parent directories, so
    # make sure the output directory exists.
    os.makedirs(path, exist_ok=True)

    print('Generating ...')
    # Presumably the null-goal hash depends only on the observation shape, so
    # it is computed once rather than once per task.
    null_goal = to_hash(env.reset() * 0)
    for t, goals in enumerate(tasks):
        print(goals, t)
        model, task = _generate_task_models(env_key, goals, max_model, null_goal)
        save(path + "/" + letters[t] + ".dqn", model)
        save(path + "/" + letters[t] + ".bin", task)
        print(list(task.values()))
        print(t, 'saved')


if __name__ == '__main__':
    main()
