# encoding: utf-8
import argparse
import os
import logging
import math
import time
import sys
from tqdm import tqdm
import torch

from network import create_network
from utils import load_checkpoint, get_val_loader
from collections import OrderedDict


class AvgMeter(object):
    """
    Keep a running, count-weighted mean of scalar values.

    Each call to ``update`` adds ``mean_var * count`` to the running sum and
    ``count`` to the running total, then refreshes ``mean``.
    """

    name = "No name"

    def __init__(self, name="No name", fmt=":.2f"):
        """
        Args:
            name: Label used as the prefix of the string representation.
            fmt: Format spec for the mean, including the leading colon
                (for example ``":.2f"``).
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Set every accumulated statistic back to zero."""
        self.sum = 0  # weighted sum of all values seen so far
        self.mean = 0  # sum / num, refreshed by update()
        self.num = 0  # total count seen so far
        self.now = 0  # most recent value passed to update()

    def update(self, mean_var, count=1):
        """
        Fold a new value into the running mean.

        Args:
            mean_var: The value to add; NaN is replaced by ``1e6`` and a
                warning is printed.
            count: Weight of ``mean_var`` (how many samples it stands for).
        """
        if math.isnan(mean_var):
            mean_var = 1e6
            print("Avgmeter getting Nan!")
        self.now = mean_var
        self.num += count

        self.sum += mean_var * count
        # Leave the previous mean untouched while the total count is still
        # zero instead of raising ZeroDivisionError.
        if self.num:
            self.mean = float(self.sum) / self.num

    def __str__(self):
        """Return ``"<name>-<mean>"`` with the mean rendered using ``self.fmt``."""
        # Only the numeric part goes through str.format, so a name that
        # happens to contain braces cannot break the formatting.
        mean_str = ("{" + self.fmt + "}").format(self.mean)
        return "{}-{}".format(self.name, mean_str)


def run_eval(ds_val, max_iters, net):
    """
    Run ``max_iters`` validation batches through ``net`` and average metrics.

    Args:
        ds_val: Batch source; ``ds_val.next()`` must return ``(data, label)``.
        max_iters: Number of batches to evaluate.
        net: Model called as ``net(data, label)``; it must return a loss dict
            with key ``"loss"`` and an output dict with keys ``"Err1"`` and
            ``"Err5"``.

    Returns:
        Dict with keys ``"top1-err"``, ``"top5-err"`` and ``"loss"``, each
        holding the running mean wrapped in a ``torch.tensor``.
    """
    net.eval()
    progress = tqdm(range(max_iters))
    err1_meter, err5_meter, loss_meter = AvgMeter(), AvgMeter(), AvgMeter()
    progress.set_description("Validation")

    with torch.no_grad():
        for _ in progress:
            tic = time.time()
            data, label = ds_val.next()
            data = data.cuda()
            label = label.type(torch.long).cuda()
            # Time spent fetching the batch and moving it to the GPU.
            data_time = time.time() - tic

            loss_dict, output_dict = net(data, label)
            batch_loss = loss_dict["loss"]
            batch_err1 = output_dict["Err1"]
            batch_err5 = output_dict["Err5"]

            err1_meter.update(batch_err1)
            err5_meter.update(batch_err5)
            loss_meter.update(batch_loss.item())

            # Show the running averages next to the progress bar.
            progress.set_postfix(
                OrderedDict(
                    [
                        ("data-time", "{:.2f}".format(data_time)),
                        ("top1-err", "{:.2f}".format(err1_meter.mean)),
                        ("top5-err", "{:.2f}".format(err5_meter.mean)),
                        ("loss", "{:.5f}".format(loss_meter.mean)),
                    ]
                )
            )

    named_meters = (
        ("top1-err", err1_meter),
        ("top5-err", err5_meter),
        ("loss", loss_meter),
    )
    return {key: torch.tensor(meter.mean) for key, meter in named_meters}


def test_per_iter(ds_val, max_iters, net, iters):
    """
    Evaluate ``net`` and log the averaged metrics for one checkpoint.

    Args:
        ds_val: Validation batch source forwarded to ``run_eval``.
        max_iters: Number of validation batches forwarded to ``run_eval``.
        net: Model forwarded to ``run_eval``.
        iters: Integer iteration number of the checkpoint, used in the logs.

    Returns:
        The metric dict produced by ``run_eval``.
    """
    logging.info(f"Validation for model saved in iter:{iters}")

    metrics = run_eval(ds_val, max_iters, net)
    top1_err = metrics["top1-err"]
    top5_err = metrics["top5-err"]
    mean_loss = metrics["loss"]

    summary = (
        f"val at it:{iters:d}: top1-err {top1_err:.2f}, "
        f"top5-err {top5_err:.2f}, loss {mean_loss:.5f}."
    )
    logging.info(summary)

    return metrics


def main():
    """
    Command-line entry point: evaluate one saved checkpoint.

    Parses the arguments, sets up logging to stdout and to
    ``<output_dir>/test_log.txt``, builds the network and the validation
    loader, then evaluates ``<output_dir>/iter-<iter>.pth`` if it exists.
    A missing checkpoint is reported in the log instead of being skipped
    silently.
    """
    parser = argparse.ArgumentParser(description="PyTorch ClassiFication Inference")
    parser.add_argument("--output_dir", type=str, default="./output")
    parser.add_argument("--batch_size", type=int, default=200)
    parser.add_argument("--eval_dataset_dir", type=str)
    parser.add_argument("--iter", type=int, help="The iteration number for testing")
    parser.add_argument(
        "--num_val_samples",
        type=int,
        default=50000,
        help="Number of validation samples used to derive the number of batches",
    )
    args = parser.parse_args()

    log_format = "[%(asctime)s] %(message)s"
    logging.basicConfig(
        stream=sys.stdout, level=logging.INFO, format=log_format, datefmt="%d %I:%M:%S"
    )
    # FileHandler opens the log file immediately, so the directory must exist.
    os.makedirs(args.output_dir, exist_ok=True)
    fh = logging.FileHandler(f"{args.output_dir}/test_log.txt")
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info(args)

    args.model = create_network()
    args.device = torch.device("cuda")
    args.model.to(args.device)
    ds_val = get_val_loader(args.eval_dataset_dir, args.batch_size)

    model_file = os.path.join(args.output_dir, f"iter-{args.iter}.pth")
    if not os.path.exists(model_file):
        logging.error(f"checkpoint not found, nothing to evaluate: {model_file}")
        return

    logging.info(
        f"\n\n***start to evaluate iteration of {args.iter}***"
    )
    load_checkpoint(args, model_file)
    # Floor division: a trailing partial batch is not evaluated.
    max_iters = args.num_val_samples // args.batch_size
    test_per_iter(ds_val, max_iters, args.model, args.iter)


if __name__ == "__main__":
    main()
    # os._exit() terminates without flushing stdio buffers or running cleanup
    # handlers, so flush the log handlers and standard streams explicitly
    # first; otherwise redirected output can be truncated.
    logging.shutdown()
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(0)
