import numpy as np
import torch
import os
import random    

import gym
from envs.env import *
from dqn.libraries import *

if __name__ == '__main__':
    # Script entry point. It (1) round-trips a randomly sampled task through
    # task_exp / exp_task and prints the result as a sanity check, (2) calls
    # train(..., type_='evf') once per base task, and (3) combines the
    # trained results with exp_evf using the expression found in step (1).

    # Make sure the output directory exists. exist_ok=True avoids the
    # check-then-create race of os.path.exists() followed by os.makedirs().
    path = './dqn/data'
    os.makedirs(path, exist_ok=True)

    # Default environment; it exposes the full list of goals.
    env = make_env()
    all_goals = env.all_goals
    n_goals = len(all_goals)

    # Sanity checks: express a random task in terms of 20 random base tasks,
    # map the expression back to a task, and print both for comparison.
    # NOTE(review): presumably task_exp and exp_task are inverses of each
    # other, so the last print should report equality -- confirm.
    tasks = [sample_random(n_goals) for _ in range(20)]
    task = sample_random(n_goals)
    exp = task_exp(tasks, task, n_goals)
    task_ = exp_task(tasks, exp, n_goals)
    print(str(task))
    print(str(task_))
    print(exp)
    print(task == task_)

    # A single replay buffer is shared by every training run below.
    # NOTE(review): the meaning of the positional arguments (1e6, 32, 4) is
    # defined by ReplayBuffer, which is not visible here -- confirm.
    replay_buffer = ReplayBuffer(int(1e6), 32, 4)

    # Register an all-zero goal, built by multiplying the reset observation
    # by 0, in the buffer's goal list and in its list of goal hashes.
    virtual_goal = env.reset() * 0
    replay_buffer.goals.append(virtual_goal)
    replay_buffer.goals_hash.append(to_hash(virtual_goal))

    # Train on each base task. The loop variable is deliberately not called
    # `task`, so the sampled task checked above is not overwritten.
    evfs = []
    for base_task in tasks:
        # A task is indexed per goal; entries equal to 1 select that goal.
        goals = [all_goals[g] for g in range(n_goals) if base_task[g] == 1]
        env = make_env(goals=goals)
        # Only the first element of train()'s return value is kept.
        evfs.append(
            train(
                env,
                type_='evf',
                max_timesteps=100,
                replay_buffer=replay_buffer,
            )[0]
        )

    max_evf = get_max(evfs)
    print('SOP')
    # Combine the per-task results according to the expression `exp`.
    sop_evf = exp_evf(evfs, exp, n_goals, max_evf)
    