# evaluate a smoothed classifier on a dataset
import argparse
import os, re, json
from datasets import get_dataset, DATASETS, get_num_classes
from core import Smooth
from time import time
import torch
import datetime
from architectures import get_architecture
from tqdm import trange, tqdm
import numpy as np
import scipy
import pandas as pd
import hashlib

import sys
sys.path.append('../')


from kWTA import models
from kWTA import activation
from kWTA import attack
from kWTA import training
from kWTA import utilities
from kWTA import densenet
from kWTA import resnet
from kWTA import wideresnet

import json

# Command-line interface.
#
# Positional arguments: dataset name, path to the saved base classifier,
# the noise level sigma, and the output file. The optional flags select a
# defense (JEM / diffusion), the size of the basis search, and which dataset
# examples to process.
parser = argparse.ArgumentParser(description='Certify many examples')
parser.add_argument("dataset", choices=DATASETS, help="which dataset")
parser.add_argument("base_classifier", type=str, help="path to saved pytorch model of base classifier")
parser.add_argument("sigma", type=float, help="noise hyperparameter")
parser.add_argument("outfile", type=str, help="output file")
parser.add_argument("--infty", action="store_true", help="use infty norm")
parser.add_argument("--diffusion_defense", action="store_true", help="use diffusion defense")
# Passed as a JSON string on the command line and decoded after parsing.
parser.add_argument("--timestep_respacing", default=None, type=str, help="timestep respacing list for diffusion defense")
parser.add_argument("--num_basis", default=1, type=int, help="number of basis vectors")
parser.add_argument("--num_bins", default=1, type=int, help="number of bins per basis vectors")
parser.add_argument("--num_pgd_iterations", default=100, type=int, help="number of PGD iterations")
parser.add_argument("--batch", type=int, default=1000, help="batch size")
# --start / --max / --skip are used as range(start, max, skip) over dataset
# indices; --max == -1 means "up to the end of the dataset".
parser.add_argument("--skip", type=int, default=1, help="how many examples to skip")
parser.add_argument("--start", type=int, default=0, help="where to start with examples")
parser.add_argument("--max", type=int, default=-1, help="stop after this many examples")
parser.add_argument("--split", choices=["train", "test"], default="test", help="train or test set")
# NOTE(review): --N0, --N and --alpha are not read anywhere in the visible
# part of this script; they are kept so existing command lines keep working.
parser.add_argument("--N0", type=int, default=100)
parser.add_argument("--N", type=int, default=100000, help="number of samples to use")
parser.add_argument("--alpha", type=float, default=0.001, help="failure probability")
# Values > -1 enable the JEM defense branch of the script.
parser.add_argument("--jem_defense", type=int, default=-1, help="use jem defense and refine this many times")
parser.add_argument("--num_corrupt", type=int, default=None, help="number of corrupted noise patterns to use")
args = parser.parse_args()

# Decode the JSON-encoded respacing value. A malformed value is reported
# through argparse (usage message, exit status 2) instead of surfacing as an
# uncaught JSONDecodeError traceback.
if args.timestep_respacing:
    try:
        args.timestep_respacing = json.loads(args.timestep_respacing)
    except json.JSONDecodeError as err:
        parser.error("--timestep_respacing must be valid JSON: {}".format(err))

if __name__ == "__main__":
    # Script entry point.
    #
    # 1. Load the base classifier from the checkpoint (JEM model, a checkpoint
    #    carrying an "arch" entry, or a bare state_dict for a kWTA ResNet).
    # 2. Optionally wrap it in the diffusion defense and build the smoothed
    #    classifier.
    # 3. Build a fixed, seeded orthonormal basis in input space.
    # 4. For every selected example run the auto-attack, brute-force and
    #    sampled searches and store one row per example in a JSON file.
    #    The file is rewritten after every example and is read back on
    #    start-up so an interrupted run can resume.

    # load the base classifier
    checkpoint = torch.load(args.base_classifier)

    # use jem defense if requested
    if args.jem_defense > -1:
        # import jem defense from JEM_for_smoothing repo
        sys.path.append('JEM_for_smoothing/')
        from JEM_for_smoothing.attack_model import CCF, DummyModel, gradient_attack_wrapper
        import torch.nn as nn

        f = CCF(28, 10, None)
        if "model_state_dict" in checkpoint:
            # loading from a new checkpoint
            f.load_state_dict(checkpoint["model_state_dict"])
        else:
            # loading from an old checkpoint
            f.load_state_dict(checkpoint)

        # --jem_defense is an int and is > -1 here, so it is either positive
        # or exactly zero; no third case exists.
        if args.jem_defense > 0:
            # NOTE(review): the requested refinement count is not forwarded
            # here (DummyModel defaults are used) - confirm this is intended.
            f = DummyModel(f)
        else:
            f = DummyModel(f, n_steps_refine=args.jem_defense, n_dup_chains=1, sigma=0., sgld_sigma=0.)

        base_classifier = f
        base_classifier.to('cuda')
        base_classifier = nn.DataParallel(base_classifier).to('cuda')
        base_classifier.eval()
        base_classifier = gradient_attack_wrapper(base_classifier, True) # make deterministic
        # NOTE(review): the prefix says "0refine" for every --jem_defense
        # value, not only for 0 - confirm before relying on file names.
        args.outfile = os.path.join(os.path.dirname(args.outfile), "0refine_jem_" + os.path.basename(args.outfile))

    elif isinstance(checkpoint, dict) and "arch" in checkpoint:  # model is from original smoothing repo
        base_classifier = get_architecture(checkpoint["arch"], args.dataset)
        base_classifier.load_state_dict(checkpoint['state_dict'])
        # Output name: <arch without "cifar_"><checkpoint basename>_<outfile basename>
        arch_tag = checkpoint["arch"].replace("cifar_", "")
        ckpt_tag = os.path.basename(args.base_classifier).replace('.pth', '').replace('checkpoint.tar', '')
        args.outfile = os.path.join(os.path.dirname(args.outfile), arch_tag + ckpt_tag + '_' + os.path.basename(args.outfile))
    else:  # checkpoint is only state_dict
        # The architecture is inferred from a substring of the checkpoint path.
        if "resnet18_cifar" in args.base_classifier:
            base_classifier = resnet.ResNet18().cuda()
        elif "spresnet18_0.1_cifar" in args.base_classifier:
            base_classifier = resnet.SparseResNet18(sparsities=[0.1,0.1,0.1,0.1], sparse_func='vol').cuda()
        elif "spresnet18_0.2_cifar" in args.base_classifier:
            base_classifier = resnet.SparseResNet18(sparsities=[0.2,0.2,0.2,0.2], sparse_func='vol').cuda()
        else:
            # Fail with a clear message instead of a NameError on the
            # load_state_dict call below.
            raise ValueError(
                "cannot infer the architecture from checkpoint path {!r}; expected it to contain "
                "'resnet18_cifar', 'spresnet18_0.1_cifar' or 'spresnet18_0.2_cifar'".format(args.base_classifier)
            )
        base_classifier.load_state_dict(checkpoint)
    base_classifier.eval()

    # use diffusion defense if requested
    if args.diffusion_defense:
        from diffusion.defense import Diffusion_Defense_Model
        base_classifier = Diffusion_Defense_Model(base_classifier, timestep_respacing=args.timestep_respacing)
        args.outfile = os.path.join(os.path.dirname(args.outfile), "diff_" + os.path.basename(args.outfile))

    # create the smoothed classifier g
    smoothed_classifier = Smooth(base_classifier, get_num_classes(args.dataset), args.sigma)

    # prepare output file
    if 'kWTA' in args.base_classifier:
        kwta_tag = 'kWTA_' + os.path.basename(args.base_classifier).replace('.pth', '')
        args.outfile = os.path.join(os.path.dirname(args.outfile), kwta_tag + '_' + os.path.basename(args.outfile))
    if args.infty:
        args.outfile = args.outfile.replace(".json", "_infty.json")

    # Helper that turns whitespace-separated array text into comma-separated
    # text; it is not called in the visible part of this script.
    replace_spaces_with_comma = lambda x: re.sub(r'\s+', ', ', re.sub(r'\[\s+', '[', x))

    # Load previous results (if any) so finished examples can be skipped.
    if os.path.exists(args.outfile):
        # read into pandas with json
        df = pd.read_json(args.outfile, orient='records')
    else:
        df = pd.DataFrame(columns=["idx", "label", "inference", "bruteforce", "bruteforce_loss", "sampled", "sampled_loss", "pgd", "pgd_loss", "time"])
    # Index rows by dataset index so df.loc[i] below addresses example i.
    df.index = df.idx

    # iterate through the dataset
    dataset = get_dataset(args.dataset, args.split)

    # set random seed for np based on dim and bin
    np.random.seed(args.num_basis*int(1e6) + args.num_bins)

    # `import scipy` alone does not guarantee that the linalg submodule is
    # loaded, so import it explicitly before using scipy.linalg.orth.
    import scipy.linalg

    # create orthogonal basis vectors with numpy
    basis_vectors = np.random.randn(args.num_basis, *dataset[0][0].size())
    basis_vectors = basis_vectors.reshape(args.num_basis, -1)
    basis_vectors = scipy.linalg.orth(basis_vectors.T).T
    basis_vectors = basis_vectors / np.linalg.norm(basis_vectors, axis=1, keepdims=True)

    # verify the orthonormality of the basis vectors
    assert np.all(np.dot(basis_vectors, basis_vectors.T).round(13) == np.eye(args.num_basis))

    # back to one basis vector per row, each shaped like a dataset input
    basis_vectors = basis_vectors.reshape(args.num_basis, *dataset[0][0].size())

    # Select the example indices still to be processed. Any existing row
    # (even a single one) means we are resuming and must skip finished indices.
    stop = len(dataset) if args.max == -1 else args.max
    if len(df['idx']) > 0:
        finished = set(df['idx'].values)
        indices = tqdm(sorted(set(range(args.start, stop, args.skip)) - finished))
    else:
        indices = trange(args.start, stop, args.skip)

    # if tqdm indices is empty, then we are done
    if len(indices) == 0:
        print("Results already exist in the output file.")
        sys.exit(0)

    # Batch size for the auto-attack. With a respacing list the batch is
    # capped at max(1, 50 // len(list)); otherwise --batch is used directly.
    # It does not depend on the example, so compute it once.
    if args.timestep_respacing:
        pgd_batch_size = min(args.batch, max(1, int(50 // len(args.timestep_respacing))))
    else:
        pgd_batch_size = args.batch

    for i in indices:
        (x, label) = dataset[i]

        before_time = time()
        x = x.cuda()

        # Fixed torch seed per example (a per-input hash of x was considered
        # as an alternative seed but is not used).
        seed = 1337
        torch.manual_seed(seed)
        num_corrupt = args.num_corrupt

        if num_corrupt:
            # num_corrupt noise patterns shaped like x, each scaled to have
            # L2 norm equal to sigma.
            corruption_noise = torch.randn_like(x.repeat((num_corrupt, 1, 1, 1)), device='cuda')
            corruption_noise = corruption_noise / torch.norm(corruption_noise.view(corruption_noise.shape[0], -1), dim=1, keepdim=True).view(-1, 1, 1, 1)
            corruption_noise = corruption_noise * args.sigma
            # Snapshot used below to check the attacks do not modify the noise.
            corruption_noise_numpy = corruption_noise.cpu().numpy()
        else:
            # Covers both "not given" (None) and an explicit 0.
            corruption_noise = None
            corruption_noise_numpy = None

        if corruption_noise is None:
            logits = base_classifier(x.unsqueeze(0))
            y_hat = logits.argmax(1)
        else:
            logits = base_classifier(torch.clamp(x + corruption_noise, 0., 1.))
            predictions = logits.argmax(1)
            predictions = predictions.reshape((-1, num_corrupt))
            # calculate the mode of the last axis of predictions
            y_hat = torch.mode(predictions, dim=1)[0]

        # NOTE(review): this overwrites the y_hat computed just above, so the
        # stored "inference" is always the prediction on the clean input.
        # The earlier forward pass is kept because removing it could change
        # random-state consumption for stochastic models - confirm intent.
        y_hat = base_classifier(x.unsqueeze(0)).argmax(1)

        smoothed_classifier.log_dir = os.path.dirname(args.outfile)
        print("running autoattack")
        # The positional -1. and 20 are passed through unchanged; their
        # meaning is defined by Smooth._autoattack_with_basis (not shown here).
        pgd_counts, pgd_losses = smoothed_classifier._autoattack_with_basis(x, basis_vectors, -1., args.num_pgd_iterations, 20, pgd_batch_size, infty=args.infty, corruption_noise=corruption_noise)

        # The noise must not have been modified by the attack. Guard on the
        # tensor itself so --num_corrupt 0 (no noise) does not call .cpu() on None.
        assert corruption_noise is None or np.all(corruption_noise_numpy == corruption_noise.cpu().numpy())

        print("running bruteforce")
        bruteforce_counts, bruteforce_maxloss = smoothed_classifier._bruteforce_discrete_basis(x, args.num_bins, basis_vectors, args.batch, infty=args.infty, corruption_noise=corruption_noise)

        # Give the sampled search the same budget as brute force, but at
        # least 1000 samples.
        num_bruteforce_samples = max(np.sum(bruteforce_counts), 1000)

        assert corruption_noise is None or np.all(corruption_noise_numpy == corruption_noise.cpu().numpy())

        print("running sampled")
        sampled_counts, sampled_maxloss = smoothed_classifier._sample_discrete_noise(x, num_bruteforce_samples, basis_vectors, args.batch, infty=args.infty, corruption_noise=corruption_noise)

        after_time = time()

        time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))

        # add row to dataframe at idx
        df.loc[i] = [i, label, y_hat.cpu().numpy()[0], bruteforce_counts, bruteforce_maxloss, sampled_counts, sampled_maxloss, pgd_counts, pgd_losses, time_elapsed]

        # Rewrite the whole file after every example so progress survives an
        # interruption.
        df.to_json(args.outfile, orient='records')
