import os 
import argparse
import time 
import datetime 
from torchvision import transforms, datasets
import torch
import torch.nn as nn

from architectures import get_architecture
from datasets import get_dataset, DATASETS, get_num_classes

from core import Smooth 
from DRM_sigma_est import DiffusionModel


# Relative path string for CIFAR-10 data. It is not referenced anywhere else in
# the visible part of this file -- presumably kept for other scripts or
# historical reasons; verify before removing.
CIFAR10_DATA_DIR = "data/cifar10"

def main(args):
    """Certify a denoise-then-classify model and log per-example results.

    Builds a ``DiffusionModel`` denoiser and a base classifier from the given
    checkpoints, wraps them in ``CertifyModel`` and ``Smooth``, then calls
    ``certify`` on every dataset index in ``[args.start, min(args.end, len))``
    that is a multiple of ``args.skip``. One tab-separated row per example is
    written to ``args.outfile`` and the overall accuracy is printed to stdout.

    Args:
        args: Parsed command-line namespace. Reads ``diffusion_path``,
            ``base_classifier``, ``sigma_cand``, ``sigma_label_dir``,
            ``split``, ``sigma``, ``outfile``, ``start``, ``end``, ``skip``,
            ``N0``, ``N``, ``alpha`` and ``batch_size``.

    Raises:
        ValueError: If no diffusion timestep reaches the requested noise
            level (the schedule is exhausted before ``b / a >= 2 * sigma``).
    """
    denoiser = DiffusionModel(args.diffusion_path)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(args.base_classifier)

    # The classifier has one output class per sigma candidate.
    num_classes = len(args.sigma_cand)
    classifier = get_architecture(checkpoint["arch"], 'cifar10', num_classes, False)
    classifier.load_state_dict(checkpoint['state_dict'])

    # The label file name encodes the candidates, e.g. "0.250_0.500_1.000_test.npy".
    sigma_cand_str = "_".join(f"{sigma:.3f}" for sigma in args.sigma_cand)
    sigma_label_path = os.path.join(args.sigma_label_dir, f"{sigma_cand_str}_{args.split}.npy")
    dataset = get_dataset('cifar10', f"{args.split}_sigma_est", sigma_label_path)

    # Get the timestep t corresponding to noise level sigma: the first t >= 1
    # whose ratio sqrt(1 - alpha_bar_t) / sqrt(alpha_bar_t) reaches the target.
    # NOTE(review): the factor 2 presumably accounts for the diffusion model
    # working on inputs scaled to [-1, 1] -- confirm against DiffusionModel.
    target_sigma = args.sigma * 2
    real_sigma = 0
    t = 0
    while real_sigma < target_sigma:
        t += 1
        try:
            a = denoiser.diffusion.sqrt_alphas_cumprod[t]
            b = denoiser.diffusion.sqrt_one_minus_alphas_cumprod[t]
        except IndexError as err:
            # Ran past the end of the noise schedule without reaching the target.
            raise ValueError(
                f"sigma={args.sigma} is too large: no diffusion timestep "
                f"reaches the noise level {target_sigma}"
            ) from err
        real_sigma = b / a

    certify_model = CertifyModel(denoiser, classifier)

    smoothed_classifier = Smooth(certify_model, num_classes, args.sigma, t)

    # dirname is '' when outfile has no directory part; makedirs('') would raise.
    outdir = os.path.dirname(args.outfile)
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    total_num = 0
    correct = 0
    # The context manager guarantees the log is closed even if certify raises.
    with open(args.outfile, 'w') as f:
        print("idx\tlabel\tpredict\tradius\tcorrect\ttime", file=f, flush=True)

        for i in range(args.start, min(args.end, len(dataset))):
            # Only indices that are multiples of skip are evaluated.
            if i % args.skip != 0:
                continue

            (x, label, _, _) = dataset[i]
            x = x.cuda()

            before_time = time.time()
            with torch.cuda.amp.autocast():
                prediction, radius = smoothed_classifier.certify(
                    x, args.N0, args.N, args.alpha, args.batch_size)
            after_time = time.time()

            correct += int(prediction == label)

            time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))
            total_num += 1

            # NOTE: the "correct" column is the running total of correct
            # predictions so far, not a per-example 0/1 flag.
            # flush=True keeps the log readable while the (long) run progresses.
            print("{}\t{}\t{}\t{:.3}\t{}\t{}".format(
                i, label, prediction, radius, correct, time_elapsed), file=f, flush=True)

    if total_num == 0:
        # Avoid a ZeroDivisionError when the index range selects nothing.
        print("sigma %.2f: no examples were evaluated (check --start/--end/--skip)" % args.sigma)
        return

    print("sigma %.2f accuracy of smoothed classifier %.4f "%(args.sigma, correct/float(total_num)))


class CertifyModel(nn.Module):
    """Pipeline module: run the denoiser, then classify its first output.

    Both sub-models are switched to eval mode and moved to CUDA when the
    wrapper is constructed.
    """

    def __init__(self, denoiser, classifier):
        """Store the two sub-models and prepare them for GPU inference.

        Args:
            denoiser: Callable as ``denoiser(x, t)`` returning a pair; must
                support ``.eval()`` and ``.cuda()``.
            classifier: Callable on the denoiser's first output; must
                support ``.eval()`` and ``.cuda()``.
        """
        super().__init__()
        self.denoiser = denoiser
        self.classifier = classifier

        # Same order as assignment: denoiser first, then classifier.
        for submodel in (self.denoiser, self.classifier):
            submodel.eval().cuda()

    def forward(self, x, t):
        """Return the classifier output for the denoised version of ``x``.

        Args:
            x: Input forwarded unchanged to the denoiser.
            t: Second argument forwarded unchanged to the denoiser.

        Returns:
            Whatever ``self.classifier`` returns for the denoiser's first
            output; the denoiser's second output is discarded.
        """
        denoised, _ = self.denoiser(x, t)
        return self.classifier(denoised)


if __name__ == "__main__":
    # Command-line entry point: parse the options below and hand them to main().
    parser = argparse.ArgumentParser(description='Predict on many examples')

    # Positional arguments.
    parser.add_argument("base_classifier", type=str, help="path to saved pytorch model of base classifier")
    parser.add_argument("sigma", type=float, help="noise hyperparameter")
    parser.add_argument("outfile", type=str, help="output file")

    # Model and dataset-range options.
    parser.add_argument("--diffusion_path", type=str, help="path to diffusion model",
                        default="models/diffusion/cifar10_uncond_50M_500K.pt")
    parser.add_argument("--skip", type=int, default=10, help="how many examples to skip")
    parser.add_argument("--start", type=int, default=0, help="starting index")
    parser.add_argument("--end", type=int, default=50000, help="ending index")
    parser.add_argument('--sigma_cand', type=float, nargs='+', default=[0.25, 0.5, 1.0],
                        help='sigma candidates')

    # Options passed through to Smooth.certify.
    parser.add_argument("--N0", type=int, default=100, help="number of samples to use")
    parser.add_argument("--N", type=int, default=100000, help="number of samples to use")
    parser.add_argument("--batch_size", type=int, default=1000, help="batch size")
    parser.add_argument("--alpha", type=float, default=0.001, help="failure probability")

    parser.add_argument("--split", type=str, default="test", help="train or test set")
    # main() joins this directory with "<sigma candidates>_<split>.npy".
    parser.add_argument('--sigma_label_dir', type=str, default='data/sigma_label/base',
                        help='directory containing the sigma label .npy files')
    # Not read by main() in this file; kept so existing command lines still parse.
    parser.add_argument('--trainset_suffix', type=str, default='',
                        help='suffix for dataset name')
    args = parser.parse_args()

    main(args)