from __future__ import print_function
import os
import time
import datetime
import argparse
import numpy
import random
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter

import time
import data
import models
from utils import *
import options
import glob
import shutil
import logging
import youtokentome as yttm
from onlineoptim import SGDHD,SGDTLR
from tools import Dictionary, Corpus, repackage_hidden, batchify, get_batch
from torch.autograd import Variable
from collections import Counter


# Evaluation script: restore a trained Seq2SeqTransformer checkpoint and
# print the score returned by compute_bleu on the test split.

# Locations of the on-disk artifacts this script reads.
DATA_DIR = "data/"
BPE_MODEL_PATH = "data/bpe.32000.model"
CHECKPOINT_PATH = "/home/jing/AutoDropBatch/codes/machine_translation/checkpoints/adam/run_ms_0/best_model.pth.tar"

Args = options.Options()
args = Args.parse()

# Use the GPU when one is present, otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

bpe_model = yttm.BPE(model=BPE_MODEL_PATH)

# One sample per batch, in dataset order.
test_loader = data.load(
    DATA_DIR,
    split='test',
    batch_size=1,
    shuffle=False,
    bpe_model=bpe_model)

# NOTE(review): these hyper-parameters presumably have to match the ones the
# checkpoint was trained with, otherwise load_state_dict raises on the
# mismatching parameter shapes -- confirm against the training script.
model = models.Seq2SeqTransformer(
    num_encoder_layers=3,
    num_decoder_layers=3,
    emb_size=512,
    nhead=8,
    vocab_size=bpe_model.vocab_size(),
    dim_feedforward=512,
    dropout=args.dropout_rate)
model = model.to(device)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written on a GPU can also be restored on a CPU-only host.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state_dict)

# Switch to inference mode so dropout is disabled while scoring; without this
# the module stays in training mode unless compute_bleu changes it itself.
model.eval()

# No gradients are needed for evaluation.
with torch.no_grad():
    score = compute_bleu(model, test_loader, bpe_model, device)

print(score)