import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(base_path)

import torch
import argparse
from utils.utils import set_seed, create_file_if_not_exist, create_folder_overwrite_if_exist, str2bool, load_model
from data.used.make_data import get_01bp_data, get_tsp_data_v1, get_tsp_data_v2, get_tsp_data_v3, EXAMPLE_RENDER
from argparse import Namespace
from dataloader.code.dataset import RLFullDataset, BlendableDataset
from evaluate_test.evaluate_utils import eval_policy
from dataloader.code.data_samplers import build_training_data_loader
from environment.wrapper import LMPromptEnv

def build_dataloader(args: Namespace, eval_args: Namespace, dataset: RLFullDataset):
    """Split ``dataset`` and wrap each part in a blended dataset plus a loader.

    Args:
        args: Model/training config; only ``args.split`` is read here and is
            forwarded to ``dataset.split_dataset``.
        eval_args: Evaluation options; ``eval_args.traindata_logger`` is passed
            as ``log_data`` to the training ``BlendableDataset`` and the whole
            namespace is forwarded to ``build_training_data_loader``.
        dataset: The full dataset to split.

    Returns:
        ``(dataloader_val, dataloader_train, dataset_val, dataset_train)`` --
        note the validation objects come first. The datasets returned are the
        ``BlendableDataset`` wrappers, not the raw splits.
    """
    # split training set and evaluation set
    train_split, val_split = dataset.split_dataset(args.split)

    # Each split is blended on its own with weight 1. Only the training
    # wrapper receives the check_visited / log_data options.
    blended_train = BlendableDataset(
        [train_split],
        [1],
        batch_size=1,
        check_visited=False,
        log_data=eval_args.traindata_logger,
    )
    blended_val = BlendableDataset([val_split], [1], batch_size=1)

    # Both loaders are built with epoch_total_samples=1; they differ only in
    # the is_eval flag.
    train_loader = build_training_data_loader(
        eval_args, blended_train, epoch_total_samples=1, is_eval=False
    )
    val_loader = build_training_data_loader(
        eval_args, blended_val, epoch_total_samples=1, is_eval=True
    )

    return val_loader, train_loader, blended_val, blended_train

def _infer_model_device(model):
    """Return the device of ``model``'s first parameter, or None if unknown."""
    parameters = getattr(model, 'parameters', None)
    if not callable(parameters):
        return None
    first_param = next(iter(parameters()), None)
    return None if first_param is None else first_param.device


def get_loss(gato, dataloader, dataset, eval_args, logger=None, is_train=False):
    """Print the loss of every batch in ``dataloader`` and report the sampled index.

    The model is called with gradients disabled. When logging is enabled the
    sample indices recorded by ``dataset`` are handed to the per-environment
    logger and the (single) index drawn for the batch is remembered.

    Args:
        gato: Callable model returning a 4-tuple whose second item is the loss.
        dataloader: Iterable of batches; element 0 of each batch is fed to the model.
        dataset: Object exposing ``get_logged_data()``; only used when logging.
        eval_args: Namespace providing ``traindata_logger`` and ``eval_env``.
        logger: Optional mapping from environment name to an object with
            ``log_data``. A ``None`` entry disables rendering for that
            environment but the sample index is still returned.
        is_train: Only selects the label used in the printed line.

    Returns:
        The sample index logged for the last batch, or ``None`` when
        ``logger`` is ``None`` or ``eval_args.traindata_logger`` is false.

    Raises:
        AssertionError: If the logged data does not contain exactly one index.
    """
    desc = 'Train loss' if is_train else 'Eval loss'

    # Derive the device from the model itself instead of relying on a
    # module-level ``device`` that only exists when this file runs as a script.
    target_device = _infer_model_device(gato)

    idx = None
    for batch in dataloader:
        rl_task_input = batch[0]
        if target_device is not None:
            # The batch type is defined outside this file: its ``to`` may work
            # in place (returning None) or return a moved copy like
            # ``torch.Tensor.to``. Use the returned object when there is one.
            moved = rl_task_input.to(device=target_device)
            if moved is not None:
                rl_task_input = moved

        with torch.set_grad_enabled(False):
            _, loss, _, _ = gato(rl_task_input)
        print(f'{desc}:\t{loss.item()}')

        # log train data if necessary
        idx = None
        if logger is not None and eval_args.traindata_logger:
            logged_data = dataset.get_logged_data()

            env_logger = logger[eval_args.eval_env]
            if env_logger is not None:
                # NOTE(review): is_train=False is passed regardless of the
                # ``is_train`` argument, as in the original code -- confirm intent.
                env_logger.log_data(logged_data, seed=42, is_train=False)

            idxs = list(logged_data.values())[0]
            # Validate before indexing so an empty record fails with the
            # intended assertion rather than an IndexError.
            assert len(idxs) == 1
            idx = idxs[0]
    return idx

    #return np.array(epoch_losses).mean()
if __name__ == "__main__":
    # Debugging entry point: load a checkpoint, compute the loss on one
    # training sample, then roll the policy out on the problem with the same
    # index under several decoding settings.
    create_file_if_not_exist(f'{base_path}/model/test.txt')

    # eval paras
    parser = argparse.ArgumentParser()
    parser.add_argument("--bp-max-item-num", type=int, default=7)
    parser.add_argument("--tsp-city-num", type=int, default=10,)
    parser.add_argument("--ckpt-path", type=str, default=None,)
    parser.add_argument("--snapshot-path", type=str, default=None,)
    parser.add_argument("--eval-env", type=str, default='Env_01BP_V1')
    parser.add_argument("--eval-problem-set", type=str, default='problem', choices={'problem', 'train_problem'})
    parser.add_argument("--eval-iters-COP", type=int, default=0,)   # 0 means use the whole evaluation problem set
    parser.add_argument("--eval-iters-RL", type=int, default=5,)
    parser.add_argument("--batch-num", type=int, default=0,)
    parser.add_argument("--batch-size", type=int, default=1,)
    parser.add_argument("--eval-batch-num", type=int, default=0,)
    parser.add_argument("--eval-batch-size", type=int, default=1,)
    parser.add_argument("--regen-times", type=int, default=5,)
    parser.add_argument("--policy-logger", type=str2bool, default=False, nargs="?", const=True)
    parser.add_argument("--traindata-logger", type=str2bool, default=False, nargs="?", const=True)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--use-prompt", type=str2bool, default=True)
    parser.add_argument("--use-mem", type=str2bool, default=True)
    parser.add_argument(
        "--dataloader-type",
        type=str,
        default="sequential",
        choices=["sequential", "random"],
        help="Fetch data sequentially or out-of-order",
    )
    parser.add_argument("--prompt-strategy", type=str, 
        default="stochastic_subseq;moving_prompt",
        choices={
            "stochastic_timestep;moving_prompt",
            "stochastic_subseq;moving_prompt",
            "stochastic_timestep;fixed_prompt",
            "stochastic_subseq;fixed_prompt",
        },
    )

    eval_args = parser.parse_args()

    # NOTE(review): the assignments below unconditionally override whatever was
    # passed on the command line -- kept as-is because the script's current
    # behavior depends on them.
    eval_args.ckpt_path = 'TSP5(DDP-10000)_80_504_6_5/best/0.78_seed43_epoch40.pt'
    eval_args.eval_env = 'Env_TSP_V2'
    eval_args.tsp_city_num = 5
    eval_args.eval_problem_set='train_problem'
    eval_args.policy_logger = True
    eval_args.traindata_logger = True
    eval_args.use_prompt = False
    eval_args.use_mem = True

    seed = 42
    set_seed(seed)

    # exactly one of ckpt_path / snapshot_path must be given
    assert (eval_args.ckpt_path is None) ^ (eval_args.snapshot_path is None), \
        "specify exactly one of --ckpt-path and --snapshot-path"
    # this script only supports a single GPU
    assert eval_args.dataloader_type == "sequential", \
        "only the sequential dataloader is supported by this script"

    # load ckpt
    # The experiment name is the first path component. split() is used instead
    # of slicing with find('/'), which would drop the last character when the
    # path contains no '/'.
    model_path = eval_args.ckpt_path if eval_args.ckpt_path is not None else eval_args.snapshot_path
    exp_name = model_path.split('/', 1)[0]
    args, gato, current_epoch = load_model(
        config_path=f'{base_path}/ckpt/{exp_name}/config.json', 
        ckpt_path=None if eval_args.ckpt_path is None else f'{base_path}/ckpt/{eval_args.ckpt_path}',
        snapshot_path=None if eval_args.snapshot_path is None else f'{base_path}/ckpt/{eval_args.snapshot_path}'
    )
    # evaluation-time options take precedence over the values stored in the config
    args.use_mem = eval_args.use_mem
    args.use_prompt = eval_args.use_prompt
    args.bp_max_item_num = eval_args.bp_max_item_num
    args.tsp_city_num = eval_args.tsp_city_num

    device = torch.device("cuda:0" if torch.cuda.is_available() and torch.cuda.device_count() >= 1 else "cpu")
    gato = gato.to(device)

    # load prompt dataset and eval problem
    # Each supported environment maps to its data getter and the keyword
    # arguments later forwarded to LMPromptEnv.
    env_data_table = {
        'Env_01BP_V1': (get_01bp_data, {'max_quantity': eval_args.bp_max_item_num}),
        'Env_TSP_V1': (get_tsp_data_v1, {'num_nodes': eval_args.tsp_city_num}),
        'Env_TSP_V2': (get_tsp_data_v2, {'num_nodes': eval_args.tsp_city_num}),
    }
    if eval_args.eval_env not in env_data_table:
        # The original `raise False` would itself fail with a TypeError.
        raise ValueError(
            f"unsupported --eval-env {eval_args.eval_env!r}; "
            f"expected one of {sorted(env_data_table)}"
        )
    get_data, env_para = env_data_table[eval_args.eval_env]
    # for 'train' and 'prompt' only the first returned element is used
    dataset_train = get_data(args, data_type='train')[0]
    dataset_prompt = get_data(args, data_type='prompt')[0]
    env_problem = get_data(args, data_type=eval_args.eval_problem_set)
    args.eval_env_name = dataset_prompt.env_name
    args.eval_dataset_name = dataset_prompt.dataset_name
    
    # build episode render if we need to check generated episodes during training
    logger = None
    if eval_args.policy_logger:
        create_folder_overwrite_if_exist(f'{base_path}/visualize/eval/log/{args.eval_env_name}/{args.eval_dataset_name}')
        logger = EXAMPLE_RENDER[args.eval_env_name]()
    logger = {eval_args.eval_env: logger}

    # build dataloader
    (
        dataloader_val, 
        dataloader_train, 
        dataset_val, 
        dataset_train 
    ) = build_dataloader(args, eval_args, dataset_train)

    # check loss on the training sample
    gato.eval()
    gato.transformer.same_length = False        # use normal context length when loss calculating (TransformerXL back bone)
    with torch.inference_mode():
        problem_idx = get_loss(gato, dataloader_train, dataset_train, eval_args, logger, is_train=True)

    if problem_idx is None:
        # get_loss only reports an index when train-data logging is active;
        # fail with a clear message instead of a TypeError on list indexing.
        raise RuntimeError(
            "get_loss did not report a sample index; enable --traindata-logger "
            "so the sampled problem can be identified"
        )

    # restrict the evaluation problem set to the sample whose loss was just computed
    env_problem.answer_list = [env_problem.answer_list[problem_idx], ]
    env_problem.problem_list = [env_problem.problem_list[problem_idx], ]
    envs_problems = {eval_args.eval_env: env_problem}

    # build envs for evaluation
    eval_prompt_strat = eval_args.prompt_strategy.split(";")[-1]     # moving_prompt
    env = LMPromptEnv(
        dataset_prompt.env_name, 
        args, 
        dataset_prompt, 
        eval_prompt_strat,
        **env_para
    )
    envs = [env,]

    # figure out eval rollout times of each env 
    eval_iters_dict = {
        'COPTask': eval_args.eval_iters_COP, 
        'RLTask': eval_args.eval_iters_RL
    }
    # NOTE(review): the two branches measure the problem set differently
    # (len(answer_list) vs len(env_problem)) -- confirm they agree.
    eval_iter = len(env_problem.answer_list) if eval_iters_dict[env.task_type] == 0 else \
                 min(len(env_problem), eval_iters_dict[env.task_type]) 
    eval_iters = {env.env_name: eval_iter}

    # evaluate policy
    args.eval_env_names = [args.eval_env_name,]
    args.eval_dataset_names = [args.eval_dataset_name,]
    # Keyword arguments passed to eval_policy for each decoding setting;
    # regen_times is only supplied for the "vote" settings.
    eval_setting = {
        'greedy_free': dict(sample_action=False, hard_action_constraint=False),
        'sample_free': dict(sample_action=True, hard_action_constraint=False),
        f'vote{eval_args.regen_times}_free': dict(
            sample_action=True, hard_action_constraint=False, regen_times=eval_args.regen_times),
        'greedy_cst': dict(sample_action=False, hard_action_constraint=True),
        'sample_cst': dict(sample_action=True, hard_action_constraint=True),
        f'vote{eval_args.regen_times}_cst': dict(
            sample_action=True, hard_action_constraint=True, regen_times=eval_args.regen_times),
    }
    result = {
        setting: {'return': [], 'safe': [], 'time': []} 
        for setting in eval_setting
    }

    gato.eval()
    gato.transformer.same_length = args.use_mem         # use fixed context length when rollout with mem (TransformerXL back bone)
    with torch.inference_mode():
        for setting, rollout_kwargs in eval_setting.items():
            epi_return, epi_safe, epi_time = eval_policy(
                args, envs, gato, eval_iters, logger, seed, envs_problems, device,
                **rollout_kwargs
            )
            result[setting]['return'].extend(epi_return)
            result[setting]['safe'].extend(epi_safe)
            result[setting]['time'].extend(epi_time)