import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

from environment.used.Env_tsp_v2 import TSP_V2
from environment.used.BaseEnv_COP import DataProblem
from utils.utils import create_folder_overwrite_if_exist, set_seed, get_file_path, load_existing_data, save_eval_problems, init_problems_for_multi_process, save_trajectories
from gym.utils.env_checker import check_env
import numpy as np
import multiprocessing
import random
import math
from multiprocessing import Manager, Lock

def gen_an_episode(env, lock):
    """Roll out one episode by replaying the reference tour stored on the env.

    The environment is reset (which, per the original author's note, draws a
    random TSP instance and solves it with LKH), then every action of
    ``env.real_answer`` except the first entry is fed to ``env.step`` and the
    resulting transitions are packed into a d4rl-like dictionary.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` and a
            ``real_answer`` sequence. Each observation must be a mapping with
            the keys ``'position'``, ``'visited'`` and ``'current_position'``.
        lock: Lock shared by all generator processes; it is held only while
            ``env.reset()`` runs.

    Returns:
        tuple: ``(episode, problem_info, answer_info)`` where ``episode`` is
        the d4rl-style dict, ``problem_info`` is the initial observation and
        ``answer_info`` is ``env.real_answer``.

    Raises:
        AssertionError: If the episode ends before the last action, if the
            last action does not terminate the episode, or if the final
            reward is not 1.
    """
    # The solver invoked during reset is not process-safe, so the reset is
    # serialised. The context manager guarantees the lock is released even
    # when reset() raises; otherwise every other worker would block forever.
    with lock:
        observation, _ = env.reset()
    solution = env.real_answer
    problem_info = observation
    answer_info = solution

    # Build the MDP trajectory. solution[0] is skipped; the intermediate
    # actions must not finish the episode.
    obss, acts, rewards = [observation], [], []
    for action in solution[1:-1]:
        observation, reward, terminated, truncated, _ = env.step(action)
        acts.append(action)
        rewards.append(reward)
        obss.append(observation)
        assert not (terminated or truncated)

    # The final action closes the tour: it must terminate with reward 1.
    # Its resulting observation is intentionally not stored, so that
    # observations, actions and rewards all have the same length.
    action = solution[-1]
    observation, reward, terminated, truncated, _ = env.step(action)
    assert reward == 1
    assert terminated and not truncated
    acts.append(action)
    rewards.append(reward)
    assert len(obss) == len(acts) == len(rewards)

    # Pack everything in a d4rl-like layout (shape notes are the original
    # author's annotations).
    obss_positions = np.vstack([obs['position'] for obs in obss])
    obss_visiteds = np.vstack([obs['visited'] for obs in obss])
    current_position = np.array([obs['current_position'] for obs in obss])
    episode = {
        'prefix': None,
        'observations': {
            'position': obss_positions,                     # (time_steps, 2*num_nodes)
            'visited': obss_visiteds,                       # (time_steps, num_nodes)
            'current_position': current_position            # (time_steps, 2)
        },
        'actions': np.array(acts).astype(np.int32),         # (time_steps, )
        'rewards': np.array(rewards).astype(np.float32),    # (time_steps, )
        # 'terminals' only mimics the d4rl data layout; it is currently unused.
        'terminals': np.array([False] * (len(rewards)-1) + [True], dtype=bool)
    }

    return episode, problem_info, answer_info

def gen_trajectories(
    process_epi_num_to_go, seed, num_nodes, 
    lock, episodes, problems
):
    """Worker entry point: generate trajectories and publish them to shared lists.

    Args:
        process_epi_num_to_go: Number of episodes this worker generates.
        seed: Seed applied to the environment and its action space.
        num_nodes: Number of TSP nodes passed to ``TSP_V2``.
        lock: Lock shared by all generator processes.
        episodes: Shared list-like object receiving each episode dict.
        problems: Shared object exposing ``problem_list`` and ``answer_list``;
            both receive one entry per episode.
    """
    env = TSP_V2(num_nodes=num_nodes)
    # Resetting the environment runs the solver, which the episode generator
    # treats as not process-safe. The start-up sanity check (assumed to reset
    # the env internally, as gym's checker does) and the seeding reset are
    # therefore serialised as well.
    with lock:
        check_env(env)
        env.action_space.seed(seed)
        env.reset(seed=seed)
    for _ in range(process_epi_num_to_go):
        episode, problem_info, answer_info = gen_an_episode(env, lock)
        # Publish the three records atomically with respect to the other
        # generator processes; separate unsynchronised appends could
        # interleave and leave the lists misaligned by index.
        with lock:
            episodes.append(episode)
            problems.problem_list.append(problem_info)
            problems.answer_list.append(answer_info)

def gen_eval_problems(
    process_prob_num_to_go, seed, num_nodes, 
    lock, problems
):
    """Worker entry point: generate evaluation problems with reference answers.

    Args:
        process_prob_num_to_go: Number of problems this worker generates.
        seed: Seed applied to the environment and its action space.
        num_nodes: Number of TSP nodes passed to ``TSP_V2``.
        lock: Lock shared by all generator processes.
        problems: Shared object exposing ``problem_list`` and ``answer_list``;
            both receive one entry per generated problem.
    """
    env = TSP_V2(num_nodes=num_nodes)
    # Every reset runs the solver, which is not process-safe, so the start-up
    # check (assumed to reset the env internally, as gym's checker does) and
    # the seeding reset are serialised too.
    with lock:
        check_env(env)
        env.action_space.seed(seed)
        env.reset(seed=seed)
    for _ in range(process_prob_num_to_go):
        # The context manager releases the lock even if reset() raises, and
        # keeping both appends inside it stops other workers from
        # interleaving entries, so problem i always matches answer i.
        with lock:
            observation, _ = env.reset()
            problems.problem_list.append(observation)
            problems.answer_list.append(env.real_answer)

if __name__ == "__main__":
    # ---- Configuration ----
    num_nodes = 20
    file_name = f'tsp{num_nodes}'
    data_name = 'TSP_V2'
    prob_num = 10000                     # number of evaluation problems
    epi_num = 20000                      # total number of trajectories to generate
    save_interval = 500                  # save trajectories every this many episodes
    save_interval_prob = 500             # save evaluation problems every this many problems
    # Number of generator processes. os.cpu_count() may return None when the
    # count cannot be determined, so fall back to a single process.
    # (Set a fixed value here, e.g. 7, to leave some cores free.)
    processes_num = os.cpu_count() or 1
    overwrite = False                    # whether to overwrite existing data
    assert epi_num % 10 == 0
    assert save_interval % 10 == 0
    train_epi_num = int(epi_num * 0.9)   # training trajectories; the rest serve as prompt trajectories
    # random.randint requires integer bounds: float arguments are deprecated
    # since Python 3.10 and raise TypeError from 3.12 on.
    set_seed(random.randint(100_000, 200_000))

    # Folder for intermediate files the solver may produce
    create_folder_overwrite_if_exist(f'{base_path}/temp') 

    # Output folders / file paths
    train_path, prompt_path, problem_path, train_problem_path = get_file_path(data_name, file_name)
    
    # Load data that already exists on disk and work out how much is left
    # for each process to generate
    (
        ex_train_epi, ex_prompt_epi, 
        ex_train_problem, ex_eval_problem, 
        process_epi_num_to_go, 
        process_prob_num_to_go
    ) = load_existing_data(
        train_path, prompt_path, problem_path, train_problem_path,
        prob_num, epi_num, overwrite, processes_num
    )

    # ---- Evaluation problems (multi-process) ----
    if len(ex_eval_problem.answer_list) < prob_num:
        with Manager() as manager:
            problems = init_problems_for_multi_process(manager, ex_eval_problem)
            
            # Start the process that saves problems
            save_process = multiprocessing.Process(
                target=save_eval_problems, 
                args=(prob_num, save_interval_prob, problem_path, problems)
            )
            save_process.start()

            # Start the problem generator processes, each with its own seed
            processes = []
            lock = Lock()
            for process_id in range(processes_num):
                process = multiprocessing.Process(
                    target=gen_eval_problems, 
                    args=(
                        process_prob_num_to_go, random.randint(0, 100_000), num_nodes, 
                        lock, problems
                    )
                )
                processes.append(process)
                process.start()

            # Wait for all processes
            for process in processes:
                process.join()
            save_process.join()

    # ---- Training trajectories, prompt trajectories and training problems (multi-process) ----
    with Manager() as manager:
        # Trajectory and problem containers shared by all processes,
        # pre-filled with whatever was already on disk
        episodes = manager.list()
        episodes.extend(ex_train_epi) 
        episodes.extend(ex_prompt_epi) 
        problems = init_problems_for_multi_process(manager, ex_train_problem)

        # Start the process that saves the data
        save_process = multiprocessing.Process(
            target=save_trajectories, 
            args=(
                epi_num, save_interval, 
                train_path, prompt_path, train_problem_path, 
                episodes, problems
            )
        )
        save_process.start()

        # Start the trajectory generator processes, each with its own seed
        processes = []
        lock = Lock()
        for process_id in range(processes_num):
            process = multiprocessing.Process(
                target=gen_trajectories, 
                args=(
                    process_epi_num_to_go, random.randint(0, 100_000), num_nodes, 
                    lock, episodes, problems
                )
            )
            processes.append(process)
            process.start()

        # Wait for all processes
        for process in processes:
            process.join()
        save_process.join()