import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from models import Discriminator
from torch.optim import lr_scheduler
from utils import *
from PIL import Image
from vocab import Vocabulary, load_vocab
from data_loader import get_data_loader
from collections import OrderedDict
import re
import glob
def sortedStringList(array=None):
    """Return the distinct strings of *array* in natural (numeric-aware) order.

    Each string is keyed by the list of integers formed by its runs of decimal
    digits, e.g. ``"img12.png"`` -> ``[12]``, so ``"a2"`` sorts before ``"a10"``.
    Strings without digits get the key ``[]`` and therefore sort first.
    Duplicate entries are collapsed to one, and strings with equal keys keep
    the order of their first appearance (the sort is stable).

    Args:
        array: iterable of strings; ``None`` (the default) means empty.

    Returns:
        A new list of the unique input strings, sorted by their numeric key.
    """
    if array is None:
        array = []
    # With one capturing group, re.split puts the captured digit runs at the
    # odd indices of its result, so [1::2] selects exactly the numbers.
    digit_run = re.compile(r"(\d+)")
    keys = OrderedDict()
    for text in array:
        keys[text] = [int(number) for number in digit_run.split(text)[1::2]]
    return sorted(keys, key=keys.__getitem__)

def fix_key(state_dict):
    """Return ``state_dict["state_dict"]`` with a leading ``module.`` stripped from each key.

    Presumably used for checkpoints saved from a model wrapped in
    ``nn.DataParallel``, which prefixes parameter names with ``module.``.
    Only one leading prefix is removed per key, and entry order is preserved.

    Args:
        state_dict: mapping that holds the parameter mapping under the
            ``"state_dict"`` key.

    Returns:
        A new ``OrderedDict`` with the renamed keys and the original values.
    """
    prefix = 'module.'
    return OrderedDict(
        (name[len(prefix):] if name.startswith(prefix) else name, tensor)
        for name, tensor in state_dict["state_dict"].items()
    )

# Device for the model and every tensor moved onto it: CUDA when PyTorch can
# see a GPU, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Let cuDNN benchmark and cache the fastest algorithm per input shape. This only
# pays off when input shapes stay fixed; otherwise repeated benchmarking adds overhead.
cudnn.benchmark = True


def _batch_accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring class equals its label.

    Args:
        outputs: model scores. The arg-max is taken over dim 1, so this assumes
            ``(batch, num_classes)`` — TODO confirm against ``Discriminator``.
        labels: integer class targets, one per row of *outputs*.
    """
    prediction = outputs.detach().max(1)[1].cpu()  # predicted class per row
    return prediction.eq(labels.detach().cpu()).sum().item() / len(labels)


def main():
    """
    Training and validation.

    Trains a ``Discriminator`` on question batches with cross-entropy loss and
    Adam. The learning rate is decayed by ``StepLR`` once per epoch. After every
    epoch it saves the sampled train/validation loss and accuracy curves to
    ``./loss_acc/*.npy`` and the model weights to ``./checkpoints/model.pth``.

    Reads the module-level ``args`` namespace created in the ``__main__`` block.
    """
    import os

    # NOTE(review): VocabDict and np are not imported explicitly in this file;
    # presumably they come from ``from utils import *`` — confirm.
    word_map = VocabDict(args.input_dir + '/vocab_questions.txt')

    ## VQA
    print("Discriminator Creating...")
    # NOTE(review): these sizes are hard-coded; the --rnn_num_layers,
    # --num_hidden and --embed_size CLI options are not used here.
    model = Discriminator(
        qst_vocab_size=len(word_map),
        word_embed_size=300,
        num_layers=2,
        hidden_size=512,
        embed_size=1024,
    )
    model = model.to(device)
    optimizer = torch.optim.Adam([{"params": model.parameters()}], lr=args.lr)
    # Multiplies the learning rate by args.gamma every args.step_size epochs.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    # Loss function
    criterion = torch.nn.CrossEntropyLoss().to(device)

    data_loader = get_data_loader(
        input_dir=args.input_dir,
        input_vqa_train='train.npy',
        input_vqa_valid='valid.npy',
        max_qst_length=args.max_qst_length,
        batch_size=args.batch_size,
        num_workers=args.num_workers)
    trn_loader = data_loader["train"]
    val_loader = data_loader["valid"]

    # np.save / torch.save do not create missing directories.
    os.makedirs("./loss_acc", exist_ok=True)
    os.makedirs("./checkpoints", exist_ok=True)
    model_path = './checkpoints/model.pth'

    # Metrics sampled every 100th batch, accumulated across all epochs.
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    for epoch in range(args.epochs):
        model.train()
        for i, (qsts, labels) in enumerate(trn_loader):
            optimizer.zero_grad()

            # to device
            qsts = qsts.to(device)
            labels = labels.to(device)

            outputs = model(qsts)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                accuracy = _batch_accuracy(outputs, labels)
                print("train: {}/{}, loss: {}, acc: {}".format(i, len(trn_loader), loss.item(), accuracy))
                train_losses.append(loss.item())
                train_accs.append(accuracy)

        # Advance the LR schedule once per epoch, after this epoch's optimizer
        # steps; previously the scheduler was created but never stepped.
        scheduler.step()

        model.eval()
        # Evaluation only: don't build autograd graphs for validation batches.
        with torch.no_grad():
            for i, (qsts, labels) in enumerate(val_loader):
                # to device
                qsts = qsts.to(device)
                labels = labels.to(device)

                outputs = model(qsts)
                loss = criterion(outputs, labels)

                if i % 100 == 0:
                    accuracy = _batch_accuracy(outputs, labels)
                    print("val: {}/{}, loss: {}, acc: {}".format(i, len(val_loader), loss.item(), accuracy))
                    val_losses.append(loss.item())
                    val_accs.append(accuracy)

        # Overwritten every epoch with the full history so far / latest weights.
        np.save("./loss_acc/train_loss.npy", train_losses)
        np.save("./loss_acc/train_accs.npy", train_accs)
        np.save("./loss_acc/val_loss.npy", val_losses)
        np.save("./loss_acc/val_accs.npy", val_accs)
        torch.save(model.state_dict(), model_path)

if __name__ == '__main__':
    # Command-line entry point: parse options, record them next to the
    # checkpoints, then train.
    import argparse
    import json
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='./../Dataset',
                        help='input directory for visual question answering.')
    parser.add_argument('--lr', type=float,
                        default=0.001, help='learning rate')
    parser.add_argument('--max_qst_length', type=int, default=30,
                        help='maximum length of question. the length in the VQA dataset = 26.')
    parser.add_argument('--epochs', type=int, default=3,
                        help='epoch num.')

    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size.')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of processes working on cpu.')
    # NOTE(review): the options below are parsed and saved to params.json but
    # not read by main() in this file.
    parser.add_argument('--rnn_num_layers', type=int, default=1,
                        help='number of layers of rnn models.')
    parser.add_argument('--rec_unit', type=str,
                        default='lstm', help='choose "gru", "lstm" or "elman"')
    parser.add_argument('--sample', default=False,
                        action='store_true', help='just show result, requires --checkpoint_file')
    parser.add_argument('--log_step', type=int,
                        default=500, help='number of steps in between calculating loss')
    parser.add_argument('--num_hidden', type=int,
                        default=512, help='number of hidden units in the RNN')
    parser.add_argument('--embed_size', type=int,
                        default=512, help='number of embeddings in the RNN')
    parser.add_argument('--gamma', type=float,
                        default=0.1, help='lr step scheduler gamma')
    parser.add_argument('--step_size', type=int,
                        default=1, help='lr step scheduler step')

    args = parser.parse_args()

    # open() does not create missing directories; make sure ./checkpoints exists.
    os.makedirs("./checkpoints", exist_ok=True)
    with open("./checkpoints/params.json", mode="w", encoding="utf-8") as f:
        json.dump(args.__dict__, f, indent=4)
    main()
