import argparse
import datetime
import os
import sys
#import setGPU
#from time import time
#import time as basetime
import time
#from tqdm import tqdm
from collections import OrderedDict
from math import ceil

############################################################################3

import numpy as np
from scipy.stats import norm
# SciPy standard normal used by the attacks to turn z_val into a confidence level.
from scipy.stats import norm as stats_norm
try:
    from scipy.stats import binom_test
except ImportError:  # removed in SciPy 1.12 (superseded by scipy.stats.binomtest)
    binom_test = None
from statsmodels.stats.proportion import multinomial_proportions_confint as multi_conf
from statsmodels.stats.proportion import proportion_confint as binom_conf
from statsmodels.stats.proportion import proportion_confint as binom_conf
from statsmodels.stats.proportion import multinomial_proportions_confint as multi_conf

import torch
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel import DistributedDataParallel as DDP

import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision.utils
from torchvision import datasets
from torchvision import models
from torchvision.models.resnet import ResNet, Bottleneck

###############################################################
# Utilities
###############################################################

def at_radius(quantity, radius, count, flag_upper=True):
    """Return the fraction of entries of ``quantity`` on one side of ``radius``.

    Counts entries ``>= radius`` when ``flag_upper`` is true, otherwise entries
    ``<= radius``, and divides that count by ``count``.
    """
    hits = (quantity >= radius) if flag_upper else (quantity <= radius)
    return np.sum(hits) / count


def basic_predict(model, images, device):
    """Return the most frequent class among ``return_wrapper``'s per-sample predictions.

    ``images`` is detached and copied first, so no gradient flows back to the caller.
    Ties go to the smallest class index: the unique values are sorted and
    ``argmax`` picks the first maximal count.
    """
    batch = images.detach().clone()
    _, _, per_sample = return_wrapper(model, batch, device)

    classes_seen, frequency = torch.unique(per_sample, sorted=True, return_counts=True)
    winner = classes_seen[torch.argmax(frequency)]
    return winner.detach()
    
def return_wrapper(model, inpt, device, validate=False):
    """Run ``model`` on a batch and summarise its class predictions.

    Single-channel inputs (``inpt.shape[1] == 1``) are repeated to 3 channels
    before calling ``model``.

    With ``validate=True`` returns ``(unique_classes, counts, argmax_per_row)``
    taken from the raw model output.

    Otherwise each row's output (scaled by 100) goes through a hard
    Gumbel-softmax, and the function returns:
      * the per-class fraction of rows voting for each class (moved to ``device``),
      * a detached copy of the input (reduced back to 1 channel if it was expanded),
      * the per-row class index (moved to ``device``).
    """
    n_rows = inpt.shape[0]

    was_single_channel = inpt.shape[1] == 1
    if was_single_channel:
        inpt = inpt.repeat(1, 3, 1, 1)

    if validate:
        row_pred = torch.argmax(model(inpt), dim=1)
        uniq, cnt = torch.unique(row_pred, return_counts=True)
        return uniq, cnt, row_pred

    snapshot = inpt.detach().clone()
    one_hot = F.gumbel_softmax(100 * model(inpt), tau=1, hard=True, dim=1)
    votes = one_hot.sum(dim=0).reshape(1, -1)

    if was_single_channel:
        snapshot = snapshot[:, 0, :, :].unsqueeze(1)

    fractions = (votes[0, :] / n_rows).to(device)
    row_pred = torch.argmax(one_hot, dim=1).to(device)
    return fractions, snapshot, row_pred
        
def pred_classes_to_counts_v2(pred_classes, samples):
    """Return ``(top_class, second_class, top_count, second_count)`` from row 0.

    ``pred_classes[0, :]`` holds per-class fractions. The counts are those
    fractions multiplied by ``samples`` and truncated with ``int``.
    """
    ordered_vals, ordered_idx = torch.sort(pred_classes[0, :])

    top_class = ordered_idx[-1].item()
    runner_up = ordered_idx[-2].item()
    top_count = int(ordered_vals[-1] * samples)
    runner_up_count = int(ordered_vals[-2] * samples)

    return top_class, runner_up, top_count, runner_up_count
    
def torch_ci(counts, samples, alpha):
    """Two-sided normal-approximation (Wald) interval for ``counts / samples``.

    Returns ``(p_hat - h, p_hat + h)`` where ``h = z_{1-alpha/2} * sqrt(p_hat (1 - p_hat) / samples)``.
    ``counts`` must be a tensor (``torch.sqrt`` is applied to the derived ratio).
    The bounds are not clipped to [0, 1].
    """
    # Local names deliberately avoid the module-level `norm` / `dist` imports.
    standard_normal = torch.distributions.Normal(0, 1)
    p_hat = counts / samples
    std_err = torch.sqrt(p_hat * (1 - p_hat) / samples)
    half_width = -standard_normal.icdf(torch.tensor(alpha) / 2.) * std_err
    return p_hat - half_width, p_hat + half_width
    
    
def norm_distance(a, b, flag):
    """Frobenius/L2 norm of ``a - b``.

    When ``flag`` is true only channel 0 (``[:, 0, :, :]``) of both tensors is
    compared. Callers use this for single-channel inputs that were expanded to
    3 channels.
    """
    difference = a[:, 0, :, :] - b[:, 0, :, :] if flag else a - b
    return torch.linalg.norm(difference)
        
def calculate_statistics(final_result, first_result, final_time, first_time, success):
    """Summarise attack radii and timings, split by attack success.

    Args:
        final_result, first_result: per-example radii (final / first found).
        final_time, first_time: per-example timings matching the radii.
        success: either a boolean mask over the examples or an array/list of
            integer indices of the successful examples.

    Returns a 12-tuple:
        (set_mean, set_median, time_mean, time_u_mean, time_total_mean,
         set_mean_f, set_median_f, time_mean_f, time_u_mean_f, time_total_mean_f,
         p90(first_result), p90(final_result))
    where the ``*_f`` values use ``first_*`` inputs. ``set_*`` and ``time_mean``
    are over successful examples. ``time_u_mean`` is over unsuccessful ones, and
    ``time_total_mean`` is over all examples. A statistic over an empty group is
    reported as 0.
    """
    final_result = np.asarray(final_result)
    first_result = np.asarray(first_result)
    final_time = np.asarray(final_time)
    first_time = np.asarray(first_time)

    # Normalise `success` to a boolean mask. `~success` on an index array is
    # bitwise-not, which selects mirrored indices instead of the complement.
    success = np.asarray(success)
    succeeded = np.zeros(len(final_result), dtype=bool)
    if success.size:
        succeeded[success] = True
    failed = ~succeeded

    def _mean_or_zero(values):
        # Empty groups give 0 instead of NaN (and a RuntimeWarning).
        return np.mean(values) if values.size else 0

    def _median_or_zero(values):
        return np.median(values) if values.size else 0

    def _summary(result, timing):
        return (
            _mean_or_zero(result[succeeded]),
            _median_or_zero(result[succeeded]),
            _mean_or_zero(timing[succeeded]),
            _mean_or_zero(timing[failed]),
            np.mean(timing),
        )

    final_stats = _summary(final_result, final_time)
    first_stats = _summary(first_result, first_time)
    return (*final_stats, *first_stats,
            np.percentile(first_result, 90), np.percentile(final_result, 90))
    
############################################       
# Attacks
############################################

# CW-L2 Attack
def cw_l2_attack(model, images, labels, samples, sigma, device, targeted=False, c=1e-4, kappa=0, max_iter=100, learning_rate=0.01, printing=False, z_val=4., definitive=False):
    """Carlini-Wagner style L2 attack against a Monte-Carlo smoothed classifier.

    The image is parameterised as ``a = (tanh(w) + 1) / 2``. ``w`` is optimised
    with Adam on ``||a - images||^2 + c * f(a)``, where ``f`` is the CW margin
    loss computed on class frequencies. The frequencies are estimated from
    ``samples`` Gaussian-noised copies (std ``sigma``).

    After every step the current iterate is re-sampled. If the top class differs
    from ``labels`` and the Clopper-Pearson lower bound of the top class exceeds
    the upper bound of the runner-up, the iterate's distance to ``images`` is
    recorded and the smallest one is kept. The bounds use a one-sided level
    derived from ``z_val``.

    Every ``max(1, max_iter // 10)`` steps the loss is compared with the last
    check. If it has increased, optimisation stops early.

    ``definitive`` has no effect: the confidence level is always derived from
    ``z_val``. It is kept for signature compatibility.

    Returns:
        ``(attack_images, flag, min_val, min_recorded, first_cw, first_time, final_cw_time)``:
        the last iterate, whether a confident adversarial was found, its smallest
        distance (1e6 if none), that adversarial image (or None), the distance and
        elapsed time of the first one found, and the elapsed time when the best
        one was found (total runtime if none).
    """
    inpt_flag = images.shape[1] == 1
    start_time = time.time()

    # Needed on every call. Previously it was only set when `definitive` was
    # truthy, so the default path hit an UnboundLocalError.
    alpha = 0.5 * (1 - stats_norm.cdf(z_val))

    device = images.device
    images = images.to(device)
    labels = labels.to(device)

    # Define f-function (CW margin loss on smoothed class frequencies)
    def f(x, samples, sigma):
        device = x.device
        adv_input = x.repeat(samples, 1, 1, 1)
        adv_input += torch.randn_like(adv_input) * sigma

        if adv_input.shape[1] != 3:
            adv_input = adv_input.repeat(1, 3, 1, 1)

        outputs, _, _ = return_wrapper(model, adv_input, device)
        one_hot_labels = torch.eye(outputs.shape[0])[labels].to(device)

        i, _ = torch.max((1 - one_hot_labels) * outputs, dim=1)
        j = torch.masked_select(outputs, one_hot_labels.bool())

        if targeted:
            return torch.clamp(i - j, min=-kappa)
        return torch.clamp(j - i, min=-kappa)

    w = torch.zeros_like(images, requires_grad=True).to(device)
    optimizer = optim.Adam([w], lr=learning_rate)

    prev = 1e10
    # Guard against modulo-by-zero when max_iter < 10.
    check_every = max(1, max_iter // 10)

    flag = False
    min_val, min_recorded = 1e6, None
    # Defined up front so the return below is valid even when max_iter == 0.
    attack_images = 1/2*(nn.Tanh()(w) + 1)

    for step in range(max_iter):
        a = 1/2*(nn.Tanh()(w) + 1)

        loss1 = nn.MSELoss(reduction='sum')(a, images)
        loss2 = torch.sum(c*f(a, samples, sigma))
        cost = loss1 + loss2

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # Early stop when the loss does not converge. Fall through to the normal
        # return so callers always receive the full 7-tuple.
        if step % check_every == 0:
            if cost.item() > prev:
                if printing:
                    print('Attack Stopped due to CONVERGENCE....')
                attack_images = a
                break
            prev = cost.item()
        if printing:
            print('- Learning Progress : %2.2f %%        ' %((step+1)/max_iter*100), end='\r')

        attack_images = 1/2*(nn.Tanh()(w) + 1)

        # Evaluation only: no gradients are needed through this forward pass.
        with torch.no_grad():
            adv_input = attack_images.repeat(samples, 1, 1, 1)
            adv_input += torch.randn_like(adv_input) * sigma
            if adv_input.shape[1] != 3:
                adv_input = adv_input.repeat(1, 3, 1, 1)
            sm_output, _, _ = return_wrapper(model, adv_input, device)

        vals, indices = torch.topk(sm_output, 2, sorted=True)
        max_class = indices[0]
        E0, E1 = vals[0], vals[1]

        if max_class != labels:
            # Clopper-Pearson lower bound of the top class vs upper bound of the
            # runner-up. Previously these bounds were computed but then discarded.
            E0_lower = binom_conf(E0.item() * samples, samples, alpha=alpha, method='beta')[0]
            E1_upper = binom_conf(E1.item() * samples, samples, alpha=alpha, method='beta')[1]
            if E0_lower > E1_upper:
                min_temp = norm_distance(attack_images, images, inpt_flag).detach().cpu().numpy()
                if min_temp < min_val:
                    min_val = min_temp
                    final_cw_time = time.time() - start_time
                    min_recorded = attack_images.detach().clone()
                    if flag is False:
                        first_cw = min_temp
                        first_time = time.time() - start_time
                        flag = True

    if flag is False:
        final_cw_time = time.time() - start_time
        first_cw = min_val
        first_time = final_cw_time

    return attack_images, flag, min_val, min_recorded, first_cw, first_time, final_cw_time
    
def direct_attack(model, yp, images, samples, class_set, z_val, min_recorded, sigma, device, inpt_flag, yp_best=None, lambda_val=0.5, delta=None, cutoff_step=0.5):
    """Take one step of the direct gradient attack on a Monte-Carlo smoothed classifier.

    Draws ``samples`` Gaussian-noised copies (std ``sigma``) of ``yp`` and
    estimates class frequencies with ``return_wrapper``. The top-2 frequencies
    E0/E1 are shifted onto Clopper-Pearson bounds (``binom_conf`` with
    ``method='beta'``, level from ``z_val``). A normalised gradient step on
    ``yp`` is then taken, whose objective depends on whether the top class is
    still ``class_set[0]``.

    Note: ``yp`` is detached in place (``yp.detach_()``), so the caller's tensor
    is modified.

    Returns:
        ``(yp_new, yp_grad, min_recorded, yp_best, lambda_val, delta)``. When a
        confident misclassification closer than ``min_recorded`` is found, the
        returned ``yp_best`` is the pre-step ``yp`` and ``min_recorded`` is its
        distance to ``images``.
    """
    stepsize = 0.05
    printing = False     
    device = yp.device
    
    relu = torch.nn.ReLU()
    # Local `norm` (torch Normal) shadows the module-level scipy `norm`.
    norm = torch.distributions.Normal(0,1)
    
    alpha = 0.5*(1 - stats_norm.cdf(z_val)) #0.005

    yp.detach_()
    yp.requires_grad = True
    
    adv_input = yp.repeat(samples, 1, 1, 1)
    adv_input += torch.randn_like(adv_input) * sigma #torch.normal(mean=torch.zeros_like(adv_input), std=sigma)
    
    if adv_input.shape[1] != 3:
        adv_input = adv_input.repeat(1,3,1,1) # Could also do x = self.conv1, but this is just introducing more trainable parameters

    sm_output, x, pred_classes = return_wrapper(model, adv_input, device)
    
    if printing:
        print('#'*20, pred_classes.shape, yp.shape)
    
    vals, indices = torch.topk(sm_output, 2, sorted=True)
    
    max_class, second_class = indices[0], indices[1]

    E0, E1 = vals[0], vals[1] #vals[0]
    
    
    # Clopper-Pearson bounds, converted to offsets from the point estimates so
    # the bounded values below stay differentiable w.r.t. E0/E1 (and hence yp).
    E0_t, E0_u_t = binom_conf(E0.detach().cpu().numpy() * samples, samples, alpha=alpha, method='beta')
    E1_l_t, E1_t = binom_conf(E1.detach().cpu().numpy() * samples, samples, alpha=alpha, method='beta') #1 - E_0
    E0_v, E1_v = E0.detach().cpu().numpy(), E1.detach().cpu().numpy()
    E0_t, E0_u_t = E0_t - E0_v, E0_u_t - E0_v
    E1_l_t, E1_t = E1_l_t - E1_v, E1_t - E1_v
    
    # From here on: E0 = lower bound of the top class, E1 = upper bound of the runner-up.
    E0, E0_u, E1_l, E1 = E0 + torch.tensor(E0_t).to(device), E0 + torch.tensor(E0_u_t).to(device), E1 + torch.tensor(E1_l_t).to(device), E1 + torch.tensor(E1_t).to(device)

        
    # sigma/2 * (PhiInv(E0) - PhiInv(E1)), clipped at 0; same form as the Cohen et al.
    # randomized-smoothing radius. Used below to size the step.
    cohen_val = relu(0.5*sigma*(norm.icdf(relu(E0)) - norm.icdf(relu(E1))))
     
    # NOTE(review): this overwrites the `delta` argument, so the value threaded
    # through by the caller is ignored. delta only affects debug prints and the
    # returned value.
    delta = 0 
    if max_class == class_set[0]:    
        # Still predicted as class_set[0]: minimise |E0 - E1 + 0.1|, with step in [0.05, cutoff_step].
        objective_recorded = torch.abs(E0 - E1 + 0.1)
        
        stepsize = torch.min(torch.max(1.05*cohen_val, torch.tensor(0.05)), torch.tensor(cutoff_step))
        if printing:
            print(E0 - E1, '#'*5)
        
    else:
        # Misclassified: trade off the bound gap against distance to `images`, weighted by lambda_val.
        if printing:
            print(E0 - E1, torch.linalg.norm(yp-images), lambda_val, delta, '!'*5)
        if E0 < E1: # Adversarial example but not confident, then still prioritise the adversarial example
            objective_recorded = 10*torch.abs(E0 - E1 - 0.05) + lambda_val * norm_distance(yp, images, inpt_flag) 
            stepsize = torch.tensor(cutoff_step / 2)
        else:            
            objective_recorded = torch.abs(E0 - E1) + lambda_val * norm_distance(yp, images, inpt_flag)
            stepsize = torch.min(torch.max(0.99*cohen_val, torch.tensor(0.05)), torch.tensor(cutoff_step))
        lambda_val *= 1.05

    # L2-normalised gradient of the objective w.r.t. the current point.
    yp_grad = torch.autograd.grad(objective_recorded, yp)[0].detach()
    yp_grad = yp_grad / (1e-5 + torch.linalg.norm(yp_grad))
    
      
    # Gradient step unless the gradient is NaN or the gap is saturated (~1);
    # in those cases apply a small random jitter instead.
    if (torch.isnan(torch.sum(yp_grad)) == 0) and ((E0 - E1) < (1 - 1e-5)):
        yp_new = yp - stepsize*yp_grad
    else:
        yp_new = yp + 0.05*torch.randn_like(yp)
    
    yp_new = torch.clip(yp_new, 0, 1).detach()
    yp_new.requires_grad = True


    if max_class != class_set[0]:
        # NOTE(review): original_gap / new_gap are computed but never used.
        original_gap = (E0 - E1).detach()
        new_gap = E0 - E1
        if printing:
            print('AT: ', E0 - E1)
        if E0 > E1:
            # Confident misclassification: record the pre-step point if it is the closest so far.
            distance = norm_distance(yp, images, inpt_flag) #torch.linalg.norm(yp - images)
            lambda_val *= 1.15
            if printing:
                print('First delta: ', delta)
            delta = torch.max(delta - (0.9*(E0 - E1)), torch.tensor(0.001))

            if distance < min_recorded:
                min_recorded = distance
                return yp_new, yp_grad, min_recorded, yp.detach().clone(), lambda_val, delta
        else:
            if printing:
                print('Second delta: ', delta)
            delta = delta + (1.01*(E1 - E0))
            if printing:
                print('Changing delta to: ', delta)
                
    return yp_new, yp_grad, min_recorded, yp_best, lambda_val, delta
      
    
def fgsm_attack_v4(model, images, labels, samples, sigma, device, eps=0.3, alpha=2/255, iters=40, z_val=4, probabilistic=False):
    """Iterative L-inf sign-gradient attack against a Monte-Carlo smoothed classifier.

    Each iteration does the following:
      * Draws ``samples`` Gaussian-noised copies (std ``sigma``) of the current
        iterate and estimates class frequencies with ``return_wrapper``.
      * If the top class differs from ``labels[0]``, checks with simultaneous
        multinomial confidence intervals (level from ``z_val``) that the top
        class's lower bound exceeds the true label's upper bound. If so, records
        the smallest such perturbation, measured with ``norm_distance``.
      * Moves the iterate by ``alpha * sign(grad)`` of the cross-entropy loss,
        then projects it into the L-inf ball of radius ``eps`` around the input
        and into [0, 1].

    ``probabilistic`` is unused and kept for signature compatibility. The
    caller's ``images`` tensor is not modified.

    Returns:
        ``(images, flag, pred, min_val, min_recorded, first_pgd, first_time, final_pgd_time)``:
        the last iterate, whether a confident adversarial was found, the last
        prediction (None if ``iters == 0``), the smallest recorded distance (1e6
        if none), that adversarial image (or None), the distance and elapsed time
        of the first one found, and the total runtime.
    """
    inpt_flag = images.shape[1] == 1
    start_time = time.time()
    first_flag = False

    device = images.device
    # Private copy: setting requires_grad / accumulating .grad below must not
    # leak onto the caller's tensor.
    images = images.detach().clone()
    labels = labels.to(device)
    loss = nn.CrossEntropyLoss()

    alpha_probability = 0.5 * (1 - stats_norm.cdf(z_val))

    ori_images = images.detach()
    flag = False
    pred = None
    min_val = torch.tensor(1e6)
    min_recorded = None
    for _ in range(iters):
        images.requires_grad = True

        adv_input = images.repeat(samples, 1, 1, 1)
        adv_input += torch.randn_like(adv_input) * sigma
        if adv_input.shape[1] != 3:
            adv_input = adv_input.repeat(1, 3, 1, 1)

        sm, _, _ = return_wrapper(model, adv_input, device)
        pred = torch.argmax(sm)
        if pred != labels[0]:
            pred = pred.detach().cpu().numpy()
            m_c = multi_conf(samples * sm.detach().cpu().numpy(), alpha=alpha_probability)
            lab = labels[0].detach().cpu().numpy()

            # Confident only if the predicted class's lower bound beats the true label's upper bound.
            if m_c[pred, 0] > m_c[lab, 1]:
                flag = True
                delta = norm_distance(images, ori_images, inpt_flag)
                if delta < min_val:
                    min_val = delta
                    min_recorded = images.detach()
                    if not first_flag:
                        first_pgd = min_val.detach().clone().cpu().numpy()
                        first_time = time.time() - start_time
                        first_flag = True

        model.zero_grad()
        cost = loss(sm.reshape(1, -1), labels)
        cost.backward()

        adv_images = images + alpha * images.grad.sign()
        eta = torch.clamp(adv_images - ori_images, min=-eps, max=eps)
        images = torch.clamp(ori_images + eta, min=0, max=1).detach_()

    final_pgd_time = time.time() - start_time
    if not first_flag:
        first_time = final_pgd_time
        first_pgd = min_val.detach().cpu().numpy()

    return images, flag, pred, min_val.detach().cpu().numpy(), min_recorded, first_pgd, first_time, final_pgd_time
     

def deepfool_attack(model, images, label, device, classes, samples, sigma, overshoot=0.02, nb_candidate=10, max_iter=100):
    """DeepFool-style minimal perturbation attack on a Monte-Carlo smoothed classifier.

    Class frequencies are estimated from ``samples`` Gaussian-noised copies
    (std ``sigma``) via ``return_wrapper``. While the top class still equals
    ``label``, for at most ``max_iter`` iterations:
      * take the gradients of the top ``nb_candidate`` frequencies;
      * choose the closest linearised boundary;
      * add its step ``r_i`` to the cumulative perturbation ``r_tot``.

    Each iterate is ``clip(x + r_tot)``, where ``x`` is the (channel-repeated)
    original input. The returned adversarial is ``clip(x + (1 + overshoot) * r_tot)``.
    The caller's ``images`` tensor is not modified.

    Returns:
        ``(flag, adv_x, distance, elapsed)``. ``flag`` is True if the last
        evaluated iterate is no longer classified as ``label``. If ``r_tot``
        becomes NaN, returns ``(False, None, 0., elapsed)``.
    """
    start_time = time.time()
    ori_images = images.detach()
    nb_candidate = int(min(nb_candidate, classes))

    # Private copy so the caller's tensor is not flagged requires_grad.
    images = ori_images.clone().requires_grad_()
    needs_repeat = ori_images.shape[1] != 3
    if needs_repeat:
        images = images.repeat(1, 3, 1, 1)
    # Fixed starting point: every iterate is base + r_tot (r_tot is cumulative).
    base = images.detach()

    adv_input = images.repeat(samples, 1, 1, 1)
    adv_input = adv_input + torch.randn_like(adv_input) * sigma
    logits, _, _ = return_wrapper(model, adv_input, device)
    pred = torch.argmax(logits)

    w = torch.squeeze(torch.zeros(adv_input.size()[1:])).to(device)
    r_tot = torch.zeros(images.size()).to(device)

    iteration = 0
    while pred == label and iteration < max_iter:
        predictions_val = torch.topk(logits, nb_candidate)[0]
        gradients = torch.stack(jacobian(predictions_val, images, nb_candidate), dim=1)
        with torch.no_grad():
            # Closest linearised boundary among the other candidates.
            pert = 1e10
            for k in range(1, nb_candidate):
                w_k = gradients[0, k, ...] - gradients[0, 0, ...]
                f_k = predictions_val[k] - predictions_val[0]
                pert_k = (f_k.abs() + 0.00001) / w_k.view(-1).norm()
                if pert_k < pert:
                    pert = pert_k
                    w = w_k
            r_i = pert * w / w.view(-1).norm()
            r_tot += r_i

        if torch.sum(torch.isnan(r_tot)) > 0:
            return False, None, 0., time.time() - start_time

        # Apply the cumulative perturbation to the ORIGINAL input. Adding it to
        # the current iterate would count every earlier step again.
        images = torch.clamp(base + r_tot, 0, 1).detach().requires_grad_()
        if needs_repeat:
            images = images.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)

        adv_input = images.repeat(samples, 1, 1, 1)
        adv_input = adv_input + torch.randn_like(adv_input) * sigma
        logits, _, _ = return_wrapper(model, adv_input, device)
        pred = torch.argmax(logits)

        iteration += 1

    adv_x = torch.clamp(base + (1 + overshoot) * r_tot, 0., 1.)
    # NOTE(review): for inputs with != 3 channels adv_x has 3 channels, so this
    # difference broadcasts against ori_images. Unlike norm_distance, it is not
    # restricted to channel 0 — confirm this is the intended metric.
    distance = torch.linalg.norm(adv_x - ori_images).detach().cpu().numpy()
    flag = False
    if pred != label:
        flag = True

    return flag, adv_x, distance, time.time() - start_time


def jacobian(predictions, x, classes):
    """Return ``[d predictions[k] / d x for k in range(classes)]``.

    Each gradient is computed with ``retain_graph=True`` so the same graph can
    be differentiated once per entry.
    """
    def _grad_of(index):
        target = predictions[index]
        (grad_x,) = torch.autograd.grad(
            target, x, grad_outputs=torch.ones_like(target), retain_graph=True
        )
        return grad_x

    return [_grad_of(index) for index in range(classes)]
    
def pgd_attack(model, images, labels, samples, sigma, device, eps=0.3, alpha=20/255, iters=40, z_val=4, probabilistic=False):
    """Iterative L2 fast-gradient attack against a Monte-Carlo smoothed classifier.

    Each iteration does the following:
      * Draws ``samples`` Gaussian-noised copies (std ``sigma``) of the current
        iterate and estimates class frequencies with ``return_wrapper``.
      * If the top class differs from ``labels[0]``, checks with simultaneous
        multinomial confidence intervals (level from ``z_val``) that the top
        class's lower bound exceeds the true label's upper bound. If so, records
        the smallest such perturbation, measured with ``norm_distance``.
      * Moves the iterate by ``alpha`` along the L2-normalised cross-entropy
        gradient and clips it to [0, 1].

    No eps-ball projection is performed, so ``eps`` and ``probabilistic`` are
    unused; they are kept for signature compatibility. The caller's ``images``
    tensor is not modified.

    Returns:
        ``(images, flag, pred, min_val, min_recorded, first_pgd, first_time, final_pgd_time)``:
        the last iterate, whether a confident adversarial was found, the last
        prediction (None if ``iters == 0``), the smallest recorded distance (1e6
        if none), that adversarial image (or None), the distance and elapsed time
        of the first one found, and the elapsed time when the best one was found
        (total runtime if none).
    """
    # This is now the Iterative Fast Gradient Method for L2 Norms
    inpt_flag = images.shape[1] == 1
    start_time = time.time()
    first_flag = False

    device = images.device
    # Private copy: setting requires_grad / accumulating .grad below must not
    # leak onto the caller's tensor.
    images = images.detach().clone()
    labels = labels.to(device)
    loss = nn.CrossEntropyLoss()

    alpha_probability = 0.5 * (1 - stats_norm.cdf(z_val))

    ori_images = images.detach()
    flag = False
    pred = None
    min_val = torch.tensor(1e6)
    min_recorded = None
    for _ in range(iters):
        images.requires_grad = True

        adv_input = images.repeat(samples, 1, 1, 1)
        adv_input += torch.randn_like(adv_input) * sigma
        if adv_input.shape[1] != 3:
            adv_input = adv_input.repeat(1, 3, 1, 1)

        sm, _, _ = return_wrapper(model, adv_input, device)
        pred = torch.argmax(sm)
        if pred != labels[0]:
            pred = pred.detach().cpu().numpy()
            m_c = multi_conf(samples * sm.detach().cpu().numpy(), alpha=alpha_probability)
            lab = labels[0].detach().cpu().numpy()

            # Confident only if the predicted class's lower bound beats the true label's upper bound.
            if m_c[pred, 0] > m_c[lab, 1]:
                flag = True
                delta = norm_distance(images, ori_images, inpt_flag)
                if delta < min_val:
                    final_pgd_time = time.time() - start_time
                    min_val = delta
                    min_recorded = images.detach()
                    if not first_flag:
                        first_pgd = min_val.detach().clone().cpu().numpy()
                        first_time = time.time() - start_time
                        first_flag = True

        model.zero_grad()
        cost = loss(sm.reshape(1, -1), labels)
        cost.backward()

        grad = images.grad
        grad_norm = torch.linalg.norm(grad)
        images = torch.clamp(images.detach() + alpha * (grad / (grad_norm + 1e-12)), min=0, max=1).detach_()

    if not first_flag:
        final_pgd_time = time.time() - start_time
        first_time = final_pgd_time
        first_pgd = min_val.detach().cpu().numpy()

    return images, flag, pred, min_val.detach().cpu().numpy(), min_recorded, first_pgd, first_time, final_pgd_time
    
def direct_loop(model, classes, labels, samples, yp, images, z_val, sigma, device, iters=100, new=True, cutoff_step=0.5):
    """Run ``direct_attack`` for ``iters`` steps and report the best radius found.

    ``class_set[0]`` is ``labels[0]`` and the remaining entries are every other
    class index. ``yp`` is the starting point. It is updated by each step, and
    ``lambda_val``, ``delta`` and ``yp_best`` are threaded between steps.

    Returns:
        ``(new_d_time, min_recorded, first_time, first_radii)``:
        the elapsed time of the last improvement of the best distance, the best
        distance (1000 if no confident adversarial was found), and the elapsed
        time and distance of the first one found.
    """
    start_time = time.time()
    first_flag = False

    class_set = torch.zeros(classes).to(device)
    class_set[0] = labels[0]
    class_dummy = torch.arange(classes).to(device)
    class_set[1:] = class_dummy[class_dummy != labels[0]]

    min_recorded = 1000  # sentinel: no confident adversarial found yet
    yp_best, lambda_val, delta_val = None, 0.5, None
    inpt_flag = images.shape[1] == 1
    new_d_time = None
    for _ in range(iters):
        min_recorded_old = min_recorded
        if new:
            yp, _, min_recorded, yp_best, lambda_val, delta_val = direct_attack(model, yp, images, samples, class_set, z_val, min_recorded, sigma, device, inpt_flag, yp_best=yp_best, lambda_val=lambda_val, delta=delta_val, cutoff_step=cutoff_step)
        else:
            print('Unimplemented')

        if (min_recorded_old - 1e-5) > min_recorded:
            new_d_time = time.time() - start_time

        if min_recorded < 1000 and not first_flag:
            first_time = time.time() - start_time
            first_radii = min_recorded.detach().cpu().numpy()
            first_flag = True

    if min_recorded >= 1000:
        new_d_time = time.time() - start_time
        first_time = new_d_time
        first_radii = min_recorded
    elif new_d_time is None:
        # Found, but never improved by more than 1e-5: use the time it was found.
        new_d_time = first_time

    return new_d_time, min_recorded, first_time, first_radii
    

