import argparse
import torch
import time
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np


import os
import sys
# Make the script's own directory the working directory, so relative paths
# used later (e.g. "../ccbpp_data/", "../rl_result/") resolve relative to this
# file rather than to wherever the script was launched from.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# Put the parent directory first on the import search path ("..", resolved
# against the directory chosen just above) so the local modules imported
# below can be found.
# NOTE(review): the original note said "for order_dataset", but the visible
# imports are bpp_env / rl_policy / bpp_data_generator — likely stale, confirm.
sys.path.insert(0, "..")  # for order_dataset: add the parent directory to the import path
from bpp_env import bpp_env, multi_ccbpp_env
from rl_policy import Policy
from bpp_data_generator import BPPDataset


# Command-line configuration. `args` and `device` are module-level globals read
# by npy2tensor / eval / rollout below.
parser = argparse.ArgumentParser()

# Evaluation data size and policy-network hyper-parameters (must match the
# checkpoint being loaded).
parser.add_argument('--test_size', default=200, type=int, help='Test data size')
parser.add_argument("--discount_factor", type=float, default=1)
parser.add_argument('--encoder_layer_num', type=int, default=2)
parser.add_argument('--head_num', type=int, default=4)
parser.add_argument("--embed_dim", type=int, default=128)
parser.add_argument("--ff_dim", type=int, default=128)
parser.add_argument("--clipping_const", type=float, default=10)

parser.add_argument('--fit', type=str, default="FF", help='fit scheme in CCBPP')

# Batching: each test batch is rolled out `sample_size` times (best-of-N).
parser.add_argument("--batch_size", type=int, default=50)
parser.add_argument("--sample_size", type=int, default=16)

# Problem-instance parameters; they also select the data file and checkpoint name.
parser.add_argument('--N', type=int, default=200, help='Number of items in CCBPP')
parser.add_argument('--B', type=int, default=100, help='capacity of item in CCBPP')
parser.add_argument('--C', type=int, default=5, help='compartments limit of item in CCBPP')
parser.add_argument('--Q', type=int, default=10, help='Number of classes of item in CCBPP')
# The previous help text was a copy of --Q's; describe what the code actually does with M.
parser.add_argument('--M', type=int, default=1,
                    help='M parameter of the CCBPP instance (M > 1 selects the multi-CCBPP environment)')
parser.add_argument('--data_type', type=str, default="normal", help='normal/large/uniform')

args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def npy2tensor(state):
    """Turn an environment ``(query, mask)`` pair into tensors on ``device``.

    ``state`` must unpack into exactly two array-likes. The first becomes a
    ``torch.FloatTensor`` and the second a ``torch.BoolTensor``; both are moved
    to the module-level ``device`` and returned as the list ``[query, mask]``.
    """
    np_query, np_mask = state
    return [
        torch.FloatTensor(np_query).to(device),
        torch.BoolTensor(np_mask).to(device),
    ]


@torch.no_grad()
def eval(test_loader, policy):
    """Evaluate ``policy`` on ``test_loader`` using best-of-N sampling.

    Each batch is rolled out ``args.sample_size`` times with ``rollout``. For
    every instance, the minimum score over those samples (the fewest bins
    used) is kept. The average wall-clock time per evaluated instance is
    printed.

    NOTE: the name shadows the builtin ``eval``; it is kept for callers.

    Args:
        test_loader: iterable of dict batches whose ``'weights'`` entry is a
            tensor accepted by ``rollout``.
        policy: policy module; it is switched to eval mode here.

    Returns:
        Mean of the per-instance best scores over every evaluated instance,
        or NaN if the loader yields no batches.
    """
    t1 = time.time()
    policy.eval()
    score_list = []
    num_instances = 0

    iterator = tqdm(test_loader, unit='Batch')
    for sample_batched in iterator:
        test_item_batch = sample_batched['weights'].to(device)
        # Size from the actual batch: the final batch is smaller than
        # args.batch_size whenever test_size is not a multiple of it.
        batch_len = test_item_batch.size(0)
        sample_score = np.zeros([batch_len, args.sample_size])
        for sample_idx in range(args.sample_size):
            sample_score[:, sample_idx] = rollout(policy, test_item_batch)
        # Among the samples, keep the smallest score (fewest bins) per instance.
        score_list.append(sample_score.min(1))
        num_instances += batch_len
    print(" %.4f seconds for each instance." % (
        (time.time() - t1) / max(num_instances, 1)))

    if not score_list:
        return float('nan')
    # Concatenate rather than np.mean(list): batches may differ in length.
    return np.mean(np.concatenate(score_list))


def rollout(policy, batch_item, baseline=20):
    """Run ``policy`` for one full episode on a batch of CCBPP instances.

    Args:
        policy: module exposing ``pre_forward(weights, classes, B)``; calling
            it on a state returns ``(action, log_prob)``.
        batch_item: tensor indexed ``[batch, feature, item]``, where feature 0
            holds item weights and feature 1 holds integer class ids
            (expected to lie in ``[0, args.Q)``).
        baseline: unused; kept only for interface compatibility.

    Returns:
        The negated reward from the final ``env.step`` call. This is presumably
        a per-instance array of length ``batch``; TODO confirm against the env.
    """
    items_np = batch_item.cpu().numpy()
    if args.M == 1:
        env = bpp_env(items_np, args.B, args.Q, args.C, fit=args.fit)
    else:
        env = multi_ccbpp_env(items_np, args.B, args.Q, args.C, args.M, fit=args.fit)

    state, _, _ = env.reset()
    state = npy2tensor((state[0], state[1]))

    weights = batch_item[:, 0, :]
    # Take sizes from the data, not args: the last test batch may be smaller
    # than args.batch_size, and the item count must match `weights`.
    batch_len, num_items = weights.size(0), weights.size(1)
    # One-hot encode class ids along the last dim -> (batch, items, Q).
    classes = torch.zeros(batch_len, num_items, args.Q, device=batch_item.device)
    classes.scatter_(2, batch_item[:, 1, :].unsqueeze(2).long(), 1)

    policy.pre_forward(weights, classes, args.B)

    while True:
        a, _ = policy(state)  # (bs, )
        a = a.detach().cpu().long().numpy()
        next_state, r, done = env.step(a)
        if done:
            break
        state = npy2tensor((next_state[0], next_state[1]))

    return -r

if __name__ == '__main__':
    import pickle

    # Load the test data and the trained policy, then report the mean best-of-N score.

    model_file_name = "CCBPP-RL-N%d-B%d-C%d-Q%d-M%d" % (
        args.N, args.B, args.C, args.Q, args.M)

    test_dataset = BPPDataset("test", args.test_size, args.N, args.B, args.C, args.Q, args.M,
                              args.data_type, data_dir="../ccbpp_data/")
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    model_params = {
        'input_dim': 1 + args.Q,
        'embedding_dim': args.embed_dim,
        'sqrt_embedding_dim': args.embed_dim ** (1 / 2),
        'encoder_layer_num': args.encoder_layer_num,
        'head_num': args.head_num,
        'logit_clipping': args.clipping_const,
        'ff_dim': args.ff_dim,
        'eval_type': 'argmax',
    }

    policy = Policy(**model_params).to(device)
    checkpoint_path = '../rl_result/%s_attn_step_params.pt' % model_file_name
    # map_location lets a checkpoint saved on a GPU load on the CPU fallback device.
    policy.load_state_dict(torch.load(checkpoint_path, map_location=device))
    test_score = eval(test_loader, policy)

    print(f'testing score:{test_score:.2f}')
