import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

from environment.used.Env_tsp_v1 import TSP_V1
from environment.used.BaseEnv_COP import DataProblem
from utils.utils import create_folder_if_not_exist, set_seed
from gym.utils.env_checker import check_env
import numpy as np
import pickle
from tqdm import tqdm

# Build the TSP environment and verify it against the gym API.
num_nodes = 5
env = TSP_V1(num_nodes=num_nodes) 
check_env(env.unwrapped) # Check that the environment conforms to the gym API

# Initialize random seeds.
# NOTE(review): keep this statement order. check_env above already called into
# the env before seeding, and these calls presumably fix the random stream the
# rollouts below depend on, so reordering them may change the generated data.
seed = 100
env.action_space.seed(seed)
env.reset(seed=seed)
set_seed(seed)  # presumably seeds the global RNGs (random/numpy/...) -- confirm in utils.utils

# Output folder for the generated data files.
create_folder_if_not_exist(f'{base_path}/data/used/TSP_V1')

# Sizes for the collected training data and the evaluation problem set.
epi_num = 5000                 # number of training trajectories to generate
eval_prob_num = 200             # number of problems used for evaluation
num_train = int(epi_num * 0.9)  # split point: episodes[:num_train] are training trajectories, the rest are prompt trajectories

# Optional step: generate and save the evaluation problem set.
# Disabled by default, matching the original file, where this code was wrapped
# in a bare triple-quoted string and never ran. Set the flag to True to run it.
# NOTE(review): enabling it calls env.reset() eval_prob_num extra times before
# the training rollouts below, which presumably advances the env's RNG and so
# changes the training data generated afterwards.
SAVE_EVAL_PROBLEMS = False
if SAVE_EVAL_PROBLEMS:
    problems = DataProblem(problem_list=[], answer_list=[])
    for _ in tqdm(range(eval_prob_num), total=eval_prob_num, desc='collecting eval problems'):
        # Reset the env to sample a random TSP instance; the full problem
        # specification should be contained in the initial observation.
        observation, info = env.reset()
        problems.problem_list.append(observation)
        problems.answer_list.append(env.real_answer)
    with open(f'{base_path}/data/used/TSP_V1/tsp{num_nodes}_eval.pkl', 'wb') as f:
        pickle.dump(problems, f)
    print('eval_problems saved')

# Generate the training trajectories by replaying each instance's reference
# solution, and convert every episode to a d4rl-style dict.


def _rollout_expert_episode(env, observation, solution):
    """Replay ``solution`` in ``env`` and record the per-step transitions.

    Args:
        env: environment that was just reset and produced ``observation``.
        observation: observation returned by that ``env.reset()`` call.
        solution: node sequence for the instance. ``solution[0]`` is not
            stepped (the caller checks it equals 0); ``solution[1:]`` is fed
            to ``env.step`` one action at a time.

    Returns:
        ``(obss, acts, rewards)``: three equal-length lists where ``acts[t]``
        was taken on ``obss[t]`` and produced ``rewards[t]``. The observation
        after the final action is not kept.

    Raises:
        AssertionError: if the episode ends or is truncated before the last
            action, if the last action does not terminate it (without
            truncation), or if the last reward is not 1. These checks are
            ``assert`` statements and are skipped under ``python -O``.
    """
    obss, acts, rewards = [observation], [], []
    # Every action except the last must keep the episode running.
    for action in solution[1:-1]:
        observation, reward, terminated, truncated, _ = env.step(action)
        acts.append(action)
        rewards.append(reward)
        obss.append(observation)
        assert not (terminated or truncated), 'episode ended before the last action'
    # The final action must terminate the episode with reward 1.
    action = solution[-1]
    _, reward, terminated, truncated, _ = env.step(action)
    assert reward == 1, f'unexpected final reward {reward!r}'
    assert terminated and not truncated, 'last action did not terminate the episode'
    acts.append(action)
    rewards.append(reward)

    assert len(obss) == len(acts) == len(rewards)
    return obss, acts, rewards


def _to_d4rl_episode(obss, acts, rewards):
    """Pack one rolled-out episode into the d4rl-style dict saved by this script.

    Per-step observation fields are stacked along a leading time axis. The
    'terminals' field only mimics the d4rl layout (True on the last step) and
    currently has no effect.
    """
    return {
        'prefix': None,
        'observations': {
            'position': np.vstack([obs['position'] for obs in obss]),               # (time_steps, 2*num_nodes)
            'visited': np.vstack([obs['visited'] for obs in obss]),                 # (time_steps, num_nodes)
            'first_index': np.array([obs['first_index'] for obs in obss]),          # (time_steps, )
            'current_index': np.array([obs['current_index'] for obs in obss]),      # (time_steps, )
        },
        'actions': np.array(acts).astype(np.int32),         # (time_steps, )
        'rewards': np.array(rewards).astype(np.float32),    # (time_steps, )
        'terminals': np.array([False] * (len(rewards) - 1) + [True], dtype=bool),
    }


episodes = []
try:
    for _ in tqdm(range(epi_num), total=epi_num, desc='collecting data'):
        # Reset to sample a random TSP instance and take the env-provided
        # answer directly; the separate TSP_lkh re-solve/cross-check was
        # dropped to speed up generation.
        observation, _ = env.reset()
        solution = env.real_answer
        assert solution[0] == 0, 'expected the solution to start at node 0'
        obss, acts, rewards = _rollout_expert_episode(env, observation, solution)
        episodes.append(_to_d4rl_episode(obss, acts, rewards))
finally:
    # Release the environment even if a consistency check above fails.
    env.close()

# Optional step: save the collected episodes, split into training trajectories
# (episodes[:num_train]) and prompt trajectories (episodes[num_train:]).
# Disabled by default, matching the original file, where this code was wrapped
# in a bare triple-quoted string and never ran. While disabled, the episodes
# collected above are discarded when the script exits. Set the flag to True to
# write the pickle files.
SAVE_TRAIN_DATA = False
if SAVE_TRAIN_DATA:
    with open(f'{base_path}/data/used/TSP_V1/tsp{num_nodes}_train.pkl', 'wb') as f:
        pickle.dump(episodes[:num_train], f)
    print('data_train saved')
    with open(f'{base_path}/data/used/TSP_V1/tsp{num_nodes}_prompt.pkl', 'wb') as f:
        pickle.dump(episodes[num_train:], f)
    print('data_prompt saved')
