
import torch
import numpy as np
import os
import itertools
import string

from dqn.library import *
from envs.env import *
from matplotlib import pyplot as plt

if __name__ == '__main__':
    # For every task (a list of goals), roll out the loaded per-goal DQNs,
    # record the reward observed at each terminal goal, and save:
    #   <path>/<t>.dqn : {goal id -> state_dict}
    #   <path>/<t>.bin : {goal hash -> reward}
    path = './dqn/models'
    env_key = "MiniGrid-PickUpObj-v0"
    env = make_env(env_key)

    # Number of entries each per-task model dict must collect before saving
    # (the value the original script hard-coded).
    num_goal_entries = 16
    # Step cap for a single rollout episode.
    max_steps = 100
    # Safety cap on rollout episodes per task. Without it, the collection loop
    # would spin forever if fewer than `num_goal_entries` distinct terminal
    # goals are reachable.
    max_episodes = 10000

    print('Loading ...')
    object_dir = './dqn/object'
    max_model = {}
    for name in os.listdir(object_dir):
        print(name)
        # The goal key is the last "_"-separated token before the final
        # extension, e.g. "<prefix>_<goal>.<ext>" -> "<goal>".
        goal = name.split(".")[-2].split("_")[-1]
        loaded = load(os.path.join(object_dir, name), env)
        max_model[goal] = loaded[goal]

    base = False
    # The first task has an empty goal list.
    tasks = [[]]
    if base:
        # One task per row of the identity matrix, i.e. one goal per task.
        path += "/base_sfs"
        bases = np.eye(len(env.all_goals))
        tasks += [[env.all_goals[i] for i in np.where(row == 1)[0]]
                  for row in bases]
    else:
        # One task per colour (that colour paired with every object type),
        # then one task per object type (every colour paired with that type).
        path += "/not_base_sfs"
        tasks += [list(itertools.product([color], OBJ_TYPES))
                  for color in OBJ_COLORS]
        tasks += [list(itertools.product(OBJ_COLORS, [obj_type]))
                  for obj_type in OBJ_TYPES]
    # Make sure the output directory exists before the save() calls below.
    os.makedirs(path, exist_ok=True)

    # Hash of the reset observation multiplied by zero. It is used below as
    # the fallback ("no goal") key into max_model and always gets id '0'.
    vgoal = to_hash(env.reset() * 0)
    # goal hash -> string id. Shared across all tasks so a given goal hash
    # receives the same id in every saved model dict.
    tasks_goals = {vgoal: '0'}
    print('Generating ...')
    for t, task_goals in enumerate(tasks):
        task = {}
        model = {'0': max_model[vgoal].state_dict()}
        print(task_goals, t)

        env = make_env(env_key, goals=task_goals)
        episodes = 0
        while len(model) < num_goal_entries:
            if episodes >= max_episodes:
                raise RuntimeError(
                    "task {}: only {} of {} goal entries collected after {} "
                    "episodes".format(t, len(model), num_goal_entries,
                                      max_episodes))
            episodes += 1
            obs = env.reset()

            for _step in range(max_steps):
                action = select_action(max_model, obs)
                obs_, reward, done, _info = env.step(action)

                if done:
                    goal = to_hash(obs_)
                    if goal not in tasks_goals:
                        tasks_goals[goal] = str(len(tasks_goals))
                    goal_id = tasks_goals[goal]
                    task[goal] = reward
                    # Rewarded goals store their own network. Any other
                    # terminal state falls back to the vgoal network.
                    source = goal if reward > 0 else vgoal
                    model[goal_id] = max_model[source].state_dict()
                    break
                obs = obs_

        print(list(model.keys()))
        save(os.path.join(path, str(t) + ".dqn"), model)
        save(os.path.join(path, str(t) + ".bin"), task)
        print(list(task.values()))
        print(t, 'saved')
