import os
import sys
import glob
import shutil
# Repository root: the parent of this script's directory. It is appended to
# sys.path so the project packages imported below (`utils`, `environment`)
# resolve no matter where the script is launched from.
_script_dir = os.path.dirname(__file__)
base_path = os.path.abspath(os.path.join(_script_dir, os.pardir))
sys.path.append(base_path)

import pickle
from utils.utils import create_folder_if_not_exist, create_folder_overwrite_if_exist
from environment.example.Env_Maze2d import Maze2D_ShowEpi
from environment.example.Env_FourRooms import FourRooms_ShowEpi

def load_data(env_name: str, data_name: str):
    """Load one pickled dataset belonging to an environment.

    Args:
        env_name: Environment sub-folder under ``data/example`` (e.g. ``'FourRooms'``).
        data_name: Dataset file stem; ``.pkl`` is appended to form the file name.

    Returns:
        The unpickled object. The rest of this script iterates it as a
        sequence of episode dicts that each have an ``'observations'`` entry.
    """
    pkl_path = f'{base_path}/data/example/{env_name}/{data_name}.pkl'
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted local data.
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)

# Environment whose datasets are visualized; supported values are
# 'Maze2d' and 'FourRooms'. Its visualization output folder is created up front.
ENV_NAME = 'FourRooms'
create_folder_if_not_exist(f'{base_path}/visualize/dataset/{ENV_NAME}')

# Collect every dataset (*.pkl) of the selected environment. glob returns
# matches in arbitrary order, so sort them to make the processing order
# reproducible from run to run.
file_list = sorted(glob.glob(os.path.join(f'{base_path}/data/example/{ENV_NAME}', '*.pkl')))
file_names = [os.path.basename(path) for path in file_list]

# Trajectory renderer matching the selected environment.
if ENV_NAME == 'Maze2d':
    epi_render = Maze2D_ShowEpi()
elif ENV_NAME == 'FourRooms':
    epi_render = FourRooms_ShowEpi()
else:
    # Fail loudly and name the offending value. Raising a non-exception
    # object would itself raise a TypeError that hides the actual cause.
    raise ValueError(
        f"Unsupported ENV_NAME {ENV_NAME!r}; expected 'Maze2d' or 'FourRooms'"
    )

# Render every episode of every dataset into its own image. Each dataset
# gets a freshly (re)created folder named after the dataset file stem.
for file_name in file_names:
    data_name = file_name[:-len('.pkl')]  # glob above only matched '*.pkl'
    out_dir = f'{base_path}/visualize/dataset/{ENV_NAME}/{data_name}'
    create_folder_overwrite_if_exist(out_dir)

    episodes = load_data(env_name=ENV_NAME, data_name=data_name)
    for i, epi in enumerate(episodes):
        obss = epi['observations']
        # Image name encodes the episode index and its number of observations.
        save_path = f'{out_dir}/{i}_len={len(obss)}.png'
        epi_render.show_episode(obss.tolist(), save_path=save_path)

    

