import collections as cls
import torch
import torch.nn.functional as F
def FGSM_attack(image, epsilon, data_grad):
    """Apply a single Fast Gradient Sign Method step to ``image``.

    Every element is moved by ``epsilon`` in the direction of the sign of its
    gradient, and the result is clipped back into the ``[0, 1]`` range.

    Args:
        image: input tensor to perturb.
        epsilon: step size (scalar). ``None`` is not supported.
        data_grad: gradient of the loss w.r.t. ``image`` (same shape).

    Returns:
        The perturbed tensor, clamped to ``[0, 1]``.
    """
    step = epsilon * data_grad.sign()
    adversarial = image + step
    return adversarial.clamp(0, 1)
def make_advers_samples(model,test_dataloader,device,epsilon=None):
    """Craft FGSM adversarial examples from the correctly classified inputs.

    ``model(x)`` must return an indexable sequence whose last element holds the
    logits ``[bs, n_classes]``; that element is used for the prediction and the
    loss. Only samples the model already classifies correctly are attacked.

    Any batch size is supported. For each batch with at least one correctly
    classified sample, one ``[perturbed_x, y]`` pair is appended, restricted to
    those samples. With batch size 1 this is exactly one pair per correctly
    classified sample.

    NOTE(review): per-sample gradients are taken from one batched backward pass.
    This assumes samples do not interact in the forward pass, e.g. no
    train-mode BatchNorm. Confirm that the model is in eval mode.

    Args:
        model: callable as ``model(x)``.
        test_dataloader: iterable of ``(x, y)`` batches.
        device: device each batch is moved to.
        epsilon: FGSM step size. It must be given: ``FGSM_attack`` multiplies
            it with a tensor, so ``None`` raises ``TypeError``.

    Returns:
        list of ``[perturbed_x, y]`` pairs.
    """
    adv_examples = []
    for x, y in test_dataloader:
        x, y = x.to(device), y.long().to(device)
        x.requires_grad = True
        logits = model(x)[-1]
        correct = logits.argmax(dim=1) == y
        if not correct.any():
            continue
        # Summed (not averaged) loss: each sample's input gradient is then the
        # gradient of its own loss, unscaled by the batch size. FGSM only uses
        # the sign, and for a single sample this equals the mean reduction.
        loss = F.cross_entropy(logits[correct], y[correct], reduction="sum")
        inputs_grad = torch.autograd.grad(loss, x)[0]  # same shape as x
        perturbed_data = FGSM_attack(
            x[correct], epsilon=epsilon, data_grad=inputs_grad[correct]
        )
        adv_examples.append([perturbed_data, y[correct]])

    # Count samples rather than batches. This is identical for batch size 1.
    n_examples = sum(int(labels.shape[0]) for _, labels in adv_examples)
    print("the adver_sample's number:{}".format(n_examples))
    return adv_examples


def make_advers_samples_normal_model(model,test_dataloader,device,epsilon=None,task_number_=None):
    """Craft FGSM adversarial examples for a model called as ``model(x, task_number_)``.

    ``model(x, task_number_)`` must return an indexable sequence whose last
    element holds the logits ``[bs, n_classes]``. Only samples the model
    already classifies correctly are attacked.

    Any batch size is supported. For each batch with at least one correctly
    classified sample, one ``[perturbed_x, y]`` pair is appended, restricted to
    those samples. With batch size 1 this is exactly one pair per correctly
    classified sample.

    NOTE(review): per-sample gradients are taken from one batched backward pass.
    This assumes samples do not interact in the forward pass, e.g. no
    train-mode BatchNorm. Confirm that the model is in eval mode.

    Args:
        model: callable as ``model(x, task_number_)``.
        test_dataloader: iterable of ``(x, y)`` batches.
        device: device each batch is moved to.
        epsilon: FGSM step size. It must be given: ``FGSM_attack`` multiplies
            it with a tensor, so ``None`` raises ``TypeError``.
        task_number_: forwarded unchanged as the model's second argument.

    Returns:
        list of ``[perturbed_x, y]`` pairs.
    """
    adv_examples = []
    for x, y in test_dataloader:
        x, y = x.to(device), y.long().to(device)
        x.requires_grad = True
        logits = model(x, task_number_)[-1]
        correct = logits.argmax(dim=1) == y
        if not correct.any():
            continue
        # Summed (not averaged) loss: each sample's input gradient is then the
        # gradient of its own loss, unscaled by the batch size. FGSM only uses
        # the sign, and for a single sample this equals the mean reduction.
        loss = F.cross_entropy(logits[correct], y[correct], reduction="sum")
        inputs_grad = torch.autograd.grad(loss, x)[0]  # same shape as x
        perturbed_data = FGSM_attack(
            x[correct], epsilon=epsilon, data_grad=inputs_grad[correct]
        )
        adv_examples.append([perturbed_data, y[correct]])

    # Count samples rather than batches. This is identical for batch size 1.
    n_examples = sum(int(labels.shape[0]) for _, labels in adv_examples)
    print("the adver_sample's number:{}".format(n_examples))
    return adv_examples

# def maps_reshape( maps, padding, filter_size, stride  ):
#
#     bs, channel, height, weight = maps.shape
#     pad = torch.zeros( [ bs, channel, height + padding*2, weight + padding*2 ] )
#     maps = pad[ :, :, padding:padding + height, padding:padding + weight ]
#     bs, channel, height, weight = maps.shape
#
#     temp = []
#     for map in maps: #[channel, height, weight]  # dim of batch_size
#         for h in range( 0, height - filter_size, stride ):
#             for w in range( 0, weight - filter_size, stride ):
#                 temp.append( map[ :, h:h+filter_size, w:w+filter_size ].reshape(-1) ) #[channel*k*k]
#
#     new_map = torch.cat( temp ) #[ bs*h0*w0, channel*k*k ]
#     return new_map
def maps_reshape(maps, filter_size):
    """Flatten stride-1 ``filter_size`` x ``filter_size`` patches of ``maps`` into rows.

    ``maps`` is unpacked as ``[bs, channel, height, wide]``, and no padding is
    applied. Patch origins cover rows ``0 .. height - filter_size - 1`` and
    columns ``0 .. wide - filter_size - 1``.

    NOTE(review): the last valid origin (``height - filter_size`` and
    ``wide - filter_size``) is never visited. Compared with a stride-1
    convolution sweep, one row and one column of patches are dropped. This is
    kept because callers may depend on the row count; confirm it is intended.

    Returns:
        Tensor ``[bs * (height - filter_size) * (wide - filter_size),
        channel * filter_size ** 2]`` when both spatial sizes exceed
        ``filter_size``. Rows are ordered by sample, then patch row, then patch
        column. ``torch.cat`` raises if no patch exists.
    """
    n_samples, _, rows, cols = maps.shape
    row_origins = range(rows - filter_size)
    col_origins = range(cols - filter_size)
    patches = [
        maps[s, :, r:r + filter_size, c:c + filter_size].flatten().unsqueeze(dim=0)
        for s in range(n_samples)
        for r in row_origins
        for c in col_origins
    ]
    return torch.cat(patches)

def maps_reshape_1(maps, pad, kernel_size, stride):
    """Zero-pad ``maps`` and concatenate its flattened sliding-window patches.

    Both spatial dims of the 4-D ``maps`` are padded by ``pad`` on each side
    with ``F.pad``. A ``kernel_size`` x ``kernel_size`` window then visits the
    origins ``range(0, padded_size - kernel_size, stride)`` in each spatial dim.

    NOTE(review): that ``range`` excludes the origin
    ``padded_size - kernel_size`` itself, which a conv-style sweep would
    include. Left unchanged; confirm it is intended.

    Returns:
        A 1-D tensor. Each patch is flattened to ``channel * kernel_size ** 2``
        values, and the patches are concatenated end to end: by sample, then
        patch row, then patch column. The result is NOT
        ``[n_patches, channel * k * k]``, so callers must reshape it.
        ``torch.cat`` raises if there are no patches.
    """
    padded = F.pad(maps, (pad, pad, pad, pad), "constant", 0)
    _, _, rows, cols = padded.shape
    row_origins = range(0, rows - kernel_size, stride)
    col_origins = range(0, cols - kernel_size, stride)
    pieces = [
        sample[:, r:r + kernel_size, c:c + kernel_size].reshape(-1)
        for sample in padded  # iterate over the batch dimension
        for r in row_origins
        for c in col_origins
    ]
    return torch.cat(pieces)

def maps_reshape_strade(maps, filter_size,stride=2.):
    """Flatten ``filter_size`` x ``filter_size`` patches taken every ``stride`` pixels.

    Args:
        maps: 4-D tensor unpacked as ``[bs, channel, height, wide]``.
        filter_size: side length of each square patch.
        stride: step between patch origins. It must be a positive whole
            number. Whole-valued floats such as the default ``2.`` are
            converted to ``int``, because tensor slicing rejects float indices.

    Returns:
        Tensor ``[bs * n_h * n_w, channel * filter_size ** 2]`` with one
        flattened patch per row. Rows are ordered by sample, then patch row,
        then patch column, where ``n_h = int((height - filter_size) / stride)``
        and ``n_w`` is defined likewise. If there are no patches, an empty
        ``[0, channel * filter_size ** 2]`` tensor is returned.

        NOTE(review): a conv-style sweep would visit
        ``(height - filter_size) // stride + 1`` origins per axis. The existing
        count is kept so that callers see the same shapes.

    Raises:
        ValueError: if ``stride`` is not a positive whole number.
    """
    if stride != int(stride) or stride < 1:
        raise ValueError(f"stride must be a positive whole number, got {stride!r}")
    stride = int(stride)  # slice indices must be ints; 2. would raise TypeError
    bs, channel, height, wide = maps.shape
    n_rows = int((height - filter_size) / stride)
    n_cols = int((wide - filter_size) / stride)
    temp = [
        maps[i, :, j * stride:j * stride + filter_size,
             k * stride:k * stride + filter_size].flatten().unsqueeze(dim=0)
        for i in range(bs)
        for j in range(n_rows)
        for k in range(n_cols)
    ]
    if not temp:
        # torch.cat([]) would raise; return a well-shaped empty result instead.
        return maps.new_empty((0, channel * filter_size * filter_size))
    return torch.cat(temp)
def evaluate_acc_multihead(model, dataloader, device, task_number=None):
    """Return the top-1 accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and left there. Each batch is moved to
    ``device``, and the logits are produced under ``torch.no_grad()``. The
    number of argmax hits is divided by ``len(dataloader.dataset)``, so an
    empty dataset raises ``ZeroDivisionError``.

    NOTE(review): ``task_number`` is accepted but never used. The model is
    always called as ``model(x)``.

    Returns:
        float accuracy in ``[0, 1]``.
    """
    model.eval()
    hits = 0
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        with torch.no_grad():
            scores = model(inputs)
        predicted = scores.argmax(dim=1)
        hits += torch.eq(predicted, targets).sum().float().item()
    return hits / len(dataloader.dataset)
def partpool(m):
    """Average ``m`` over non-overlapping 2x2 spatial blocks.

    This behaves like a 2x2 average pool with stride 2 and floor semantics:
    an odd trailing row or column is ignored. Each output cell is
    ``0.25 * (a + b + c + d)`` over its 2x2 block.

    Args:
        m: 4-D tensor ``[batch, c, h, w]``.

    Returns:
        Tensor ``[batch, c, h // 2, w // 2]`` on ``m``'s device. Floating
        inputs keep their dtype. Integer inputs are promoted to torch's default
        float dtype by the multiplication with ``0.25``.
    """
    _, _, h, w = m.shape
    half_h, half_w = h // 2, w // 2
    # Drop an odd trailing row/column so every output cell sees a full block.
    m = m[:, :, :2 * half_h, :2 * half_w]
    # Strided views select the same corner of every 2x2 block at once.
    top_left = m[:, :, 0::2, 0::2]
    top_right = m[:, :, 0::2, 1::2]
    bottom_left = m[:, :, 1::2, 0::2]
    bottom_right = m[:, :, 1::2, 1::2]
    return 0.25 * (top_left + top_right + bottom_left + bottom_right)
def get_m_thegama(w,tensor1,tensor2,gramma,beta):
    """Rescale ``w`` channel by channel, in place, from a two-point fit of ``tensor1`` vs ``tensor2``.

    For every channel ``i`` (``c`` comes from dim 1 of ``tensor1``):

    * ``thegama`` and ``mean`` are the slope and intercept of the line through
      the points ``(tensor2[0][i][0][0], tensor1[0][i][0][0])`` and
      ``(tensor2[0][i][0][1], tensor1[0][i][0][1])``, so that
      ``tensor1 ~ thegama * tensor2 + mean``. Only the first two entries of
      row 0 of sample 0 are used for the estimate.
    * ``w[:, i]`` is overwritten with
      ``(w_i / thegama) * (1 - mean / tensor1[:, i]) * gramma[i] - w_i * beta[i] / tensor1[:, i]``,
      where ``w_i`` is the channel's value before the assignment.
      ``1e-8`` is added to every denominator.

    NOTE(review): ``gramma``/``beta`` look like per-channel affine
    (batch-norm style) parameters; confirm with callers. The ``+ 1e-8`` only
    protects against exact zeros. A denominator near ``-1e-8`` still divides
    by roughly zero.

    Args:
        w: 4-D tensor indexed as ``w[:, i, :, :]``. It is modified in place and
            must broadcast against ``tensor1[:, i, :, :]``.
        tensor1: 4-D tensor. Its shape, ``[0][i][0][0:2]`` and ``[:, i]`` are read.
        tensor2: 4-D tensor. Only ``[0][i][0][0:2]`` is read.
        gramma: per-channel values indexed by ``i``.
        beta: per-channel values indexed by ``i``.

    Returns:
        ``w`` itself, the same object that was modified in place.
    """
    # Only c (the channel count) is used from the unpacked shape.
    b,c,h,wide=tensor1.shape
    for i in range(c):
        # Two sample points per channel: first two entries of row 0, sample 0.
        x1_1=tensor1[0][i][0][0]
        x1_2=tensor1[0][i][0][1]
        x2_1 = tensor2[0][i][0][0]
        x2_2 = tensor2[0][i][0][1]
        # Slope and intercept of the line through the two points.
        thegama=(x1_1-x1_2)/(x2_1-x2_2+1e-8)
        mean=x1_1-thegama*x2_1
        # The right-hand side is evaluated fully (with the old w values) before the in-place assignment.
        w[:,i,:,:]=(w[:,i,:,:]/(thegama+1e-8))*(1-mean/(tensor1[:,i,:,:]+1e-8))*gramma[i]-w[:,i,:,:]*beta[i]/(tensor1[:,i,:,:]+1e-8)

    return w
def get_m_thegama1(w,tensor1,tensor2,tensor3):
    """Rescale ``w`` channel by channel, in place, by ``tensor3[i] / thegama``.

    For every channel ``i`` (``c`` comes from dim 1 of ``tensor1``),
    ``thegama`` is the slope of the line through the points
    ``(tensor2[0][i][0][0], tensor1[0][i][0][0])`` and
    ``(tensor2[0][i][0][1], tensor1[0][i][0][1])``. Then ``w[:, i]`` is
    replaced by ``w[:, i] / (thegama + 1e-8) * tensor3[i]``.

    NOTE(review): the intercept ``mean`` is computed but unused. It was only
    needed by the commented-out variant of the formula kept below. The
    ``+ 1e-8`` terms only guard against exact zeros.

    Args:
        w: 4-D tensor indexed as ``w[:, i, :, :]``; modified in place.
        tensor1: 4-D tensor. Only its shape and ``[0][i][0][0:2]`` are read.
        tensor2: 4-D tensor. Only ``[0][i][0][0:2]`` is read.
        tensor3: per-channel multipliers indexed by ``i``.

    Returns:
        ``w`` itself, the same object that was modified in place.
    """
    # Only c (the channel count) is used from the unpacked shape.
    b,c,h,wide=tensor1.shape
    for i in range(c):
        # Two sample points per channel: first two entries of row 0, sample 0.
        x1_1=tensor1[0][i][0][0]
        x1_2=tensor1[0][i][0][1]
        x2_1 = tensor2[0][i][0][0]
        x2_2 = tensor2[0][i][0][1]
        thegama=(x1_1-x1_2)/(x2_1-x2_2+1e-8)
        mean=x1_1-thegama*x2_1
        # w[:,i,:,:]=(w[:,i,:,:]/(thegama+1e-8))*(1-mean/(tensor1[:,i,:,:]+1e-8))
        w[:, i, :, :] = (w[:, i, :, :] / (thegama + 1e-8))*tensor3[i]
    return w
def get_relu(w,tensor1):
    """Mask ``w`` by the positive entries of ``tensor1``, as a ReLU derivative.

    Builds a mask that is 1 where ``tensor1 > 0``, 0 where ``tensor1`` is
    negative or zero, and unchanged (e.g. NaN) elsewhere, then multiplies it
    with ``w``.

    The mask is built on a copy, so the caller's ``tensor1`` is no longer
    overwritten. The in-place version also raised for leaf tensors that
    require grad.

    Args:
        w: tensor broadcast-compatible with ``tensor1``.
        tensor1: reference tensor whose sign selects the entries kept.

    Returns:
        ``w * mask`` as a new tensor.
    """
    mask = tensor1.clone()
    mask[mask < 0] = 0
    mask[mask > 0] = 1
    return torch.mul(w, mask)
def make_orthogonality_pool(pool1,pool2,c,device):
    """Merge two per-key pools and keep, per key, the leading right-singular vectors.

    For every key of ``pool1``, ``pool1[key]`` and ``pool2[key]`` are
    concatenated along dim 0 and decomposed with ``torch.linalg.svd`` on the
    CPU. The function keeps the smallest number ``k`` of leading right-singular
    vectors whose cumulative energy ``sum(S[:k] ** 2)`` reaches
    ``c[key] * ||merged||_F ** 2``. If the threshold is never reached, all
    ``min(rows, cols)`` vectors are kept.

    The index of the first vector that meets the threshold is printed, followed
    by a summary of the direction counts.

    Args:
        pool1: mapping from key to a 2-D tensor.
        pool2: mapping containing every key of ``pool1`` (else ``KeyError``),
            with tensors that concatenate with ``pool1``'s along dim 0.
        c: indexable by the same keys; energy fraction per key, presumably
            in ``[0, 1]``.
        device: device the resulting bases are moved to.

    Returns:
        ``defaultdict(list)`` mapping each key to a ``[k, n_features]`` tensor
        whose rows are the retained right-singular vectors.
    """
    Pool = cls.defaultdict(list)
    for index in pool1:
        Pool[index] = torch.cat([pool1[index], pool2[index]], dim=0)
    for n_n, pool_dim in Pool.items():
        pool_dim = pool_dim.cpu()
        whole_ = torch.sum(torch.abs(pool_dim) ** 2)
        _, S, Vh = torch.linalg.svd(pool_dim, full_matrices=True)
        # U and V have orthonormal columns, so the squared Frobenius norm of the
        # rank-(k+1) reconstruction equals sum(S[:k+1] ** 2). A cumulative sum
        # replaces rebuilding the matrix for every k.
        energy = torch.cumsum(S ** 2, dim=0)
        reached = torch.nonzero(energy >= whole_ * c[n_n]).flatten()
        if reached.numel() > 0:
            k_K = int(reached[0])
            print(k_K)
        else:
            # Threshold never met (or S is empty): keep every singular direction.
            k_K = S.shape[0] - 1
        # Rows of Vh are the right-singular vectors (equivalent to V[:, :k+1].T).
        Pool[n_n] = Vh[:k_K + 1].to(device)
    # This summary was previously placed after `return` and never printed.
    print(f'number of directions in pool is {[pc.shape[0] for pc in Pool.values()]} ')
    return Pool

def gradient_cnn( x0, x1 ):
    """Per-sample summed input gradients for every element of a 4-D activation.

    For each sample ``b`` and each flattened position ``col`` of
    ``x1[b]`` (``[c, h, w]``), that scalar is differentiated w.r.t. ``x0``.
    The gradient is summed over all non-batch entries of ``x0``, giving one
    value per sample, and those values are added into column ``col``.
    Hence ``temp[b2, col] = sum_b sum_j d x1[b, col] / d x0[b2, j]``. When
    samples do not interact in computing ``x1``, only the ``b == b2`` term is
    non-zero.

    Args:
        x0: tensor with leading batch dim ``bs`` (the same ``bs`` as ``x1``)
            that ``x1`` was computed from with autograd enabled.
        x1: 4-D tensor ``[bs, c, h, w]``.

    Returns:
        CPU tensor ``[bs, c * h * w]`` with torch's default dtype.
    """
    bs, c, h, w = x1.shape
    temp = torch.zeros([bs, c * h * w])
    for b in range(bs):
        for col, element in enumerate(x1[b].reshape(-1)):  # element is a scalar
            # retain_graph keeps the graph alive for the next element. The
            # result is detached, so no higher-order graph (create_graph) is needed.
            grad = torch.autograd.grad(element, x0, retain_graph=True)[0]
            # [bs, ...] -> [bs]: one summed gradient per sample of x0
            temp[:, col] += grad.reshape(bs, -1).detach().cpu().sum(dim=1)
    return temp



def gradient_fc( x0, x1 ):
    """Per-sample summed input gradients for every feature of a 2-D activation.

    For each sample ``b`` and feature ``col`` of ``x1`` (``[bs, f1]``),
    ``x1[b, col]`` is differentiated w.r.t. ``x0``. The gradient is summed over
    ``x0``'s non-batch entries per sample and added into column ``col``. Hence
    ``temp[b2, col] = sum_b sum_j d x1[b, col] / d x0[b2, j]``. When samples do
    not interact, only the ``b == b2`` term is non-zero.

    Args:
        x0: tensor with leading batch dim ``bs`` that ``x1`` was computed from
            with autograd enabled.
        x1: 2-D tensor ``[bs, f1]``.

    Returns:
        CPU tensor ``[bs, f1]`` with torch's default dtype.
    """
    bs, f1 = x1.shape
    temp = torch.zeros([bs, f1])
    for b in range(bs):
        for col, element in enumerate(x1[b]):  # element is a scalar
            # retain_graph keeps the graph alive for the next element. The
            # result is detached, so no higher-order graph (create_graph) is needed.
            grad = torch.autograd.grad(element, x0, retain_graph=True)[0]
            temp[:, col] += grad.reshape(bs, -1).detach().cpu().sum(dim=1)
    return temp  # [bs, f1]




def vector_extend( input_grad,  channel_in, height, width, stride=1, kernel_size=3, pad=1, ):
    """Scatter each row of ``input_grad`` into per-window vectors laid out like a conv filter.

    A ``height`` x ``width`` board of ones is zero-padded by ``pad`` on every
    side. For every window origin ``(h, w)`` visited by the loops below, the
    function records two sets of indices, repeated for each of the
    ``channel_in`` channels:

    * ``weight_index``: which of the ``kernel_size * kernel_size`` filter taps
      fall on non-padding cells.
    * ``place_index``: which unpadded input cells those taps cover.

    Each row of ``input_grad`` then yields one vector per window, of length
    ``channel_in * kernel_size * kernel_size``. The vector is laid out as a
    C-order flattening of ``[channel_in, kernel_size, kernel_size]``. Taps on
    real input cells get that cell's value from the row; taps on padding stay 0.

    The index arithmetic (``place_index + height * width * i``) assumes each row
    of ``input_grad`` is a C-order flattening of ``[channel_in, height, width]``.
    TODO confirm against callers.

    NOTE(review): window origins come from
    ``range(0, height - kernel_size, stride)`` (and likewise for ``width``).
    That bound uses the *unpadded* size, although the windows are cut from the
    padded board. A full sweep of the padded board would use
    ``range(0, height + 2 * pad - kernel_size + 1, stride)``. Confirm that the
    reduced set of windows is intended and matches the code consuming the result.

    Args:
        input_grad: iterable of 1-D rows (see the layout assumption above).
        channel_in: number of input channels.
        height: height of the unpadded input.
        width: width of the unpadded input.
        stride: step between window origins.
        kernel_size: side length of the square filter.
        pad: zero-padding added on each side.

    Returns:
        Tensor ``[n_rows * n_windows, channel_in * kernel_size ** 2]``, where
        ``n_rows`` is the number of rows in ``input_grad``, stacked with
        ``torch.vstack``. Rows are ordered by input row, then window. The
        vectors use torch's default dtype and device. ``torch.vstack`` raises
        if no window is visited.
    """
    # height, width: spatial size of the input before padding
    board = torch.ones( [ height, width ] )
    # 1 = real input cell, 0 = padding
    board = F.pad( board, ( pad, pad,pad, pad ), "constant", 0)

    index = []

    for h in range( 0, height - kernel_size, stride ):
        for w in range( 0, width - kernel_size, stride ):
            filter = board[ h:h+kernel_size, w:w+kernel_size ]
            weight_index = torch.where( filter.reshape(-1) == 1 )[0] # filter taps that land on non-padding cells

            board__ = board.clone()
            board__[ h:h+kernel_size, w:w+kernel_size  ] += 1 # window cells on real input become 2 (padding inside the window becomes 1)
            board__ = board__[ pad:pad+height, pad:pad+width ] # crop back to the unpadded input
            place_index = torch.where( board__.reshape(-1) == 2 )[0]

            # Both index lists are row-major, so the k-th weight tap pairs with the k-th input cell.
            weight_index = torch.cat( [weight_index + (kernel_size * kernel_size) * i for i in range(channel_in) ] )
            place_index = torch.cat([place_index + (height * width) * i for i in range(channel_in)])
            # Every channel repeats the same pattern, offset by one channel's size (k*k for weights, height*width for inputs).

            index.append( [place_index, weight_index] )

    input_grad_extend = []
    for row in input_grad:
        for place_index, weight_index in index:

            # Taps on padding keep the 0 from this initialisation.
            row_rectify = torch.zeros( channel_in * kernel_size * kernel_size )
            row_rectify[ weight_index ] = row[ place_index ]
            input_grad_extend .append( row_rectify )

    input_grad_extend = torch.vstack( input_grad_extend )

    return input_grad_extend