import numpy
import numpy as np
import torch
import torch.nn.functional as F
from data_load_cifa100_linux import MyDataset
from torch.utils.data import DataLoader
import collections as cls
import gc
def FGSM_attack(image, epsilon, data_grad):
    """Apply one FGSM (fast gradient sign method) step to ``image``.

    Every element of ``image`` is shifted by ``epsilon`` in the direction of
    the element-wise sign of ``data_grad`` and the result is clamped to the
    range ``[0, 1]``.

    Args:
        image: Tensor to perturb.
        epsilon: Step size multiplied with the gradient sign.
        data_grad: Gradient tensor; only its element-wise sign is used.

    Returns:
        The perturbed tensor, clamped to ``[0, 1]``.
    """
    step = epsilon * data_grad.sign()
    return torch.clamp(image + step, 0, 1)


def make_advers_samples_normal_model(model, test_dataloader, device, epsilon=None):
    """Create one FGSM-perturbed batch for every batch of ``test_dataloader``.

    Each batch is unpacked as ``(x, y, t)``, moved to ``device`` and run
    through ``model(x, t)``. The gradient of the cross-entropy loss with
    respect to ``x`` is passed to ``FGSM_attack``. Every batch is perturbed,
    whether or not the model classified it correctly.

    Args:
        model: Callable invoked as ``model(x, t)`` that returns logits.
        test_dataloader: Iterable yielding ``(x, y, t)`` triples.
        device: Target passed to ``Tensor.to`` for all three tensors.
        epsilon: FGSM step size forwarded to ``FGSM_attack``.

    Returns:
        A list with one ``[perturbed_x, y, t]`` entry per batch.
    """
    adv_examples = []
    for batch in test_dataloader:
        inputs, labels, tasks = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        tasks = tasks.to(device)
        # The input must be a differentiable leaf so that its gradient can be
        # requested from autograd below.
        inputs.requires_grad_(True)

        logits = model(inputs, tasks)
        loss = torch.nn.functional.cross_entropy(logits, labels)
        (inputs_grad,) = torch.autograd.grad(loss, inputs)

        perturbed = FGSM_attack(inputs, epsilon=epsilon, data_grad=inputs_grad)
        adv_examples.append([perturbed, labels, tasks])

    print("the adver_sample's number:{}".format(len(adv_examples)))
    return adv_examples





def make_advers_samples_normal_model_resnet(model, test_dataloader, device, epsilon=None):
    """Create one FGSM-perturbed batch for every batch of ``test_dataloader``.

    Each batch is unpacked as ``(x, y, t)``; the labels are cast with
    ``.long()`` before being moved to ``device``. The gradient of the
    cross-entropy loss of ``model(x, t)`` with respect to ``x`` is passed to
    ``FGSM_attack``. No filtering on prediction correctness is performed.

    Args:
        model: Callable invoked as ``model(x, t)`` that returns logits.
        test_dataloader: Iterable yielding ``(x, y, t)`` triples.
        device: Target passed to ``Tensor.to`` for all three tensors.
        epsilon: FGSM step size forwarded to ``FGSM_attack``.

    Returns:
        A list with one ``[perturbed_x, y, t]`` entry per batch.
    """
    adv_examples = []
    for batch in test_dataloader:
        inputs, labels, tasks = batch
        inputs = inputs.to(device)
        labels = labels.long().to(device)
        tasks = tasks.to(device)
        # The input must be a differentiable leaf so that its gradient can be
        # requested from autograd below.
        inputs.requires_grad_(True)

        logits = model(inputs, tasks)
        loss = torch.nn.functional.cross_entropy(logits, labels)
        (inputs_grad,) = torch.autograd.grad(loss, inputs)

        perturbed = FGSM_attack(inputs, epsilon=epsilon, data_grad=inputs_grad)
        adv_examples.append([perturbed, labels, tasks])
    return adv_examples


def evaluate_acc_multihead(model, dataloader, device):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode for the pass and is always switched
    back to train mode afterwards (regardless of the mode it was in before).

    Args:
        model: Module invoked as ``model(x, t)`` that returns logits whose
            ``argmax`` over dim 1 is the predicted class.
        dataloader: Iterable of ``(x, y, t)`` batches exposing a ``dataset``
            attribute; ``len(dataloader.dataset)`` is the denominator.
        device: Target passed to ``Tensor.to`` for all three tensors.

    Returns:
        Number of correct predictions divided by ``len(dataloader.dataset)``.
    """
    model.eval()
    hits = 0
    for inputs, labels, tasks in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        tasks = tasks.to(device)
        with torch.no_grad():
            logits = model(inputs, tasks)
        predictions = logits.argmax(dim=1)
        hits += (predictions == labels).sum().float().item()
    accuracy = hits / len(dataloader.dataset)
    model.train()
    return accuracy
def get_validation_loss(model, valid_loader, valid_losses=None, device=None):
    """Return the average cross-entropy loss of ``model`` over ``valid_loader``.

    The model is put in eval mode for the pass and switched back to train
    mode before returning.

    Args:
        model: Module invoked as ``model(x, t)`` that returns logits.
        valid_loader: Iterable yielding ``(x, y, t)`` batches.
        valid_losses: Optional list that receives one loss value per batch.
            When a list is supplied, the returned average covers every entry
            of that list, including entries it already contained. When
            omitted, a fresh list is used for each call. (The previous
            mutable ``[]`` default silently accumulated losses across calls.)
        device: Target passed to ``Tensor.to`` for all three tensors.

    Returns:
        ``numpy.average`` of ``valid_losses`` after the pass, or ``0`` when
        ``valid_loader`` yields no batch.
    """
    if valid_losses is None:
        valid_losses = []

    model.eval()  # prep model for evaluation
    valid_loss = 0
    saw_batch = False
    # Only loss values are recorded, so no autograd graph is needed.
    with torch.no_grad():
        for x, y, t in valid_loader:
            x, y, t = x.to(device), y.to(device), t.to(device)
            # Kept from the original implementation: clears the parameter
            # gradients before each forward pass.
            model.zero_grad()
            output = model(x, t)
            loss = torch.nn.functional.cross_entropy(output, y)
            valid_losses.append(loss.item())
            saw_batch = True

    if saw_batch:
        valid_loss = numpy.average(valid_losses)
    model.train()
    return valid_loss

def maps_reshape_strade(maps, filter_size, stride=1):
    """Flatten sliding windows of a 4-D tensor into the rows of a 2-D tensor.

    For every sample along the first dimension, square windows of side
    ``filter_size`` spanning all channels are cut out and flattened. Window
    offsets are ``range(0, height - filter_size, stride)`` (and likewise for
    the width), so the window that would start exactly at
    ``height - filter_size`` is not included.

    Args:
        maps: Tensor of shape ``[bs, channel, height, width]``.
        filter_size: Side length of the square window.
        stride: Step between window offsets. Integral floats such as ``1.0``
            are accepted and converted to ``int``; the previous float default
            made ``range`` raise ``TypeError``.

    Returns:
        Tensor of shape ``[num_windows, channel * filter_size * filter_size]``
        with rows ordered by sample, then row offset, then column offset.

    Raises:
        ValueError: If ``stride`` is not an integral value.
    """
    step = int(stride)
    if step != stride:
        raise ValueError(f"stride must be an integral value, got {stride!r}")

    _, _, height, width = maps.shape
    patches = []
    for sample in maps:  # [channel, height, width]
        for top in range(0, height - filter_size, step):
            for left in range(0, width - filter_size, step):
                window = sample[:, top:top + filter_size, left:left + filter_size]
                patches.append(window.reshape(1, -1))  # [1, channel*k*k]
    return torch.cat(patches, dim=0)


def gradient_cnn(x0, x1):
    """Accumulate per-element gradients of ``x1`` with respect to ``x0``.

    For every scalar element of ``x1`` the gradient with respect to ``x0`` is
    computed, reshaped to ``[bs, -1]`` and summed over the second dimension.
    The resulting ``[bs]`` vector is added to the column of the accumulator
    that corresponds to the element's flattened position within its sample.
    Contributions from all samples of ``x1`` land in the same column.

    Args:
        x0: Tensor that ``x1`` was computed from; its element count must be
            divisible by the batch size of ``x1``.
        x1: Tensor of shape ``[bs, c, height, wide]`` attached to the
            autograd graph of ``x0``.

    Returns:
        CPU tensor of shape ``[bs, c, height, wide]``, detached from autograd.
    """
    bs, c, height, wide = x1.shape
    accumulated = torch.zeros([bs, c * height * wide])
    for sample in x1:  # [c, height, wide]
        for col, scalar in enumerate(sample.reshape(-1)):
            # autograd.grad needs a scalar output; grad has the shape of x0.
            (grad,) = torch.autograd.grad(scalar, x0, create_graph=True)
            per_sample = grad.reshape(bs, -1).detach().cpu().sum(dim=1)  # [bs]
            accumulated[:, col] += per_sample
    return accumulated.reshape(bs, c, height, wide)
def gradient_fc(x0, x1):
    """Accumulate per-element gradients of a 2-D ``x1`` with respect to ``x0``.

    For every scalar ``x1[b, col]`` the gradient with respect to ``x0`` is
    computed, reshaped to ``[bs, -1]`` and summed over the second dimension.
    The resulting ``[bs]`` vector is added to column ``col`` of the
    accumulator, so contributions from all rows ``b`` share that column.

    Args:
        x0: Tensor that ``x1`` was computed from; its element count must be
            divisible by the batch size of ``x1``.
        x1: Tensor of shape ``[bs, f1]`` attached to the autograd graph of
            ``x0``.

    Returns:
        CPU tensor of shape ``[bs, f1]``, detached from autograd.
    """
    bs, f1 = x1.shape
    accumulated = torch.zeros([bs, f1])
    for row in x1:  # [f1]
        for col, scalar in enumerate(row):
            # autograd.grad needs a scalar output; grad has the shape of x0.
            (grad,) = torch.autograd.grad(scalar, x0, create_graph=True)
            per_sample = grad.reshape(bs, -1).detach().cpu().sum(dim=1)  # [bs]
            accumulated[:, col] += per_sample
    return accumulated
def maps_reshape_1(maps, pad, kernel_size, stride):
    """Zero-pad a 4-D tensor and flatten its sliding windows into rows.

    The last two dimensions of ``maps`` are padded with ``pad`` zeros on
    every side. For each sample, ``int((size - kernel_size) / stride)``
    window positions are visited per spatial dimension, where ``size`` is the
    padded height or width; each window spans all channels.

    Args:
        maps: Tensor of shape ``[bs, channel, height, width]``.
        pad: Number of zeros added on each side of the last two dimensions.
        kernel_size: Side length of the square window.
        stride: Step between consecutive window offsets.

    Returns:
        CPU tensor of shape ``[num_windows, channel * kernel_size ** 2]`` with
        rows ordered by sample, then row position, then column position.
    """
    padded = F.pad(maps, (pad, pad, pad, pad), "constant", 0)
    bs, _, height, wide = padded.shape

    patches = []
    for i in range(bs):
        for j in range(int((height - kernel_size) / stride)):
            top = j * stride
            for k in range(int((wide - kernel_size) / stride)):
                left = k * stride
                window = padded[i, :, top:top + kernel_size, left:left + kernel_size]
                patches.append(window.flatten().unsqueeze(dim=0))
    return torch.cat(patches).cpu()

def vector_extend(input_grad, channel_in, height, width, stride=1, kernel_size=3, pad=1, ):
    """Scatter rows of ``input_grad`` into per-window kernel-shaped rows.

    A ``height x width`` mask of ones is zero-padded by ``pad``. For every
    window offset in ``range(0, height - kernel_size, stride)`` (and likewise
    for the width) two index sets are recorded:

    * the positions inside the flattened ``kernel_size x kernel_size`` window
      that lie on non-padding cells, and
    * the positions inside the flattened un-padded ``height x width`` plane
      that the window covers.

    Both sets are repeated for each of the ``channel_in`` channels with the
    matching per-channel offset. Each row of ``input_grad`` is then expanded
    into one row per window: a zero vector of length
    ``channel_in * kernel_size ** 2`` whose window positions receive the
    row's values at the covered plane positions.

    Args:
        input_grad: Iterable of 1-D tensors indexed by positions in
            ``[0, channel_in * height * width)``.
        channel_in: Number of input channels.
        height: Height of the un-padded input.
        width: Width of the un-padded input.
        stride: Step between window offsets.
        kernel_size: Side length of the square window.
        pad: Zero padding applied on every side of the mask.

    Returns:
        Tensor of shape
        ``[len(input_grad) * num_windows, channel_in * kernel_size ** 2]``.
    """
    padded_mask = F.pad(torch.ones([height, width]), (pad, pad, pad, pad), "constant", 0)
    taps = kernel_size * kernel_size
    plane = height * width

    index_pairs = []
    for top in range(0, height - kernel_size, stride):
        for left in range(0, width - kernel_size, stride):
            window = padded_mask[top:top + kernel_size, left:left + kernel_size]
            # Kernel taps that fall on real (non-padding) input cells.
            weight_index = torch.where(window.reshape(-1) == 1)[0]

            # Mark the window on a copy of the mask; real cells under the
            # window become 2. Cropping the padding away then yields their
            # positions in the un-padded plane.
            marked = padded_mask.clone()
            marked[top:top + kernel_size, left:left + kernel_size] += 1
            cropped = marked[pad:pad + height, pad:pad + width]
            place_index = torch.where(cropped.reshape(-1) == 2)[0]

            # Repeat both index sets for every channel with its own offset.
            weight_index = torch.cat([weight_index + taps * ch for ch in range(channel_in)])
            place_index = torch.cat([place_index + plane * ch for ch in range(channel_in)])

            index_pairs.append((place_index, weight_index))

    extended_rows = []
    for row in input_grad:
        for place_index, weight_index in index_pairs:
            patch_row = torch.zeros(channel_in * taps)
            patch_row[weight_index] = row[place_index]
            extended_rows.append(patch_row)

    return torch.vstack(extended_rows)
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf instead of the np.Inf alias, which was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record a new validation loss and update the stopping state.

        The score is ``-val_loss``. The first call, and every call whose score
        is at least ``best_score + delta``, saves a checkpoint and resets the
        counter. Any other call increments the counter and sets
        ``early_stop`` once the counter reaches ``patience``. ``early_stop``
        is never reset to False here.

        Args:
            val_loss: Validation loss of the current epoch.
            model: Model whose ``state_dict()`` is saved on improvement.
        """
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


def bn_grad(grad, s_weight, thegama):
    """Scale ``grad`` per channel by ``s_weight / sqrt(thegama ** 2 + 1e-5)``.

    Args:
        grad: Tensor of shape ``[b, c, h, w]``.
        s_weight: Tensor with ``c`` elements, one multiplier per channel.
        thegama: Tensor with ``c`` elements; it is squared, offset by ``1e-5``
            and square-rooted to form the per-channel divisor.

    Returns:
        Tensor of shape ``[b, c, h, w]``.
    """
    _, c, _, _ = grad.shape

    # Give both per-channel vectors the full [b, c, h, w] shape of ``grad``.
    spread = thegama.reshape(1, c, 1, 1).expand_as(grad)
    gain = s_weight.reshape(1, c, 1, 1).expand_as(grad)

    inverse_spread = 1 / torch.sqrt(spread ** 2 + 1e-5)
    return grad * inverse_spread * gain





def get_m_thegama(grad, tensor1, s_weight, device):
    """Multiply each channel of ``grad`` by a matrix built from ``tensor1``.

    All inputs are detached and moved to the CPU. With ``n = b * h * w``, and
    for every channel ``c``:

    * ``mu`` is the mean of ``tensor1[:, c]``,
    * ``sigma = sqrt(std(tensor1[:, c]) ** 2 + 1e-5)`` (``torch.std`` default,
      i.e. the unbiased estimator),
    * ``d = tensor1[:, c].flatten() - mu``,
    * ``g = grad[:, c].flatten()``.

    The result row for the channel is ``g @ M`` with

        M = s_weight[c] * (I / sigma - 1 / (n * sigma)
                           - outer(d, d) / ((n - 1) * sigma ** 3))

    where the ``1 / (n * sigma)`` term is added to every matrix entry. This
    looks like the Jacobian of a batch-norm layer -- NOTE(review): confirm
    against the caller.

    The product is evaluated in closed form,

        g @ M = s_weight[c] * (g / sigma - sum(g) / (n * sigma)
                               - dot(g, d) * d / ((n - 1) * sigma ** 3))

    so the ``n x n`` matrix is never materialised. No input is modified: the
    earlier row-wise in-place update could write into the caller's ``grad``
    when the permuted reshape was a view, e.g. for batch size 1 on the CPU.

    Args:
        grad: Tensor of shape ``[b, c, h, w]``.
        tensor1: Tensor of shape ``[b, c, h, w]`` the statistics are taken from.
        s_weight: Tensor with ``c`` elements, one multiplier per channel.
        device: Unused; kept for interface compatibility.

    Returns:
        CPU tensor of shape ``[b, c, h, w]`` with the dtype of ``grad``.
    """
    b, c, h, width = tensor1.shape
    grad = grad.detach().cpu()
    tensor1 = tensor1.detach().cpu()
    s_weight = s_weight.detach().cpu()

    n = b * h * width

    # Per-channel statistics, shaped [c, 1] for broadcasting against [c, n].
    s_mean = torch.mean(tensor1, dim=[0, 2, 3]).reshape(c, 1)
    s_thegama = torch.sqrt(
        torch.square(torch.std(tensor1, dim=[0, 2, 3])) + 1e-05
    ).reshape(c, 1)
    scale = s_weight.reshape(c, 1)

    # [b, c, h, w] -> [c, b, h, w] -> [c, b*h*w]
    values = torch.permute(tensor1, dims=[1, 0, 2, 3]).reshape(c, n)
    grad_rows = torch.permute(grad, dims=[1, 0, 2, 3]).reshape(c, n)

    centered = values - s_mean
    grad_sum = grad_rows.sum(dim=1, keepdim=True)               # [c, 1]
    grad_dot = (grad_rows * centered).sum(dim=1, keepdim=True)  # [c, 1]

    grad_next = scale * (
        grad_rows / s_thegama
        - grad_sum / (n * s_thegama)
        - grad_dot * centered / ((n - 1) * s_thegama ** 3)
    )
    # The earlier implementation wrote results into a tensor with grad's
    # dtype; keep that dtype even if type promotion happened above.
    grad_next = grad_next.to(grad.dtype)

    # [c, b*h*w] -> [c, b, h, w] -> [b, c, h, w]
    return torch.permute(grad_next.reshape(c, b, h, width), dims=[1, 0, 2, 3])
    # for c_ in range(c): #[b*h*w]
    #     temp = tensor1[:,c_,:,:].reshape(-1,1) # [b,h,w] - [ b*h*w,1 ]
    #     temp =  -1/( n*s_thegama[c_]) - ( 1/ ((n-1) * s_thegama[c_]**3) ) *  torch.matmul ( ( temp  - s_mean[c_] ), ( temp.T  - s_mean[c_] ) ) #[b*h*w,b*h*w]
    #     temp = s_weight[c_] * ( temp + torch.diag( torch.ones ( temp.shape[0] ) * 1/s_thegama[c_] ) ) #[b*h*w,b*h*w]
    #
    #     grad_next[c_] = torch.matmul( grad_next[c_], temp )
    # matrix = s_weight2 * ((1-1/n)/s_thegama - (1/ ((n-1)*s_thegama**3) ) * (tensor1 - s_mean) ** 2 )

    # matrix = torch.ones_like(tensor1) #[bs,c,,h,w]
    # for b_counter in range(b):
    #     for c_counter in range(c):
    #         for i in range(h):
    #             for j in range(wide):
    #                 matrix[b_counter][c_counter][i][j] = s_weight[c_counter]*(
    #                         (1-1/n)*(1/s_thegama[c_counter])-(1/(n-1))*(1/((s_thegama[c_counter])**3))*((tensor1[b_counter][c_counter][i][j]-s_mean[c_counter])**2)
    #                     )
    # else:
    #     matrix[c_counter*h*wide+i][c_counter*h*wide+j] = s_weight[c_counter]*(-1/n)*(1/(s_thegama[c_counter]))
    # for c_counter in range(c):
    #     temp = s_weight[c] * ((1-1/n)/s_thegama[c] - 1/ ((n-1)*s_thegama[c]**3) ( tensor1[:,c,:,:] - s_mean[c] ) ** 2 )
    #     matrix[:,c,:,:] = temp



def make_advers_samples_normal_model_1(model, test_dataloader, device, epsilon=None, batch_size=10):
    """Build FGSM-perturbed batches from the correctly classified samples.

    The work is done in two passes:

    1. Every batch ``(x, y, t)`` of ``test_dataloader`` is run through
       ``model(x, t)``; samples whose arg-max prediction equals the label are
       collected on the CPU.
    2. The collected samples are wrapped in ``MyDataset`` and re-batched with
       a shuffling ``DataLoader``. Each new batch is perturbed with
       ``FGSM_attack`` using the gradient of its cross-entropy loss with
       respect to the inputs.

    Args:
        model: Callable invoked as ``model(x, t)`` that returns logits.
        test_dataloader: Iterable yielding ``(x, y, t)`` triples.
        device: Target passed to ``Tensor.to`` for all tensors.
        epsilon: FGSM step size forwarded to ``FGSM_attack``.
        batch_size: Batch size of the second-pass ``DataLoader``.

    Returns:
        A list with one ``[perturbed_x, y, t]`` entry per second-pass batch.
    """
    adv_examples = []
    correct_samples = []

    # Pass 1: keep only the samples the model already classifies correctly.
    for batch in test_dataloader:
        x, y, t = batch
        x = x.to(device)
        y = y.long().to(device)
        t = t.to(device)
        output = model(x, t)
        for index in range(x.shape[0]):
            predicted = output[index].argmax(dim=0)
            if predicted.item() != y[index].item():
                continue
            correct_samples.append((x[index].cpu(), y[index].cpu(), t[index].cpu()))

    # Pass 2: re-batch the retained samples and attack each batch.
    retained = MyDataset(correct_samples)
    adver_dataloader = DataLoader(dataset=retained, batch_size=batch_size, shuffle=True)
    for x_, y_, t_ in adver_dataloader:
        x_ = x_.to(device)
        y_ = y_.to(device)
        t_ = t_.to(device)
        # The input must be a differentiable leaf for autograd.grad below.
        x_.requires_grad_(True)

        output_ = model(x_, t_)
        loss = torch.nn.functional.cross_entropy(output_, y_)
        (inputs_grad,) = torch.autograd.grad(loss, x_)

        perturbed = FGSM_attack(x_, epsilon=epsilon, data_grad=inputs_grad)
        adv_examples.append([perturbed, y_, t_])

    print("the adver_sample's number:{}".format(len(adv_examples)))
    return adv_examples

def make_advers_samples_normal_model_2(model, test_dataloader, device, epsilon=None):
    """Build FGSM-perturbed samples from correctly classified inputs only.

    Each batch ``(x, y, t)`` is run through ``model(x, t)``. Samples whose
    arg-max prediction equals the label are perturbed with ``FGSM_attack``
    using the gradient of their cross-entropy loss with respect to ``x``;
    batches without any correct prediction are skipped.

    The earlier ``if init_pred == y`` test only worked for batch size 1 and
    raised for larger batches, because the truth value of a multi-element
    tensor is ambiguous. A boolean mask is used instead; for batches in which
    every prediction is correct (including every accepted batch of size 1)
    the behaviour is unchanged.

    Args:
        model: Callable invoked as ``model(x, t)`` that returns logits.
        test_dataloader: Iterable yielding ``(x, y, t)`` triples.
        device: Target passed to ``Tensor.to`` for all three tensors.
        epsilon: FGSM step size forwarded to ``FGSM_attack``.

    Returns:
        A list with one ``(perturbed_x, y, t)`` tuple per batch that contained
        at least one correctly classified sample. For partially correct
        batches the tensors are restricted to the correct samples.
    """
    adv_examples = []

    for x, y, t in test_dataloader:
        x, y, t = x.to(device), y.long().to(device), t.to(device)
        x.requires_grad = True
        output = model(x, t)

        correct = output.argmax(dim=1) == y  # [bs] boolean mask
        if not bool(correct.any()):
            continue
        all_correct = bool(correct.all())

        # Only correctly classified samples contribute to the loss.
        if all_correct:
            loss = torch.nn.functional.cross_entropy(output, y)
        else:
            loss = torch.nn.functional.cross_entropy(output[correct], y[correct])

        # Calculate the gradient of the loss with respect to the input.
        inputs_grad = torch.autograd.grad(loss, x)[0]
        perturbed_data = FGSM_attack(x, epsilon=epsilon, data_grad=inputs_grad)

        if all_correct:
            adv_examples.append((perturbed_data, y, t))
        else:
            # assumes t has the batch as its first dimension -- TODO confirm
            adv_examples.append((perturbed_data[correct], y[correct], t[correct]))

    print("the adver_sample's number:{}".format(len(adv_examples)))

    return adv_examples


def construct_W0_cnn(weight, height, width, batch_size, stride=1, pad=1, height_out=42, width_out=42):
    """Sum, per window position, the kernel weights that hit real input cells.

    A ``height x width`` mask of ones is zero-padded by ``pad``. For every
    window offset in ``range(0, height + 2 * pad - kernel_size, stride)`` (and
    likewise for the width) the kernel taps lying on non-padding cells are
    located, and for each output channel the weights at those taps are summed
    over all input channels. The per-channel lists of sums are flattened,
    repeated ``batch_size`` times and reshaped.

    Args:
        weight: Tensor of shape ``[c_out, c_in, k, k]``.
        height: Height of the un-padded input.
        width: Width of the un-padded input.
        batch_size: Number of times the result is repeated along dim 0.
        stride: Step between window offsets.
        pad: Zero padding applied on every side of the mask.
        height_out: Height used for the final reshape.
        width_out: Width used for the final reshape.
            ``height_out * width_out`` must equal the number of visited window
            positions, otherwise the final reshape fails.

    Returns:
        Tensor of shape ``[batch_size, c_out, height_out, width_out]``.
    """
    channel_out, channel_in, kernel_size, _ = weight.shape
    padded_mask = F.pad(torch.ones([height, width]), (pad, pad, pad, pad), "constant", 0)

    # One list of per-window sums for every output channel.
    per_channel = [[] for _ in range(channel_out)]

    for top in range(0, height + 2 * pad - kernel_size, stride):
        for left in range(0, width + 2 * pad - kernel_size, stride):
            window = padded_mask[top:top + kernel_size, left:left + kernel_size]
            # Kernel taps that fall on real (non-padding) input cells.
            valid_taps = torch.where(window.reshape(-1) == 1)[0]

            for c_o, sums in enumerate(per_channel):
                partial = [
                    weight[c_o, c_i, :, :].reshape(-1)[valid_taps].sum()
                    for c_i in range(channel_in)
                ]  # [channel_in]
                sums.append(torch.Tensor(partial).sum())

    # [c_out][positions] -> [1, c_out * positions] -> [bs, c_out * positions]
    # -> [bs, c_out, height_out, width_out]
    flat = torch.Tensor(per_channel).reshape(1, -1)
    return flat.tile(batch_size, 1).reshape(batch_size, channel_out, height_out, width_out)