import torch
import torch.nn as nn
import torchvision
import datetime
import torchvision.transforms as transforms

import numpy as np
import sys
import os

# Make the directory two levels above this script's folder importable.
# Note that `parentdir` is therefore the grandparent of `currentdir`'s
# script file location, not its direct parent.
script_path = os.path.realpath(__file__)
currentdir = os.path.dirname(script_path)
parentdir = os.path.dirname(os.path.dirname(currentdir))
sys.path.append(parentdir)
from math import ceil
import torch.utils.data as utils

from statsmodels.stats.proportion import proportion_confint
# Only preprocessing step applied to each dataset image: conversion to a tensor.
transform = transforms.Compose([transforms.ToTensor()])



from scipy.stats import norm
import argparse
from utils.Certified.architectures import ARCHITECTURES, get_architecture
from utils.Certified.datasets import DATASETS, get_dataset, get_num_classes
# Command-line interface: four required positionals followed by two optional
# integer settings. Arguments are parsed immediately at module level.
parser = argparse.ArgumentParser(description='certify many examples')
parser.add_argument(
	"dataset",
	choices=DATASETS,
	help="which dataset",
)
parser.add_argument(
	"base_classifier",
	type=str,
	help="path to saved pytorch model of base classifier",
)
parser.add_argument(
	"sigma",
	type=float,
	help="noise hyperparameter",
)
parser.add_argument(
	"outfile",
	type=str,
	help="output file",
)
parser.add_argument(
	"--batch",
	type=int,
	default=400,
	help="batch size",
)
parser.add_argument(
	"--skip",
	type=int,
	default=20,
	help="how many examples to skip",
)
args = parser.parse_args()

# Test split of the requested dataset; any value other than "cifar10"
# falls back to MNIST.
if args.dataset == "cifar10":
	dataset_cls = torchvision.datasets.CIFAR10
else:
	dataset_cls = torchvision.datasets.MNIST
testset = dataset_cls(root='./data', train=False, download=True, transform=transform)
# One example per batch, served in dataset order.
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

# The ensemble consists of three checkpoints stored next to each other as
# "<base_classifier>.0", "<base_classifier>.1" and "<base_classifier>.2".
models = []
for member in range(3):
	checkpoint = torch.load("{}.{}".format(args.base_classifier, member))
	architecture = get_architecture(checkpoint["arch"], args.dataset)
	model = nn.DataParallel(architecture).cuda()
	model.load_state_dict(checkpoint['state_dict'])
	model.eval()
	models.append(model)


# Counters start at zero. `sigma` is the noise level from the command line;
# _sample_noise and certify read it as a module-level global.
correct = total = 0

sigma = args.sigma

def _sample_noise(net, x, num, batch_size):
	"""Count the ensemble's class predictions for noisy copies of ``x``.

	``num`` copies of ``x`` are perturbed with Gaussian noise whose standard
	deviation is the module-level ``sigma`` and are processed in chunks of at
	most ``batch_size``. Every model in ``net`` classifies each noisy copy;
	for each copy, the prediction that is counted is the one made by the
	model with the largest gap between its two highest softmax probabilities.

	Args:
		net: sequence of models; each is called on a batch and its output is
			passed through a softmax over dimension 1.
		x: tensor holding a single example. It is tiled along a new leading
			batch dimension (the caller in this file passes ``X[0]`` of a
			CUDA batch).
		num: total number of noisy samples to draw.
		batch_size: maximum number of samples per forward pass.

	Returns:
		Integer numpy array of length 10 whose entry ``c`` is the number of
		noisy samples assigned to class ``c``.
	"""
	num_classes = 10  # size of the returned histogram, fixed as before
	counts = np.zeros(num_classes, dtype=int)
	softmax = nn.Softmax(1)
	n_models = len(net)
	remaining = num
	# Pure inference: without no_grad every forward pass would record
	# autograd state that is never used, wasting memory.
	with torch.no_grad():
		for _ in range(ceil(num / batch_size)):
			this_batch_size = min(batch_size, remaining)
			remaining -= this_batch_size
			batch = x.repeat((this_batch_size, 1, 1, 1))
			# randn_like already follows the device and dtype of `batch`.
			noisy = batch + torch.randn_like(batch) * sigma

			# Column j holds model j's top-1 class and its top-1/top-2
			# probability gap for every sample of the chunk.
			margin = torch.zeros((this_batch_size, n_models))
			pred = torch.zeros((this_batch_size, n_models))
			for j, model in enumerate(net):
				# Ascending sort: the last column is the largest probability.
				probs, order = softmax(model(noisy)).sort()
				pred[:, j] = order[:, -1]
				margin[:, j] = probs[:, -1] - probs[:, -2]

			# Per sample, pick the model with the largest gap and keep the
			# class that model predicted.
			best_model = (margin.sort()[1])[:, -1]
			predictions = pred[torch.arange(pred.size(0)), best_model].type(torch.LongTensor)
			counts += _count_arr(predictions.cpu().numpy(), num_classes)
	return counts

def _count_arr(arr, length):
	counts = np.zeros(length, dtype=int)
	for idx in arr:
		counts[idx] += 1
	return counts

def confidence_bound(na, n, alp):
	"""Return the lower end of a "beta"-method confidence interval for na / n.

	The interval is requested from ``proportion_confint`` with
	``alpha = 2 * alp``; only its lower endpoint is returned.

	Args:
		na: number of successes.
		n: number of trials.
		alp: significance level; doubled before being passed on as ``alpha``.
	"""
	lower, _upper = proportion_confint(na, n, alpha=2 * alp, method="beta")
	return lower

def certify(net, x, n0, n, alp, batch_size, truth):
	"""Predict a class for ``x`` under noise and compute its radius.

	A first sample of ``n0`` noisy predictions selects the candidate class
	(the most frequent one). A second, independent sample of ``n`` noisy
	predictions is then used to lower-bound the probability of that class
	via ``confidence_bound``. If the bound is below 0.5 the function
	abstains.

	Args:
		net: sequence of models, forwarded to ``_sample_noise``.
		x: single input example, forwarded to ``_sample_noise``.
		n0: number of noisy samples used to select the candidate class.
		n: number of noisy samples used to estimate its probability.
		alp: significance level passed to ``confidence_bound``.
		batch_size: maximum number of samples per forward pass.
		truth: unused; kept so existing callers keep working.

	Returns:
		``(predicted_class, radius)`` where ``radius`` is
		``sigma * norm.ppf(lower_bound)`` using the module-level ``sigma``,
		or ``(-1, 0.0)`` when abstaining.
	"""
	# Selection: guess the top class from a small sample.
	selection_counts = _sample_noise(net, x, n0, batch_size)
	ca = selection_counts.argmax().item()

	# Estimation: a fresh sample bounds how often that class is predicted.
	estimation_counts = _sample_noise(net, x, n, batch_size)
	nA = estimation_counts[ca].item()
	pA = confidence_bound(nA, n, alp)

	if pA < 0.5:
		# Abstain. norm.ppf(pA) is negative (or -inf) here, so report a
		# radius of zero rather than a meaningless negative value.
		return -1, 0.0

	return ca, sigma * norm.ppf(pA)

# Certify every `args.skip`-th test example and write one tab-separated row
# per certified example to `args.outfile`.
from time import time

outdir = os.path.dirname(args.outfile)
# dirname is '' when the output file has no directory part, and
# os.makedirs('') raises; exist_ok avoids the check-then-create race.
if outdir:
	os.makedirs(outdir, exist_ok=True)

# The context manager closes the file even if certification raises midway.
with open(args.outfile, 'w') as f:
	print("idx\tlabel\tpredict\tradius\tcorrect\ttime", file=f, flush=True)

	for i, data in enumerate(testloader):
		# Only examples whose index is a multiple of --skip are certified.
		if i % args.skip != 0:
			continue
		X, y = data
		X, y = X.cuda(), y.cuda()

		before_time = time()
		# 100 samples select the class, 100000 estimate its probability,
		# at significance level 0.001.
		prediction, radius = certify(models, X[0], 100, 100000, 0.001, args.batch, y[0].item())
		after_time = time()
		label = y[0].item()

		# An abstention (prediction == -1) never matches a label.
		correct = int(prediction == label)
		time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))
		# Flush per row so partial results survive an interrupted run.
		print("{}\t{}\t{}\t{:.3}\t{}\t{}".format(
			i, label, prediction, radius, correct, time_elapsed), file=f, flush=True)
