import os 
import argparse
import datetime
from time import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

from scipy.stats import norm, binomtest
import numpy as np
from math import ceil
from statsmodels.stats.proportion import proportion_confint

from tqdm import tqdm

from architectures import get_architecture
from datasets import get_dataset, DATASETS, get_num_classes

from DRM_sigma_est import DiffusionModel

# Command-line interface.
# NOTE: the arguments are parsed at module level, so this file is meant to be
# executed as a script; importing it triggers argument parsing.
parser = argparse.ArgumentParser(description='Predict on many examples')
parser.add_argument("base_classifier", type=str, help="path to saved pytorch model of base classifier")
parser.add_argument("sigma", type=float, help="noise hyperparameter")
parser.add_argument("outfile", type=str, help="output file")
parser.add_argument("--diffusion_path", type=str, help="path to diffusion model",
                    default="models/diffusion/cifar10_uncond_50M_500K.pt")
parser.add_argument("--batch_size", type=int, default=200, help="batch size")
# NOTE(review): --skip and --max are accepted but do not appear to be read by
# the rest of this script -- confirm before relying on them.
parser.add_argument("--skip", type=int, default=1, help="how many examples to skip")
parser.add_argument("--max", type=int, default=-1, help="stop after this many examples")
parser.add_argument("--split", choices=["train", "test"], default="test", help="train or test set")
parser.add_argument("--N", type=int, default=100, help="number of samples to use")
parser.add_argument("--alpha", type=float, default=0.001, help="failure probability")

parser.add_argument('--sigma_cand', type=float, nargs='+', default=[0.25, 0.5, 1.0],
                    help='sigma candidates')
args = parser.parse_args()

# N is used as a divisor below and as a replication count later on, so it has
# to be validated first: N == 0 would otherwise surface as a bare
# ZeroDivisionError, and a negative N could slip through the modulo check.
if args.N < 1:
    raise ValueError("N must be a positive integer")

# batch_size counts model inputs (images x N replicas); the script divides it
# by N to get the number of images per batch, so it must divide evenly.
if args.batch_size % args.N != 0:
    raise ValueError("batch_size must be a multiple of N")

class CertifyModel(nn.Module):
    """Chain a denoiser and a classifier into a single module.

    ``forward(x, t)`` calls the denoiser with ``(x, t)``, takes the first
    element of the pair it returns, and feeds that to the classifier.
    """

    def __init__(self, denoiser, classifier):
        """Put both sub-models in eval mode, move them to CUDA and store them."""
        super().__init__()

        # Same treatment for both parts: inference mode, on the GPU.
        for part in (denoiser, classifier):
            part.eval().cuda()

        self.denoiser = denoiser
        self.classifier = classifier

    def forward(self, x, t):
        """Return the classifier output for the denoised version of ``x``.

        The denoiser must return exactly two values; the second one is not
        used here.
        """
        denoised, _unused = self.denoiser(x, t)
        return self.classifier(denoised)


class Smooth(object):
    """A smoothed classifier g.

    Wraps a base model that is called as ``base_classifier(batch, t)`` and
    whose output is reduced with ``argmax`` over dimension 1. The prediction
    for an input is the class the base model picks most often over ``n``
    evaluations of that input.

    ``sigma`` is stored on the instance but is not read by any method of this
    class; ``t`` is forwarded unchanged to the base model.
    """

    # Sentinel label for "abstain". Kept for API compatibility; no method of
    # this class currently returns it.
    ABSTAIN = -1

    def __init__(self, base_classifier: torch.nn.Module, num_classes: int, sigma: float, t: int):
        """Store the base model, the number of classes, sigma and the step ``t``."""
        self.base_classifier = base_classifier
        self.num_classes = num_classes
        self.sigma = sigma
        self.t = t

    def predict(self, x: torch.Tensor, n: int, alpha: float, batch_size: int) -> np.ndarray:
        """Return the majority-vote class for every input in ``x``.

        :param x: batch of inputs (must be 4-D; see ``_sample_noise``)
        :param n: number of base-model evaluations per input
        :param alpha: accepted for interface compatibility; not used
        :param batch_size: forwarded to ``_sample_noise``, which ignores it
        :return: integer array of length ``len(x)`` with the winning class
            index of each input
        """
        self.base_classifier.eval()
        counts = self._sample_noise(x, n, batch_size)

        # Only the top class is needed. (A runner-up lookup would fail with an
        # IndexError when there is a single class.) ``argsort()[-1]`` rather
        # than ``argmax`` is used so that ties keep resolving to whichever
        # tied index argsort places last.
        return np.array([row.argsort()[-1] for row in counts], dtype=int)

    def _sample_noise(self, x: torch.Tensor, num: int, batch_size) -> np.ndarray:
        """Count, per input, how often the base model picks each class.

        The whole batch is tiled ``num`` times along dimension 0 and sent
        through the base model in a single call.

        :param x: batch of inputs; ``repeat((num, 1, 1, 1))`` requires 4 dims
        :param num: number of evaluations per input
        :param batch_size: unused; kept so existing callers keep working
        :return: integer array of shape ``(len(x), num_classes)``
        """
        n_inputs = len(x)
        if n_inputs == 0:
            # Nothing to classify; also avoids an ambiguous ``view(num, -1)``
            # on an empty tensor.
            return np.zeros((0, self.num_classes), dtype=int)

        with torch.no_grad():
            batch = x.repeat((num, 1, 1, 1))
            # Tiling lays the replicas out as [x_0..x_B-1, x_0..x_B-1, ...],
            # so after the view, column i holds every vote for input i.
            predictions = self.base_classifier(batch, self.t).argmax(1).view(num, -1)

        # One device-to-host transfer for all votes instead of one per input.
        votes = predictions.cpu().numpy()

        counts = np.zeros((n_inputs, self.num_classes), dtype=int)
        for i in range(n_inputs):
            counts[i] = self._count_arr(votes[:, i], self.num_classes)
        return counts

    def _count_arr(self, arr: np.ndarray, length: int) -> np.ndarray:
        """Return a histogram of the class indices in ``arr``.

        :param arr: 1-D array of non-negative integer class indices
        :param length: number of classes, i.e. minimum length of the result
        :return: array whose entry ``k`` is the number of occurrences of ``k``
        """
        return np.bincount(arr, minlength=length)

if __name__ == "__main__":
    # load the base classifier
    denoiser = DiffusionModel(args.diffusion_path)
    # NOTE(review): torch.load unpickles the checkpoint -- only point this at
    # checkpoint files from a trusted source.
    checkpoint = torch.load(args.base_classifier)

    # The classifier has one output class per sigma candidate.
    num_classes = len(args.sigma_cand)

    classifier = get_architecture(checkpoint["arch"], 'cifar10', num_classes, False)
    classifier.load_state_dict(checkpoint['state_dict'])

    # The dataset split name encodes the sigma candidates, e.g.
    # "test_0.250_0.500_1.000" for the defaults.
    sigma_cand_str = "_".join(["%.3f" % sigma for sigma in args.sigma_cand])
    dataset = get_dataset('cifar10', "_".join([args.split, sigma_cand_str]))

    # batch_size counts model inputs (images x N replicas), so each loader
    # batch holds batch_size / N images. Divisibility is checked at startup.
    img_batch_size = args.batch_size // args.N

    dataloader = DataLoader(
        dataset,
        batch_size=img_batch_size,
        shuffle=False,
        num_workers=4
    )

    # Find the first diffusion step t >= 1 whose noise-to-signal ratio
    # sqrt(1 - alpha_bar_t) / sqrt(alpha_bar_t) reaches 2 * sigma.
    # NOTE(review): if the schedule never reaches target_sigma, the index
    # presumably runs past the end of the schedule -- confirm the intended
    # range of sigma.
    target_sigma = args.sigma * 2
    real_sigma = 0
    t = 0
    while real_sigma < target_sigma:
        t += 1
        a = denoiser.diffusion.sqrt_alphas_cumprod[t]
        b = denoiser.diffusion.sqrt_one_minus_alphas_cumprod[t]
        real_sigma = b / a

    certify_model = CertifyModel(denoiser, classifier)

    # create the smoothed classifier g
    smoothed_classifier = Smooth(certify_model, num_classes, args.sigma, t)

    # prepare output file: the directory has to exist *before* the file is
    # opened, and a bare filename has no directory component to create.
    outdir = os.path.dirname(args.outfile)
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    total_num = 0

    # The context manager guarantees the log is closed even if a batch fails.
    with open(args.outfile, 'w') as f:
        print("idx\tlabel\tpredict\tcorrect\ttime", file=f, flush=True)

        # iterate through the dataset
        for x, label, _, _ in tqdm(dataloader):
            x = x.cuda()
            before_time = time()

            # make the prediction
            prediction = smoothed_classifier.predict(x, args.N, args.alpha, args.batch_size)

            after_time = time()

            # Wall-clock time of the whole batch; it is repeated on every row
            # that belongs to this batch.
            time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))

            # log the prediction and whether it was correct
            for i, p in enumerate(prediction):
                # Per-sample correctness flag (1 = prediction matches label).
                is_correct = int(int(p) == int(label[i]))
                print("{}\t{}\t{}\t{}\t{}".format(
                    total_num, label[i], p, is_correct, time_elapsed), file=f, flush=True)
                total_num += 1

    # Re-read the log and store the predicted labels as a NumPy array next to it.
    df = pd.read_csv(args.outfile, sep="\t")

    predict = df["predict"].values

    np.save(args.outfile.replace(".tsv", ".npy"), predict)
