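# Load the open-source TSP dataset released with Pointer Networks and convert it
# into pickled evaluation problems and expert trajectories for the TSP_V3 environment.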
import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

import numpy as np
import os
import pickle
from tqdm import tqdm
from environment.used.Env_tsp_v3 import TSP_V3
from environment.used.BaseEnv_COP import DataProblem
from utils.utils import create_folder_if_not_exist, set_seed
from gym.utils.env_checker import check_env
import numpy as np
import pickle
from tqdm import tqdm

def load_problem_data(filename, num=None):
    """Load TSP problems and their reference tours from a Pointer Network text file.

    Every non-blank line is expected to contain whitespace-separated city
    coordinates (``x1 y1 x2 y2 ...``), the literal token ``output``, and then
    the reference tour as 1-based city indices.  Indices are shifted to
    0-based and the final tour entry is dropped (presumably the repeated start
    city that closes the tour -- confirm against the dataset).

    Args:
        filename: Path of the text file to read.
        num: Maximum number of records to load.  ``None`` loads every
            non-blank line in the file.

    Returns:
        A ``DataProblem`` whose ``prefix_list`` holds one ``float32`` array of
        shape ``(num_cities, 2)`` per record and whose ``answer_list`` holds
        the matching 0-based tours as lists of ``int``.

    Raises:
        ValueError: If a record has no ``output`` separator, contains a
            non-numeric token, or has an odd number of coordinates.
    """
    problems = DataProblem(prefix_list=[], answer_list=[])
    with open(filename, 'r') as f:
        if num is None:
            # Count the records up front so the progress bar knows its total.
            num = sum(1 for line in f if line.strip())
            f.seek(0)

        # Blank lines carry no record; skipping them keeps a stray empty line
        # from aborting the whole load.
        records = (line for line in f if line.strip())
        for i, line in tqdm(enumerate(records), total=num):
            if i == num:
                break

            # split() without an argument tolerates tabs and repeated spaces.
            tokens = line.split()
            try:
                sep = tokens.index('output')
            except ValueError:
                raise ValueError(
                    f"record {i} of {filename} has no 'output' separator"
                ) from None

            position = np.array(
                [float(tok) for tok in tokens[:sep]], dtype=np.float32
            ).reshape(-1, 2)
            # 1-based -> 0-based, then drop the last tour entry.
            real_answer = [int(tok) - 1 for tok in tokens[sep + 1:]][:-1]

            problems.prefix_list.append(position)
            problems.answer_list.append(real_answer)

    return problems

def _parse_pn_record(line):
    """Split one Pointer Network record into city coordinates and a 0-based tour."""
    # split() without an argument tolerates tabs and repeated spaces.
    tokens = line.split()
    try:
        sep = tokens.index('output')
    except ValueError:
        raise ValueError(f"record has no 'output' separator: {line[:60]!r}") from None

    position = np.array(
        [float(tok) for tok in tokens[:sep]], dtype=np.float32
    ).reshape(-1, 2)
    # 1-based -> 0-based, then drop the last tour entry.
    tour = [int(tok) - 1 for tok in tokens[sep + 1:]][:-1]
    return position, tour


def _rollout_reference_tour(env, position, tour):
    """Replay a reference tour in ``env`` and package the trajectory as an episode dict."""
    # Reset the environment onto this specific problem.
    observation, _ = env.reset(problem_info=(position, None, tour))
    solution = env.real_answer
    assert solution[0] == 0
    assert solution == tour

    # Generate the MDP trajectory: solution[0] is never stepped (the asserts
    # above pin it to city 0); every later entry is taken as an action.
    obss, acts, rewards = [observation], [], []
    for action in solution[1:-1]:
        observation, reward, terminated, truncated, _ = env.step(action)
        acts.append(action)
        rewards.append(reward)
        obss.append(observation)
        assert not (terminated or truncated)

    # The final action must end the episode with reward 1; its resulting
    # observation is not stored, so observations and actions stay aligned.
    action = solution[-1]
    observation, reward, terminated, truncated, _ = env.step(action)
    assert reward == 1
    assert terminated and not truncated
    acts.append(action)
    rewards.append(reward)

    assert len(obss) == len(acts) == len(rewards)

    # Convert to a d4rl-style layout for saving.
    return {
        'prefix': {
            'position': position.flatten().astype(np.float32),
        },
        'observations': {
            'visited': np.vstack([obs['visited'] for obs in obss]),                    # (time_steps, num_nodes)
            'current_position': np.array([obs['current_position'] for obs in obss]),   # (time_steps, 2)
        },
        'actions': np.array(acts).astype(np.int32),         # (time_steps, )
        'rewards': np.array(rewards).astype(np.float32),    # (time_steps, )
        # 'terminals' only mimics the d4rl data layout; it is currently unused.
        'terminals': np.array([False] * (len(rewards)-1) + [True], dtype=bool)
    }


def load_PN_data(filename, env, epi_num=None):
    """Turn Pointer Network TSP records into expert trajectories plus their problems.

    Every non-blank line of ``filename`` is parsed into city coordinates and a
    reference tour, the tour is replayed step by step in ``env``, and the
    resulting observations, actions and rewards are stored as one episode.

    Args:
        filename: Path of the Pointer Network text file.
        env: Environment used for the replay.  It must accept
            ``reset(problem_info=(position, None, tour))``, expose
            ``real_answer`` and return 5-tuples from ``step``.  It is closed
            before this function returns.
        epi_num: Maximum number of records to convert.  ``None`` converts
            every non-blank line in the file.

    Returns:
        ``(episodes, problems)`` where ``episodes`` is a list of episode dicts
        (keys ``prefix``, ``observations``, ``actions``, ``rewards``,
        ``terminals``) and ``problems`` is a ``DataProblem`` holding the
        coordinates and 0-based tours of the same records, in the same order.

    Raises:
        ValueError: If a record cannot be parsed.
        AssertionError: If the replay disagrees with the reference tour, for
            example an early termination or a final reward other than 1.
    """
    episodes = []
    problems = DataProblem(prefix_list=[], answer_list=[])
    with open(filename, 'r') as f:
        if epi_num is None:
            # Count the records up front so the progress bar knows its total.
            epi_num = sum(1 for line in f if line.strip())
            f.seek(0)

        # Blank lines carry no record; skipping them keeps a stray empty line
        # from aborting the whole load.
        records = (line for line in f if line.strip())
        for i, line in tqdm(enumerate(records), total=epi_num, desc='Loading data'):
            if i == epi_num:
                break

            position, real_answer = _parse_pn_record(line)

            # Store a private copy so the saved problem does not alias the
            # array handed to the environment.
            problems.prefix_list.append(position.copy())
            problems.answer_list.append(real_answer)

            episodes.append(_rollout_reference_tour(env, position, real_answer))

    env.close()

    return episodes, problems

# Build the TSP environment for 20-city instances
num_nodes = 20
env = TSP_V3(num_nodes=num_nodes)
# check_env(env)  # optional gym API conformance check, left disabled

# Seed the action space, the environment and the project-wide seeding helper
seed = 100
env.action_space.seed(seed)
env.reset(seed=seed)
set_seed(seed)

# Folder that receives every generated pickle file
create_folder_if_not_exist(f'{base_path}/data/used/TSP_V3')

# Load the evaluation problems from the test split and pickle them
eval_source = f'{base_path}/data/PointerNetwork/tsp{num_nodes}_test.txt'
eval_target = f'{base_path}/data/used/TSP_V3/tsp{num_nodes}_problem.pkl'
problems = load_problem_data(eval_source, num=None)
with open(eval_target, 'wb') as f:
    pickle.dump(problems, f)
    print('eval_problems saved')

# Collect expert trajectories together with the problems they solve
epi_num = 100000               # number of trajectories requested from the dataset

# Load trajectory data
episodes, problems = load_PN_data(f'{base_path}/data/PointerNetwork/tsp{num_nodes}.txt', env=env, epi_num=epi_num)

# 90/10 split into training and prompt trajectories.  The split point is taken
# from the number of episodes actually loaded: the file may hold fewer than
# epi_num records, and a split point based on the request would then leave the
# prompt set short or empty.
num_train = len(episodes) * 9 // 10
if len(episodes) < epi_num:
    print(f'only {len(episodes)} of the requested {epi_num} trajectories were available')

with open(f'{base_path}/data/used/TSP_V3/tsp{num_nodes}_train.pkl', 'wb') as f:
    pickle.dump(episodes[:num_train], f)
    print('data_train saved')
with open(f'{base_path}/data/used/TSP_V3/tsp{num_nodes}_prompt.pkl', 'wb') as f:
    pickle.dump(episodes[num_train:], f)
    print('data_prompt saved')

# Also export the training trajectories' problems; evaluation performance on
# these problems should be close to optimal.
problems.prefix_list = problems.prefix_list[:num_train]
problems.answer_list = problems.answer_list[:num_train]
with open(f'{base_path}/data/used/TSP_V3/tsp{num_nodes}_train_problem.pkl', 'wb') as f:
    pickle.dump(problems, f)
    print('train_eval_problems saved')