import os
import sys
gcn_base_path = os.path.abspath('/data1/XXX/graph-convnet-tsp')
sys.path.append(gcn_base_path)
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

import pickle
import numpy as np

from environment.used.BaseEnv_COP import DataProblem, RawData
from eval import eval_main

# Input (TSP_V3) and output (TSP_V5) dataset directories.
from_data_dir, to_data_dir = (
    '/data1/XXX/gato-revise/toy-gato-data/data/used/TSP_V3/',
    '/data1/XXX/gato-revise/toy-gato-data/data/used/TSP_V5/',
)
# Nodes per TSP instance; also selects the tsp{node_num}_*.pkl file names.
node_num = 20

def _attach_node_embeddings(samples, node_embeddings, num_nodes, mask_repeat=8):
    """Swap raw node coordinates in trajectory samples for node embeddings, in place.

    For every sample (a dict with 'prefix', 'prefix_masks' and 'observations'
    sub-dicts) this:
      * stores the flattened embedding array as prefix['node_embedding'];
      * builds prefix_masks['node_embedding'] by repeating the position mask
        ``mask_repeat`` times along its last axis;
      * for each step's observations['current_position'], finds the node whose
        coordinates match exactly and stores that node's embedding row as
        observations['current_embedding'];
      * removes prefix['position'], prefix_masks['position'] and
        observations['current_position'].

    Args:
        samples: list of sample dicts, modified in place.
        node_embeddings: per-sample embedding arrays aligned with ``samples``,
            indexed by node along axis 0.
        num_nodes: nodes per instance; prefix['position'] must reshape to
            (num_nodes, 2).
        mask_repeat: repeat factor applied to the position mask.
            NOTE(review): 8 presumably makes the mask line up with the
            flattened embedding width — confirm against eval_main's output.

    Raises:
        ValueError: if the number of embeddings differs from the number of
            samples, or a current position matches no node.
    """
    if len(node_embeddings) != len(samples):
        raise ValueError(
            f'expected {len(samples)} node embeddings, got {len(node_embeddings)}'
        )
    for idx, (sample, emb) in enumerate(zip(samples, node_embeddings)):
        prefix = sample['prefix']
        prefix_masks = sample['prefix_masks']
        observations = sample['observations']

        # Node coordinates as (1, num_nodes, 2) against each step's current
        # position as (T, 1, 2): broadcasting gives a (T, num_nodes) table, so
        # neither the node count nor the step count is hard-coded.
        node_pos = prefix['position'].reshape(1, num_nodes, 2)
        cur_pos = observations['current_position'][:, None, :]
        # Both coordinates must match; comparing x alone yields extra or
        # misaligned indices whenever two nodes share an x value.
        hits = (node_pos == cur_pos).all(axis=-1)
        unmatched = ~hits.any(axis=1)
        if unmatched.any():
            raise ValueError(
                f'sample {idx}: {int(unmatched.sum())} current position(s) '
                'match no node in prefix position'
            )
        node_idx = hits.argmax(axis=1)  # first matching node for every step

        prefix['node_embedding'] = emb.reshape(-1)
        prefix_masks['node_embedding'] = prefix_masks['position'].repeat(mask_repeat, axis=-1)
        observations['current_embedding'] = emb[node_idx]

        prefix.pop('position')
        prefix_masks.pop('position')
        observations.pop('current_position')


### prompt.pkl
with open(os.path.join(from_data_dir, f'tsp{node_num}_prompt.pkl'), 'rb') as f:
    prompt_data = pickle.load(f)
    print(f'prompt_data loaded, length: {len(prompt_data)}')
# np.insert(..., 0, 0) prepends a 0 to each action sequence before embedding.
data_input = [(x['prefix']['position'], np.insert(x['actions'], 0, 0)) for x in prompt_data]  # (num*2,), (num,)
node_embeddings = eval_main(data_input)
_attach_node_embeddings(prompt_data, node_embeddings, node_num)

### problem.pkl
with open(os.path.join(from_data_dir, f'tsp{node_num}_problem.pkl'), 'rb') as f:
    problem_data = pickle.load(f)
    data_length = len(problem_data.answer_list)
    print(f'problem loaded, length: {data_length}')
data_input = [(problem_data.problem_list[i]['position'], problem_data.answer_list[i]) for i in range(data_length)]  # (num,2), (num,)
node_embeddings = eval_main(data_input)
if len(node_embeddings) != data_length:
    raise ValueError(f'expected {data_length} node embeddings, got {len(node_embeddings)}')
# Problems keep their 'position' entry; the embedding is stored unflattened.
for problem, emb in zip(problem_data.problem_list, node_embeddings):
    problem['node_embedding'] = emb

# NOTE(review): every problem/answer is duplicated (the second half references
# the same objects as the first); presumably deliberate — confirm. This also
# assumes answer_list is a Python list: on an ndarray `+` would add
# element-wise instead of concatenating.
problem_data.problem_list = problem_data.problem_list + problem_data.problem_list
problem_data.answer_list = problem_data.answer_list + problem_data.answer_list

### train.pkl
with open(os.path.join(from_data_dir, f'tsp{node_num}_train.pkl'), 'rb') as f:
    train_data = pickle.load(f)
    print(f'train_data loaded, length: {len(train_data)}')
data_input = [(x['prefix']['position'], np.insert(x['actions'], 0, 0)) for x in train_data]  # (num*2,), (num,)
node_embeddings = eval_main(data_input)
_attach_node_embeddings(train_data, node_embeddings, node_num)




# Write the converted datasets; create the destination directory if it does
# not exist yet (open(..., 'wb') would otherwise raise FileNotFoundError).
os.makedirs(to_data_dir, exist_ok=True)
for split_name, payload, label in (
    ('train', train_data, 'train_data'),
    ('problem', problem_data, 'problem'),
    ('prompt', prompt_data, 'prompt_data'),
):
    with open(os.path.join(to_data_dir, f'tsp{node_num}_{split_name}.pkl'), 'wb') as f:
        pickle.dump(payload, f)
        print(f'{label} saved')
