# Load the open-source TSP dataset released with Pointer Network.
# Deprecated: the quality of this data differs from that of the data used in the AM paper.
import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

import numpy as np
import os
from tqdm import tqdm
from environment.used.Env_tsp_v2 import TSP_V2
from environment.used.BaseEnv_COP import DataProblem
from utils.utils import create_folder_if_not_exist, set_seed
from gym.utils.env_checker import check_env
import numpy as np
from utils.COP_slover import calc_tsp_distance
from tqdm import tqdm

def load_problem_data(filename, num=None):
    """Load TSP problems and reference tours from a Pointer Network text file.

    Each non-blank line is expected to look like
    ``x1 y1 x2 y2 ... output t1 t2 ... tk``, where the ``t*`` tour indices are
    1-based. Blank lines (e.g. a trailing newline) are skipped, and tokens may
    be separated by any amount of whitespace.

    Args:
        filename: path to the text data file.
        num: maximum number of problems to load; ``None`` loads every
            non-blank line of the file.

    Returns:
        A DataProblem whose ``problem_list`` holds observation dicts with
        ``'current_position'`` (node 0's coordinates, float32) and
        ``'position'`` (all node coordinates as float32, with node 0's row
        overwritten by zeros), and whose ``answer_list`` holds the 0-based
        tours with the last index of each line dropped.

    Raises:
        ValueError: if a non-blank line lacks the ``output`` separator or
            contains a non-numeric / odd number of coordinate tokens.
    """
    problems = DataProblem(problem_list=[], answer_list=[])
    with open(filename, 'r') as f:
        if num is None:
            # Count only lines that actually hold a problem, then rewind.
            num = sum(1 for line in f if line.strip())
            f.seek(0)

        with tqdm(total=num) as pbar:
            for line in f:
                if len(problems.problem_list) >= num:
                    break
                # split() (no argument) tolerates repeated/trailing whitespace,
                # which split(' ') would turn into empty, unparsable tokens.
                tokens = line.split()
                if not tokens:
                    continue

                sep = tokens.index('output')
                position = np.array([float(t) for t in tokens[:sep]]).reshape(-1, 2)
                # Convert to 0-based indices and drop the last entry
                # (presumably the closing return to the start node — confirm).
                real_answer = [int(t) - 1 for t in tokens[sep + 1:]][:-1]

                observation = {}
                observation['current_position'] = position[0].copy().astype(np.float32)
                # NOTE(review): node 0's coordinates are zeroed in 'position'
                # here, whereas load_PN_data keeps them; confirm which
                # convention the consumer expects.
                position[0] = 0
                observation['position'] = position.astype(np.float32)

                problems.problem_list.append(observation)
                problems.answer_list.append(real_answer)
                pbar.update()

    return problems

def load_PN_data(filename, env, epi_num=None):
    """Load Pointer Network TSP instances and replay each one through ``env``.

    Each non-blank line of ``filename`` is expected to look like
    ``x1 y1 x2 y2 ... output t1 t2 ... tk`` with 1-based tour indices. For
    every instance the environment is reset with the problem and its reference
    tour. The environment's stored answer is checked against the parsed tour,
    and the tour length (``calc_tsp_distance``) is accumulated so the progress
    bar can show the running mean objective.

    Args:
        filename: path to the Pointer Network text file.
        env: environment exposing ``reset(options=...)``, ``real_answer`` and
            ``close()``. It is closed before this function returns, including
            when an error is raised.
        epi_num: maximum number of instances to load; ``None`` loads every
            non-blank line.

    Returns:
        ``(episodes, problems)``. ``episodes`` is currently always an empty
        list, because the trajectory-generation code inside the loop is
        disabled. ``problems`` is a DataProblem whose ``problem_list`` holds
        dicts with ``'current_position'`` (node 0's coordinates) and
        ``'position'`` (all node coordinates), both float32, and whose
        ``answer_list`` holds the 0-based reference tours with the last index
        of each line dropped.

    Raises:
        AssertionError: if the environment's answer does not start at node 0
            or differs from the parsed tour.
        ValueError: if a non-blank line lacks the ``output`` separator or has
            malformed coordinate tokens.
    """
    episodes = []  # stays empty while trajectory generation is disabled
    total_distance = 0.0  # running sum; avoids an O(n^2) np.mean per step
    problems = DataProblem(problem_list=[], answer_list=[])
    try:
        with open(filename, 'r') as f:
            if epi_num is None:
                # Count only lines that actually hold an instance, then rewind.
                epi_num = sum(1 for line in f if line.strip())
                f.seek(0)

            with tqdm(total=epi_num, desc='Loading data') as pbar:
                for line in f:
                    if len(problems.problem_list) >= epi_num:
                        break
                    # split() (no argument) tolerates repeated/trailing
                    # whitespace; blank lines are skipped instead of crashing.
                    tokens = line.split()
                    if not tokens:
                        continue

                    sep = tokens.index('output')
                    position = np.array([float(t) for t in tokens[:sep]]).reshape(-1, 2)
                    # 1-based -> 0-based; the last index is dropped
                    # (presumably the return to the start node — confirm).
                    real_answer = [int(t) - 1 for t in tokens[sep + 1:]][:-1]

                    problem = {}
                    problem['current_position'] = position[0].copy().astype(np.float32)
                    problem['position'] = position.astype(np.float32)

                    problems.problem_list.append(problem)
                    problems.answer_list.append(real_answer)

                    # Reset the environment with this instance. `observation`
                    # is only consumed by the disabled trajectory code below.
                    observation, _ = env.reset(
                        options={'problem_info': (None, problem, real_answer)})
                    solution = env.real_answer
                    assert solution[0] == 0
                    assert solution == real_answer
                    total_distance += calc_tsp_distance(position, real_answer)

                    # Trajectory generation (disabled). Uncommenting this makes
                    # `episodes` hold one d4rl-format dict per instance.
                    #
                    # # Generate the MDP trajectory
                    # obss, acts, rewards = [observation, ], [], []
                    # for action in solution[1:-1]:
                    #     observation, reward, terminated, truncated, info = env.step(action)
                    #     acts.append(action)
                    #     rewards.append(reward)
                    #     obss.append(observation)
                    #     assert not (terminated or truncated)
                    # action = solution[-1]
                    # observation, reward, terminated, truncated, info = env.step(action)
                    # assert reward == 1
                    # assert terminated and not truncated
                    # acts.append(action)
                    # rewards.append(reward)
                    #
                    # assert len(obss) == len(acts) == len(rewards)
                    #
                    # # Pack into d4rl format for saving
                    # obss_positions = np.vstack([obs['position'] for obs in obss])
                    # obss_visiteds = np.vstack([obs['visited'] for obs in obss])
                    # current_position = np.array([obs['current_position'] for obs in obss])
                    # episode = {
                    #     'prefix': None,
                    #     'observations': {
                    #         'position': obss_positions,                     # (time_steps, 2*num_nodes)
                    #         'visited': obss_visiteds,                       # (time_steps, num_nodes)
                    #         'current_position': current_position            # (time_steps, 2)
                    #     },
                    #     'actions': np.array(acts).astype(np.int32),         # (time_steps, )
                    #     'rewards': np.array(rewards).astype(np.float32),    # (time_steps, )
                    #     # 'terminals' only mimics the d4rl data layout; it is currently unused
                    #     'terminals': np.array([False] * (len(rewards)-1) + [True], dtype=bool)
                    # }
                    # episodes.append(episode)

                    mean_distance = total_distance / len(problems.problem_list)
                    pbar.set_postfix({'obj': f'{mean_distance:.2f}'})
                    pbar.update()
    finally:
        # Close the environment even if parsing or an assertion fails.
        env.close()

    return episodes, problems

# Build the environment
num_nodes = 20
env = TSP_V2(num_nodes=num_nodes)
check_env(env)

# Initialise random seeds
seed = 100
env.action_space.seed(seed)
env.reset(seed=seed)
set_seed(seed)

# Folder where the generated data is saved
create_folder_if_not_exist(f'{base_path}/data/used/TSP_V2')

# Load and save evaluation problems (disabled). Re-enabling the save
# requires `import pickle`, which this file does not currently import.
# problems = load_problem_data(f'{base_path}/data/PointerNetwork/tsp{num_nodes}_test.txt', num=None)
# with open(f'{base_path}/data/used/TSP_V2/tsp{num_nodes}_eval.pkl', 'wb') as f:
#     pickle.dump(problems, f)
#     print(f'eval_problems saved')

# Collect trajectory data and evaluation problems
# epi_num = 1000000                 # number of training trajectories (full-size run)
epi_num = 10000

# Load trajectory data
episodes, problems = load_PN_data(f'{base_path}/data/PointerNetwork/tsp{num_nodes}.txt', env=env, epi_num=epi_num)

# Split into training and prompt trajectories (90% / 10%). The split is based
# on the number of instances actually loaded: if the file holds fewer than
# epi_num instances, a split computed from epi_num would put everything in
# the training portion and leave the prompt portion empty.
num_train = int(len(problems.problem_list) * 0.9)

# Save the train / prompt splits (disabled; requires `import pickle`):
# with open(f'{base_path}/data/used/TSP_V2/tsp{num_nodes}_train.pkl', 'wb') as f:
#     pickle.dump(episodes[:num_train], f)
#     print(f'data_train saved')
# with open(f'{base_path}/data/used/TSP_V2/tsp{num_nodes}_prompt.pkl', 'wb') as f:
#     pickle.dump(episodes[num_train:], f)
#     print(f'data_prompt saved')

# Also build evaluation problems from the training instances; performance on
# these should be close to optimal.
problems.problem_list = problems.problem_list[:num_train]
problems.answer_list = problems.answer_list[:num_train]

# Save the training-set evaluation problems (disabled; requires `import pickle`):
# with open(f'{base_path}/data/used/TSP_V2/tsp{num_nodes}_train_eval.pkl', 'wb') as f:
#     pickle.dump(problems, f)
#     print(f'train_eval_problems saved')