import torch
#from resnet import *
import threading
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torchvision
import robustbench
import numpy as np

from robustbench.utils import load_model


from wrn import *

# Placeholder for a base directory; the literal below is meant to be replaced by the user.
# NOTE(review): no visible code reads this value — confirm it is still needed.
base_path = "Enter base path"


# NOTE: the nudge used to be launched through threading.Thread and joined right away
# (pattern from https://www.geeksforgeeks.org/multithreading-python-set-1/). It now runs
# inline: same result, but errors propagate instead of leaving a stale ``delta`` behind.
class ECAC(nn.Module):
    """Defence wrapper that nudges an input until a strong and a weak model agree.

    ``forward`` takes the weak model's prediction on ``X``, runs a short
    targeted PGD (``pgdTowards_Both``) that moves ``X`` towards that label
    under both models, and keeps the perturbation only for the samples where
    the strong model's prediction on the nudged input equals the weak model's
    original prediction. For every other sample the perturbation is reverted
    (set to zero). The strong model's logits on the resulting input are
    returned.

    Args:
        strongM: model whose logits are returned.
        weakM: model whose prediction on the clean input is the nudge target.
        step_size: PGD step size used by ``forward``.
        epsilon: L-infinity radius of the random start and of the projection.

    Attributes:
        delta: perturbation computed by the most recent nudge / ``forward``.

    Design notes kept from the original author (a prime marks a prediction on
    the nudged input x'):

    TODO:
      - Find correct names for p1 and p2.
      - Decide what to do when nudging does not produce agreement with weakM.
        If the strongM prediction does not change, it likely means strongM
        was correct, so we revert. If both predictions change, we cannot say
        which of the four predictions is correct (the original note breaks
        off here; the current code simply reverts).

    Case 1: strongM == weakM
        If strongM' == weakM: no issue, keep x' (here strongM' == strongM too).
        Else (strongM' != weakM):
            If strongM' == weakM': revert? Or keep, since they agree and are
                likely correct.
            Else (strongM' != weakM'): definitely revert?
    Case 2: strongM != weakM (it is much easier to mislead weakM!)
        If strongM' == weakM: we likely fixed it. Also the weakest step of the
            defence if the adversary knows the defence.
        Else (strongM' != weakM):
            If strongM' == strongM: weakM was likely incorrect, so revert.
            Else if strongM' == weakM': both have flipped, no easy way to
                tell: revert?
            Else (strongM' != weakM'): same as above, but definitely revert.

    All of the above is over-engineering: the implemented rule is simply to
    revert whenever the nudged strongM prediction fails to agree with weakM.
    """

    def __init__(self, strongM, weakM, step_size = 0.02, epsilon = 0.031 ):
        super(ECAC, self).__init__()
        self.strongM   = strongM
        self.weakM     = weakM
        self.step_size = step_size
        self.epsilon   = epsilon

    def _nudge(self, models, X, Y, num_steps, step_size):
        """Targeted PGD on the summed cross-entropy of ``models`` towards ``Y``.

        Starts from ``X`` plus uniform noise in ``[-epsilon, epsilon]``, takes
        ``num_steps`` signed-gradient descent steps, and after each step
        projects back into the epsilon-ball around ``X`` and into ``[0, 1]``.
        Stores the resulting perturbation in ``self.delta`` and returns it.
        """
        X = X.detach()
        criterion = nn.CrossEntropyLoss()
        # Noise is drawn directly on X's device/dtype (no host-to-device copy).
        noise = torch.empty_like(X).uniform_(-self.epsilon, self.epsilon)
        X_pgd = X + noise
        for _ in range(num_steps):
            X_pgd = X_pgd.detach().requires_grad_(True)
            # The caller may be running under no_grad; the attack needs a graph.
            with torch.enable_grad():
                loss = sum(criterion(model(X_pgd), Y) for model in models)
            # Gradient w.r.t. the input only, so the wrapped models' parameters
            # do not accumulate .grad as a side effect of inference.
            (grad,) = torch.autograd.grad(loss, X_pgd)
            # Descent step: move *towards* label Y.
            X_pgd = X_pgd.detach() - step_size * grad.sign()
            X_pgd = X + torch.clamp(X_pgd - X, -self.epsilon, self.epsilon)
            X_pgd = torch.clamp(X_pgd, 0, 1.0)
        self.delta = (X_pgd - X).detach().clone()
        return self.delta

    def pgdTowards(self, model, X, Y, num_steps = 2, step_size = 0.007):
        """Nudge ``X`` towards label ``Y`` using ``model`` only.

        Stores the perturbation in ``self.delta`` and also returns it.
        """
        return self._nudge((model,), X, Y, num_steps, step_size)

    def pgdTowards_Both(self, X, Y, num_steps = 2, step_size = 0.007):
        """Nudge ``X`` towards label ``Y`` using strongM and weakM jointly.

        The loss is the sum of both models' cross-entropies. Stores the
        perturbation in ``self.delta`` and also returns it.
        """
        return self._nudge((self.strongM, self.weakM), X, Y, num_steps, step_size)

    def forward(self, X, y = None):
        """Return strongM logits on ``X`` nudged towards weakM's prediction.

        ``y`` is accepted for interface compatibility and is not used.
        ``X`` itself is not modified.
        """
        with torch.no_grad():
            weak_pred = self.weakM(X).max(1)[1]
        # One PGD step with the configured step size.
        delta = self.pgdTowards_Both(X, weak_pred, 1, self.step_size)
        with torch.no_grad():
            agrees = self.strongM(X + delta).max(1)[1] == weak_pred
        # Revert the nudge wherever strongM still disagrees with weakM.
        # ``delta`` is the same tensor as ``self.delta``, so that is updated too.
        delta[~agrees] = 0
        return self.strongM(X + delta)


#################################
#Now the AAA model from their code:


def loss(y, logits, targeted=False, loss_type='margin_loss'):
    """Per-sample attack loss computed from logits and one-hot labels.

    Args:
        y: one-hot label mask with the same shape as ``logits``. Boolean or
            0/1 numeric arrays are accepted; the value is converted to a
            boolean mask before it is used for indexing.
        logits: 2-D floating-point array (one row per sample).
        targeted: flips the sign of the result.
        loss_type: ``'margin_loss'`` or ``'cross_entropy'``.

    Returns:
        1-D array with one value per sample. For ``'margin_loss'`` this is the
        smallest difference between the labelled class's logit and any other
        class's logit (negated when ``targeted``). For ``'cross_entropy'`` it
        is ``log p`` of the labelled class (``-log p`` when ``targeted``).

    Raises:
        ValueError: if ``loss_type`` is not one of the two supported names.
    """
    if loss_type not in ('margin_loss', 'cross_entropy'):
        raise ValueError('Wrong loss.')

    # Indexing with an integer one-hot would select whole rows instead of the
    # labelled entries, so always index with a boolean mask.
    mask = np.asarray(y, dtype=bool)

    if loss_type == 'margin_loss':
        label_logit = (logits * mask).sum(1, keepdims=True)
        diff = label_logit - logits
        # Exclude the labelled class itself from the minimum.
        diff[mask] = np.inf
        margin = diff.min(1, keepdims=True)
        value = margin * -1 if targeted else margin
    else:
        # NOTE(review): ``softmax`` is not defined in the visible part of this
        # file; presumably supplied by another module — confirm.
        probs = softmax(logits)
        value = -np.log(probs[mask])
        value = value * -1 if not targeted else value
    return value.flatten()

class AAALinear(nn.Module):
    """AAA ("linear" variant) logit post-processing wrapped around a classifier.

    For every batch ``forward`` computes the classifier's logits and then
    optimises a copy of them with Adam so that

    * the margin between the two largest logits moves to a target obtained by
      reflecting the original margin about its "attractor" (the midpoint of
      the ``attractor_interval``-wide bin that contains the margin), scaled by
      ``reverse_step``; and
    * the softmax probability of the originally predicted class stays close to
      the original (temperature-scaled) top probability, weighted by
      ``calibration_loss_weight``.

    Args:
        dataset, norm, model_dir: forwarded to ``robustbench.utils.load_model``
            when ``arch`` is not a ``torchvision.models`` attribute.
        arch: torchvision model name, otherwise a RobustBench model name.
        device: device the classifier is moved to and results are returned on.
        batch_size: number of samples per classifier call in ``forward``.
        attractor_interval, reverse_step, calibration_loss_weight: see above.
        num_iter: Adam iterations per batch.
        optimizer_lr: Adam learning rate.
        do_softmax: if true, ``softmax`` is applied to the result.
        **kwargs: accepted and ignored.

    NOTE(review): ``predict``, ``ece_score``, ``softmax`` and ``plt`` are used
    below but are not defined or imported in the visible part of this file;
    presumably they come from another module — confirm before relying on
    ``forward_undefended``, the temperature searches or ``do_softmax=True``.
    """

    # Template for the descriptive ``arch`` tag (used in plot file names).
    _ARCH_TAG = '%s_AAAlinear-Lr-%.1f-Ai-%d-Cw-%d'

    def __init__(self, dataset, arch, norm, model_dir,
        device="cuda", batch_size=1000, attractor_interval=4,
        reverse_step=1, num_iter=100, calibration_loss_weight=5,
        optimizer_lr=0.1, do_softmax=False, **kwargs):
        super(AAALinear, self).__init__()
        print("In AAALinear")
        self.dataset = dataset
        # Look the architecture up first so that an AttributeError raised while
        # building or moving the model is not mistaken for "unknown torchvision
        # architecture".
        factory = getattr(torchvision.models, arch, None)
        if factory is not None:
            self.cnn = factory(pretrained=True).to(device).eval()
            self.mean = [0.485, 0.456, 0.406]
            self.std = [0.229, 0.224, 0.225]
        else:
            self.cnn = robustbench.utils.load_model(model_name=arch, dataset=dataset, threat_model=norm, model_dir=model_dir).to(device).eval()
            # Identity normalisation for the RobustBench branch.
            self.mean = [0]
            self.std = [1]

        self.loss = loss
        self.batch_size = batch_size
        self.device = device

        self.attractor_interval = attractor_interval
        self.reverse_step = reverse_step
        self.dev = 0.5
        self.optimizer_lr = optimizer_lr
        self.calibration_loss_weight = calibration_loss_weight
        self.num_iter = num_iter
        self.arch_ori = arch
        self.arch = self._ARCH_TAG % (self.arch_ori, self.reverse_step, self.attractor_interval, self.calibration_loss_weight)
        self.temperature = 1
        self.do_softmax = do_softmax

    def set_hp(self, reverse_step, attractor_interval=6, calibration_loss_weight=5):
        """Update the defence hyper-parameters and refresh the ``arch`` tag.

        Note that the default ``attractor_interval`` here (6) differs from the
        constructor default (4).
        """
        self.attractor_interval = attractor_interval
        self.reverse_step = reverse_step
        self.calibration_loss_weight = calibration_loss_weight
        self.arch = self._ARCH_TAG % (self.arch_ori, self.reverse_step, self.attractor_interval, self.calibration_loss_weight)

    def forward_undefended(self, x):
        """Return the classifier's output via the external ``predict`` helper."""
        return predict(x, self.cnn, self.batch_size, self.device, self.mean, self.std)

    def get_tuned_temperature(self):
        """Return the pre-tuned temperature for ``arch_ori``, or ``None`` if unknown."""
        t_dict = {
            'Standard': 2.08333,
            'resnet50': 1.1236,
            'resnext101_32x8d': 1.26582,
            'vit_b_16': 0.94,
            'wide_resnet50_2': 1.20482,
            'Rebuffi2021Fixing_28_10_cutmix_ddpm': 0.607,
            'Salman2020Do_50_2': 0.83,
            'Dai2021Parameterizing': 0.431,
            'Rade2021Helper_extra': 0.58
        }
        return t_dict.get(self.arch_ori, None)

    def _sweep_temperature(self, predict_pair, y_val, step_size, plot_prefix):
        """Grid-search the temperature with the lowest ``ece_score``.

        ``predict_pair(t)`` must return the predictions for temperature ``t``
        and for temperature ``1 / t``. The best temperature is stored in
        ``self.temperature`` and a scatter plot of all candidates is written to
        ``demo/<plot_prefix>-<arch>-<temperature>.png``.
        """
        ts, eces = [], []
        # Fall back to the neutral temperature 1 if no candidate beats the sentinel.
        ece_best, t_best = 100, 1
        # Same grid as before minus t == 0, which divides by zero and yields
        # the degenerate temperatures 0 and infinity.
        for t in np.arange(0, 1, step_size)[1:]:
            y_pred1, y_pred2 = predict_pair(t)

            ts += [t, 1/t]
            ece1, ece2 = ece_score(y_pred1, y_val), ece_score(y_pred2, y_val)
            eces += [ece1, ece2]
            if ece1 < ece_best:
                ece_best = ece1
                t_best = t
            if ece2 < ece_best:
                ece_best = ece2
                t_best = 1/t
            print('t-curr=%.3f, acc=%.2f, %.2f, ece=%.4f, %.4f, t-best=%.5f, ece-best=%.4f' %
            (t, (y_pred1.argmax(1) == y_val.argmax(1)).mean() * 100, (y_pred2.argmax(1) == y_val.argmax(1)).mean() * 100,
            ece1 * 100, ece2 * 100,
            t_best, ece_best * 100))
        self.temperature = t_best

        plt.rcParams["figure.dpi"] = 500
        plt.rcParams["font.family"] = "times new roman"
        plt.scatter(ts, eces, color='#9467bd')
        plt.xscale('log')
        plt.xlabel('temperature')
        plt.ylabel('ece on validation set')
        plt.savefig('demo/%s-%s-%.4f.png' % (plot_prefix, self.arch, self.temperature))
        plt.close()

    def temperature_rescaling(self, x_val, y_val, step_size=0.001):
        """Tune ``self.temperature`` on the undefended predictions for ``x_val``.

        Candidates are ``t`` and ``1 / t`` for ``t`` on a ``step_size`` grid in
        (0, 1); predictions are rescaled as ``y_pred / t`` and ``y_pred * t``.
        """
        y_pred = self.forward_undefended(x_val)
        self._sweep_temperature(lambda t: (y_pred / t, y_pred * t), y_val, step_size, 't')

    def temperature_rescaling_with_aaa(self, x_val, y_val, step_size=0.001):
        """Tune ``self.temperature`` using the defended ``forward`` output.

        Uses the pre-tuned value from ``get_tuned_temperature`` when one exists;
        otherwise runs the same grid search as ``temperature_rescaling`` but
        re-runs ``forward`` for every candidate temperature.
        """
        self.temperature = self.get_tuned_temperature()
        if self.temperature is not None: return

        def predict_pair(t):
            self.temperature = t
            y_pred1 = self.forward(x_val)
            self.temperature = 1/t
            y_pred2 = self.forward(x_val)
            return y_pred1, y_pred2

        self._sweep_temperature(predict_pair, y_val, step_size, 'taaa')

    def _optimize_logits(self, logits_ori):
        """Return the AAA-adjusted logits for one batch of original logits."""
        prob_max_ori = torch.softmax(logits_ori / self.temperature, dim=1).max(1)[0]
        value, index_ori = torch.topk(logits_ori, k=2, dim=1)
        # One-hot mask of the originally predicted class.
        mask_first = torch.zeros(logits_ori.shape, device=self.device)
        mask_first[torch.arange(logits_ori.shape[0]), index_ori[:, 0]] = 1

        margin_ori = value[:, 0] - value[:, 1]
        # Midpoint of the attractor_interval-wide bin containing the margin.
        attractor = ((margin_ori / self.attractor_interval + self.dev).round() - self.dev) * self.attractor_interval
        # Reflect the margin about the attractor (scaled by reverse_step).
        target = attractor - self.reverse_step * (margin_ori - attractor)

        # forward() may be called under no_grad; the optimisation needs gradients.
        with torch.enable_grad():
            logits = logits_ori.clone().requires_grad_(True)
            optimizer = torch.optim.Adam([logits], lr=self.optimizer_lr)
            for _ in range(self.num_iter):
                prob = torch.softmax(logits, dim=1)
                loss_calibration = ((prob * mask_first).max(1)[0] - prob_max_ori).abs().mean()

                value, _ = torch.topk(logits, k=2, dim=1)
                margin = value[:, 0] - value[:, 1]
                loss_defense = (margin - target).abs().mean()

                total_loss = loss_defense + loss_calibration * self.calibration_loss_weight
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
        return logits.detach()

    def forward(self, x):
        """Return defended logits for ``x``.

        ``x`` may be a NumPy array or a torch tensor; it is quantised to 8-bit
        levels and normalised with ``self.mean`` / ``self.std`` before being
        fed to the classifier in chunks of ``batch_size``. NumPy input yields a
        NumPy array; tensor input yields a tensor on ``self.device``.
        """
        is_numpy = isinstance(x, np.ndarray)
        if is_numpy:
            x = np.floor(x * 255.0) / 255.0
            x = ((x - np.array(self.mean)[np.newaxis, :, np.newaxis, np.newaxis]) / np.array(self.std)[np.newaxis, :, np.newaxis, np.newaxis]).astype(np.float32)
        else:
            x = torch.floor(x * 255.0) / 255.0
            x = ((x - torch.as_tensor(self.mean, device=self.device)[None, :, None, None]) / torch.as_tensor(self.std, device=self.device)[None, :, None, None])

        logits_list = []
        for start in range(0, x.shape[0], self.batch_size):
            x_curr = x[start:start + self.batch_size]
            if is_numpy: x_curr = torch.as_tensor(x_curr, device=self.device)
            with torch.no_grad():
                logits_ori = self.cnn(x_curr).detach()
            # Results are parked on the CPU between batches.
            logits_list.append(self._optimize_logits(logits_ori).cpu())

        logits = torch.vstack(logits_list)
        if is_numpy: logits = logits.numpy()
        if self.do_softmax: logits = softmax(logits)
        # Only tensors can be moved to a device; NumPy results are returned as-is.
        if torch.is_tensor(logits): logits = logits.to(self.device)
        return logits






################################
class ImageNormalizer(nn.Module):
    """Module that standardises its input as ``(x - mean) / std``.

    Args:
        mean: value subtracted from the input; must broadcast against it.
        std: value the centred input is divided by; must broadcast against it.

    Both values are stored as non-persistent buffers, so they follow the
    module through ``.to()`` / ``.cuda()`` / ``.half()`` / ``.double()`` while
    staying out of ``state_dict`` (existing checkpoints keep loading).
    """

    def __init__(self, mean, std):
        super(ImageNormalizer, self).__init__()
        self.register_buffer('mean', torch.as_tensor(mean), persistent=False)
        self.register_buffer('std', torch.as_tensor(std), persistent=False)

    def forward(self, x):
        """Return the normalised ``x``."""
        return (x - self.mean) / self.std


def _load_robustbench_imagenet(arch):
    """Load the ImageNet / Linf RobustBench model ``arch`` onto CUDA in eval mode."""
    return load_model(model_name=arch,
                      dataset='imagenet',
                      threat_model='Linf',
                      model_dir="rbmodels/"
                     ).to("cuda").eval()


def get_model(m_type, m_name, args):
    """Build the model selected by ``m_type`` / ``m_name`` (both case-insensitive).

    Supported combinations:
        * ``rn50`` + ``nat``: RobustBench ``Standard_R50``.
        * ``wrn50`` + ``nat``: torch.hub ``wide_resnet50_2`` behind an
          ``ImageNormalizer``.
        * ``wrn50`` + ``sat`` or ``Salman2020Do_50_2``: RobustBench
          ``Salman2020Do_50_2``.
        * any type + ``ecac``: ``ECAC`` around ``get_model(m_type, args["strongM"])``
          and ``get_model(m_type, "nat")``; reads ``args["nudge"]`` and
          ``args["epsilon"]``.
        * ``wrn50`` + ``aaa``: ``AAALinear`` around ``wide_resnet50_2``; reads
          ``args["device"]`` and ``args["batch_size"]``.

    Args:
        m_type: architecture family name.
        m_name: model / defence name.
        args: mapping with the keys listed above; may be ``None`` for the
            combinations that do not read it.

    Returns:
        The constructed model, or ``None`` (after printing a message) when the
        combination is not recognised.
    """
    print("In load model: args:", args)
    m_type = m_type.lower()
    m_name = m_name.lower()

    if m_type == "rn50" and m_name == "nat":
        return _load_robustbench_imagenet('Standard_R50')

    if m_type == "wrn50":
        if m_name == "nat":
            pre_model = torch.hub.load('pytorch/vision:v0.13.0',
                                       'wide_resnet50_2',
                                       pretrained=True).to("cuda").eval()
            # Keep the statistics on the same (current) CUDA device as the model.
            mean = torch.tensor([0.485, 0.456, 0.406], device='cuda').reshape([1, -1, 1, 1])
            std = torch.tensor([0.229, 0.224, 0.225], device='cuda').reshape([1, -1, 1, 1])
            return nn.Sequential(
                ImageNormalizer(mean, std),
                pre_model
            )
        # m_name has been lower-cased above, so compare against lower-case names.
        if m_name in ("sat", "salman2020do_50_2"):
            return _load_robustbench_imagenet('Salman2020Do_50_2')

    if m_name == "ecac":
        strongM = get_model(m_type, args["strongM"], None)
        weakM   = get_model(m_type, "nat", None)
        return ECAC(strongM, weakM, step_size = float(args["nudge"]), epsilon = float(args["epsilon"]))

    if m_type == "wrn50" and m_name == "aaa":
        # Parameters extracted using a breakpoint in VSCode in the provided code.
        return AAALinear(
            dataset='imagenet',
            arch="wide_resnet50_2",
            norm='Linf',
            device=torch.device(args["device"]),
            batch_size=int(args["batch_size"]),
            model_dir="rbmodels",
            do_softmax=False,

            n_in=0,
            n_out=0,

            attractor_interval=6,
            reverse_step=1,
            calibration_loss_weight= 5,
            num_iter=100,
            optimizer_lr=0.1
        )

    print("\n\n\n\n Failed to load model. \n model is None.")
    return None
