import torch
import numpy as np


import torch.nn as nn

import utils


def unnormalize(x, dataset='imagenet'):
    """Undo dataset normalization on ``x`` via ``utils.invert_normalization``.

    Only the exact value ``'cifar'`` selects the CIFAR statistics; every other
    ``dataset`` value (including unknown names) falls back to ``'imagenet'``.
    """
    stats_name = 'cifar' if dataset == 'cifar' else 'imagenet'
    return utils.invert_normalization(x, stats_name)


def normalize(x, dataset='imagenet'):
    """Apply dataset normalization to ``x`` via ``utils.apply_normalization``.

    Only the exact value ``'cifar'`` selects the CIFAR statistics; every other
    ``dataset`` value (including unknown names) falls back to ``'imagenet'``.
    """
    stats_name = 'cifar' if dataset == 'cifar' else 'imagenet'
    return utils.apply_normalization(x, stats_name)


from torch.distributions import Beta


# Reference: mapping inputs from [0, 1] to [-1, 1] corresponds to normalizing
# each channel with mean = 0.5 and std = 0.5.
def model_noise(imgs_tensor, sigma=0.001, alpha=0, beta=0):
    """Sample additive Gaussian noise with the same shape as ``imgs_tensor``.

    The noise is ``sigma * N(0, 1)``. When both ``alpha`` and ``beta`` are
    positive, each sample's noise (samples indexed along dim 0) is further
    multiplied by an independent draw from ``Beta(alpha, beta)``, i.e. a random
    per-sample factor in (0, 1).

    Args:
        imgs_tensor: input batch; dim 0 is treated as the batch dimension and
            any number of trailing dimensions is supported.
        sigma: standard deviation of the Gaussian noise.
        alpha, beta: Beta concentrations; the per-sample scaling is only used
            when both are > 0.

    Returns:
        The noise tensor only (it is NOT added to ``imgs_tensor``), on the same
        device and with the same dtype as ``imgs_tensor``.
    """
    # Drawn first so the RNG consumption order is unchanged.
    noise = torch.randn_like(imgs_tensor)
    if alpha > 0 and beta > 0:
        batch_size = imgs_tensor.size(0)
        scale_dist = Beta(torch.FloatTensor([alpha]), torch.FloatTensor([beta]))
        # One factor per sample, broadcast over every remaining dim. The factor
        # lives on the input's device instead of a hard-coded .cuda(), so CPU
        # and multi-GPU inputs work.
        scale_shape = (batch_size,) + (1,) * (imgs_tensor.dim() - 1)
        scale = scale_dist.sample((batch_size,)).view(scale_shape).to(
            device=imgs_tensor.device, dtype=imgs_tensor.dtype)
        return noise * sigma * scale
    return noise * sigma



class BlackBoxModel(nn.Module):
    """Query-counting, logging wrapper around a classifier for attack evaluation.

    Outside attack mode the wrapper simply forwards inputs to ``model``. After
    :meth:`init_model` switches it into attack mode, each :meth:`forward` call
    counts queries, applies the configured defense (see :meth:`apply_defense`),
    and records perturbation norms, target-label probability and accuracy every
    ``log_interval`` queries.
    """

    # Random resize-and-pad ("rnp") defense: each image is upscaled to a random
    # side length in [_RNP_MIN_SIZE, _RNP_MAX_SIZE] and zero-padded to a
    # _RNP_MAX_SIZE x _RNP_MAX_SIZE canvas.
    _RNP_MIN_SIZE = 224
    _RNP_MAX_SIZE = 248

    def __init__(self, model, defense='gaussian', dataset='imagenet', avg_iter=1):
        """
        Args:
            model: the wrapped classifier; called on normalized inputs.
            defense: ``'rnp'`` adds random resize-and-pad after the noise step;
                any other value applies only the noise step.
            dataset: ``'cifar'`` selects the CIFAR budgets/checkpoints; any
                other value selects the ImageNet ones.
            avg_iter: number of independently defended copies whose logits are
                averaged (and whose predictions are majority-voted) per query.
        """
        super(BlackBoxModel, self).__init__()
        self.model = model
        self.attack_mode = False
        self.dataset = dataset
        self.avg_iter = avg_iter
        if dataset == 'cifar':
            self.l2_epsilon = 1.0
            self.linf_epsilon = 8. / 255.
            # Query counts at which the current adversarial image is saved.
            self.chk_idx = [500, 1000, 2000, 2500, 5000, 7500, 10000]
        else:
            self.l2_epsilon = 5.
            self.linf_epsilon = 0.05
            self.chk_idx = [1000, 5000, 10000, 20000]
        self.defense = defense

    def _random_resize_pad(self, img):
        """Randomly upscale, normalize and zero-pad one image given as a 1-element batch.

        The NumPy RNG is drawn in the order: side length, vertical offset,
        horizontal offset. Padding happens after normalization, so the border
        is exactly zero in normalized space.
        """
        max_size = self._RNP_MAX_SIZE
        rnd_size = np.random.randint(self._RNP_MIN_SIZE, max_size + 1)
        img = torch.nn.functional.interpolate(img, size=(rnd_size, rnd_size), mode='nearest')
        slack = max_size - rnd_size
        top = np.random.randint(0, slack + 1)
        left = np.random.randint(0, slack + 1)
        # F.pad order for the last two dims: (left, right, top, bottom).
        pads = (left, slack - left, top, slack - top)
        img = normalize(img, self.dataset)
        return torch.nn.functional.pad(img, pads, "constant", 0)

    def apply_defense(self, batch):
        """Return ``batch`` transformed by the configured defense and normalized.

        ``batch`` is expected in [0, 1] pixel space. If ``self.sigma > 0``, noise
        from :func:`model_noise` is added and the result is clamped to [0, 1].
        With ``defense == 'rnp'``, each image is then randomly resized and padded
        (see :meth:`_random_resize_pad`). Every path ends in
        ``normalize(..., self.dataset)``, so the output can be fed to ``self.model``.
        """
        if self.sigma > 0:
            noise = model_noise(batch, sigma=self.sigma, alpha=self.alpha, beta=self.beta)
            noised_batch = (batch + noise).clamp(0, 1)
        else:
            noised_batch = batch

        if self.defense != 'rnp':
            return normalize(noised_batch, self.dataset)

        # Per-image processing so every image gets its own random size/offset.
        padded = [self._random_resize_pad(noised_batch[idx:idx + 1])
                  for idx in range(noised_batch.size(0))]
        if not padded:
            size = self._RNP_MAX_SIZE
            return noised_batch.new_zeros((0, noised_batch.size(1), size, size))
        return torch.cat(padded, dim=0)

    @staticmethod
    def _first_success_index(pred, y, norms, epsilon):
        """Index of the first example that is misclassified AND has ``norm <= epsilon``.

        Returns ``None`` if no example in the batch satisfies both conditions.
        """
        hits = (pred.ne(y) & (norms <= epsilon)).tolist()
        return next((idx for idx, hit in enumerate(hits) if hit), None)

    def forward(self, batch, unnormalization=True):
        """Classify ``batch``; in attack mode also count queries and update logs.

        Args:
            batch: input images. With ``unnormalization=True`` they are taken to
                be dataset-normalized and are mapped back to pixel space first.
            unnormalization: whether to unnormalize ``batch`` before use.

        Returns:
            Model logits. In attack mode with ``avg_iter > 1`` these are the
            logits averaged over ``avg_iter`` defended copies of the batch.
            Outside attack mode, or once ``num_queries`` exceeds
            ``max_num_queries``, ``batch`` is passed to ``self.model`` unchanged
            and nothing is counted.
        """
        # The budget is checked before counting, so the last counted batch may
        # push num_queries past max_num_queries.
        if not self.attack_mode or self.num_queries > self.max_num_queries:
            return self.model(batch)

        batch_size = batch.size(0)
        self.num_queries += batch_size

        if unnormalization:
            batch = unnormalize(batch, self.dataset)
        batch = batch.clamp(0, 1)

        perturbation = (batch - self.original_image).view(batch_size, -1)
        l2_norm = perturbation.norm(2, 1)
        linf_norm = perturbation.abs().max(1)[0]

        preds_clean = None
        if self.avg_iter == 1:
            noised_batch = self.apply_defense(batch)
            if self.dataset == 'cifar':
                clean_batch = normalize(batch, self.dataset)
                if batch_size == 1:
                    # One forward pass over [defended, clean] instead of two.
                    mixed = self.model(torch.cat((noised_batch, clean_batch), dim=0))
                    output, output_clean = mixed[0:1], mixed[1:2]
                else:
                    output_clean = self.model(clean_batch)
                    output = self.model(noised_batch)
                preds_clean = output_clean.data.max(1)[1]
            else:
                output = self.model(noised_batch)
            pred = output.data.max(1)[1]
        else:
            clean_batch = normalize(batch, self.dataset)
            output_clean = self.model(clean_batch)
            preds_clean = output_clean.data.max(1)[1]
            votes = torch.zeros_like(output_clean)
            rows = torch.arange(votes.size(0), device=votes.device)
            output_sum = None
            for _ in range(self.avg_iter):
                output = self.model(self.apply_defense(batch))
                votes[rows, output.data.max(1)[1]] += 1
                output_sum = output if output_sum is None else output_sum + output
            output = output_sum / self.avg_iter
            # The reported label is the majority vote over the defended copies.
            pred = votes.data.max(1)[1]

        prob = torch.index_select(torch.softmax(output.data, dim=1), 1, self.y)

        if self.dataset == 'cifar':
            # Count queries whose defended prediction differs from the clean one.
            self.log_ne_count += preds_clean.ne(pred).float().sum().item()

        # Query number of batch[0] (queries are numbered from 1).
        first_query = self.num_queries - batch_size + 1
        if self.log_l_2_query_count < 0:
            idx = self._first_success_index(pred, self.y, l2_norm, self.l2_epsilon)
            if idx is not None:
                self.log_l_2_query_count = first_query + idx
        if self.log_l_inf_query_count < 0:
            idx = self._first_success_index(pred, self.y, linf_norm, self.linf_epsilon)
            if idx is not None:
                self.log_l_inf_query_count = first_query + idx

        # One entry per log_interval boundary crossed by this batch (a batch
        # larger than log_interval can cross several).
        while (self.next_log_queries <= self.num_queries
               and self.next_log_queries <= self.max_num_queries):
            cur_log_idx = self.next_log_queries // self.log_interval - 1
            # Position inside this batch of query number next_log_queries.
            cur_idx = batch_size - (self.num_queries - self.next_log_queries) - 1
            self.log_query_point[cur_log_idx] = self.next_log_queries
            self.log_l_2[cur_log_idx] = l2_norm[cur_idx]
            self.log_l_inf[cur_log_idx] = linf_norm[cur_idx]
            self.log_prob[cur_log_idx] = prob[cur_idx]
            self.log_acc[cur_log_idx] = (pred[cur_idx] == self.y).float()
            if self.next_log_queries in self.chk_idx:
                adv = batch[cur_idx:cur_idx + 1]
                if self.log_adv is None:
                    self.log_adv = adv
                else:
                    self.log_adv = torch.cat([self.log_adv, adv], dim=0)
            self.next_log_queries += self.log_interval

        return output

    def predict_label(self, batch):
        """Return argmax labels for ``batch`` (passed to ``forward`` without unnormalization)."""
        output = self.forward(batch, unnormalization=False)
        return output.data.max(1)[1]

    def set_log(self, batch, unnormalization=True):
        """Write a log entry for ``batch[0]`` when the query count is off a log boundary.

        Fills the slot of ``next_log_queries`` with the current ``num_queries``
        and always appends ``batch[0]`` to ``log_adv``. It is presumably called
        once an attack stops between boundaries — confirm with callers.
        ``next_log_queries`` is not advanced.
        """
        if self.next_log_queries <= self.max_num_queries and self.num_queries % self.log_interval != 0:
            batch_size = batch.size(0)
            if unnormalization:
                batch = unnormalize(batch, self.dataset)

            perturbation = (batch - self.original_image).view(batch_size, -1)
            l2_norm = perturbation.norm(2, 1)
            linf_norm = perturbation.abs().max(1)[0]
            output = self.model(self.apply_defense(batch))

            prob = torch.index_select(torch.softmax(output.data, dim=1), 1, self.y)
            pred = output.data.max(1)[1]
            cur_log_idx = self.next_log_queries // self.log_interval - 1
            cur_idx = 0
            self.log_query_point[cur_log_idx] = self.num_queries
            self.log_l_2[cur_log_idx] = l2_norm[cur_idx]
            self.log_l_inf[cur_log_idx] = linf_norm[cur_idx]
            self.log_prob[cur_log_idx] = prob[cur_idx]
            self.log_acc[cur_log_idx] = (pred[cur_idx] == self.y).float()
            adv = batch[cur_idx:cur_idx + 1]
            if self.log_adv is None:
                self.log_adv = adv
            else:
                self.log_adv = torch.cat([self.log_adv, adv], dim=0)
        return

    def init_model(self, attack_setting, x, y):  # Batch argument.
        """Enter attack mode and reset all counters and logs.

        Args:
            attack_setting: dict read for 'sigma', 'alpha', 'beta',
                'log_interval', 'max_num_queries' and 'attack'.
            x: the clean reference image that perturbation norms are measured against.
            y: label(s) a prediction is compared to when judging correctness.
        """
        self.attack_mode = True
        self.original_image = torch.cuda.FloatTensor(x)
        self.y = torch.cuda.LongTensor(y)
        self.sigma = attack_setting['sigma']
        self.alpha = attack_setting['alpha']
        self.beta = attack_setting['beta']
        self.log_ne_count = 0
        self.log_interval = attack_setting['log_interval']
        self.max_num_queries = attack_setting['max_num_queries']
        self.num_queries = 0
        self.next_log_queries = attack_setting['log_interval']
        self.attack = attack_setting['attack']
        num_logs = self.max_num_queries // self.log_interval
        self.log_queries = np.arange(0, self.max_num_queries, self.log_interval)
        self.log_query_point = np.zeros((num_logs))
        self.log_l_2 = np.zeros((num_logs))
        self.log_l_inf = np.zeros((num_logs))
        self.log_acc = np.zeros((num_logs))
        self.log_prob = np.zeros((num_logs))
        self.log_adv = None
        # -1 means "no misclassified example within the budget seen yet".
        self.log_l_2_query_count = -1
        self.log_l_inf_query_count = -1

    def get_num_queries(self):
        """Return the number of queries counted since :meth:`init_model`."""
        return self.num_queries

    def get_log(self):
        """Return the recorded logs as a tuple; CIFAR runs also include ``log_ne_count``."""
        logs = (
            self.log_query_point, self.log_prob, self.log_acc, self.log_l_2, self.log_l_inf,
            self.log_l_2_query_count, self.log_l_inf_query_count, self.log_adv)
        if self.dataset == 'cifar':
            return logs + (self.log_ne_count,)
        return logs

