import time
from pathlib import Path

import numpy as np
import torch

from src.utils.loading import load_model, get_mazes
from src.utils.plotting import plot_mazes

# Run each pretrained model on one shared set of test mazes and save a PDF
# per model showing the inputs, the model's predictions and the solutions.
MODEL_NAMES = ['dt_net', 'pi_net_1']
OUTPUT_DIR = Path('outputs/mazes')

# The maze set does not depend on the model, so load it once (instead of once
# per model) and reuse it; this also guarantees every model is plotted against
# the same inputs.
inputs, solutions = get_mazes(dataset='easy-to-hard-data', maze_size=9, num_mazes=10)

# Make sure the destination exists before any plot is written.
# NOTE(review): harmless if plot_mazes already creates parent dirs itself.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

for model_name in MODEL_NAMES:
    model = load_model(model_name)
    # iters=40 is forwarded unchanged to model.predict; its exact meaning is
    # defined by the model implementation (not visible here).
    predictions = model.predict(inputs, iters=40)
    plot_mazes(
        inputs,
        predictions=predictions,
        solutions=solutions,
        file_name=str(OUTPUT_DIR / f'{model_name}.pdf'),
    )

