import os
import sys
import glob
# Project root: three directory levels above the folder containing this script.
base_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '../../..')
)
# Put the project root on the import path so project-level imports
# (such as ``utils.utils`` below) can be resolved when this file is run directly.
sys.path.append(base_path)

import pickle
import numpy as np
from utils.utils import load_data, create_folder_if_not_exist, create_folder_overwrite_if_exist

# Name of the environment whose datasets are dumped to text logs.
ENV_NAME = 'TSP_V1'
# Folder that receives the visualization logs for this environment.
create_folder_if_not_exist(f'{base_path}/visualize/dataset/{ENV_NAME}')

# Collect every pickled dataset stored for the selected environment.
# glob returns matches in arbitrary order, so sort them to process (and report)
# the datasets in a reproducible order across runs and machines.
file_list = sorted(glob.glob(os.path.join(f'{base_path}/data/used/{ENV_NAME}', '*.pkl')))
file_names = [os.path.basename(file) for file in file_list]
# To restrict the run to specific datasets, override the list, e.g.:
# file_names = ['tsp5_eval.pkl', 'tsp5_prompt.pkl', 'tsp5_train.pkl']

# Write one text log per dataset file.


def _write_problem_log(log, problems):
    """Write every evaluation problem and its stored answer to ``log``.

    Args:
        log: Writable text stream that receives the dump.
        problems: Object exposing ``problem_list`` and ``answer_list``; the
            entry at index ``i`` of each list belongs to the same problem.
            Each problem is a mapping with the keys ``'position'``,
            ``'visited'``, ``'first_index'`` and ``'current_index'``.

    Raises:
        AssertionError: If a problem is not in the expected initial state
            (start and current node are node 0, only node 0 is visited).
    """
    for i, answer in enumerate(problems.answer_list):
        log.write('-'*25+f' problem-{i} '+'-'*25+'\n')
        obs = problems.problem_list[i]
        # Flat coordinate array -> one (x, y) row per city.
        position = obs['position'].reshape((-1, 2))
        num_nodes = position.shape[0]

        # Sanity checks: a stored problem must describe the initial state.
        assert obs['current_index'].item() == 0, \
            f'problem-{i}: current_index is expected to be 0'
        assert obs['first_index'].item() == 0, \
            f'problem-{i}: first_index is expected to be 0'
        initial_visited = np.array([1]+[0]*(num_nodes-1))
        assert np.array_equal(obs['visited'], initial_visited), \
            f'problem-{i}: only node 0 is expected to be marked as visited'

        log.write(f'city position:\n{position}\n\n')
        log.write(f'solution: {answer}\n\n')


def _write_episode_log(log, episodes):
    """Write every training episode, step by step, to ``log``.

    Args:
        log: Writable text stream that receives the dump.
        episodes: Iterable of mappings with the keys ``'actions'``,
            ``'rewards'`` and ``'observations'``; ``'observations'`` is itself
            a mapping with the keys ``'position'``, ``'visited'``,
            ``'first_index'`` and ``'current_index'``, each indexable by step.

    Raises:
        AssertionError: If any step of an episode has a non-zero
            ``first_index``.
    """
    for i, epi in enumerate(episodes):
        log.write('-'*25+f' episode-{i} '+'-'*25+'\n')
        acts = epi['actions']
        rewards = epi['rewards']
        obss = epi['observations']
        positions = obss['position']
        visiteds = obss['visited']
        current_idxs = obss['current_index']
        assert (obss['first_index'] == 0).all(), \
            f'episode-{i}: first_index is expected to be 0 at every step'

        # Only the first step's positions are printed, as one (x, y) row per city.
        # NOTE(review): the original script had a disabled LKH reference-solution
        # dump at this point; it is intentionally not reproduced here.
        log.write(f'city position:\n{positions[0].reshape((-1, 2))}\n\n')
        for t, reward in enumerate(rewards):
            log.write(f'visited:    \t{visiteds[t]}\n')
            log.write(f'current idx:\t{current_idxs[t].item()}\n')
            log.write(f'take action:\t{acts[t]}\n')
            log.write(f'get reward: \t{reward}\n\n')


# Presumably resets the output folder so logs from earlier runs do not linger;
# verify against ``utils.utils.create_folder_overwrite_if_exist``.
create_folder_overwrite_if_exist(f'{base_path}/visualize/dataset/{ENV_NAME}')
for file_name in file_names:
    # Dataset name is the file name without its extension.
    data_name = os.path.splitext(file_name)[0]
    dataset = load_data(env_name=ENV_NAME, data_name=data_name)

    # Names ending in 'problem' hold evaluation problems; everything else is
    # treated as a collection of training episodes.
    if data_name.endswith('problem'):
        write_log = _write_problem_log
    else:
        write_log = _write_episode_log

    log_path = f'{base_path}/visualize/dataset/{ENV_NAME}/{data_name}.txt'
    # Explicit encoding keeps the output independent of the platform default.
    with open(log_path, 'w', encoding='utf-8') as log:
        write_log(log, dataset)

    print(f'{data_name}.txt saved')