import numpy as np
import torch
import os
import sys
import random    
import torch.multiprocessing as mp

import gym
from envs.env import *
from dqn.library import *

def _load_base_models(path_models, env):
    """Load the base task and value models stored in *path_models*.

    Entries whose extension is ``.bin`` are read with ``torch.load`` into the
    task dict. Every other non-hidden entry is passed to ``load(path, env)``
    (provided by the star-imported project modules) and stored in the value
    dict. Both dicts are keyed by the entry name with its final extension
    removed.

    Args:
        path_models: directory holding the saved base models.
        env: environment forwarded to ``load`` for the value models.

    Returns:
        ``(tasks_SOP, values_SOP)`` as two dicts keyed by model stem.
    """
    tasks_SOP, values_SOP = {}, {}
    for name in os.listdir(path_models):
        if name.startswith('.'):
            # Skip hidden entries such as .DS_Store; they are not models.
            continue
        # splitext splits at the last dot only, so stems may contain dots and
        # names without an extension no longer crash the unpacking.
        stem, ext = os.path.splitext(name)
        full_path = os.path.join(path_models, name)
        if ext == '.bin':
            tasks_SOP[stem] = torch.load(full_path)
        else:
            values_SOP[stem] = load(full_path, env)
    return tasks_SOP, values_SOP


def _task_vector_to_dict(goal_keys, task_vector):
    """Pair each goal key with the entry at the same position of *task_vector*.

    Both sequences must enumerate the goals in the same order; that alignment
    is assumed here, not checked.

    Raises:
        ValueError: if *goal_keys* and *task_vector* differ in length.
    """
    if len(goal_keys) != len(task_vector):
        raise ValueError(
            f"task vector has {len(task_vector)} entries but there are "
            f"{len(goal_keys)} goal keys"
        )
    return dict(zip(goal_keys, task_vector))


def _visualize(env, model, max_episodes, max_trajectory):
    """Render rollouts of *model* in *env* until episodes run out or the window closes.

    Each episode picks a goal once with ``select_goal`` and then follows
    ``select_action`` for at most *max_trajectory* steps. When a step ends the
    episode (or the window has been closed), it prints that step's reward.
    """
    for _episode in range(max_episodes):
        obs = env.reset()
        goal = select_goal(model, obs)

        for _step in range(max_trajectory):
            env.render()
            action = select_action(model, obs, goal)
            obs, reward, done, _info = env.step(action)

            if done or env.window.closed:
                print(reward)
                break

        if env.window.closed:
            return


if __name__ == '__main__':
    path = './dqn/data'  # output directory used by the disabled training code below
    env_key = "MiniGrid-PickUpObj-v0"
    env = make_env(env_key)

    print('Loading ...')
    path_models = './dqn/models/base_evfs'
    tasks_SOP, values_SOP = _load_base_models(path_models, env)

    # Each row is a binary task vector; entries equal to 1 select goals from
    # env.all_goals (see the goal selection below).
    tasks = np.array([
            # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0],
            [1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1],
            [0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0]
            ])

    t = 0  # index of the task row to visualize
    task_vector = tasks[t]
    goals = [env.all_goals[i] for i in np.where(task_vector == 1)[0]]
    env = make_env(env_key, goals=goals)
    print(goals)

    # NOTE(review): assumes a base task model named 'a' exists and that its
    # keys enumerate goals in the same order as env.all_goals — confirm.
    g = list(tasks_SOP['a'].keys())
    # task = {g[i]:0 for i in range(len(g))}
    # task['448c2227a13213ca0020050bf0185731'] = 1
    # Use the same row that configured the environment; this was previously
    # hard-coded to tasks[0] and diverged from the env whenever t != 0.
    task = _task_vector_to_dict(g, task_vector)
    exp = task_exp(tasks_SOP, task)
    task_ = exp_task(tasks_SOP, exp)
    model = exp_value(values_SOP, exp)
    print(exp)
    print(task == task_)
    print(list(task.values()))
    print(list(task_.values()))

    print('Visualizing ...')
    _visualize(env, model, max_episodes=50000, max_trajectory=20)

    # max_episodes=100#int(1e6)
    # mean_episodes=1000
    # num_runs = 1
    # data_SOP = np.zeros((num_runs,max_episodes//mean_episodes)) 
    # for i in range(num_runs):
    #     print('run',i)            
    #     task, model, eval_returns = train(env, learned=(tasks_SOP, values_SOP), max_episodes=max_episodes, mean_episodes=mean_episodes)
    #     data_SOP[i] = eval_returns
    
    #     torch.save(data_SOP, path+'/exp1_SOP.h5')
