from refinement.graph import Node, depth_first_traversal
from refinement.goal import Goal
from env.dirl_grid import RoomsEnv
from env.rooms_envs import GRID_PARAMS_LIST, MAX_TIMESTEPS, START_ROOM, FINAL_ROOM

from env.gridworldenv import ContinuousGridworld


# Goal regions given as lower/upper bounds over indices 0 and 1
# (presumably the x/y components of the state -- confirm against Goal).
start = Goal(
    lower_bound=[1, 1],
    upper_bound=[7, 7],
    index_range=[0, 1],
)
end = Goal(
    lower_bound=[9, 9],
    upper_bound=[15, 15],
    index_range=[0, 1],
)

# Two-node graph: "start" has "goal" as its only child.
# NOTE(review): the meaning of the two boolean flags is defined by
# refinement.graph.Node and is not visible from this file -- confirm there.
start_node = Node(start, False, False, "start")
goal_node = Node(end, True, True, "goal")
start_node.add_child(goal_node)


def _make_gridworld():
    """Return a new ContinuousGridworld using this module's fixed door layout.

    A fresh ``custom_doors`` list is built on every call, so the environments
    created from it never share a list object.
    """
    # Each entry pairs two 2-tuples of integers (presumably the grid cells /
    # rooms joined by a door -- confirm against ContinuousGridworld).
    return ContinuousGridworld(
        custom_doors=[
            ((0, 0), (0, 1)),
            ((0, 1), (0, 2)),
            ((0, 2), (1, 2)),
            ((1, 1), (1, 2)),
            ((1, 0), (1, 1)),
            ((1, 0), (2, 0)),
        ],
        render_mode="rgb_array",
    )


# NOTE(review): a RoomsEnv-based environment (index 1 of GRID_PARAMS_LIST,
# START_ROOM, FINAL_ROOM and MAX_TIMESTEPS) was previously sketched here but
# is disabled; only the ContinuousGridworld environments below are created.
train_env = _make_gridworld()
test_env = _make_gridworld()
    
def run_3grid(
    minimum_reach: float = 0.9,
    n_episodes: int = 100000,
    n_episodes_test: int = 1000,
    path: str = "",
) -> None:
    """Run ``depth_first_traversal`` on this module's graph and environments.

    The module-level ``start_node``, ``train_env`` and ``test_env`` are passed
    first, followed by the arguments below in the same order. The value
    returned by ``depth_first_traversal`` (if any) is discarded.

    Args:
        minimum_reach: Forwarded unchanged (presumably a required reach
            probability threshold -- confirm in ``depth_first_traversal``).
        n_episodes: Forwarded unchanged (presumably the number of training
            episodes -- confirm in ``depth_first_traversal``).
        n_episodes_test: Forwarded unchanged (presumably the number of
            evaluation episodes -- confirm in ``depth_first_traversal``).
        path: Forwarded unchanged (presumably an output location -- confirm
            in ``depth_first_traversal``).
    """
    # Arguments are passed positionally; their order must match the
    # signature of depth_first_traversal.
    depth_first_traversal(
        start_node,
        train_env,
        test_env,
        minimum_reach,
        n_episodes,
        n_episodes_test,
        path,
    )

