from utils import *
from funcs import eval_snn
import argparse
from PreProcess import GetCifar10, GetCifar100
from Models.ResNet import *
from Models.VGG import *

# Command-line options as (flag, type, default, help) rows; each row becomes
# one `add_argument` call on the parser below.
_CLI_OPTIONS = (
    ('--dataset', str, 'CIFAR10', 'Dataset name'),
    ('--datadir', str, '../datasets', 'Dataset location'),
    ('--savename', str, 'MyModel', 'Model saving name'),
    ('--device', str, 'cuda:0', 'Device'),
    ('--batchsize', int, 50, 'Batch size'),
    ('--T', int, 512, 'Simulation length'),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)

args = parser.parse_args()

# get model
# Load the whole pickled model object stored at `<savename>.pkl`.
# `map_location` remaps the stored tensors onto --device, so a checkpoint saved
# on one device (e.g. cuda:0) still loads when that device differs from the
# requested one or is not present on this machine.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# checkpoint files from trusted sources.
model = torch.load(args.savename + '.pkl', map_location=args.device)

# Convert the trained network for SNN simulation. Both helpers come from the
# star imports; their exact behavior is inferred from their names
# (swap activations for neuron modules, fold/remove BatchNorm) — TODO confirm.
model = replace_activation_by_neuron(model)
search_fold_and_remove_bn(model)

# get data
# Lower-cased dataset name -> loader factory. Each factory is called as
# factory(datadir, batchsize); only its second return value (used as the test
# loader) is kept.
_DATASET_LOADERS = {
    'cifar10': GetCifar10,
    'cifar100': GetCifar100,
}
_loader = _DATASET_LOADERS.get(args.dataset.lower())
if _loader is None:
    # argparse's error() prints usage plus the message to stderr and exits,
    # so execution never continues with `test` unbound.
    parser.error('unable to find dataset ' + args.dataset)
_, test = _loader(args.datadir, args.batchsize)

# evaluating
acc = eval_snn(test, model, sim_len=args.T, device=args.device)

# print results
# `acc[t-1]` is divided by the test-set size to yield an accuracy, so it is
# presumably the count of correct predictions after t simulation steps
# (TODO confirm against eval_snn). Use the real test-set size rather than
# hard-coding CIFAR's 10000 so other datasets/splits report correct values.
# NOTE(review): assumes `test` is a torch DataLoader exposing `.dataset`.
num_test = len(test.dataset)

# Report every power of two strictly below T, then T itself (printed once).
t = 2
while t < args.T:
    print(f'time step {t}, Accuracy {acc[t-1]/num_test}')
    t *= 2
print(f'time step {args.T}, Accuracy {acc[args.T-1]/num_test}')