import torch
import torch.nn as nn
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import argparse
from vgg import vgg16, vgg19
from torch.nn import CrossEntropyLoss
from torchvision.models.vgg import vgg16 as normal_vgg
from mobilenet import mobilenet_v2
from tqdm import tqdm
from dataset import *
from transform import pad_zero, get_model, getdata_set_name, num_class
from parameters import get_matrix, identity_matrix
import global_time
import time
import sys

def top5(y_pred, target, k=5):
    """Count the samples whose true label is among the top-``k`` predictions.

    Args:
        y_pred: Score tensor indexed as ``y_pred[sample, class]``; higher
            scores rank first.
        target: Integer class labels, one per row of ``y_pred``, on the same
            device as ``y_pred``.
        k: Number of highest-scoring classes that count as a hit (default 5).
            Clamped to the number of classes so models with fewer than ``k``
            classes do not fail.

    Returns:
        A 0-d integer tensor (on ``y_pred``'s device) holding the hit count.
    """
    k = min(k, y_pred.size(1))
    # topk avoids fully sorting every row; only the candidate indices are needed.
    topk_idx = y_pred.topk(k, dim=1).indices
    # Compare each label against its own row's k candidates in one broadcast.
    # A class index appears at most once per row, so the sum counts samples.
    return (topk_idx == target.unsqueeze(1)).sum()

def main(is_blind, batch_size, val_loader, model_name):
    """Evaluate a pretrained model's top-1 / top-5 accuracy on ``val_loader``.

    Args:
        is_blind: When true, each image batch is passed through ``pad_zero``
            before the forward pass, and the model output is trimmed back to
            the number of targets in the batch.
        batch_size: Not used here (the loader is already batched); kept for
            interface compatibility.
        val_loader: Iterable of ``(image, target)`` batches.
        model_name: Model identifier forwarded to ``get_model``.

    Returns:
        ``(top1_acc, top5_acc)`` as Python floats, or ``None`` when the loader
        yields no samples.
    """
    # system settings
    device = torch.device("cuda:0")
    dtype = torch.float32

    # get models
    num_classes = 1000
    model = get_model(model_name, is_blind, device, dtype, num_classes, pretrained=True)
    model.to(device)
    print(model)

    model.eval()
    # Running counts of correct predictions, kept on the device so the loop
    # does not force a host sync every batch.
    top1_correct = torch.zeros((), dtype=torch.long, device=device)
    top5_correct = torch.zeros((), dtype=torch.long, device=device)
    total_size = 0
    with torch.no_grad():
        for image, target in tqdm(val_loader):
            image = image.to(device)
            target = target.to(device)

            # padding zeros
            if is_blind:
                image = pad_zero(image)

            n_samples = target.size(0)
            total_size += n_samples

            y_pred = model(image)

            if is_blind:
                # NOTE(review): presumably pad_zero can enlarge the batch; keep
                # only the outputs that correspond to real targets — confirm.
                y_pred = y_pred[:n_samples]

            top1 = torch.argmax(y_pred, dim=1)
            top1_correct += (top1 == target).sum()
            top5_correct += top5(y_pred, target)

    if total_size == 0:
        # Avoid a ZeroDivisionError when the loader is empty.
        print("validation loader produced no samples; accuracy undefined")
        return None

    acc1 = top1_correct.item() / total_size
    acc5 = top5_correct.item() / total_size
    print("top 1 acc ", acc1, "top 5 acc ", acc5)
    print(global_time.mean, global_time.std)
    return acc1, acc5

if __name__ == '__main__':
    # Command-line interface for running the evaluation.
    cli = argparse.ArgumentParser()
    cli.add_argument('--blind', type=int, default=0)
    cli.add_argument('--batch_size', type=int, default=32)
    cli.add_argument('--mean', type=float, default=0.0)
    cli.add_argument('--std', type=float, default=1.0)
    cli.add_argument("-m", "--model", required=True, help="model name")
    opts = cli.parse_args()

    # Initialise the shared global_time module and record the requested
    # mean/std on it (main() prints them after evaluation).
    global_time.init()
    global_time.mean, global_time.std = opts.mean, opts.std

    # Blind mode is requested with --blind 1, but is forced off when both
    # mean and std are zero.
    blind = opts.blind == 1 and not (global_time.mean == 0 and global_time.std == 0)

    loader = data_loader(0, opts.batch_size, is_train=False)
    main(blind, opts.batch_size, loader, opts.model)
