import os
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))

import sys
matnet_base_path = os.path.abspath('/data1/XXX/MatNet/FFSP/FFSP_MatNet')
sys.path.append(matnet_base_path)

import shutil
import torch
import numpy as np
import random
import json
import argparse
from model import Gato
import fsspec
import time
import pickle
import math
import numpy as np
from typing import List
from environment.used.BaseEnv_COP import DataProblem, RawData
from utils.COP_slover import calc_vrp_distance, calc_tsp_distance, calc_op_total, calc_pctsp_cost, knapsack_dp
from matnet_test import eval_main

# Sentinel reward constant. NOTE(review): presumably the reward assigned when a
# COP episode fails; it is not referenced in the visible part of this file, so
# confirm against its callers.
COP_FAILED_RWD = -1000
# Two-level lookup table: problem name -> {integer key -> objective value}.
# NOTE(review): judging by the name, the inner key looks like the problem size and
# the value like the objective obtained by a random policy -- TODO confirm.
DEFAULT_RND_OBJ_VALUE = {
    'BP':    {20: 38.2366},
    'PCTSP': {10: 1, 20: 9.178939850997924},
    'OP':    {10: 1, 20: 1.9316279964268208},
    'CVRP':  {10: 1, 20: 13.234259510421753},
    'TSP':   {10: 1, 20: 10.432572023057938, 100: 52.16533660888672, 200: 104.30537414550781},
    'SPCTSP': {10: 1, 20: 9.614649021140231},
    'FFSP':  {20: 10.0},
    'MIS': {20:0}
}

def create_folder_if_not_exist(floder_path):
    """Create the directory ``floder_path`` including missing parents.

    Nothing happens (and no error is raised) when the directory already exists.
    """
    os.makedirs(name=floder_path, exist_ok=True)

def create_folder_overwrite_if_exist(floder_path):
    """Recreate ``floder_path`` as an empty directory.

    Any existing directory tree at that path is deleted first.
    """
    already_there = os.path.exists(floder_path)
    if already_there:
        # Remove the old tree so that the folder created below starts out empty.
        shutil.rmtree(floder_path)
    create_folder_if_not_exist(floder_path)

def create_file_if_not_exist(file_path):
    """Make sure ``file_path`` exists as a file, creating parent folders if needed.

    An already existing file is left untouched: it is opened in append mode, so
    its content is preserved.

    Args:
        file_path: Path of the file to create.
    """
    try:
        with open(file_path, 'a'):
            pass
    except FileNotFoundError:
        # The parent directory is missing. Derive it with os.path so the
        # platform's path separator is honoured; slicing at the last '/' breaks
        # on back-slash paths and drops the final character when the path
        # contains no '/' at all.
        folder_path = os.path.dirname(file_path)
        if folder_path:
            os.makedirs(folder_path, exist_ok=True)
        with open(file_path, 'w'):
            pass
        # NOTE(review): this one-second pause is kept from the original code;
        # its purpose is not visible from this file.
        time.sleep(1)

def set_seed(seed, envs=None):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also writes PYTHONHASHSEED, CUDA_LAUNCH_BLOCKING and CUBLAS_WORKSPACE_CONFIG
    into ``os.environ`` and, when ``envs`` is given, calls ``env.reset(seed=seed)``
    on every element.
    """
    # Seed every random number generator in the same order as before.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    os.environ.update({
        "PYTHONHASHSEED": str(seed),
        "CUDA_LAUNCH_BLOCKING": "1",
        "CUBLAS_WORKSPACE_CONFIG": ":16:8",
    })

    if envs is not None:
        for single_env in envs:
            single_env.reset(seed=seed)

    # The strict determinism switches below are intentionally left disabled:
    # some operations do not have a deterministic implementation, so the
    # experiment results cannot be reproduced completely anyway.
    #   torch.backends.cudnn.benchmark = False
    #   torch.backends.cudnn.deterministic = True
    #   torch.use_deterministic_algorithms(True)

def moving_average(a, window_size):
    """Return a centred moving average of the sequence ``a``.

    The interior uses a full window of ``window_size`` samples. Towards both
    ends the window shrinks symmetrically to 1, 3, 5, ... samples, so for an odd
    ``window_size`` (not larger than ``len(a)``) the result is as long as ``a``.

    NOTE(review): the 1, 3, 5, ... edge windows only line up with an odd
    ``window_size``; even values look unsupported -- confirm with callers.
    """
    # Interior values: difference of prefix sums divided by the window length.
    prefix = np.cumsum(np.insert(a, 0, 0))
    interior = (prefix[window_size:] - prefix[:-window_size]) / window_size

    # Sizes of the shrinking edge windows.
    edge_sizes = np.arange(1, window_size-1, 2)
    head = np.cumsum(a[:window_size-1])[::2] / edge_sizes
    # Same computation on the reversed tail of the sequence.
    tail_reversed = np.cumsum(a[:-window_size:-1])[::2] / edge_sizes
    return np.concatenate((head, interior, tail_reversed[::-1]))

def top_k_logits(logits, k):
    """Keep the ``k`` largest values along the last dimension, set the rest to -inf.

    Works for tensors with any number of leading dimensions (including 1-D
    input). Entries that tie with the k-th largest value are kept as well, so
    more than ``k`` finite entries may remain.

    Args:
        logits: Tensor of scores; it is not modified.
        k: Number of top entries to keep per row of the last dimension.

    Returns:
        A new tensor of the same shape as ``logits``.
    """
    top_values, _ = torch.topk(logits, k)
    # Smallest kept value per row; the trailing singleton dimension lets it
    # broadcast against ``logits`` whatever the number of leading dimensions.
    threshold = top_values[..., -1:]
    out = logits.clone()
    out[out < threshold] = -float('Inf')
    return out

def str2bool(x):
    """Convert the literal string "True" / "False" to the matching bool.

    Any other value fails the assertion.
    """
    is_true = x == "True"
    assert is_true or x == "False"
    return bool(is_true)

def load_model(config_path, ckpt_path=None, snapshot_path=None):
    """Build a Gato model from a JSON config file and restore its weights.

    Exactly one of ``ckpt_path`` and ``snapshot_path`` must be given.

    Args:
        config_path: JSON file whose top-level keys become the attributes of the
            ``argparse.Namespace`` passed to ``Gato``.
        ckpt_path: File containing a state dict. The epoch is parsed from the
            file name: the characters between the last ``'epoch'`` and the
            final three characters (e.g. ``'.pt'``) must form an integer.
        snapshot_path: Snapshot opened through fsspec; it must contain the keys
            ``'model_state'`` and ``'finished_epoch'``.

    Returns:
        Tuple ``(args, gato, current_epoch)``.
    """
    assert (ckpt_path is None) ^ (snapshot_path is None)
    with open(config_path, 'r') as f:
        config_dict = json.load(f)
    args = argparse.Namespace(**config_dict)
    gato = Gato(args)

    current_epoch = 0
    if ckpt_path is not None:
        # Load onto the CPU first, exactly like the snapshot branch does, so a
        # checkpoint saved on a GPU can also be restored on a machine without
        # that device; load_state_dict copies the values into the model anyway.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        gato.load_state_dict(state_dict)
        current_epoch = int(ckpt_path[ckpt_path.rfind('epoch')+5:-3])
    else:
        # fsspec offers one Python interface for many storage back-ends, so the
        # same call opens local files as well as e.g. AWS S3 or GCS objects.
        snapshot = fsspec.open(snapshot_path)
        with snapshot as f:
            snapshot_data = torch.load(f, map_location="cpu")
            gato.load_state_dict(snapshot_data["model_state"])
            current_epoch = snapshot_data['finished_epoch']
            print(f"The loaded snapshot was saved at Epoch {current_epoch}")
    return args, gato, current_epoch

def split_dataproblem(problem:DataProblem, from_idx:int, to_idx:int):
    """Return a new DataProblem holding the entries ``[from_idx, to_idx)``.

    ``prefix_list`` and ``problem_list`` are optional and stay ``None`` when
    they are ``None`` in the input; ``answer_list`` must be present.
    """
    assert problem.answer_list is not None

    def _window(items):
        # Optional fields are passed through untouched when absent.
        if items is None:
            return None
        return items[from_idx:to_idx]

    return DataProblem(
        prefix_list=_window(problem.prefix_list),
        problem_list=_window(problem.problem_list),
        answer_list=problem.answer_list[from_idx:to_idx],
    )

def split_rawdata(data:RawData, from_idx:int, to_idx:int):
    """Return a RawData with the entries ``[from_idx, to_idx)`` of ``data``.

    The problem, answer and cost lists must have equal length. ``seed_list`` is
    passed on as-is (neither sliced nor copied).
    """
    assert len(data.problem_list) == len(data.answer_list) == len(data.cost_list)
    window = slice(from_idx, to_idx)
    return RawData(
        seed_list=data.seed_list,
        problem_list=data.problem_list[window],
        answer_list=data.answer_list[window],
        cost_list=data.cost_list[window],
    )

def merge_rawdata(source_data:RawData, target_data:RawData):
    """Return a RawData with the records of ``source_data`` appended to ``target_data``.

    Only the problem, answer and cost lists are concatenated (target first);
    the result carries the ``seed_list`` of ``target_data`` alone.
    """
    for dataset in (source_data, target_data):
        assert len(dataset.problem_list) == len(dataset.answer_list) == len(dataset.cost_list)

    merged = {
        field: getattr(target_data, field) + getattr(source_data, field)
        for field in ('problem_list', 'answer_list', 'cost_list')
    }
    return RawData(seed_list=target_data.seed_list, **merged)

# ------------------------ traj data generation ------------------------
def get_file_path(data_name, file_name):
    """Return the four .pkl paths used for one generated data set.

    The folder ``<base_path>/data/used/<data_name>`` is created when missing.
    The returned tuple is ``(train, prompt, problem, train_problem)``.
    """
    data_path = f'{base_path}/data/used/{data_name}'
    create_folder_if_not_exist(data_path)
    suffixes = ('train', 'prompt', 'problem', 'train_problem')
    return tuple(f'{data_path}/{file_name}_{suffix}.pkl' for suffix in suffixes)

def load_data(data_name:str, data_file_name:str, num_problem:int=0, num_episode:int=0):
    """Load an already generated .pkl file with trajectory or problem data.

    Files whose name ends with ``'problem'`` are cut to the first ``num_problem``
    problems, all other files to the first ``num_episode`` entries; a limit of
    0 returns the complete content.
    """
    pkl_path = f'{base_path}/data/used/{data_name}/{data_file_name}.pkl'
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)

    if data_file_name.endswith('problem'):
        if num_problem == 0:
            return data
        return split_dataproblem(data, 0, num_problem)

    if num_episode == 0:
        return data
    return data[:num_episode]

def init_problems_for_multi_process(manager, existing_problems:DataProblem):
    """Copy ``existing_problems`` into containers created by ``manager``.

    A ``manager.Namespace()`` is returned whose ``answer_list`` is always a
    ``manager.list()``; ``prefix_list`` and ``problem_list`` are manager lists
    too unless they are ``None`` in the input, in which case they stay ``None``.
    """
    shared = manager.Namespace()

    # Optional fields: mirror None, otherwise create a shared list and fill it.
    for field in ('prefix_list', 'problem_list'):
        source = getattr(existing_problems, field)
        if source is None:
            setattr(shared, field, None)
        else:
            setattr(shared, field, manager.list())
            getattr(shared, field).extend(source)

    # The answers are mandatory.
    shared.answer_list = manager.list()
    shared.answer_list.extend(existing_problems.answer_list)

    return shared

def get_best_obj_value(env_name, problems:DataProblem):
    """Return an array with one objective value per problem in ``problems``.

    How the value is obtained depends on ``env_name``:
      * BP: ``knapsack_dp`` is run on the problem, the stored answer is ignored.
      * (A)TSP, CVRP, OP, PCTSP, SPCTSP: the stored answer is scored with the
        matching ``calc_*`` helper.
      * FFSP: ``eval_main`` is called on the stacked ``durations`` arrays.
      * MIS: a fixed constant is used for every problem.

    Raises:
        NotImplementedError: for any other environment name.
    """
    def _pairs():
        # Built on demand so that branches which never look at the answers do
        # not touch ``problems.answer_list`` at all.
        return zip(problems.problem_list, problems.answer_list)

    def _depot_then_nodes(prob):
        # Depot coordinates first, followed by the nodes as rows of two values.
        return np.vstack((prob['pos_depot'][None,:], prob['pos_node'].reshape((-1, 2))))

    if env_name in ['Env_BP_V1', 'Env_BP_V2']:
        obj_list = []
        for prob, _ in _pairs():
            best_value, _ = knapsack_dp(prob['capacity_left'].item(), prob['item_volumes'], prob['item_values'])
            obj_list.append(best_value)
    elif env_name in ['Env_ATSP_V1', 'Env_ATSP_V2', 'Env_TSP_V1', 'Env_TSP_V2', 'Env_TSP_V3', 'Env_TSP_V4']:
        obj_list = [
            calc_tsp_distance(prob['position'].reshape((-1,2)), ans)
            for prob, ans in _pairs()
        ]
    elif env_name.startswith('Env_CVRP'):
        obj_list = [calc_vrp_distance(_depot_then_nodes(prob), ans) for prob, ans in _pairs()]
    elif env_name == 'Env_OP_V1':
        obj_list = [calc_op_total(prob['prize'], ans) for prob, ans in _pairs()]
    elif env_name in ['Env_OP_V2', 'Env_OP_V3', 'Env_OP_V4']:
        # In these versions every index of the answer is shifted down by one
        # before scoring.
        obj_list = [calc_op_total(prob['prize'], np.array(ans)-1) for prob, ans in _pairs()]
    elif env_name.startswith('Env_PCTSP'):
        # The last element of the answer is dropped before scoring.
        obj_list = [
            calc_pctsp_cost(_depot_then_nodes(prob), prob['penalty'], prob['prize'], ans[:-1])
            for prob, ans in _pairs()
        ]
    elif env_name.startswith('Env_SPCTSP'):
        # Same as PCTSP but scored with the 'stoc_prize' entry of the problem.
        obj_list = [
            calc_pctsp_cost(_depot_then_nodes(prob), prob['penalty'], prob['stoc_prize'], ans[:-1])
            for prob, ans in _pairs()
        ]
    elif env_name.startswith('Env_FFSP'):
        durations = [prob['durations'] for prob in problems.problem_list]
        # The last entry along the second axis is cut off before evaluation.
        _, returns = eval_main(np.array(durations)[:, :-1, :])
        obj_list = list(returns)
    elif env_name.startswith('Env_MIS'):
        # NOTE(review): hard-coded constant for every problem -- looks like a
        # placeholder, confirm.
        obj_list = [10.45645 for _ in _pairs()]
    else:
        raise NotImplementedError

    assert None not in obj_list
    return np.array(obj_list)

def load_data_and_check_quantity(file_names:List, env_name:str, num_problem:int=0, num_episode:int=0, check_ave_obj:bool=False):
    """Load several data files and print their size and observation value ranges.

    The last four characters of every file name (e.g. ``'.pkl'``) are stripped
    to obtain the data name. For problem files the number of answers is printed,
    plus the mean objective when ``check_ave_obj`` is set; for all other files
    the number of records is printed. In both cases the min/max of every
    observation key follows.

    Returns:
        Dict mapping each data name to the loaded records.
    """
    def _print_value_ranges(key_source, samples):
        # Track [min, max] per key, starting from the keys of the first sample.
        bounds = {key: [float('inf'), -float('inf')] for key in key_source}
        for sample in samples:
            for key, value in sample.items():
                lowest = value.min()
                if lowest < bounds[key][0]:
                    bounds[key][0] = lowest
                highest = value.max()
                if highest > bounds[key][1]:
                    bounds[key][1] = highest
        for key, bound in bounds.items():
            print(f'\t{key:18}\tin range {bound}')

    separator = '-'*60
    file_data = {}
    print(separator)
    for file_name in file_names:
        data_name = file_name[:-4]
        records = load_data(env_name, data_name, num_problem, num_episode)
        if data_name.endswith('problem'):
            if check_ave_obj:
                obj_array = get_best_obj_value(env_name=f'Env_{env_name}', problems=records)
                print(f'{data_name:20}\t{len(records.answer_list)}\t{obj_array.mean()}')
            else:
                print(f'{data_name:20}\t{len(records.answer_list)}')
            _print_value_ranges(records.problem_list[0], records.problem_list)
        else:
            print(f'{data_name:20}\t{len(records)}')
            _print_value_ranges(
                records[0]['observations'],
                (episode['observations'] for episode in records),
            )
        file_data[data_name] = records
        print()
    print(separator)
    return file_data

def load_existing_data(
    train_path, prompt_path, problem_path, train_problem_path,
    prob_num, epi_num, overwrite=False, processes_num=1
):
    """Load already generated problems and trajectories to resume data generation.

    With ``overwrite`` set, or for files that do not exist, empty defaults are
    used. The remaining work is split evenly (rounded up) over
    ``processes_num`` processes.

    Returns:
        ``(train episodes, prompt episodes, train problems, eval problems,
        episodes to generate per process, problems to generate per process)``.
    """
    def _restore(path, fallback):
        # Use the pickled content only when resuming and the file is present.
        if overwrite or not os.path.isfile(path):
            return fallback
        with open(path, 'rb') as f:
            return pickle.load(f)

    ex_train_epi = _restore(train_path, [])
    ex_prompt_epi = _restore(prompt_path, [])
    ex_train_prob = _restore(train_problem_path, DataProblem(problem_list = [], answer_list = []))
    ex_eval_prob = _restore(problem_path, DataProblem(problem_list = [], answer_list = []))

    train_epi_count = len(ex_train_epi)
    prompt_epi_count = len(ex_prompt_epi)
    epi_count = train_epi_count + prompt_epi_count
    eval_prob_count = len(ex_eval_prob.answer_list)
    # Every training episode must come with its training problem.
    assert train_epi_count == len(ex_train_prob.answer_list)

    process_epi_num_to_go = math.ceil((epi_num - epi_count) / processes_num)
    process_prob_num_to_go = math.ceil((prob_num - eval_prob_count) / processes_num)

    # Print a summary of the data set.
    rule = '-'*50
    print(rule)
    print(f'There are [{eval_prob_count}] problem generated, [{prob_num-eval_prob_count}] to go')
    print(f'There are [{epi_count}] episode generated, [{epi_num-epi_count}] to go')
    print(f'    Training episode:   {train_epi_count}')
    print(f'    Prompt episode:     {prompt_epi_count}')
    print(rule)

    return ex_train_epi, ex_prompt_epi, ex_train_prob, ex_eval_prob, process_epi_num_to_go, process_prob_num_to_go

def save_trajectories(
    epi_num, save_interval, train_path, prompt_path, train_problem_path,
    episodes, problems
):
    '''
    Save the training trajectories, the problems of the training trajectories and
    the prompt trajectories; this is the target function of the saver process
    used during data generation.

    The function polls ``len(episodes)`` every 10 ms until it reaches ``epi_num``.
    Whenever the polled count is a non-zero multiple of ``save_interval`` that has
    not been saved yet, an intermediate save is written; one final save is
    written after the loop ends.
    '''
    # Write the current state to disk
    def _save():
        # Save training and prompt trajectories: the first 90% (rounded down)
        # of the ``save_num`` episodes go to the training file, the rest to the
        # prompt file. ``save_num`` is read from the enclosing scope at call time.
        save_train_num = int(save_num * 0.9)
        with open(train_path, 'wb') as f:
            pickle.dump(episodes[:save_train_num], f)
        with open(prompt_path, 'wb') as f:
            pickle.dump(episodes[save_train_num: save_num], f)
        
        # Save the problems that belong to the training trajectories
        assert problems.answer_list is not None
        save_problem = split_dataproblem(problems, 0, save_train_num)
        with open(train_problem_path, 'wb') as f:
            pickle.dump(save_problem, f)

    started_time = time.time()
    save_num = len(episodes)
    last_printed = 0
    while save_num < epi_num:
        # NOTE(review): a multiple of save_interval is only saved if a poll
        # happens to observe exactly that count; if the count jumps past it
        # between two polls that intermediate save is skipped.
        if save_num != last_printed and save_num != 0 and save_num % save_interval == 0:
            print(f'[{round(time.time()-started_time,2)} s]: {save_num} episodes is saving', end="", flush=True)
            _save()
            # '\r' rewrites the progress line printed just above.
            print(f'\r[{round(time.time()-started_time,2)} s]: {save_num} episodes saved!'+' '*10, end="", flush=True)
            print()
            last_printed = save_num
        save_num = len(episodes)
        time.sleep(0.01)

    # Final save; save_num is the last polled length, which may exceed epi_num.
    _save()
    print(f'[{round(time.time()-started_time,2)} s]: {save_num} episodes saved, saving process exited')

def save_eval_problems(problem_num, save_interval, problem_path, problems):
    '''
    Save the evaluation problems; this is the target function of the problem
    saver process used during data generation.

    The function polls ``len(problems.answer_list)`` every 10 ms until it reaches
    ``problem_num``. Whenever the polled count is a non-zero multiple of
    ``save_interval`` that has not been saved yet, an intermediate save is
    written; the final save keeps exactly the first ``problem_num`` problems.
    '''
    # Write the first ``save_num`` problems to disk; ``save_num`` is read from
    # the enclosing scope at call time.
    def _save():
        assert problems.answer_list is not None
        save_problem = split_dataproblem(problems, 0, save_num)
        with open(problem_path, 'wb') as f:
            pickle.dump(save_problem, f)

    started_time = time.time()
    save_num = len(problems.answer_list)
    last_printed = 0
    while save_num < problem_num:
        # NOTE(review): a multiple of save_interval is only saved if a poll
        # happens to observe exactly that count.
        if save_num != last_printed and save_num != 0 and save_num % save_interval == 0:
            print(f'[{round(time.time()-started_time,2)} s]: {save_num} problems is saving', end="", flush=True)
            _save()
            # '\r' rewrites the progress line printed just above.
            print(f'\r[{round(time.time()-started_time,2)} s]: {save_num} problems saved!'+' '*10, end="", flush=True)
            print()
            last_printed = save_num
        save_num = len(problems.answer_list)
        time.sleep(0.01)

    # Cap the final save at the requested number of problems.
    save_num = problem_num
    _save()
    print(f'[{round(time.time()-started_time,2)} s]: {save_num} problems saved, saving process exited\n\n')

# ------------------------ raw data generation ------------------------
def split_raw_dataset(dataset:RawData, from_idx:int, to_idx:int):
    """Return a RawData with the entries ``[from_idx, to_idx)`` of ``dataset``.

    Unlike the record lists, ``seed_list`` is not cut: a full shallow copy of it
    is stored in the result.
    """
    assert dataset.answer_list is not None
    window = slice(from_idx, to_idx)
    return RawData(
        seed_list=dataset.seed_list[:],
        problem_list=dataset.problem_list[window],
        answer_list=dataset.answer_list[window],
        cost_list=dataset.cost_list[window],
    )

def load_existing_raw_data(dataset_path, overwrite=False):
    """Load already generated raw data so that data generation can be resumed.

    An empty RawData is returned when ``overwrite`` is set or when
    ``dataset_path`` is not an existing file.
    """
    empty_dataset = RawData(seed_list=[], problem_list=[], answer_list=[], cost_list=[])
    # os.path.isfile is already False for paths that do not exist.
    if overwrite or not os.path.isfile(dataset_path):
        return empty_dataset
    with open(dataset_path, 'rb') as f:
        return pickle.load(f)

def merge_to_dataset(from_dataset_path_list, to_dataset_path, backup_path):
    """Merge several RawData sub-dataset files into the target dataset file.

    The target file is loaded first when it exists, every sub-dataset is
    appended to it, and the result is written to both ``to_dataset_path`` and
    ``backup_path``. The sub-dataset files are deleted only after both files
    have been written, so a failure while loading, checking or saving never
    loses data that exists only in a sub-dataset file.

    Args:
        from_dataset_path_list: Paths of the pickled RawData sub-datasets.
        to_dataset_path: Path of the merged dataset (extended when it exists).
        backup_path: Second location that receives the same merged dataset.

    Raises:
        AssertionError: if a sub-dataset file is missing or one of its seeds is
            already present in the merged data.
    """
    # Load the target dataset if there is one already.
    target_data = RawData(seed_list=[], problem_list=[], answer_list=[], cost_list=[])
    if os.path.isfile(to_dataset_path):
        with open(to_dataset_path, 'rb') as f:
            target_data = pickle.load(f)

    # Merge all sub-datasets; remember their paths for the clean-up below.
    merged_paths = []
    for sub_dataset_path in from_dataset_path_list:
        assert os.path.exists(sub_dataset_path)
        with open(sub_dataset_path, 'rb') as f:
            sub_dataset = pickle.load(f)
        assert all(elem not in target_data.seed_list for elem in sub_dataset.seed_list), "检查到重复的随机种子"
        target_data.seed_list.extend(sub_dataset.seed_list)
        target_data.problem_list.extend(sub_dataset.problem_list)
        target_data.answer_list.extend(sub_dataset.answer_list)
        target_data.cost_list.extend(sub_dataset.cost_list)
        merged_paths.append(sub_dataset_path)

    # Save to the target path and to the backup path.
    with open(to_dataset_path, 'wb') as f:
        pickle.dump(target_data, f)
    with open(backup_path, 'wb') as f:
        pickle.dump(target_data, f)

    # Only now is it safe to drop the sources.
    for sub_dataset_path in merged_paths:
        os.remove(sub_dataset_path)

def merge_list_data(from_dataset_path_list, to_dataset_path):
    """Merge several pickled list sub-datasets into the target dataset file.

    The target file is loaded first when it exists. The sub-datasets are
    appended in sorted path order and the result is written back to
    ``to_dataset_path``. The sub-dataset files are deleted only after the
    merged file has been written, so a failed save does not lose their content.

    Args:
        from_dataset_path_list: Paths of the pickled list sub-datasets.
        to_dataset_path: Path of the merged list (extended when it exists).

    Raises:
        AssertionError: if a sub-dataset file does not exist.
    """
    # Load the target dataset if there is one already.
    target_data = []
    if os.path.isfile(to_dataset_path):
        with open(to_dataset_path, 'rb') as f:
            target_data = pickle.load(f)

    # Merge all sub-datasets in a deterministic (sorted) order.
    from_list = sorted(from_dataset_path_list)
    for sub_dataset_path in from_list:
        assert os.path.exists(sub_dataset_path)
        with open(sub_dataset_path, 'rb') as f:
            sub_dataset = pickle.load(f)
        target_data.extend(sub_dataset)

    # Save to the target path.
    with open(to_dataset_path, 'wb') as f:
        pickle.dump(target_data, f)

    # Only now is it safe to drop the sources.
    for sub_dataset_path in from_list:
        os.remove(sub_dataset_path)

# NOTE: the module-level string below holds disabled code (multi-process helpers
# for raw-data generation). It is never executed; the string is evaluated and
# discarded when the module is imported.
'''
def init_dataset_for_multi_process(manager, existing_dataset:RawData):
    dataset = manager.Namespace()
    dataset.seed_list = manager.list()
    dataset.problem_list = manager.list()
    dataset.answer_list = manager.list()
    dataset.cost_list = manager.list()
    dataset.seed_list.extend(existing_dataset.seed_list)
    dataset.problem_list.extend(existing_dataset.problem_list)
    dataset.answer_list.extend(existing_dataset.answer_list)
    dataset.cost_list.extend(existing_dataset.cost_list)
    return dataset

def save_raw_dataset(dataset_size, save_interval, dataset_path, dataset):
    # 保存评估问题，这是数据生成过程中问题保存进程的目标方法
    # 保存轨迹
    def _save():
        assert dataset.answer_list is not None
        save_dataset = split_raw_dataset(dataset, 0, save_num)
        with open(dataset_path, 'wb') as f:
            pickle.dump(save_dataset, f)

    started_time = time.time()
    save_num = len(dataset.answer_list)
    last_printed = 0
    while save_num < dataset_size:
        if save_num != last_printed and save_num != 0 and save_num % save_interval == 0:
            print(f'[{round(time.time()-started_time,2)} s]: {save_num} raw data is saving', end="", flush=True)
            _save()
            print(f'\r[{round(time.time()-started_time,2)} s]: {save_num} raw data saved! ave cost {np.mean(dataset.cost_list)}'+' '*10, end="", flush=True)
            print()
            last_printed = save_num
        save_num = len(dataset.answer_list)
        time.sleep(0.01)

    save_num = dataset_size
    _save()
    print(f'[{round(time.time()-started_time,2)} s]: {save_num} raw data saved, ave cost {np.mean(dataset.cost_list)}, saving process exited\n\n')
'''