# evaluate a smoothed classifier on a dataset
import torch
import datetime

import os
import math
import time
import copy
import torch
import numpy as np
import torch.nn as nn
import datetime
from tqdm import tqdm

from argparse import ArgumentParser

import sys

sys.path.append("enter your path")

import torch.distributions as D

from scipy.stats import norm

from multiquery_randomized_smoothing.src.dataset_utils import get_dataset
# from multiquery_randomized_smoothing.src.models import architectures
from multiquery_randomized_smoothing.src.train_utils import set_seed, get_save_directory_path, get_image_size
from multiquery_randomized_smoothing.src.models.single_query_arch import SINGLE_QUERY_ARCH
from multiquery_randomized_smoothing.src.models.two_query_arch import TWO_QUERY_ARCH

from statsmodels.stats.proportion import proportion_confint

def _lower_confidence_bound(NA: int, N: int, failure_prob: float) -> float:
    """Compute a one-sided lower confidence bound on a Bernoulli proportion.

    The bound is the lower endpoint of a Clopper-Pearson ("beta") interval.

    :param NA: the number of "successes" observed
    :param N: the total number of draws
    :param failure_prob: the probability with which the returned bound may fail to hold
    :return: a lower bound on the binomial proportion which holds true w.p at least (1 - failure_prob) over the samples
    """
    # The interval is two-sided with alpha / 2 in each tail, so requesting
    # alpha = 2 * failure_prob leaves exactly failure_prob below the lower end.
    two_sided_alpha = 2 * failure_prob
    lower, _upper = proportion_confint(NA, N, alpha=two_sided_alpha, method="beta")
    return lower

def certify(model, x, n_pred, n_cert, failure_prob, adv, image_size, batch_size):
    """Predict a class for x with Monte Carlo sampling and certify a radius for it.

    Two separate sampling passes are made: ``n_pred`` samples pick the candidate
    class, then ``n_cert`` further samples are used to lower-bound the
    probability of that class. The bound is converted into a radius using the
    module-level ``args.total_sigma`` (this function therefore requires the
    global ``args`` namespace to be populated before it is called).

    :param model: the classifier handed to ``monte_carlo_predictions``
    :param x: the input to certify
    :param n_pred: number of Monte Carlo samples used to select the class
    :param n_cert: number of Monte Carlo samples used to bound its probability
    :param failure_prob: allowed failure probability of the lower bound
    :param adv: norm of the threat model, "l2" or "linf"
    :param image_size: side length of the (square) image; 3 channels are assumed
    :param batch_size: batch size used while sampling
    :return: (prediction, radius, prob_lb); prediction is -1 when abstaining
    :raises NotImplementedError: if ``adv`` is "l1", which has no radius formula here
    :raises ValueError: if ``adv`` is not one of "l1", "l2", "linf"
    """
    # Validate the threat model before sampling: otherwise an unsupported value
    # is only noticed (as an unbound `radius`) after n_pred + n_cert forward passes.
    if adv == "l1":
        raise NotImplementedError("certification for adv='l1' is not implemented")
    if adv not in ("l2", "linf"):
        raise ValueError(
            "unknown adv {!r}; expected one of 'l1', 'l2', 'linf'".format(adv))

    # get n_pred monte carlo samples for prediction
    counts_prediction = monte_carlo_predictions(model, x, n_pred, batch_size)
    prediction = counts_prediction.argmax().item()

    # get n_cert monte carlo samples for certification
    counts_estimation = monte_carlo_predictions(model, x, n_cert, batch_size)
    nA = counts_estimation[prediction].item()
    prob_lb = _lower_confidence_bound(nA, n_cert, failure_prob)

    # l2 radius of the smoothed classifier: sigma * Phi^{-1}(lower bound)
    radius = args.total_sigma * norm.ppf(prob_lb)
    if adv == "linf":
        # assumes num_channels = 3
        image_dims = image_size * image_size * 3
        # an l2 ball of radius r contains the linf ball of radius r / sqrt(d)
        radius = radius / np.sqrt(image_dims)

    # NOTE: when abstaining, the radius is returned as computed (it is not
    # reset), so it is non-positive for a positive total_sigma.
    if prob_lb < 0.5:
        prediction = -1 # abstain

    return prediction, radius, prob_lb

def _count_arr(arr, length):
    """Count how many times each index occurs in arr.

    :param arr: a 1-D sequence or array of integer indices
    :param length: size of the returned count vector
    :return: an ndarray[int] of size ``length`` whose i-th entry is the number
        of occurrences of i in ``arr``
    :raises IndexError: if an index is out of range or is not an integer
    """
    counts = np.zeros(length, dtype=int)
    indices = np.asarray(arr)
    if indices.size == 0:
        # An empty Python list becomes a float array, which cannot be used as
        # an index; there is nothing to count anyway.
        return counts
    # Unbuffered scatter-add: repeated indices are each counted, unlike
    # `counts[indices] += 1`, which would count duplicates only once.
    np.add.at(counts, indices, 1)
    return counts

def monte_carlo_predictions(model, x, num, batch_size):
    """Tally the classes predicted by the model over repeated copies of x.

    The input is tiled into batches of at most ``batch_size`` copies and every
    batch is passed to ``model`` together with a ``{'mode': 'certify'}``
    tracker dict. No noise is added in this function itself; presumably the
    model injects it internally in 'certify' mode (confirm against the model
    classes). ``args.num_classes`` and ``args.mask_output`` are read from the
    module-level ``args``.

    :param model: callable invoked as model(batch, logging_trackers)
    :param x: the input [channel x width x height]
    :param num: number of samples to collect
    :param batch_size: maximum number of copies of x per forward pass
    :return: an ndarray[int] of length num_classes containing the per-class counts
    """
    class_counts = np.zeros(args.num_classes, dtype=int)
    num_batches = math.ceil(num / batch_size)
    remaining = num

    with torch.no_grad():
        for _ in range(num_batches):
            # the final batch may be smaller than batch_size
            chunk = min(batch_size, remaining)
            remaining -= chunk

            tiled_input = x.repeat((chunk, 1, 1, 1))

            # a fresh tracker dict is handed to the model on every call
            trackers = {'mode': 'certify'}
            result = model(tiled_input, trackers)
            if args.mask_output == "penalty":
                # in this configuration the model returns a pair; only the
                # first element is used here
                outputs, _ = result
            else:
                outputs = result

            predicted_classes = outputs.argmax(1)
            class_counts += _count_arr(predicted_classes.cpu().numpy(),
                                       args.num_classes)

    return class_counts

if __name__=="__main__":
    # Script entry point: parse the CLI, rebuild the model from its checkpoint
    # and write one certification record per selected test example.
    #
    # NOTE: `args` must remain a module-level name (do not move this body into
    # a function): certify() and monte_carlo_predictions() read it as a global.
    argparser = ArgumentParser()
    argparser.add_argument("--seed", type=int, default=42, help="random seed")
    argparser.add_argument("--device", type=str, default="cuda")
    argparser.add_argument("--dataset", type=str, default=None)
    argparser.add_argument("--num_classes", type=int, default=None)
    argparser.add_argument("--dataset_path", type=str, default=None)

    # dataset transformation params
    argparser.add_argument("--pad_size", type=int, help="amount of padding on single side of the image",
                           default=0)
    argparser.add_argument("--num_image_locations", type=str, help="1, 2, 4, 8  or random number of positions the center image is placed in", default=None)
    argparser.add_argument("--background", type=str, help="background on which an image is padded: black or nature", default=None)

    # some train params
    argparser.add_argument("--base_classifier", type=str, default=None)

    # DP and model arch related args
    argparser.add_argument("--num_queries", type=int,
                           default=None, help="number of queries: 1 or 2")
    argparser.add_argument("--budget_split", type=str,
                           default=None, help="fixed or learnt")
    argparser.add_argument("--first_query_budget_frac", type=float,
                           default=None, help="used when budget split is fixed")
    argparser.add_argument("--first_query_with_mask", action='store_true',
                           help="learn an average mask in the first query")
    argparser.add_argument("--second_query_mask_model", type=str,
                           help="architecture used to output per-input mask in the second query", default="None")
    argparser.add_argument("--mask_output", type=str, help="whether to apply mask penalty or use a sigmoid layer at end of mask model: penalty or sigmoid")
    argparser.add_argument("--average_queries", action='store_true',
                           help="whether to average noisy images across queries")
    argparser.add_argument("--averaging_style", type=str, default=None,
                           help="use old or new averaging")
    argparser.add_argument("--mask_init", type=str, help="how to initialize the mask: random or identity", default=None)
    argparser.add_argument("--mask_recon", type=float, help="whether to add regulization for loss, mask will also learn to reconstruct the image, this is the corresponding coefficient", default=0)

    argparser.add_argument("--wm_across_channels", type=str,
                           help="whether to learn a different mask per channel: same or different", default=None)
    argparser.add_argument("--total_sigma", type=float, default=None, help="total sigma in our design (which is what we take vanilla sigma to be)")

    # certification time params
    argparser.add_argument("--cert_batch_size", type=int, default=100)
    argparser.add_argument("--n_pred", type=int, default=100, help="")
    argparser.add_argument("--n_cert", type=int, default=10000, help="")
    argparser.add_argument("--failure_prob", type=float, default=0.05, help="")
    argparser.add_argument("--adv", type=str, default="l2")
    argparser.add_argument("--skip", type=int, default=50)
    argparser.add_argument("--max", type=int, default=-1)

    argparser.add_argument("--run_description", type=str,
                           help="short compressed description to create a run directory", default=None)
    # in case some different model is being certified
    argparser.add_argument("--log_dir", type=str, default=None)

    args = argparser.parse_args()

    # set the seed
    set_seed(args.seed)

    # an explicit --log_dir overrides the directory derived from the run args
    if args.log_dir is None:
        log_dir = get_save_directory_path(args)
    else:
        log_dir = args.log_dir

    # prepare output file; the context manager closes it even if certification
    # raises part-way through the dataset
    outfile = os.path.join(log_dir, "certification_log_"+str(args.n_cert)+".txt")
    with open(outfile, 'w+') as f:
        print("idx \t label \t predict \t "+args.adv+"_radius \t prob_lb \t correct \t time", file=f, flush=True)

        # load test dataset
        transform_params = {
            "pad_size": args.pad_size,
            "num_image_locations": args.num_image_locations,
            "background": args.background
        }
        test_dataset = get_dataset(dataset=args.dataset,
                                   split='test',
                                   path=os.path.join(
                                       args.dataset_path, args.dataset),
                                   transform_params=transform_params)

        # create model
        if args.num_queries == 1:
            model = SINGLE_QUERY_ARCH(args).to(args.device)
        else:
            model = TWO_QUERY_ARCH(args).to(args.device)

        # get the checkpointed state; map_location lets a checkpoint saved on
        # one device be restored on whatever --device was requested
        model_sd_path = os.path.join(log_dir, "model_sd.pt")
        checkpoint = torch.load(model_sd_path, map_location=args.device)

        # load the saved model parameters
        model.load_state_dict(checkpoint)

        # get image size
        image_size = get_image_size(args)

        # main certify loop: only every `skip`-th example is certified
        model.eval()
        for i in tqdm(range(0, len(test_dataset), args.skip)):

            # stop once the index reaches `max`; a negative value means no
            # limit. `>=` is used because the visited indices are multiples of
            # `skip` and may never be exactly equal to `max`.
            if args.max >= 0 and i >= args.max:
                break

            (x, label) = test_dataset[i]

            before_time = time.time()
            x = x.to(device=args.device)

            # certify the prediction of g around x
            prediction, radius, prob_lb = certify(model,
                                                  x,
                                                  args.n_pred,
                                                  args.n_cert,
                                                  args.failure_prob,
                                                  args.adv,
                                                  image_size,
                                                  args.cert_batch_size)

            after_time = time.time()
            # an abstention (prediction == -1) is recorded as incorrect
            correct = int(prediction == label)

            time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))
            print("{} \t {} \t {} \t {:.3} \t {:3} \t {} \t {}".format(
                i, label, prediction, radius, prob_lb, correct, time_elapsed), file=f, flush=True)