import os 
import argparse
import time 
import datetime 
from torchvision import transforms, datasets
import torch

from core import Smooth 
from DRM_classifier import DiffusionRobustModel


# Root directory handed to torchvision's CIFAR10 dataset in main(). It is a
# relative path, so it resolves against the current working directory, and the
# dataset is opened with download=False, so the data must already exist there.
CIFAR10_DATA_DIR = "data/cifar10"

def _sigma_to_timestep(diffusion, sigma):
    """Return the first diffusion timestep whose noise level reaches ``2 * sigma``.

    The noise level at timestep ``t`` is computed as
    ``sqrt_one_minus_alphas_cumprod[t] / sqrt_alphas_cumprod[t]``. The search
    starts at ``t = 1``, so ``t = 0`` is never selected for a positive sigma.
    For ``sigma <= 0`` the loop body never runs and 0 is returned.

    Args:
        diffusion: object exposing indexable ``sqrt_alphas_cumprod`` and
            ``sqrt_one_minus_alphas_cumprod`` schedules (presumably arrays of
            equal length -- TODO confirm).
        sigma: smoothing noise standard deviation from the command line.

    Returns:
        The selected integer timestep.

    Raises:
        ValueError: if no timestep in the schedule reaches the target level.
    """
    # The x2 presumably converts a noise std defined on [0, 1] images (the
    # ToTensor output) to the [-1, 1] scale the diffusion model works in --
    # TODO confirm against DiffusionRobustModel.
    target_sigma = sigma * 2
    sqrt_alphas = diffusion.sqrt_alphas_cumprod
    sqrt_one_minus_alphas = diffusion.sqrt_one_minus_alphas_cumprod
    real_sigma = 0
    t = 0
    while real_sigma < target_sigma:
        t += 1
        if t >= len(sqrt_alphas):
            # Previously this surfaced as a bare IndexError from the lookup.
            raise ValueError(
                "sigma=%s (target noise level %s) exceeds the largest noise "
                "level provided by the diffusion schedule" % (sigma, target_sigma)
            )
        real_sigma = sqrt_one_minus_alphas[t] / sqrt_alphas[t]
    return t


def main(args):
    """Certify a diffusion-denoised smoothed classifier on CIFAR-10.

    For every index ``i`` in ``[args.start, min(args.end, len(dataset)))`` with
    ``i % args.skip == 0``, runs ``Smooth.certify`` on that example and writes one
    tab-separated row ``idx, label, predict, radius, correct, time`` to
    ``args.outfile``. In each row, ``correct`` is 1 if the smoothed prediction
    equals the label and 0 otherwise. After the loop, prints the smoothed
    classifier's accuracy over the certified examples.

    Args:
        args: parsed command-line namespace. Reads diffusion_path, vit_path,
            train_set, sigma, outfile, start, end, skip, N0, N, alpha, and
            batch_size.

    Raises:
        ValueError: if ``args.sigma`` is larger than any noise level in the
            diffusion schedule.
    """
    model = DiffusionRobustModel(args.diffusion_path, args.vit_path)

    dataset = datasets.CIFAR10(CIFAR10_DATA_DIR, train=args.train_set, download=False, transform=transforms.ToTensor())

    # Diffusion timestep whose noise level matches the smoothing sigma.
    t = _sigma_to_timestep(model.diffusion, args.sigma)

    # 10 = number of CIFAR-10 classes (presumably Smooth's num_classes argument).
    smoothed_classifier = Smooth(model, 10, args.sigma, t)

    # dirname() is "" for a bare filename, and os.makedirs("") would raise.
    outdir = os.path.dirname(args.outfile)
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    total_num = 0
    num_correct = 0
    with open(args.outfile, 'w') as f:
        print("idx\tlabel\tpredict\tradius\tcorrect\ttime", file=f, flush=True)

        for i in range(args.start, min(args.end, len(dataset))):
            if i % args.skip != 0:
                continue
            print(i)

            (x, label) = dataset[i]
            x = x.cuda()

            before_time = time.time()
            with torch.cuda.amp.autocast():
                prediction, radius = smoothed_classifier.certify(x, args.N0, args.N, args.alpha, args.batch_size)
            # Stop the clock before any logging so the recorded time is certify only.
            after_time = time.time()
            print(after_time - before_time)

            # Per-example 0/1 flag for the file. The running total is kept separately.
            is_correct = int(prediction == label)
            num_correct += is_correct
            total_num += 1

            time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))
            print("{}\t{}\t{}\t{:.3}\t{}\t{}".format(
                i, label, prediction, radius, is_correct, time_elapsed), file=f, flush=True)

    if total_num == 0:
        # Guard against ZeroDivisionError when the index range selects nothing.
        print("sigma %.2f: no examples were certified (check --start/--end/--skip)" % args.sigma)
        return
    print("sigma %.2f accuracy of smoothed classifier %.4f " % (args.sigma, num_correct / float(total_num)))

if __name__ == "__main__":
    # Command-line entry point: parse options and run the certification loop.
    parser = argparse.ArgumentParser(description='Predict on many examples')
    parser.add_argument("--diffusion_path", type=str, help="path to diffusion model")
    parser.add_argument("--vit_path", type=str, help="path to vit model")
    # main() uses --sigma and --outfile unconditionally (sigma * 2,
    # os.path.dirname(outfile)). Without them it would crash late with a
    # TypeError, so require them up front.
    parser.add_argument("--sigma", type=float, required=True, help="noise hyperparameter")
    parser.add_argument("--skip", type=int, default=10, help="certify only indices that are multiples of this value")
    parser.add_argument("--start", type=int, default=0, help="starting index")
    parser.add_argument("--end", type=int, default=50000, help="ending index (exclusive)")
    # N0 and N are forwarded positionally to Smooth.certify. Their exact roles
    # are defined there (presumably selection vs. estimation sample counts --
    # confirm in core.Smooth).
    parser.add_argument("--N0", type=int, default=100, help="number of samples to use")
    parser.add_argument("--N", type=int, default=100000, help="number of samples to use")
    parser.add_argument("--batch_size", type=int, default=1000, help="batch size")
    parser.add_argument("--alpha", type=float, default=0.001, help="failure probability")
    parser.add_argument("--outfile", type=str, required=True, help="output file")
    parser.add_argument("--train_set", action='store_true', help="use the CIFAR-10 training split instead of the test split")
    args = parser.parse_args()

    main(args)