# Copyright (c) anonymous All Rights Reserved.
# Licensed under the BSD 3-Clause Clear License [see LICENSE for details]

import argparse
import json
import logging
import os
import random
from io import open
import math
import sys

from time import gmtime, strftime
from timeit import default_timer as timer

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from transformers import BertConfig

from models.module_s import TransformerModuleNet
from datasets.clevr_dataset_program import CLEVRDataset

from cfgs.path_cfgs import PATH

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--from_pretrained", default='', type=str, help="model path.",
    )
    parser.add_argument(
        "--num_module_layers", type=int, default=1, help="Number of module layers.",
    )
    parser.add_argument(
        "--arch", type=str, default='s', help="Network architecture (s, t, ta)", choices=['s', 't', 'ta'],
    )
    parser.add_argument(
        "--test", type=str, default='valB', help="target dataset", choices=['valA', 'valB', 'val']
    )
    parser.add_argument(
        "--dump", action="store_true" , help="whether dump the results."
    )
    args = parser.parse_args()
    print(args)

    print("import path cfgs")
    path_cfgs = PATH()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    config = BertConfig.from_pretrained(path_cfgs.root_path + 'config/bert_base_6layer_6conect.json')        
    config.num_module_layer = args.num_module_layers
    config.use_location_embed = True
    print(config)
    print("num_module_layer:", config.num_module_layer)

    if args.arch == 's':
        from models.module_s import TransformerModuleNet
        from datasets.clevr_dataset_program_s import CLEVRDataset
        config.num_region = 37
        print("select stack arch")
    elif args.arch == 't':
        from models.module_t import TransformerModuleNet
        from datasets.clevr_dataset_program_v2 import CLEVRDataset
        config.num_region = 74
        print("select tree arch")
    else:
        print("arch should be [s, t]")
        exit()

    model_path = args.from_pretrained if args.from_pretrained else path_cfgs.from_pretrained

    # set up model
    model = TransformerModuleNet(config)

    model.load_state_dict(torch.load(model_path))
    print(f'loaded from {model_path}')

    model.to(device)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
        print(f'use {n_gpu} GPUs')

    features_path_val = path_cfgs.path_dict_corpus_val[args.test + '_obj']
    annotation_path_val = path_cfgs.path_dict_annotation_val[args.test]
    print(f'test : {args.test}')

    validation_dataset = CLEVRDataset(
        features_path_val,
        annotation_path_val,
        path_cfgs.vocab_path,
        path_cfgs.func_vocab_path,
        path_cfgs.args_vocab_path,
        seq_len=36,
    )

    val_batch_size = 128
    gradient_accumulation_steps = 1
    
    validation_data_loader = DataLoader(validation_dataset, batch_size=val_batch_size, num_workers=2)

    numBatches = math.ceil(validation_dataset.num_dataset / val_batch_size / gradient_accumulation_steps)

    print(f'***** Running validation *****')
    print(f'Num examples = {len(validation_dataset)}')
    print(f'Batch size = {val_batch_size}')
    print(f'Num batches = {numBatches}')

    # Do the evaluation 

    torch.set_grad_enabled(False)
    start_t = timer()
    eval_total_loss = 0
    eval_total_matches = 0

    results_dict = {}

    model.eval()
    for step, batch in enumerate(validation_data_loader):
        batch = tuple(t.cuda(device=device, non_blocking=True) for t in batch)

        features, spatials, image_mask, arguments, answer_id, question_id = (
            batch
        )

        outputs, pred =  model(features, spatials, image_mask, arguments)

        loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(pred, answer_id)

        if n_gpu > 1:
            loss = loss.mean() 
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        logits = torch.max(pred, 1)[1].data  # argmax
        count_matches = ((logits - answer_id) == 0).sum().float()

        # collect results for each questions
        matches = (logits - answer_id) == 0
        for idx, qid in enumerate(question_id):
            results_dict[str(qid.item())] = [matches[idx].item(), logits[idx].item(), answer_id[idx].item()]

        eval_total_matches += count_matches.item()

        if n_gpu > 1:
            loss = loss.mean()
        
        eval_total_loss += loss.item()

        end_t = timer()
        delta_t = " Time: %5.2fs" % (end_t - start_t)
        start_t = end_t
        progressString = "\r Evaluating split '%s' [%d/%d]\t" + delta_t
        sys.stdout.write(progressString % ('val', step + 1, numBatches))
        sys.stdout.flush()

    eval_total_loss = eval_total_loss / float(validation_dataset.num_dataset)
    eval_score = eval_total_matches / float(validation_dataset.num_dataset)

    printFormat = "Evaluation: [Loss: %.5g][Score: %.5g]"
    printInfo = [eval_total_loss, eval_score]

    print(printFormat % tuple(printInfo))
    print(f'{eval_total_matches} / {validation_dataset.num_dataset}')

    if args.dump:
        output_results_json = './results.json'
        with open(output_results_json, 'w') as f:
          json.dump(results_dict, f)
        print("dump the results", output)


if __name__ == "__main__":

    main()