# evaluate a smoothed classifier on a dataset
import argparse
from email.mime import base
import os, re, json
import setGPU
from datasets import get_dataset, DATASETS, get_num_classes
from core import Smooth
from time import time
import torch
import datetime
from architectures import get_architecture
from tqdm import trange, tqdm
import numpy as np
import scipy
import pandas as pd

import sys
sys.path.append('../')


from kWTA import models
from kWTA import activation
from kWTA import attack
from kWTA import training
from kWTA import utilities
from kWTA import densenet
from kWTA import resnet
from kWTA import wideresnet

# Command-line interface. Each entry is (flags, keyword options) and is passed
# verbatim to ``parser.add_argument``; entries are registered in table order,
# so the ``--help`` listing follows this table.
parser = argparse.ArgumentParser(description='Evaluate models')
_cli_options = (
    (("--dataset",),
     dict(default='cifar10', choices=DATASETS, help="which dataset")),
    (("--base_classifier",),
     dict(default=None, type=str, help="path to saved pytorch model of base classifier")),
    (("--batch",),
     dict(type=int, default=1000, help="batch size")),
    (("--split",),
     dict(choices=["train", "test"], default="test", help="train or test set")),
)
for _flags, _options in _cli_options:
    parser.add_argument(*_flags, **_options)
# Options currently disabled (kept only as a reference list):
#   sigma (positional float), outfile (positional str), --infty, --num_basis,
#   --num_bins, --num_pgd_iterations, --skip, --max, --N0, --N, --alpha
args = parser.parse_args()

# Evaluate the plain (no added noise) top-1 accuracy of each base classifier.


def _load_base_classifier(model_file, dataset_name):
    """Load the checkpoint at *model_file* and return the classifier in eval mode.

    Two checkpoint layouts are handled:

    * a dict containing an ``"arch"`` key (plus ``'state_dict'``), as saved by the
      original smoothing repo: the architecture is rebuilt via ``get_architecture``;
    * a bare state dict (the kWTA models): the architecture is chosen from
      substrings of the file name and moved to the GPU.

    Raises:
        ValueError: if a bare state dict's file name matches no known architecture
            (previously this left the classifier unbound or silently reused the
            model from the preceding iteration).
    """
    checkpoint = torch.load(model_file)
    if isinstance(checkpoint, dict) and "arch" in checkpoint:
        classifier = get_architecture(checkpoint["arch"], dataset_name)
        classifier.load_state_dict(checkpoint['state_dict'])
    else:
        if "resnet18_cifar" in model_file:
            classifier = resnet.ResNet18().cuda()
        elif "spresnet18_0.1_cifar" in model_file:
            classifier = resnet.SparseResNet18(sparsities=[0.1, 0.1, 0.1, 0.1], sparse_func='vol').cuda()
        elif "spresnet18_0.2_cifar" in model_file:
            classifier = resnet.SparseResNet18(sparsities=[0.2, 0.2, 0.2, 0.2], sparse_func='vol').cuda()
        else:
            raise ValueError(f"cannot infer architecture for checkpoint {model_file!r}")
        classifier.load_state_dict(checkpoint)
    classifier.eval()
    return classifier


def _clean_accuracy(classifier, data, batch_size):
    """Return the fraction of samples in *data* that *classifier* labels correctly.

    ``data[k]`` is unpacked as ``(image, label)``; images are stacked into batches
    of at most *batch_size* (the last batch may be shorter) and moved to the GPU.
    """
    correct = []
    total = len(data)
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for start in range(0, total, batch_size):
            # Fetch each sample once; clamp the end so a partial final batch
            # does not index past the end of the dataset.
            samples = [data[k] for k in range(start, min(start + batch_size, total))]
            images = torch.stack([image for image, _ in samples]).cuda()
            labels = np.array([label for _, label in samples])
            predictions = classifier(images).argmax(1).cpu().numpy()
            correct.extend(predictions == labels)
    return np.mean(correct)


# (checkpoint path, display name) pairs evaluated when --base_classifier is not given.
_DEFAULT_MODELS = [
    ("models/cifar10/resnet110/noise_0.00/checkpoint.pth.tar", "resnet110_noise0.00"),
    ("models/cifar10/resnet110/noise_0.12/checkpoint.pth.tar", "resnet110_noise0.12"),
    ("models/cifar10/resnet110/noise_0.25/checkpoint.pth.tar", "resnet110_noise0.25"),
    ("models/cifar10/resnet110/noise_0.50/checkpoint.pth.tar", "resnet110_noise0.50"),
    ("models/cifar10/resnet110/noise_1.00/checkpoint.pth.tar", "resnet110_noise1.00"),
    ("kWTA_models/resnet18_cifar.pth", "resnet18"),
    ("kWTA_models/resnet18_cifar_adv.pth", "resnet18_adv"),
    ("kWTA_models/spresnet18_0.1_cifar.pth", "kWTA_spresnet18_0.1"),
    ("kWTA_models/spresnet18_0.1_cifar_adv.pth", "kWTA_spresnet18_0.1_adv"),
    ("kWTA_models/spresnet18_0.2_cifar.pth", "kWTA_spresnet18_0.2"),
    ("kWTA_models/spresnet18_0.2_cifar_adv.pth", "kWTA_spresnet18_0.2_adv"),
]

dataset = get_dataset(args.dataset, args.split)

# --base_classifier, when given, restricts evaluation to that single checkpoint;
# otherwise the built-in list above is evaluated (the original behaviour).
if args.base_classifier:
    model_tuples = [(args.base_classifier, os.path.basename(args.base_classifier))]
else:
    model_tuples = _DEFAULT_MODELS

for model_file, model_name in model_tuples:
    base_classifier = _load_base_classifier(model_file, args.dataset)
    accuracy = _clean_accuracy(base_classifier, dataset, args.batch)
    # print model name and accuracy
    print("Model: {} / Accuracy: {}".format(model_name, accuracy))

