import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

from environment.used.Env_cvrp_v1 import CVRP_V1
from environment.used.BaseEnv_COP import DataProblem
from utils.utils import create_folder_overwrite_if_exist, init_problems_for_multi_process, set_seed, save_eval_problems, get_file_path, load_existing_data, save_trajectories
from gym.utils.env_checker import check_env
import numpy as np
import multiprocessing
import random
from multiprocessing import Manager

def gen_an_episode(env):
    """Roll out one expert CVRP episode and pack it into a d4rl-style dict.

    ``env.reset()`` generates a new problem and exposes a reference solution in
    ``env.real_answer``. That solution is replayed with ``env.step`` while the
    observations, actions and rewards are recorded.

    Args:
        env: CVRP environment exposing ``reset()``, a gym-style ``step(action)``
            that returns a 5-tuple, and ``real_answer`` (set by ``reset``).

    Returns:
        Tuple ``(episode, problem_info, answer_info)``:
        ``episode`` is the d4rl-style dict, ``problem_info`` is the initial
        observation, and ``answer_info`` is ``env.real_answer``.

    Raises:
        RuntimeError: if the replayed solution terminates or truncates before its
            last action, or if the last action does not terminate the episode
            with reward 1.
    """
    # Note: this function takes no lock itself. Callers are responsible for
    # synchronising writes of its results into shared containers.
    observation, info = env.reset()
    problem_info = observation
    answer_info = env.real_answer

    # A 0 is prepended so that solution[1:-1] are the intermediate actions and
    # solution[-1] is the final one. Unpacking (instead of `[0] + answer`) also
    # works when the answer is a tuple or another iterable.
    solution = [0, *answer_info]

    # Build the MDP trajectory. obss[t] is the observation acts[t] was taken from.
    obss, acts, rewards = [observation], [], []
    for action in solution[1:-1]:
        observation, reward, terminated, truncated, info = env.step(action)
        acts.append(action)
        rewards.append(reward)
        obss.append(observation)
        if terminated or truncated:
            raise RuntimeError(
                f"episode ended before the last action (step {len(acts)}, action {action})"
            )

    # The last action must close the episode successfully. The observation it
    # produces is intentionally not stored.
    action = solution[-1]
    observation, reward, terminated, truncated, info = env.step(action)
    if reward != 1 or not terminated or truncated:
        raise RuntimeError(
            f"final action {action} did not finish the episode "
            f"(reward={reward}, terminated={terminated}, truncated={truncated})"
        )
    acts.append(action)
    rewards.append(reward)
    assert len(obss) == len(acts) == len(rewards)  # holds by construction

    # Stack per-step fields into arrays (d4rl layout).
    # The shape notes below assume each per-step field is a flat 1-D array
    # (scalar for 'capacity'); TODO confirm against CVRP_V1's observation space.
    obss_pos_depot = np.array([obs['pos_depot'] for obs in obss])
    obss_pos_node = np.vstack([obs['pos_node'] for obs in obss])
    obss_demand = np.vstack([obs['demand'] for obs in obss])
    obss_visited = np.vstack([obs['visited'] for obs in obss])
    obss_capacity = np.array([obs['capacity'] for obs in obss])
    current_position = np.array([obs['current_position'] for obs in obss])

    episode = {
        'prefix': None,
        'observations': {
            'pos_depot': obss_pos_depot,                    # (time_steps, 2)
            'pos_node': obss_pos_node,                      # (time_steps, 2*num_nodes)
            'demand': obss_demand,                          # (time_steps, num_nodes)
            'capacity': obss_capacity,                      # (time_steps, )
            'current_position': current_position,           # (time_steps, 2)
            'visited': obss_visited                         # (time_steps, num_nodes)
        },
        'actions': np.array(acts).astype(np.int32),         # (time_steps, )
        'rewards': np.array(rewards).astype(np.float32),    # (time_steps, )
        # 'terminals' only mimics the d4rl data layout; it is not used at present.
        'terminals': np.array([False] * (len(rewards) - 1) + [True], dtype=bool)
    }

    return episode, problem_info, answer_info

def gen_trajectories(
    process_epi_num_to_go, seed, node_num, 
    episodes, problems, lock
):
    """Worker: generate solved CVRP episodes and publish them to shared containers.

    Args:
        process_epi_num_to_go: number of episodes this worker produces.
        seed: seed for this worker's environment (action space and first reset).
        node_num: number of nodes passed to ``CVRP_V1``.
        episodes: shared list that receives each episode dict.
        problems: shared object whose ``problem_list`` / ``answer_list`` receive
            each episode's initial observation and reference answer.
        lock: lock shared by all workers, held while appending to ``problems``.
    """
    env = CVRP_V1(node_num=node_num)
    env.action_space.seed(seed)
    # Seed once up front. The unseeded resets inside gen_an_episode presumably
    # continue from this RNG state (gym convention).
    env.reset(seed=seed)
    for _ in range(process_epi_num_to_go):
        episode, problem_info, answer_info = gen_an_episode(env)
        episodes.append(episode)
        # Hold the lock across both appends so problem_list[i] and answer_list[i]
        # always describe the same problem. `with` releases the lock even if an
        # append raises, so one failing worker cannot deadlock the others.
        with lock:
            problems.problem_list.append(problem_info)
            problems.answer_list.append(answer_info)

def gen_eval_problems(process_prob_num_to_go, seed, node_num, problems, lock):
    """Worker: generate evaluation problems (initial observation + reference answer).

    Args:
        process_prob_num_to_go: number of problems this worker produces.
        seed: seed for this worker's environment (action space and first reset).
        node_num: number of nodes passed to ``CVRP_V1``.
        problems: shared object whose ``problem_list`` / ``answer_list`` receive
            each reset observation and the matching ``env.real_answer``.
        lock: lock shared by all workers, held while appending to ``problems``.
    """
    env = CVRP_V1(node_num=node_num)
    env.action_space.seed(seed)
    env.reset(seed=seed)
    for _ in range(process_prob_num_to_go):
        observation, _ = env.reset()
        # Append both entries under one lock so problem_list[i] and answer_list[i]
        # always describe the same problem. `with` releases the lock even if an
        # append fails, so one crashing worker cannot deadlock the rest.
        with lock:
            problems.problem_list.append(observation)
            problems.answer_list.append(env.real_answer)
        

def _run_workers(target, worker_args, save_process):
    """Run `save_process` alongside one `target` worker process per entry of `worker_args`.

    The saver is started first and joined last, so it stays alive for the whole
    time the workers are producing data.
    """
    save_process.start()
    workers = []
    for args in worker_args:
        worker = multiprocessing.Process(target=target, args=args)
        worker.start()
        workers.append(worker)
    for worker in workers:
        worker.join()
    save_process.join()


def main():
    """Generate CVRP evaluation problems and expert trajectories in parallel.

    Existing data is loaded first. ``load_existing_data`` also returns how many
    items each worker process still has to generate.
    """
    # Parameters
    node_num = 50
    file_name = f'cvrp{node_num}'
    data_name = 'CVRP_V1'
    prob_num = 10000                      # number of evaluation problems
    epi_num = 0                           # total number of trajectories to generate
    save_interval = 250                   # trajectory save interval
    save_interval_prob = 250              # evaluation-problem save interval
    processes_num = os.cpu_count() or 1   # number of generator processes (cpu_count may be None)
    overwrite = False                     # whether to overwrite existing data
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # NOTE(review): the multiple-of-10 requirement presumably supports a 90/10
    # train/prompt split done downstream; confirm.
    if epi_num % 10 != 0:
        raise ValueError(f"epi_num must be a multiple of 10, got {epi_num}")
    if save_interval % 10 != 0:
        raise ValueError(f"save_interval must be a multiple of 10, got {save_interval}")
    # Integer bounds: random.randint rejects float arguments on Python 3.12+.
    set_seed(random.randint(100_000, 200_000))

    # Draw every worker seed without replacement from one pool. This stops two
    # workers from sharing a seed, including an eval worker and a training
    # worker. A shared seed would presumably make CVRP_V1 generate identical
    # problem sequences, producing duplicates or leaking eval problems into training.
    seeds = random.sample(range(100_001), 2 * processes_num)
    eval_seeds, train_seeds = seeds[:processes_num], seeds[processes_num:]

    # Folder for intermediate files the solver may produce
    create_folder_overwrite_if_exist(f'{base_path}/temp')

    # Folders & file paths for the saved data
    train_path, prompt_path, problem_path, train_problem_path = get_file_path(data_name, file_name)

    # Load data that already exists
    (
        ex_train_epi, ex_prompt_epi,
        ex_train_problem, ex_eval_problem,
        process_epi_num_to_go,
        process_prob_num_to_go
    ) = load_existing_data(
        train_path, prompt_path, problem_path, train_problem_path,
        prob_num, epi_num, overwrite, processes_num
    )

    # Generate and save evaluation problems with multiple processes
    if len(ex_eval_problem.answer_list) < prob_num:
        with Manager() as manager:
            problems = init_problems_for_multi_process(manager, ex_eval_problem)
            lock = multiprocessing.Lock()
            save_process = multiprocessing.Process(
                target=save_eval_problems,
                args=(prob_num, save_interval_prob, problem_path, problems)
            )
            _run_workers(
                gen_eval_problems,
                [
                    (process_prob_num_to_go, seed, node_num, problems, lock)
                    for seed in eval_seeds
                ],
                save_process,
            )

    # Generate and save training trajectories, prompt trajectories and training problems
    with Manager() as manager:
        # Trajectory and problem containers shared by all processes
        episodes = manager.list()
        episodes.extend(ex_train_epi)
        episodes.extend(ex_prompt_epi)
        problems = init_problems_for_multi_process(manager, ex_train_problem)
        lock = multiprocessing.Lock()
        save_process = multiprocessing.Process(
            target=save_trajectories,
            args=(
                epi_num, save_interval,
                train_path, prompt_path, train_problem_path,
                episodes, problems
            )
        )
        _run_workers(
            gen_trajectories,
            [
                (process_epi_num_to_go, seed, node_num, episodes, problems, lock)
                for seed in train_seeds
            ],
            save_process,
        )


if __name__ == "__main__":
    main()