import numpy as np
import torch
import os
import itertools
import sys

import gym
from envs.env import *
from dqn.libraries import *
from matplotlib import pyplot as plt

# Model directory constant; not referenced anywhere in this script.
path = './dqn/models'


def _whiten_black_pixels(img, shape=(32, 32, 3)):
    """Return a copy of *img* in which every all-zero-sum pixel has 255 added.

    The image is viewed as a flat list of pixels (one row per pixel, one
    column per channel); any pixel whose channel sum is exactly 0 gets 255
    added to each channel, which turns a black background white for
    0-255 valued images.

    Args:
        img: array-like holding exactly ``prod(shape)`` values.
        shape: target ``(height, width, channels)`` shape of the result.
            Defaults to the 32x32 RGB layout this script was written for.

    Returns:
        A new array of shape ``shape``; the input is left untouched.
    """
    # np.array() copies, so the caller's image is never mutated.
    pixels = np.array(img).reshape(-1, shape[-1])
    is_black = pixels.sum(axis=1) == 0
    pixels[is_black] += 255
    return pixels.reshape(shape)


if __name__ == '__main__':
    env = make_env()
    goals = env.get_goals_imgs()
    print(len(goals))

    # savefig does not create missing directories, so make sure it exists.
    out_dir = "dqn/goals"
    os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots()
    for i, goal in enumerate(goals):
        # Each imshow call adds a new image artist to the axes; without
        # clearing, every saved PDF would also embed all earlier goals.
        ax.clear()
        ax.imshow(_whiten_black_pixels(goal))
        # clear() resets the tick locators, so hide the ticks again.
        ax.set_xticks([])
        ax.set_yticks([])
        fig.savefig(out_dir + "/" + str(i) + ".pdf", bbox_inches='tight')
    plt.close(fig)

