import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

import numpy as np
import pickle
from tqdm import tqdm
from environment.used.Env_op_v2 import OP_V2
from environment.used.BaseEnv_COP import DataProblem, RawData
from utils.utils import create_folder_if_not_exist, split_rawdata, merge_rawdata
from utils.COP_slover import calc_op_total
import multiprocessing as mp

def load_existing_raw_data(node_num):
    """Load the raw train / test / prompt datasets and print a short summary.

    Each split is unpickled from ``{base_path}/data/used/_raw/op/`` when its
    file exists; a missing file yields an empty ``RawData`` instead of an
    error.

    Args:
        node_num: Number embedded in the file names (``op{node_num}_*.pkl``).

    Returns:
        Tuple ``(train_data, test_data, prompt_data)``.
    """
    raw_dir = f'{base_path}/data/used/_raw/op'
    # (label shown in the summary, file-name suffix).
    # Note that the test split is stored as "*_problem.pkl".
    splits = (('train', 'train'), ('test', 'problem'), ('prompt', 'prompt'))

    loaded = []
    for _, suffix in splits:
        # Fallback used when the file for this split does not exist.
        rawdata = RawData(seed_list=[], problem_list=[], answer_list=[], cost_list=[])
        path = f'{raw_dir}/op{node_num}_{suffix}.pkl'
        if os.path.isfile(path):
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            with open(path, 'rb') as f:
                rawdata = pickle.load(f)
        loaded.append(rawdata)

    print('-'*50)
    for rawdata, (datatype, _) in zip(loaded, splits):
        # np.mean([]) emits a RuntimeWarning before returning nan, so report
        # nan directly for an empty split.
        ave_cost = np.mean(rawdata.cost_list) if len(rawdata.cost_list) > 0 else float('nan')
        print(f'[{len(rawdata.answer_list)}] raw data for [{datatype}] loaded, ave cost [{ave_cost}], [{len(set(rawdata.seed_list))}] seeds used')
    print('-'*50)

    train_data, test_data, prompt_data = loaded
    return train_data, test_data, prompt_data

def raw_data_to_problem(raw_data:RawData, num:int=0):
    """Convert raw data into a ``DataProblem`` and verify every stored cost.

    Answers are stored in the result shifted by +1 relative to the raw data.
    For each sample the objective is recomputed with ``calc_op_total`` on the
    unshifted answer and its negation is compared with the stored cost.

    Args:
        raw_data: Source data; the first ``num`` samples are used.
        num: Number of samples to convert. ``0`` means all; larger values are
            clamped to the number of available samples.

    Returns:
        The ``DataProblem`` built from the selected samples.

    Raises:
        AssertionError: If a stored cost differs from the recomputed one by
            1e-5 or more.
    """
    available = len(raw_data.answer_list)
    num = available if num == 0 else min(num, available)
    problems = DataProblem(
        problem_list=raw_data.problem_list[:num],
        answer_list=[[a+1 for a in answer] for answer in raw_data.answer_list[:num]],
    )
    samples = zip(problems.problem_list, problems.answer_list, raw_data.cost_list[:num])
    with tqdm(samples, total=num, desc='Checking') as pbar:
        for idx, (problem, answer, cost) in enumerate(pbar):
            # Undo the +1 shift before evaluating; the stored cost is the
            # negated value returned by calc_op_total.
            recomputed = -calc_op_total(problem['prize'], np.array(answer)-1)
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; written as `not (... < tol)` so NaN also fails.
            if not abs(recomputed - cost) < 1e-5:
                raise AssertionError(
                    f'sample {idx}: stored cost {cost} does not match recomputed cost {recomputed}'
                )
    return problems

def raw_data_to_traj(info):
    """Replay one stored answer in an ``OP_V2`` environment and pack the episode.

    This is the worker mapped over a process pool, so it takes one tuple
    argument and reads its dataset from the module-level globals
    ``train_data`` / ``prompt_data`` / ``test_data``.
    NOTE(review): those globals are assigned only under ``__main__`` in this
    file, so workers presumably rely on a fork-style start method to inherit
    them -- confirm before running on a spawn-based platform.

    Args:
        info: ``(i, traj_type)`` where ``i`` indexes the selected dataset and
            ``traj_type`` is ``'train'``, ``'prompt'`` or ``'test'``.

    Returns:
        A dict with keys ``'prefix'``, ``'observations'``, ``'actions'``,
        ``'rewards'`` and ``'terminals'``, or ``None`` when the sample could
        not be replayed (the reason is logged as a warning).

    Raises:
        NotImplementedError: If ``traj_type`` is not one of the three names.
    """
    import logging  # local import keeps this block self-contained

    i, traj_type = info
    # Datasets are looked up lazily so that only the requested one must exist.
    if traj_type == 'train':
        dataset = train_data
    elif traj_type == 'prompt':
        dataset = prompt_data
    elif traj_type == 'test':
        dataset = test_data
    else:
        raise NotImplementedError(f'unknown traj_type: {traj_type!r}')

    try:
        problem, answer, cost = dataset.problem_list[i], dataset.answer_list[i], dataset.cost_list[i]
        answer = np.array(answer) + 1   # actions fed to the env are the stored answer shifted by +1
        cost = -cost                    # stored cost is the negated calc_op_total value
        # Explicit checks instead of `assert` so validation survives `python -O`.
        if not abs(calc_op_total(problem['prize'], answer-1) - cost) < 1e-5:
            raise ValueError('stored cost does not match the recomputed objective')

        # Reset the environment onto this exact problem / answer
        env = OP_V2(node_num=problem['prize'].shape[0])
        observation, _ = env.reset(options={
            'problem_info': (None, problem, answer),
            'problem_obj':(cost, 1),
            'use_default_policy_obj': False
        })
        if not (answer == env.real_answer).all():
            raise ValueError('environment real_answer differs from the stored answer')

        # Generate the MDP trajectory.
        # Note: the random-policy objective is unknown, so the DB1 reward in the
        # generated trajectory is unreliable; only the AM reward is kept.
        obss, acts, rewards = [observation, ], [], []
        for action in answer[:-1]:
            observation, reward, terminated, truncated, _ = env.step(action)
            acts.append(action)
            rewards.append(reward['AM'])
            obss.append(observation)
            if terminated or truncated:
                raise ValueError('episode ended before the last action')
        action = answer[-1]
        observation, reward, terminated, truncated, _ = env.step(action)
        # real_answer is used directly as the model answer, so reward['DB1'] == 1
        # regardless of what the DB1 objective is.
        if not (abs(reward['AM'] - 1) <= 1e-4 and abs(reward['DB1'] - 1) <= 1e-4):
            raise ValueError(f'unexpected final reward: {reward}')
        if not (terminated and not truncated):
            raise ValueError('episode did not terminate cleanly on the last action')
        acts.append(action)
        rewards.append(1)
        if not (len(obss) == len(acts) == len(rewards)):
            raise ValueError('observation / action / reward lengths differ')
    except Exception as e:
        # Best effort: a bad sample is skipped, but the reason is reported
        # instead of being dropped silently.
        logging.getLogger(__name__).warning(
            'skipping %s sample %s: %s: %s', traj_type, i, type(e).__name__, e
        )
        return None

    # Pack in a d4rl-like layout for saving
    obss_pos_depot = np.array([obs['pos_depot'] for obs in obss])
    obss_pos_node = np.vstack([obs['pos_node'] for obs in obss])
    obss_prize = np.vstack([obs['prize'] for obs in obss])
    obss_visited = np.vstack([obs['visited'] for obs in obss])
    obss_length = np.array([obs['length'] for obs in obss])
    current_position = np.array([obs['current_position'] for obs in obss])
    episode = {
        'prefix': None,
        'observations': {
            'pos_depot': obss_pos_depot,                    # (time_steps, 2)
            'pos_node': obss_pos_node,                      # (time_steps, 2*node_num)
            'prize': obss_prize,                            # (time_steps, node_num)
            'length': obss_length,                          # (time_steps, )
            'current_position': current_position,           # (time_steps, 2)
            'visited': obss_visited                         # (time_steps, node_num)
        },
        'actions': np.array(acts).astype(np.int32),         # (time_steps, )
        'rewards': np.array(rewards).astype(np.float32),    # (time_steps, )
        # The 'terminals' field only mimics the d4rl data layout; it is currently unused.
        'terminals': np.array([False] * (len(rewards)-1) + [True], dtype=bool)
    }
    return episode

def raw_data_to_traj_multiprocessing(raw_data:RawData, process_num:int=10, data_num:int=0, data_type:str='train'):
    """Build trajectories for a dataset in parallel with a process pool.

    ``raw_data`` is used only to determine how many samples exist; each worker
    call receives ``(index, data_type)`` and ``raw_data_to_traj`` selects the
    dataset itself from ``data_type``.

    Args:
        raw_data: Dataset whose length bounds the number of jobs.
        process_num: Number of worker processes.
        data_num: Number of samples to convert. ``0`` means all; larger values
            are clamped to the dataset size.
        data_type: Dataset name forwarded to ``raw_data_to_traj``.

    Returns:
        List of worker results in completion order (not index order).
    """
    available = len(raw_data.answer_list)
    if data_num == 0:
        data_num = available
    else:
        data_num = min(data_num, available)

    jobs = [(idx, data_type) for idx in range(data_num)]
    with mp.Pool(processes=process_num) as pool:
        # imap_unordered yields an iterator without a length, so tqdm is told
        # the total explicitly.
        progress = tqdm(pool.imap_unordered(raw_data_to_traj, jobs), total=data_num)
        # Drain the iterator while the pool is still alive.
        episodes = list(progress)
    return episodes

if __name__ == "__main__":
    # Run configuration
    node_num = 20
    dataset_name = 'OP_V2'
    gen_test_episodes = False
    get_train_problem = False
    worker_num = 100
    prompt_keep_num = 5000      # split point: prompt samples before it stay prompts

    # Folder used to save the generated data
    out_dir = f'{base_path}/data/used/{dataset_name}'
    create_folder_if_not_exist(out_dir)

    # Load the three kinds of raw data.
    # NOTE: train_data / test_data / prompt_data must stay module-level names,
    # because raw_data_to_traj reads the datasets as globals in the workers.
    train_data, test_data, prompt_data = load_existing_raw_data(node_num)

    # Move the surplus prompt data (beyond the split point) into the test problems
    remain_prompt_data = split_rawdata(prompt_data, 0, prompt_keep_num)
    problem_prompt_data = split_rawdata(prompt_data, prompt_keep_num, None)
    test_data = merge_rawdata(problem_prompt_data, test_data)
    prompt_data = remain_prompt_data
    print(len(train_data.answer_list), len(test_data.answer_list), len(prompt_data.answer_list))

    # Build the problems matching the training data; test performance on
    # these problems should be close to 100%
    if get_train_problem:
        train_problem = raw_data_to_problem(train_data)
        with open(f'{out_dir}/op{node_num}_train_problem.pkl', 'wb') as f:
            pickle.dump(train_problem, f)
            print('train_problems saved')

    # Build the test problems that are unseen during training
    test_problem = raw_data_to_problem(test_data)
    with open(f'{out_dir}/op{node_num}_problem.pkl', 'wb') as f:
        pickle.dump(test_problem, f)
        print('test_problems saved')

    # Build the prompt trajectories (samples that failed to replay come back as None)
    prompt_episodes = raw_data_to_traj_multiprocessing(prompt_data, worker_num, data_type='prompt')
    prompt_episodes = [epi for epi in prompt_episodes if epi is not None]
    with open(f'{out_dir}/op{node_num}_prompt.pkl', 'wb') as f:
        pickle.dump(prompt_episodes, f)
        print(f'{len(prompt_episodes)} prompt_traj saved')

    # Build the training trajectories
    train_episodes = raw_data_to_traj_multiprocessing(train_data, worker_num, data_type='train')
    train_episodes = [epi for epi in train_episodes if epi is not None]
    with open(f'{out_dir}/op{node_num}_train.pkl', 'wb') as f:
        pickle.dump(train_episodes, f)
        print(f'{len(train_episodes)} train_traj saved')

    # (Optional) optimal trajectories for the unseen test problems; the loss on
    # this data reflects how well the fit generalizes
    if gen_test_episodes:
        # Fixed: the second argument is the worker count, not node_num.
        test_episodes = raw_data_to_traj_multiprocessing(test_data, worker_num, data_type='test')
        # Fixed: drop failed samples, as is done for the prompt / train episodes.
        test_episodes = [epi for epi in test_episodes if epi is not None]
        with open(f'{out_dir}/op{node_num}_test.pkl', 'wb') as f:
            pickle.dump(test_episodes, f)
            print('test_traj saved')