import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

import json
import argparse
import numpy as np
import torch
from environment.wrapper import LMPromptEnv
from environment.DDP_wrapper import DDP_LMPromptEnv
from data.used.make_data import get_op_data_v1, get_cvrp_data_v1, get_pctsp_data_v1, get_tsp_data_v1, get_tsp_data_v2, get_tsp_data_v3, EXAMPLE_RENDER
from utils.utils import split_dataproblem, load_data, set_seed

from typing import List
from tqdm import tqdm
from environment.used.Env_cvrp_v1 import CVRP_V1, DDP_CVRP_V1
from environment.used.Env_tsp_v1 import TSP_V1, DDP_TSP_V1
from environment.used.Env_tsp_v2 import TSP_V2, DDP_TSP_V2
from environment.used.Env_op_v1 import OP_V1, DDP_OP_V1
from environment.used.Env_op_v2 import OP_V2, DDP_OP_V2
from environment.used.Env_pctsp_v1 import PCTSP_V1, DDP_PCTSP_V1
from dataloader.code.problem_loader import get_obj_value, get_obj_value_ddp, ProblemLoader
from environment.used.BaseEnv_COP import DataProblem

def _load_eval_problems(args, data_name, file_prefix, node_num):
    """Load data file ``<file_prefix><node_num>_<args.eval_problem_set>`` and keep its first ``args.eval_iters_COP`` problems."""
    data = load_data(data_name=data_name, data_file_name=f'{file_prefix}{node_num}_{args.eval_problem_set}')
    return split_dataproblem(data, 0, args.eval_iters_COP)


def get_eval_components(args, env_name):
    """Build the datasets, evaluation problems and env factories for ``env_name``.

    Args:
        args: Namespace of experiment settings. Reads ``eval_problem_set``,
            ``eval_iters_COP``, ``problem_batch_size`` and the size field that
            matches the environment (``cvrp_node_num``, ``op_node_num``,
            ``pctsp_node_num`` or ``tsp_city_num``). It is also forwarded to
            the ``get_*_data_*`` helpers.
        env_name: One of ``'Env_CVRP_V1'``, ``'Env_OP_V1'``, ``'Env_PCTSP_V1'``,
            ``'Env_TSP_V1'``, ``'Env_TSP_V2'``.

    Returns:
        ``(dataset, ddp_dataset, problems, env_ddp_builder, env_basic_builder)``:
        element 0 of the prompt dataset and of the DDP prompt dataset returned
        by the data helper, and the evaluation problems (the first
        ``args.eval_iters_COP`` entries of the loaded data). The last two items
        are zero-argument factories for the batched (DDP) env and for a
        single-problem env.

    Raises:
        ValueError: if ``env_name`` is not one of the supported names.
    """
    if env_name == 'Env_CVRP_V1':
        problems = _load_eval_problems(args, 'CVRP_V1', 'cvrp', args.cvrp_node_num)
        dataset, ddp_dataset = get_cvrp_data_v1(args, data_type='prompt', get_dataset=True, get_ddp_dataset=True)
        env_ddp_builder = lambda: DDP_CVRP_V1(node_num=args.cvrp_node_num, batch_size=args.problem_batch_size)
        env_basic_builder = lambda: CVRP_V1(node_num=args.cvrp_node_num)
    elif env_name == 'Env_OP_V1':
        problems = _load_eval_problems(args, 'OP_V1', 'op', args.op_node_num)
        dataset, ddp_dataset = get_op_data_v1(args, data_type='prompt', get_dataset=True, get_ddp_dataset=True)
        env_ddp_builder = lambda: DDP_OP_V1(node_num=args.op_node_num, batch_size=args.problem_batch_size)
        env_basic_builder = lambda: OP_V1(node_num=args.op_node_num)
    # NOTE: 'Env_OP_V2' is intentionally not handled. OP_V2 / DDP_OP_V2 are
    # imported, but no `get_op_data_v2` data builder is imported in this file.
    elif env_name == 'Env_PCTSP_V1':
        # The file name is built from the PCTSP size (it previously used cvrp_node_num by mistake).
        problems = _load_eval_problems(args, 'PCTSP_V1', 'pctsp', args.pctsp_node_num)
        dataset, ddp_dataset = get_pctsp_data_v1(args, data_type='prompt', get_dataset=True, get_ddp_dataset=True)
        env_ddp_builder = lambda: DDP_PCTSP_V1(node_num=args.pctsp_node_num, batch_size=args.problem_batch_size)
        env_basic_builder = lambda: PCTSP_V1(node_num=args.pctsp_node_num)
    elif env_name == 'Env_TSP_V1':
        problems = _load_eval_problems(args, 'TSP_V1', 'tsp', args.tsp_city_num)
        dataset, ddp_dataset = get_tsp_data_v1(args, data_type='prompt', get_dataset=True, get_ddp_dataset=True)
        env_ddp_builder = lambda: DDP_TSP_V1(num_nodes=args.tsp_city_num, batch_size=args.problem_batch_size)
        env_basic_builder = lambda: TSP_V1(num_nodes=args.tsp_city_num)
    elif env_name == 'Env_TSP_V2':
        problems = _load_eval_problems(args, 'TSP_V2', 'tsp', args.tsp_city_num)
        dataset, ddp_dataset = get_tsp_data_v2(args, data_type='prompt', get_dataset=True, get_ddp_dataset=True)
        env_ddp_builder = lambda: DDP_TSP_V2(num_nodes=args.tsp_city_num, batch_size=args.problem_batch_size)
        env_basic_builder = lambda: TSP_V2(num_nodes=args.tsp_city_num)
    else:
        # `raise False` only produced a confusing TypeError; report the bad name instead.
        raise ValueError(f'Unsupported env_name: {env_name!r}')

    return dataset[0], ddp_dataset[0], problems, env_ddp_builder, env_basic_builder


def _assert_step(
    ddp_current_seq, ddp_info_obs, current_seq, info_obs, 
    ddp_action_mask, action_mask,
    ddp_info_obj = None, info_obj = None,
    ddp_reward=None, ddp_terminated=None, ddp_truncated=None, reward=None, terminated=None, truncated=None
):
    for k in info_obs.keys():
        if info_obs[k].size == 1:
            assert abs(info_obs[k].item() - ddp_info_obs[k]) < 1e-4
        else:
            assert np.array_equal(info_obs[k], ddp_info_obs[k])
    assert torch.equal(current_seq, ddp_current_seq) or (current_seq - ddp_current_seq).abs().max().item() <= 1
    assert np.array_equal(action_mask, ddp_action_mask)
    if ddp_reward is None:
        assert ddp_terminated is None
        assert ddp_truncated is None
        assert reward is None
        assert terminated is None
        assert truncated is None
    else:
        assert terminated == ddp_terminated
        assert truncated == ddp_truncated
        assert abs(reward['AM'] - ddp_reward['AM']) < 1e-4 
        assert abs(reward['DB1'] - ddp_reward['DB1']) < 1e-4 
        assert (ddp_info_obj is None and info_obj is None) or (abs(ddp_info_obj - info_obj) < 1e-4) 

def _slot_obs(ddp_info, slot):
    """Return the observation dict of batch slot ``slot`` from a batched DDP ``info``."""
    return {k: v[slot] for k, v in ddp_info['obs'].items()}


def _reset_basic_and_check(env_basic, slot, problem_info, problem_idx, best_obj_array, random_obj_array,
                           use_default_policy_obj, ddp_current_seq, ddp_info, ddp_action_mask):
    """Reset ``env_basic`` with one problem and assert it matches DDP slot ``slot`` after its reset."""
    current_seq, info = env_basic.reset(options={
        'problem_info': problem_info,
        'problem_obj': (best_obj_array[problem_idx], random_obj_array[problem_idx]),
        'use_default_policy_obj': use_default_policy_obj
    })
    action_space_mask = env_basic.get_action_mask(hard_action_constraint=True)[0]
    _assert_step(
        ddp_current_seq[slot], _slot_obs(ddp_info, slot), current_seq, info['obs'],
        ddp_action_mask[slot], action_space_mask
    )


def assert_all(problems:DataProblem, env_ddp:DDP_LMPromptEnv, env_basics:List[LMPromptEnv]=None, best_obj_array=None, random_obj_array=None, use_default_policy_obj=False):
    """Roll out random actions in the DDP env and the single envs in lockstep, asserting they agree.

    Slot ``i`` of ``env_ddp`` is mirrored by ``env_basics[i]``. After every
    reset and every step, the sequences, observations and action masks are
    compared with ``_assert_step``. After steps, rewards, termination flags and
    objectives are compared as well. Finished slots are refilled from a
    ``ProblemLoader`` until every problem in ``problems`` has been played once.

    Args:
        problems: evaluation problems to play.
        env_ddp: batched env; ``env_ddp.env.batch_size`` sets the number of slots.
        env_basics: at least ``batch_size`` single-problem envs (required despite the ``None`` default).
        best_obj_array: per-problem reference objective indexed by dataset index.
            It is fancy-indexed with the index arrays returned by
            ``ProblemLoader.get_problem``, so presumably a numpy array.
        random_obj_array: same layout as ``best_obj_array``, for the random policy.
        use_default_policy_obj: forwarded unchanged to every reset.

    Raises:
        AssertionError: at the first disagreement between the two implementations.
    """
    set_seed(42)
    batch_size = env_ddp.env.batch_size

    # Build the problem loader.
    dataloader_problem = ProblemLoader(problems)
    problem_dataset_size = len(dataloader_problem)
    dataloader_problem.reset()

    # Load the first batch of problems into every slot.
    problem_info_list, problem_info_ddp, problem_idx_array = dataloader_problem.get_problem(num=batch_size)
    ddp_current_seq, ddp_info = env_ddp.reset(options={
        'problem_info':problem_info_ddp, 
        'problem_idx':list(range(batch_size)),
        'problem_obj':(best_obj_array[problem_idx_array], random_obj_array[problem_idx_array]),
        'use_default_policy_obj': use_default_policy_obj
    })
    ddp_action_mask = env_ddp.get_action_mask(hard_action_constraint=True)[0]
    for slot in range(batch_size):
        # Objectives are looked up by dataset index (as in the DDP reset above), not by slot.
        _reset_basic_and_check(
            env_basics[slot], slot, problem_info_list[slot], problem_idx_array[slot],
            best_obj_array, random_obj_array, use_default_policy_obj,
            ddp_current_seq, ddp_info, ddp_action_mask,
        )

    # Interaction loop.
    solve_cnt = 0
    # Slots whose problem finished after the loader ran dry. From then on they are
    # excluded from the per-slot checks and from the solved count.
    retired_slots = set()
    # Running sums for the progress-bar mean of all per-step DDP rewards.
    reward_sums = {'AM': 0.0, 'DB1': 0.0}
    reward_cnt = 0
    with tqdm(total=problem_dataset_size, desc='Rollout Assert Checking') as pbar:
        while True:
            # Sample a random legal action per active slot; retired slots get a dummy 0.
            actions = []
            ddp_action_space = env_ddp.env.get_action_value_space(hard_action_constraint=True)[0]
            for i in range(batch_size):
                act = 0
                if i not in retired_slots:
                    act = np.random.choice(ddp_action_space[i])
                    action_space = env_basics[i].env.get_action_value_space(hard_action_constraint=True)[0]
                    assert np.array_equal(action_space, ddp_action_space[i])
                actions.append(act)

            # Batched transition, then replay the same actions on every active single env.
            ddp_current_seq, ddp_reward, ddp_terminated, ddp_truncated, ddp_info = env_ddp.step(np.array(actions))
            ddp_action_mask = env_ddp.get_action_mask(hard_action_constraint=True)[0]
            assert ddp_truncated.sum() == 0

            for i in range(batch_size):
                if i in retired_slots:
                    continue
                current_seq, reward, terminated, truncated, info = env_basics[i].step(actions[i])
                action_space_mask = env_basics[i].get_action_mask(hard_action_constraint=True)[0]
                _assert_step(
                    ddp_current_seq[i], _slot_obs(ddp_info, i), current_seq, info['obs'],
                    ddp_action_mask[i], action_space_mask,
                    ddp_info['obj'][i], info['obj'],
                    {k: v[i] for k, v in ddp_reward.items()}, ddp_terminated[i], ddp_truncated[i],
                    reward, terminated, truncated
                )

            # Slots whose episode ended in this step. Retired slots are skipped so a
            # finished-but-not-refilled slot can never be counted as solved twice.
            ended = np.where(ddp_terminated)[0].tolist() + np.where(ddp_truncated)[0].tolist()
            done_idx = [i for i in ended if i not in retired_slots]
            solve_cnt += len(done_idx)

            # Update the progress bar (running mean replaces re-averaging an ever-growing list).
            am_rewards = ddp_reward['AM'].tolist()
            reward_sums['AM'] += sum(am_rewards)
            reward_sums['DB1'] += sum(ddp_reward['DB1'].tolist())
            reward_cnt += len(am_rewards)
            ret_AM = 0 if not done_idx else reward_sums['AM'] / reward_cnt
            ret_DB1 = 0 if not done_idx else reward_sums['DB1'] / reward_cnt
            pbar.set_postfix({'ret_AM' : f'{ret_AM:.4f}', 'ret_DB1': f'{ret_DB1:.4f}'})
            pbar.update(len(done_idx))

            if not done_idx:
                continue
            if solve_cnt >= problem_dataset_size:
                break

            # Refill finished slots with as many new problems as remain (at most len(done_idx)).
            problem_info_list, problem_info_ddp, new_problem_idx_array = dataloader_problem.get_problem(num=len(done_idx))
            problem_num = len(problem_info_ddp[2])      # number of problems actually returned (<= len(done_idx))
            refill_slots = done_idx[:problem_num]       # new problems go into the leading finished slots
            retired_slots.update(done_idx[problem_num:])

            # Reset the DDP env on the refilled slots.
            problem_idx_array[refill_slots] = new_problem_idx_array
            if problem_num != 0:
                ddp_current_seq, ddp_info = env_ddp.reset(options={
                    'problem_info':problem_info_ddp, 
                    'problem_idx':refill_slots,
                    'problem_obj':(best_obj_array[new_problem_idx_array], random_obj_array[new_problem_idx_array]),
                    'use_default_policy_obj': use_default_policy_obj
                })

            # Reset the matching single envs and check them against the refreshed DDP slots.
            ddp_action_mask = env_ddp.get_action_mask(hard_action_constraint=True)[0]
            for slot, problem_info, problem_idx in zip(refill_slots, problem_info_list, new_problem_idx_array):
                _reset_basic_and_check(
                    env_basics[slot], slot, problem_info, problem_idx,
                    best_obj_array, random_obj_array, use_default_policy_obj,
                    ddp_current_seq, ddp_info, ddp_action_mask,
                )

if __name__ == "__main__":
    batch_size = 500
    batch_num = 20

    # Start from the config of an existing checkpoint, then override the eval settings below.
    exp_name = 'TSP20(100000)-PCTSP20(100000)-OP20(100000)-CVRP20(100000)_500_504_8_12'
    with open(f'{base_path}/ckpt/{exp_name}/config.json', 'r') as f:
        config_dict = json.load(f)
    args = argparse.Namespace(**config_dict)

    args.disable_visited_obs = False
    args.use_prefix = False
    args.use_prompt = True
    args.use_mem = False
    args.tsp_city_num = args.pctsp_node_num = args.op_node_num = args.cvrp_node_num = 20
    args.problem_batch_size = batch_size
    args.eval_problem_set = 'problem'
    args.traj_type = 'all'
    args.eval_iters_COP = batch_size * batch_num
    # -------------------------------------------------------------------------------------------------------------------
    #env_names = ['Env_PCTSP_V1', 'Env_TSP_V2', 'Env_OP_V1', 'Env_CVRP_V1', ]
    env_names = ['Env_TSP_V2', ]
    args.tsp_city_num = 200
    args.n_position = 1024      # for prompt test
    args.data_num_tsp = 10000   # for prompt test

    disable_tqdm = False
    use_default_policy_obj = False

    for env_name in env_names:
        set_seed(42)
        dataset, dataset_ddp, problems, env_ddp_builder, env_basic_builder = get_eval_components(args, env_name)
        problem_dataset_size = min(len(problems.answer_list), args.eval_iters_COP)

        # Create one batched (DDP) env and `batch_size` single envs that mirror its slots.
        env_ddp = DDP_LMPromptEnv(env_ddp_builder(), args, dataset_ddp, 'moving_prompt',)
        env_basics = [LMPromptEnv(env_basic_builder(), args, dataset, 'moving_prompt',) for _ in range(batch_size)]

        # Prompt check: after re-seeding, the single envs' prompts must match the DDP prompts
        # slot by slot (exactly, or within an element-wise difference of at most 1).
        set_seed(42)
        ddp_prompt = env_ddp.get_prompt(strict_length=True, minimal_expert_data=True)
        set_seed(42)
        for i in tqdm(range(batch_size), desc=f'Checking Prompt of {env_name}'):
            prompt = env_basics[i].get_prompt(strict_length=True, minimal_expert_data=True)
            assert torch.equal(prompt, ddp_prompt[i]) or (prompt - ddp_prompt[i]).abs().max().item() <= 1, \
                f'Prompt mismatch in slot {i} of {env_name}'

        # Objective values reached by the expert policy and a random policy on the eval problems.
        best_obj_array_ddp, random_obj_array_ddp = get_obj_value_ddp(problems, env_ddp, disable_tqdm)
        #best_obj_array, random_obj_array = get_obj_value(problems, env_basics[0], disable_tqdm)
        assert_all(problems, env_ddp, env_basics, best_obj_array_ddp, random_obj_array_ddp, use_default_policy_obj)

        #print(f'Ave obj value of [{problem_dataset_size}] samples in [{env_name}] is: \n\trandom policy: [{random_obj_array.mean()}]; [{random_obj_array_ddp.mean()}]\n\tbest policy:   [{best_obj_array.mean()}]')
        print(f'Ave obj value of [{problem_dataset_size}] samples in [{env_name}] is: \n\trandom policy: [{random_obj_array_ddp.mean()}]\n\tbest policy:   [{best_obj_array_ddp.mean()}]')
        print('\n\n')
        