import argparse
import torch
import numpy as np
from tqdm import tqdm
import sys
sys.path.append("../")
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from bpp_data_generator import BPPDataset
from torch.utils.data import DataLoader

from bpp_env import bpp_env, multi_ccbpp_env
from rl_policy import Policy


# Command-line configuration. Parsed once at import time; the resulting
# `args`, `device` and `writer` are module-level globals read by the
# functions below.
parser = argparse.ArgumentParser()

# Dataset sizes.
parser.add_argument('--train_size', default=6400, type=int, help='Training data size')
parser.add_argument('--val_size', default=320, type=int, help='Validation data size')

# Policy network / return hyper-parameters.
parser.add_argument("--discount_factor", type=float, default=1)
parser.add_argument('--encoder_layer_num', type=int, default=2)
parser.add_argument('--head_num', type=int, default=4)
parser.add_argument("--embed_dim", type=int, default=128)
parser.add_argument("--ff_dim", type=int, default=128)
parser.add_argument("--clipping_const", type=float, default=10)

# Optimisation and evaluation schedule.
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--num_epochs", type=int, default=10)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=1e-4)
parser.add_argument("--test_interval", type=int, default=100)
parser.add_argument("--fine_tune", action="store_true")
parser.add_argument("--sample_size", type=int, default=32)

# Problem (CCBPP) definition.
parser.add_argument('--fit', type=str, default="FF", help='fit scheme in BPP')
parser.add_argument('--N', type=int, default=200, help='Number of items in CCBPP')
parser.add_argument('--B', type=int, default=100, help='capacity of item in CCBPP')
parser.add_argument('--C', type=int, default=5, help='compartments limit of item in CCBPP')
parser.add_argument('--Q', type=int, default=10, help='Number of classes of item in CCBPP')
# The help text used to be a verbatim copy of --Q's. M is forwarded to the
# dataset and decides which environment is built (bpp_env when M == 1,
# multi_ccbpp_env otherwise).
# NOTE(review): "class labels per item" is inferred from how the class rows of
# a batch are consumed in rollout() - confirm against BPPDataset.
parser.add_argument('--M', type=int, default=1,
                    help='Number of class labels per item in CCBPP; '
                         '1 uses bpp_env, any other value uses multi_ccbpp_env')
parser.add_argument('--data_type', type=str, default="normal", help='normal/large/uniform')

args = parser.parse_args()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# TensorBoard event writer; logs are written under runs/reinforce.
writer = SummaryWriter('runs/reinforce')


def npy2tensor(state):
    """Convert a ``(query, mask)`` pair into tensors on the module-level ``device``.

    Args:
        state: Two-element sequence ``(query, mask)``. Presumably array-like
            data produced by the environment - confirm against the env code.

    Returns:
        list: ``[query, mask]`` where ``query`` is a ``FloatTensor`` and
        ``mask`` a ``BoolTensor``, both moved to ``device``.
    """
    raw_query, raw_mask = state
    return [
        torch.FloatTensor(raw_query).to(device),
        torch.BoolTensor(raw_mask).to(device),
    ]


def normalize_return(rewards, gamma=args.discount_factor):
    """Compute the discounted return-to-go for every step of a reward sequence.

    Entry ``t`` of the result equals
    ``rewards[t] + gamma * rewards[t + 1] + gamma**2 * rewards[t + 2] + ...``.
    Despite the function name, no normalisation (mean/std scaling) is applied.

    Args:
        rewards: Indexable sequence of per-step rewards supporting ``len()``.
        gamma: Discount factor. The default is ``args.discount_factor`` as it
            was when this function was defined.

    Returns:
        torch.Tensor: 1-D float tensor of length ``len(rewards)`` on ``device``.
    """
    horizon = len(rewards)
    returns_to_go = torch.zeros(horizon).float().to(device)
    running = 0
    # Walk backwards so each step reuses the already-discounted tail.
    for t in range(horizon - 1, -1, -1):
        running = running * gamma + rewards[t]
        returns_to_go[t] = running

    return returns_to_go

# @torch.no_grad()
# def test(test_loader, policy):
#     policy.eval()
#     score_list = []
#     iterator = tqdm(test_loader, unit='Batch')
#     for i, sample_batched in enumerate(iterator):
#         test_item_batch = sample_batched['weights'].to(device)
#         _, score = rollout(policy, test_item_batch)
#         score_list.append(score)
#
#     return np.mean(score_list)


@torch.no_grad()
def test(test_loader, policy):
    """Evaluate ``policy`` on ``test_loader`` using repeated rollouts per batch.

    The policy is switched to eval mode. For every batch, ``rollout`` is
    called ``args.sample_size`` times and the smallest score is kept; the
    function returns the mean of the kept values over all batches.

    Note:
        ``rollout`` returns ``-r.mean()``, a single value per call, so each
        column of the score matrix is filled with one broadcast value. The
        minimum is therefore taken over batch-mean scores, not per instance.

    Args:
        test_loader: Iterable of batches; each batch is indexed with
            ``'weights'`` to obtain the item tensor.
        policy: Model to evaluate; forwarded to ``rollout``.

    Returns:
        The ``np.mean`` of the per-batch minimum scores.
    """
    policy.eval()
    per_batch_best = []

    progress = tqdm(test_loader, unit='Batch')
    for sample_batched in progress:
        test_item_batch = sample_batched['weights'].to(device)
        sample_score = np.zeros([args.batch_size, args.sample_size])
        # One column per sampled rollout of the same batch.
        for sample_idx in range(args.sample_size):
            sample_score[:, sample_idx] = rollout(policy, test_item_batch)[1]
        # Keep the smallest bin count among the sampled rollouts.
        per_batch_best.append(sample_score.min(axis=1))

    return np.mean(per_batch_best)

def rollout(policy, batch_item, baseline=-20):
    """Play one full episode on a batch of instances and build the policy-gradient loss.

    An environment is created for the batch (``bpp_env`` when ``args.M == 1``,
    ``multi_ccbpp_env`` otherwise). The policy is stepped until the
    environment reports ``done``, and the log-probabilities of all chosen
    actions are summed per instance.

    Args:
        policy: Model exposing ``pre_forward(weights, classes, capacity)`` and
            returning ``(action, log_prob)`` when called on a state.
        batch_item: 3-D tensor. Index 0 along dim 1 is read as the item
            weights and the remaining rows as class indices. Assumed layout is
            ``(batch, 1 + M, N)`` - confirm against ``BPPDataset``.
        baseline: Constant subtracted from the final reward to form the
            advantage.

    Returns:
        tuple: ``(loss, score)``. ``loss`` is the batch mean of
        ``-(r - baseline) * sum_t log_prob_t`` and ``score`` is ``-r.mean()``,
        where ``r`` is the reward returned by the last ``env.step`` call.
    """
    # Size every buffer from the batch actually received rather than from
    # args.batch_size: a loader without drop_last can yield a smaller final
    # batch, which would otherwise produce mismatched shapes.
    batch_size = batch_item.size(0)

    items_np = batch_item.cpu().numpy()
    if args.M == 1:
        env = bpp_env(items_np, args.B, args.Q, args.C, fit=args.fit)
    else:
        env = multi_ccbpp_env(items_np, args.B, args.Q, args.C, args.M, fit=args.fit)
    state, _, _ = env.reset()
    state = npy2tensor((state[0], state[1]))

    weights = batch_item[:, 0, :]
    # Multi-hot class encoding: set a 1 at every class index listed for an item.
    classes = torch.zeros(batch_size, args.N, args.Q).to(batch_item.device)
    classes.scatter_(2, batch_item[:, 1:, :].transpose(1, 2).long(), 1)

    policy.pre_forward(weights, classes, args.B)

    # Collect per-step log-probs in a list and stack once at the end instead
    # of re-concatenating (and re-copying) a growing tensor every step.
    step_log_probs = []

    while True:
        a, log_prob = policy(state)  # each of shape (bs, )
        step_log_probs.append(log_prob)
        a = a.detach().cpu().long().numpy()
        next_state, r, done = env.step(a)
        group_state, mask, _ = next_state
        if done:
            break
        state = npy2tensor((group_state, mask))

    # Stacked shape is (bs, n_steps); summing over steps gives (bs, ).
    log_probs = torch.stack(step_log_probs, dim=1).sum(1)
    adv = (torch.FloatTensor(r) - baseline).to(device)  # (bs, )
    loss = (- adv * log_probs).mean()
    score = - r.mean()
    return loss, score


if __name__ == '__main__':
    # Training entry point: build the data loaders and the policy, train for
    # args.num_epochs epochs, and evaluate on the validation loader every
    # args.test_interval training steps.
    import os

    model_file_name = "CCBPP-RL-N%d-B%d-C%d-Q%d-M%d" % (
        args.N, args.B, args.C, args.Q, args.M)

    # Checkpoint locations. The best-scoring weights go to their own file:
    # previously they shared a path with the end-of-epoch save and were
    # overwritten by it. The original path keeps receiving the latest
    # end-of-epoch weights.
    result_dir = '../rl_result/'
    # Create the output directory up front so the first save cannot fail on
    # a missing folder.
    os.makedirs(result_dir, exist_ok=True)
    last_ckpt_path = f'{result_dir}{model_file_name}_attn_step_params.pt'
    best_ckpt_path = f'{result_dir}{model_file_name}_attn_step_params_best.pt'

    # Load data.
    train_dataset = BPPDataset("train", args.train_size, args.N, args.B, args.C, args.Q, args.M,
                               args.data_type, data_dir="../ccbpp_data/")
    test_dataset = BPPDataset("valid", args.val_size, args.N, args.B, args.C, args.Q, args.M,
                              args.data_type, data_dir="../ccbpp_data/")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    model_params = {
        'input_dim': 1 + args.Q,  # one weight feature plus Q class indicator features
        'embedding_dim': args.embed_dim,
        'sqrt_embedding_dim': args.embed_dim ** (1 / 2),
        'encoder_layer_num': args.encoder_layer_num,
        'head_num': args.head_num,
        'logit_clipping': args.clipping_const,
        'ff_dim': args.ff_dim,
        # NOTE(review): if Policy decodes greedily in eval mode under
        # 'argmax', the args.sample_size rollouts per batch in test() would
        # all be identical - confirm against rl_policy.Policy.
        'eval_type': 'argmax',
    }

    policy = Policy(**model_params).to(device)
    optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # The scheduler is created but never stepped (see the disabled call in
    # the loop), so the learning rate stays at args.lr for the whole run.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=32, gamma=0.99)

    steps_per_epoch = len(train_loader)
    best_score = float('inf')
    for epoch in range(args.num_epochs):
        iterator = tqdm(train_loader, unit='Batch')
        for i, sample_batched in enumerate(iterator):

            train_item_batch = sample_batched['weights'].to(device)

            # test() switches the policy to eval mode, so re-enable training
            # mode on every step.
            policy.train()
            optimizer.zero_grad()
            loss, score = rollout(policy, train_item_batch)

            loss.backward()
            optimizer.step()
            # scheduler.step()

            global_step = epoch * steps_per_epoch + i
            iterator.set_description(f"LOSS : {loss.item():.3f}, Train BIN VALUE : {score:.2f}")
            writer.add_scalar('loss', loss.item(), global_step)
            writer.add_scalar('-reward', score, global_step)

            if (i+1) % args.test_interval == 0:
                # Check how the current policy performs on the validation set.
                test_score = test(test_loader, policy)
                if test_score < best_score:
                    best_score = test_score
                    print(f'Epoch {epoch+1} Best testing score:{best_score:.2f}')
                    torch.save(policy.state_dict(), best_ckpt_path)
                print(f'Epoch {epoch + 1} testing score:{test_score:.2f}')

        # Latest weights at the end of every epoch.
        torch.save(policy.state_dict(), last_ckpt_path)

    # Flush and close the TensorBoard writer so buffered events reach disk.
    writer.close()

    print(f'Final Best testing score:{best_score:.2f}')
    print(f"Initial learning rate: {optimizer.defaults['lr']:.6f}, Last learning rate: {scheduler.get_last_lr()[0]:.6f}")
