import torch
import torch.nn as nn
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import argparse
from vgg import vgg16, vgg16_bn, vgg19
from torchvision.models.vgg import vgg16 as cleanVGG16
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from dataset import *
from transform import pad_zero
from bias_optimizer import BiasSGD
from torch.optim import SGD
import global_time
import time
import sys
import math

def top5(y_pred, target, k=5):
    """Count the samples whose target label is among the k highest-scoring classes.

    Args:
        y_pred: Score tensor indexed as ``y_pred[sample, class]`` (dim 1 is
            the class axis).
        target: Tensor of integer class labels, one per row of ``y_pred``.
        k: How many top predictions to consider (default 5, i.e. top-5).
            Clamped to the number of classes so inputs with fewer than ``k``
            classes do not raise.

    Returns:
        0-dim integer tensor holding the number of hits, on ``y_pred``'s device.
    """
    k = min(k, y_pred.size(1))
    # topk avoids sorting every class; only the k best indices are needed.
    topk_idx = torch.topk(y_pred, k, dim=1).indices
    # Broadcast each label against its row's top-k indices: (N, 1) vs (N, k).
    labels = target.reshape(-1, 1).to(topk_idx.device)
    return (topk_idx == labels).sum()

def main(is_blind, batch_size, load_epoch, train_loader, val_loader):
    """Run the project VGG16 forward over ``val_loader`` and print model stats.

    No accuracy is computed here: the function builds the model, loads
    ImageNet weights, runs every (even-sized) validation batch through it
    without gradients, then calls ``model.print_stat()``.

    Args:
        is_blind: If True, the model is built with ``use_bias=False``,
            ``model.addNoise()`` is called, and each input batch is passed
            through ``pad_zero`` before the forward pass.
        batch_size: Currently unused in this function (kept for interface
            compatibility).
        load_epoch: Currently unused (kept for interface compatibility).
        train_loader: Only its length is printed.
        val_loader: Iterable of ``(image, target)`` batches to run.
    """
    print(len(train_loader), len(val_loader))
    # system settings
    device = torch.device("cuda:0")
    dtype = torch.float32
    num_classes = 1000

    # Both configurations share the same constructor call; bias is disabled
    # when blind and only the blind variant gets addNoise().
    model = vgg16(num_classes=num_classes, device=device, dtype=dtype, use_bias=not is_blind)
    if is_blind:
        model.addNoise()

    model.to(device)
    print(model)

    model.load_imageNet_weight()
    model.eval()

    print("validating")
    total_num = 0
    with torch.no_grad():
        for image, target in tqdm(val_loader):
            # Odd-sized batches are skipped (as before).
            # NOTE(review): presumably the pad_zero / blind pipeline needs an
            # even batch size — confirm. Checking before the host->device copy
            # means skipped batches are never transferred.
            if target.size(0) % 2 != 0:
                continue
            image = image.to(device)
            # padding zeros
            if is_blind:
                image = pad_zero(image)
            # Output is not used; the forward pass is run for its side effects
            # on the statistics reported by print_stat().
            model(image)
            total_num += target.size(0)

    print("validated", total_num, "images")
    model.print_stat()
    
if __name__ == '__main__':
    # Command-line entry point: parse options, build the data loaders,
    # configure global_time, then hand everything to main().
    # (argparse is already imported at module level.)
    parser = argparse.ArgumentParser()
    for flag, flag_type, flag_default in (
        ('--blind', int, 0),
        ('--batch_size', int, 32),
        ('--load', int, -1),
        ('--mean', float, 0.0),
        ('--std', float, 1.0),
    ):
        parser.add_argument(flag, type=flag_type, default=flag_default)
    args = parser.parse_args()

    # Only an explicit --blind 1 enables blind mode.
    is_blind = args.blind == 1
    batch_size = args.batch_size
    train_loader = data_loader(0, batch_size)
    val_loader = data_loader(0, batch_size, is_train=False)

    global_time.init()
    global_time.mean = args.mean
    global_time.std = args.std
    # Blind mode is forced off when both mean and std are zero.
    # NOTE(review): presumably because zero mean/std means no noise — confirm.
    zero_noise = global_time.mean == 0 and global_time.std == 0
    if zero_noise:
        is_blind = False

    main(is_blind, batch_size, args.load, train_loader, val_loader)
