import os
import sys
gcn_base_path = os.path.abspath('/data1/XXX/graph-convnet-tsp')
sys.path.append(gcn_base_path)
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

import pickle
import numpy as np

from environment.used.BaseEnv_COP import DataProblem, RawData


# Number of nodes per TSP instance; every pickle read or written below is
# named ``tsp{node_num}_*.pkl``.
node_num = 20

# Source directory (TSP_V3 pickles, nodes given as coordinates) and target
# directory (ATSP_V2 pickles, nodes given as distance-matrix embeddings).
from_data_dir, to_data_dir = (
    '/data1/XXX/gato-revise/toy-gato-data/data/used/TSP_V3/',
    '/data1/XXX/gato-revise/toy-gato-data/data/used/ATSP_V2/',
)


### prompt.pkl
# Convert the TSP prompt episodes to the ATSP representation:
#   * prefix['position'] (node coordinates) -> prefix['node_embedding'],
#     the flattened pairwise Euclidean distance matrix;
#   * prefix_masks['position'] -> prefix_masks['node_embedding'], the
#     coordinate mask stretched to the length of the flattened matrix;
#   * observations['current_position'] -> observations['current_embedding'],
#     the distance-matrix row of the node visited at each step.
with open(os.path.join(from_data_dir, f'tsp{node_num}_prompt.pkl'), 'rb') as f:
    prompt_data = pickle.load(f)
    print(f'prompt_data loaded, length: {len(prompt_data)}')

for episode in prompt_data:
    prefix = episode['prefix']
    prefix_masks = episode['prefix_masks']
    observations = episode['observations']

    positions = prefix.pop('position').reshape(-1, 2)
    # Pairwise Euclidean distances, shape (n_nodes, n_nodes).
    matrix = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)

    # Node count, step count and mask repeat factor are derived from the data
    # rather than hard-coded. The repeat factor is matrix size / mask length,
    # which is 10 for 20 nodes if the mask has one entry per coordinate
    # (2 * node_num entries) -- NOTE(review): confirm for other node counts.
    position_mask = prefix_masks.pop('position')
    repeat_factor, remainder = divmod(matrix.size, position_mask.shape[-1])
    if remainder:
        raise ValueError(
            f'position mask length {position_mask.shape[-1]} does not divide '
            f'node_embedding size {matrix.size}'
        )
    prefix_masks['node_embedding'] = position_mask.repeat(repeat_factor, axis=-1)

    # Locate each step's node by matching BOTH coordinates; matching on x
    # alone yields extra/misaligned indices when two nodes share an x value.
    current_positions = observations.pop('current_position')
    hits = (positions[None, :, :] == current_positions[:, None, :]).all(axis=-1)
    if not hits.any(axis=1).all():
        raise ValueError('current_position contains a point that is not a node of the instance')
    # argmax picks the first matching node; coincident nodes have identical
    # distance rows, so which one is picked does not affect the result.
    observations['current_embedding'] = matrix[hits.argmax(axis=1)]

    prefix['node_embedding'] = matrix.reshape(-1)

### problem.pkl
# Attach the pairwise Euclidean distance matrix ('node_embedding') to every
# held-out problem instance. Here 'position' is kept and the matrix is stored
# unflattened, shape (n_nodes, n_nodes).
# NOTE(review): the prompt/train conversions store node_embedding flattened;
# confirm the consumer of problem.pkl expects the 2-D form.
with open(os.path.join(from_data_dir, f'tsp{node_num}_problem.pkl'), 'rb') as f:
    problem_data = pickle.load(f)
    data_length = len(problem_data.answer_list)
    print(f'problem loaded, length: {data_length}')

# Iterate over problem_list itself (not len(answer_list)) so that every
# problem receives an embedding even if the two lists differ in length.
for problem in problem_data.problem_list:
    positions = problem['position'].reshape(-1, 2)
    problem['node_embedding'] = np.linalg.norm(
        positions[:, None, :] - positions[None, :, :], axis=-1
    )



### train.pkl
# Convert the TSP training episodes to the ATSP representation:
#   * prefix['position'] (node coordinates) -> prefix['node_embedding'],
#     the flattened pairwise Euclidean distance matrix;
#   * prefix_masks['position'] -> prefix_masks['node_embedding'], the
#     coordinate mask stretched to the length of the flattened matrix;
#   * observations['current_position'] -> observations['current_embedding'],
#     the distance-matrix row of the node visited at each step.
with open(os.path.join(from_data_dir, f'tsp{node_num}_train.pkl'), 'rb') as f:
    train_data = pickle.load(f)
    print(f'train_data loaded, length: {len(train_data)}')

for episode in train_data:
    prefix = episode['prefix']
    prefix_masks = episode['prefix_masks']
    observations = episode['observations']

    positions = prefix.pop('position').reshape(-1, 2)
    # Pairwise Euclidean distances, shape (n_nodes, n_nodes).
    matrix = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)

    # Node count, step count and mask repeat factor are derived from the data
    # rather than hard-coded. The repeat factor is matrix size / mask length,
    # which is 10 for 20 nodes if the mask has one entry per coordinate
    # (2 * node_num entries) -- NOTE(review): confirm for other node counts.
    position_mask = prefix_masks.pop('position')
    repeat_factor, remainder = divmod(matrix.size, position_mask.shape[-1])
    if remainder:
        raise ValueError(
            f'position mask length {position_mask.shape[-1]} does not divide '
            f'node_embedding size {matrix.size}'
        )
    prefix_masks['node_embedding'] = position_mask.repeat(repeat_factor, axis=-1)

    # Locate each step's node by matching BOTH coordinates; matching on x
    # alone yields extra/misaligned indices when two nodes share an x value.
    current_positions = observations.pop('current_position')
    hits = (positions[None, :, :] == current_positions[:, None, :]).all(axis=-1)
    if not hits.any(axis=1).all():
        raise ValueError('current_position contains a point that is not a node of the instance')
    # argmax picks the first matching node; coincident nodes have identical
    # distance rows, so which one is picked does not affect the result.
    observations['current_embedding'] = matrix[hits.argmax(axis=1)]

    prefix['node_embedding'] = matrix.reshape(-1)


# Write the converted train / problem / prompt pickles into the target
# directory. open(..., 'wb') does not create missing parent directories,
# so ensure the target directory exists first.
os.makedirs(to_data_dir, exist_ok=True)

with open(os.path.join(to_data_dir, f'tsp{node_num}_train.pkl'), 'wb') as f:
    pickle.dump(train_data, f)
    print(f'train_data saved')

with open(os.path.join(to_data_dir, f'tsp{node_num}_problem.pkl'), 'wb') as f:
    pickle.dump(problem_data, f)
    print(f'problem saved')

with open(os.path.join(to_data_dir, f'tsp{node_num}_prompt.pkl'), 'wb') as f:
    pickle.dump(prompt_data, f)
    print(f'prompt_data saved')


#####################################################
### generate train_problem.pkl
# Package the first (up to) `num_train_problems` training instances as a
# DataProblem: each problem_list entry holds the node coordinates
# ('position') and the flattened pairwise Euclidean distance matrix
# ('node_embedding'); each answer_list entry is the episode's action list.
#
# Instances are read from the unconverted source train pickle because the
# converted target train pickle has prefix['position'] removed and therefore
# cannot supply positions. The existence check tests the same file that is
# opened.
num_train_problems = 500
problem_list = []
answer_list = []
train_src_path = os.path.join(from_data_dir, f'tsp{node_num}_train.pkl')
if os.path.exists(train_src_path):
    print(f'generating train_problem.pkl from train.pkl')
    with open(train_src_path, 'rb') as f:
        train_src = pickle.load(f)
    # Cap at the data size so a smaller train set does not raise IndexError.
    for i in range(min(num_train_problems, len(train_src))):
        position = train_src[i]['prefix']['position']
        coords = position.reshape(-1, 2)
        matrix = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
        problem_list.append({'position': position, 'node_embedding': matrix.reshape(-1)})
        answer_list.append(list(train_src[i]['actions']))
problem_data = DataProblem(prefix_list=None, problem_list=problem_list, answer_list=answer_list)


with open(os.path.join(to_data_dir, f'tsp{node_num}_train_problem.pkl'), 'wb') as f:
    pickle.dump(problem_data, f)
    print(f'train_problem data saved')
# sys.exit() instead of the site-provided exit(), which is absent under `python -S`.
sys.exit()