# 加载来自 Pointer Network 的开源 TSP 数据集
import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

import numpy as np
import os
import pickle
from tqdm import tqdm
from environment.used.Env_tsp_v1 import TSP_V1
from environment.used.BaseEnv_COP import DataProblem
from utils.utils import create_folder_if_not_exist, set_seed
from gym.utils.env_checker import check_env
import numpy as np
import pickle
from tqdm import tqdm

def load_problem_data(filename, num=None):
    problems = DataProblem(problem_list=[], answer_list=[])
    with open(filename, 'r') as f:
        if num is None:
            num = sum(1 for line in f)
            f.seek(0)

        for i, line in tqdm(enumerate(f), total=num):
            if i == num:
                break
            split_line = line.strip().split(' ')
            coord_strs = split_line[:split_line.index('output')]
            position = np.array([float(coord) for coord in coord_strs]).reshape(-1, 2)
            answer_strs = split_line[split_line.index('output') + 1:]
            real_answer = [int(answer)-1 for answer in answer_strs][:-1]

            observation = {}
            observation['position'] = position.astype(np.float32)
            observation['visited'] = np.array([1]+[0]*(position.shape[0]-1), dtype=np.int32)
            observation['first_index'] = np.array([0], dtype=np.int32)
            observation['current_index'] = np.array([0], dtype=np.int32)
            
            problems.problem_list.append(observation)
            problems.answer_list.append(real_answer)

    return problems

def load_PN_data(filename, env, epi_num=None):
    episodes = []
    with open(filename, 'r') as f:
        if epi_num is None:
            epi_num = sum(1 for line in f)
            f.seek(0)
        
        for i, line in tqdm(enumerate(f), total=epi_num, desc='Loading data'):
            if epi_num is not None and i == epi_num:
                break
            
            split_line = line.strip().split(' ')
            coord_strs = split_line[:split_line.index('output')]
            position = np.array([float(coord) for coord in coord_strs]).reshape(-1, 2)
            answer_strs = split_line[split_line.index('output') + 1:]
            real_answer = [int(answer)-1 for answer in answer_strs][:-1]

            problem = {}
            problem['position'] = position
            problem['visited'] = np.array([1]+[0]*(position.shape[0]-1), dtype=np.int32)
            problem['first_index'] = np.array([0], dtype=np.int32)
            problem['current_index'] = np.array([0], dtype=np.int32)
            
            # 重置环境
            observation, _ = env.reset(options={'problem_info': (None, problem, real_answer)})
            solution = env.real_answer
            assert solution[0] == 0
            assert solution == real_answer

            # 生成 MDP 轨迹
            obss, acts, rewards = [observation, ], [], []
            for action in solution[1:-1]:
                observation, reward, terminated, truncated, info = env.step(action)
                acts.append(action)
                rewards.append(reward)
                obss.append(observation)
                assert not (terminated or truncated)
            action = solution[-1]
            observation, reward, terminated, truncated, info = env.step(action)  
            assert reward == 1
            assert terminated and not truncated
            acts.append(action)
            rewards.append(reward)
            
            assert len(obss) == len(acts) == len(rewards)

            # 处理成 d4rl 格式保存
            obss_positions = np.vstack([obs['position'] for obs in obss])
            obss_visiteds = np.vstack([obs['visited'] for obs in obss])
            obss_first_idxs = np.array([obs['first_index'] for obs in obss])
            obss_current_idxs = np.array([obs['current_index'] for obs in obss])
            episode = {
                'prefix': None,
                'observations': {
                    'position': obss_positions,                     # (time_steps, 2*num_nodes)
                    'visited': obss_visiteds,                       # (time_steps, num_nodes)
                    'first_index': obss_first_idxs,                 # (time_steps, )
                    'current_index': obss_current_idxs              # (time_steps, )
                },
                'actions': np.array(acts).astype(np.int32),         # (time_steps, )
                'rewards': np.array(rewards).astype(np.float32),    # (time_steps, )
                'terminals': np.array([False] * (len(rewards)-1) + [True], dtype=bool)  # 'terminals' 字段只是模仿 d4rl 的数据形式，当前没有作用
            }
            episodes.append(episode)
        env.close()

    return episodes

# ---- Environment --------------------------------------------------------
num_nodes = 20
env = TSP_V1(num_nodes=num_nodes)
check_env(env)

# ---- Seeding: action space, environment, then the project-wide helper ---
seed = 100
env.action_space.seed(seed)
env.reset(seed=seed)
set_seed(seed)

# ---- Output directory and split sizes -----------------------------------
save_dir = f'{base_path}/data/used/TSP_V1'
create_folder_if_not_exist(save_dir)
epi_num = 5000                  # number of trajectories to generate
num_train = int(epi_num * 0.9)  # first 90% -> training split, rest -> prompt split

# Disabled: dump the evaluation problems instead of trajectories.
# problems = load_problem_data(f'{base_path}/data/PointerNetwork/tsp{num_nodes}_test.txt', num=epi_num)
# with open(f'{base_path}/data/used/TSP_V1/tsp{num_nodes}_problem.pkl', 'wb') as f:
#     pickle.dump(problems, f)
#     print(f'eval_problems saved')

# ---- Build trajectories and write both splits ---------------------------
episodes = load_PN_data(f'{base_path}/data/PointerNetwork/tsp{num_nodes}.txt', env=env, epi_num=epi_num)
splits = (
    ('train', episodes[:num_train]),
    ('prompt', episodes[num_train:]),
)
for split_name, split_episodes in splits:
    with open(f'{save_dir}/tsp{num_nodes}_{split_name}.pkl', 'wb') as f:
        pickle.dump(split_episodes, f)
        print(f'data_{split_name} saved')