import torch
import numpy as np
import os

from tqdm import tqdm


from torch.autograd.gradcheck import zero_gradients
from torch.autograd import Variable

from cvxopt import matrix, solvers


def loss_fn(x1,x2):
    """Return the squared L2 distance between ``x1`` and ``x2`` along dim 1.

    The elementwise squared error (MSE with ``reduction='none'``) is summed
    over dimension 1, so the result has that dimension removed.
    """
    squared_error = torch.nn.functional.mse_loss(x1, x2, reduction='none')
    return squared_error.sum(dim=1)

def cone_greedy_second_relu(args,agent_num,num_outputs,v1,v2,L1,k1,L2,k2,dim,source_idx,source_dists,cone_all_result_dict,cone_all_perturb_dict):
    """Greedy cone-program search for perturbations that make two inputs agree
    after two ReLU layers.

    With ``D1 = diag(s1)`` and ``D2 = diag(s2)`` the current 0/1 activation
    patterns of layer 1 (``L1 x + k1``) and layer 2 (``L2 D1 (L1 x + k1) + k2``),
    each iteration solves with cvxopt ``solvers.conelp``

        minimize    t           over x = [t, delta_1, delta_2]
        subject to  ||delta_1||_2 <= t,  ||delta_2||_2 <= t,
                    v1 + delta_1 and v2 + delta_2 keep the patterns s1, s2,
                    D2 L2 D1 L1 (v1 + delta_1) == D2 L2 D1 L1 (v2 + delta_2).

    The equality is imposed on the row space of ``D2 L2 D1 L1`` (right
    singular vectors with singular value > 1e-8). Units whose sign constraint
    has a dual value > 1e-5 for *both* perturbed inputs get their pattern bit
    flipped, and the program is re-solved. The loop stops when no bit is
    flipped or the optimal ``t`` fails to decrease. In the latter case the
    last (non-improving) solution is the one reported.

    Args:
        args: namespace read for ``alg_type`` (``'svd'`` substring enables
            optimising in the coordinates of the SVD of ``L1``; requires
            ``analysis_norm == 'l2'``) and ``start_idx``.
        agent_num: index added to ``args.start_idx`` for logging / result key.
        num_outputs: perturbations are stored under ``str(num_outputs-4)``.
        v1, v2: the two inputs (presumably 1-D of length ``L1.shape[1]`` --
            TODO confirm against caller).
        L1, k1, L2, k2: weights and biases of the two layers.
        dim: unused.
        source_idx: pair of indices used to build the result key.
        source_dists: stored unchanged under ``'start_dists'``.
        cone_all_result_dict, cone_all_perturb_dict: output dicts, written at
            key ``'<start_idx+agent_num>,(<source_idx[0]>,<source_idx[1]>)'``.

    Raises:
        RuntimeError: if conelp returns no primal or dual solution
            (e.g. the cone program is infeasible).

    Side effects: sets global cvxopt ``solvers.options`` and prints progress.
    """

    solvers.options['show_progress'] = False

    print('Initial distance: %s' % np.linalg.norm(v1-v2))
    # Activation patterns are initialised at the midpoint of v1 and v2.
    v1_start=v1+(v2-v1)/2

    use_svd='svd' in args.alg_type
    if use_svd:
        print('Using SVD for dimension lowering')
        assert args.analysis_norm=='l2'
        # L1 = U1 diag(Sigma1) VT1, so L1 @ v == L1_mod @ (VT1 @ v): the
        # perturbations are optimised in the (possibly smaller) VT1 coordinates.
        U1, Sigma1, VT1 = np.linalg.svd(L1,full_matrices=False)
        L1_eff=np.dot(U1,np.diag(Sigma1))
    else:
        L1_eff=L1
    n1=L1_eff.shape[1]
    n2=L1_eff.shape[0]
    n3=L2.shape[0]

    # Cone layout of s = h - G x: n_lin linear (>= 0) rows for the ReLU sign
    # constraints, ordered [layer2(v1), layer1(v1), layer2(v2), layer1(v2)],
    # followed by two second-order cones (t, delta_1) and (t, delta_2).
    n_lin=2*n3+2*n2
    dims={'l':n_lin,'q':[n1+1,n1+1],'s':[]}

    cone_result_dict={}
    perturb_store={}

    s1=np.zeros(L1.shape[0])
    s2=np.zeros(L2.shape[0])

    L1_out=np.dot(L1,v1_start)+k1
    s1[np.where(L1_out>0)]=1.0
    D1=np.diag(s1)

    L2_out=np.dot(L2,np.dot(D1,L1_out))+k2
    s2[np.where(L2_out>0)]=1.0

    relu_condition_1=True
    relu_condition_2=True
    loss_condition=True
    perturb_norm=[]
    num_iters=0

    # Constant part of G: the two second-order cones ||delta_i|| <= t.
    G=np.zeros((n_lin+1+n1+1+n1,1+n1+n1))
    G[n_lin,0]=-1.0
    G[n_lin+1:n_lin+1+n1,1:1+n1]=np.eye(n1)
    G[n_lin+1+n1,0]=-1.0
    G[n_lin+2+n1:n_lin+2+2*n1,1+n1:1+2*n1]=np.eye(n1)

    h=np.zeros(n_lin+1+n1+1+n1)

    # Objective: minimise t = x[0].
    c=np.zeros(2*n1+1)
    c[0]=1

    solvers.options['feastol'] = 1e-7
    solvers.options['abstol'] = 1e-7
    solvers.options['reltol'] = 1e-6
    while (relu_condition_1 or relu_condition_2) and loss_condition:
        num_iters+=1
        print('Running iter %s for agent %s' % (num_iters,agent_num))
        D1=np.diag(s1)
        D2=np.diag(s2)

        # +1 for active units, -1 for inactive ones.
        relu2_sign=2*D2-np.eye(n3)
        relu1_sign=2*D1-np.eye(n2)

        # Linear rows of G: sign(pre-activation) must match the pattern for
        # both perturbed inputs.
        relu2_block=-1.0*(relu2_sign@(L2@(D1@L1_eff)))
        relu1_block=-1.0*(relu1_sign@L1_eff)
        G[:n3,1:1+n1]=relu2_block
        G[n3:n3+n2,1:1+n1]=relu1_block
        G[n3+n2:n3+n2+n3,1+n1:1+2*n1]=relu2_block
        G[n3+n2+n3:n_lin,1+n1:1+2*n1]=relu1_block

        # h: signed pre-activations at the unperturbed inputs.
        h[:n3]=relu2_sign@(L2@(D1@(L1@v1+k1))+k2)
        h[n3:n3+n2]=relu1_sign@(L1@v1+k1)
        h[n3+n2:n3+n2+n3]=relu2_sign@(L2@(D1@(L1@v2+k1))+k2)
        h[n3+n2+n3:n_lin]=relu1_sign@(L1@v2+k1)

        # Equality constraint restricted to the row space of the active path.
        A=D2@L2@D1@L1_eff
        U,Sigma,VT=np.linalg.svd(A)
        slice_idx=np.where(Sigma>1e-8)
        VT_slice=VT[slice_idx]
        A_pre=np.hstack((VT_slice,-1.0*VT_slice))
        A_final=np.hstack((np.zeros((len(slice_idx[0]),1)),A_pre))

        if use_svd:
            b=np.dot(VT_slice,np.dot(VT1,v2)-np.dot(VT1,v1))
        else:
            b=np.dot(VT_slice,v2-v1)

        sol=solvers.conelp(matrix(c),matrix(G),matrix(h),dims,matrix(A_final),matrix(b))
        # conelp returns None for x (primal infeasible) or z (dual infeasible).
        if sol['x'] is None or sol['z'] is None:
            raise RuntimeError(
                'conelp returned no solution (status %r) at iter %s for agent %s'
                % (sol['status'],num_iters,agent_num))

        curr_loss=np.array(sol['x'])[0][0]
        if num_iters>1:
            loss_condition=bool(curr_loss<perturb_norm[-1])
        perturb_norm.append(curr_loss)

        # Dual variables of the linear rows, in the same order as G/h.
        z=np.array(sol['z'])
        z12=z[:n3]
        z11=z[n3:n3+n2]
        z22=z[n3+n2:n3+n2+n3]
        z21=z[n3+n2+n3:n_lin]

        # ReLU 2 check
        z12_nz=set(np.where(z12>1e-5)[0])
        z22_nz=set(np.where(z22>1e-5)[0])

        # ReLU 1 check
        z11_nz=set(np.where(z11>1e-5)[0])
        z21_nz=set(np.where(z21>1e-5)[0])

        print('Relu 1 nz: %s, %s' % (z11_nz,z21_nz))
        print('Relu 2 nz: %s, %s' % (z12_nz,z22_nz))

        # Units blocking both inputs get their pattern bit flipped.
        intersect_idx_1=z11_nz.intersection(z21_nz)
        intersect_idx_2=z12_nz.intersection(z22_nz)

        relu_condition_1=len(intersect_idx_1)>0
        for item in intersect_idx_1:
            s1[item]=1.0-s1[item]

        relu_condition_2=len(intersect_idx_2)>0
        for item in intersect_idx_2:
            s2[item]=1.0-s2[item]

    x_opt=np.array(sol['x'])
    delta_1=x_opt[1:n1+1].reshape(n1)
    delta_2=x_opt[n1+1:].reshape(n1)

    if use_svd:
        # Map the perturbations back from VT1 coordinates to input space.
        delta_1=np.dot(VT1.T,delta_1)
        delta_2=np.dot(VT1.T,delta_2)

    v1_end=v1+delta_1
    v2_end=v2+delta_2

    # Second-layer outputs under the patterns D1, D2 of the last solve;
    # values below 1e-12 are zeroed.
    out_end_1=D2@(L2@(D1@(L1@v1_end+k1))+k2)
    out_end_2=D2@(L2@(D1@(L1@v2_end+k1))+k2)
    out_end_1[np.where(out_end_1<1e-12)]=0.0
    out_end_2[np.where(out_end_2<1e-12)]=0.0

    cone_result_dict['start_dists']=source_dists
    cone_result_dict['perturb_magnitudes']=float(x_opt[0,0])
    cone_result_dict['deeper_exact_dists']=float(np.linalg.norm(out_end_1-out_end_2))
    cone_result_dict['num_iters']=num_iters

    perturb_store[str(num_outputs-4)]=np.hstack((delta_1,delta_2))

    curr_key=str(args.start_idx+agent_num)+ ',(' + str(source_idx[0])+','+str(source_idx[1]) + ')'
    cone_all_result_dict[curr_key]=cone_result_dict
    cone_all_perturb_dict[curr_key]=perturb_store

    print('Completed run %s' % (args.start_idx+agent_num))

def cone_greedy_second_linear(args,agent_num,num_outputs,v1,v2,L1,k1,L2,k2,dim,source_idx,source_dists,cone_all_result_dict,cone_all_perturb_dict):
    """Greedy cone-program search for perturbations that make two inputs agree
    after a ReLU layer followed by a linear layer.

    With ``D1 = diag(s1)`` the current 0/1 activation pattern of ``L1 x + k1``,
    each iteration solves with cvxopt ``solvers.conelp``

        minimize    t           over x = [t, delta_1, delta_2]
        subject to  ||delta_1||_2 <= t,  ||delta_2||_2 <= t,
                    v1 + delta_1 and v2 + delta_2 keep the pattern s1,
                    L2 D1 L1 (v1 + delta_1) == L2 D1 L1 (v2 + delta_2).

    The equality is imposed on the row space of ``L2 D1 L1``. Units whose sign
    constraint has a dual value > 1e-5 for both inputs get their pattern bit
    flipped, and the program is re-solved. The loop stops when no bit is
    flipped or the optimal ``t`` fails to decrease. In the latter case the
    last (non-improving) solution is the one reported.

    Args:
        args: namespace read for ``alg_type`` (``'svd'`` substring enables
            optimising in the coordinates of the SVD of ``L1``; requires
            ``analysis_norm == 'l2'``) and ``start_idx``.
        agent_num: index added to ``args.start_idx`` for logging / result key.
        num_outputs: perturbations are stored under ``str(num_outputs-3)``.
        v1, v2: the two inputs (presumably 1-D of length ``L1.shape[1]`` --
            TODO confirm against caller).
        L1, k1, L2, k2: weights and biases of the two layers.
        dim: unused.
        source_idx: pair of indices used to build the result key.
        source_dists: stored unchanged under ``'start_dists'``.
        cone_all_result_dict, cone_all_perturb_dict: output dicts, written at
            key ``'<start_idx+agent_num>,(<source_idx[0]>,<source_idx[1]>)'``.

    Raises:
        RuntimeError: if conelp returns no primal or dual solution.

    Side effects: sets global cvxopt ``solvers.options`` and prints progress.
    """

    solvers.options['show_progress'] = False
    print('Initial distance: %s' % np.linalg.norm(v1-v2))

    use_svd='svd' in args.alg_type
    if use_svd:
        print('Using SVD for dimension lowering')
        assert args.analysis_norm=='l2'
        # L1 = U1 diag(Sigma1) VT1, so L1 @ v == L1_mod @ (VT1 @ v).
        U1, Sigma1, VT1 = np.linalg.svd(L1,full_matrices=False)
        L1_eff=np.dot(U1,np.diag(Sigma1))
    else:
        L1_eff=L1
    n1=L1_eff.shape[1]
    n2=L1_eff.shape[0]

    # Cone layout of s = h - G x: 2*n2 linear rows [layer1(v1), layer1(v2)]
    # followed by the cones (t, delta_1) and (t, delta_2).
    n_lin=2*n2
    dims={'l':n_lin,'q':[n1+1,n1+1],'s':[]}

    cone_result_dict={}
    perturb_store={}

    # ReLU pattern initialised at the midpoint of v1 and v2.
    s1=np.zeros(L1.shape[0])
    v1_start=v1+(v2-v1)/2
    L1_out=np.dot(L1,v1_start)+k1
    s1[np.where(L1_out>1e-12)]=1.0
    relu_condition=True
    loss_condition=True
    perturb_norm=[]
    num_iters=0

    # Constant part of G: the two second-order cones ||delta_i|| <= t.
    G=np.zeros((n_lin+1+n1+1+n1,1+n1+n1))
    G[n_lin,0]=-1.0
    G[n_lin+1:n_lin+1+n1,1:1+n1]=np.eye(n1)
    G[n_lin+1+n1,0]=-1.0
    G[n_lin+2+n1:n_lin+2+2*n1,1+n1:1+2*n1]=np.eye(n1)

    h=np.zeros(n_lin+1+n1+1+n1)

    # Objective: minimise t = x[0].
    c=np.zeros(2*n1+1)
    c[0]=1

    solvers.options['feastol'] = 1e-7
    solvers.options['abstol'] = 1e-7
    solvers.options['reltol'] = 1e-6
    while relu_condition and loss_condition:
        num_iters+=1
        print('Running iter %s for agent %s' % (num_iters,agent_num))
        D1=np.diag(s1)

        # +1 for active units, -1 for inactive ones.
        relu1_sign=2*D1-np.eye(n2)

        # Sign constraints on layer-1 pre-activations for both inputs.
        relu1_block=-1.0*(np.dot(relu1_sign,L1_eff))
        G[:n2,1:1+n1]=relu1_block
        G[n2:n_lin,1+n1:1+2*n1]=relu1_block

        h[:n2]=np.dot(relu1_sign,(np.dot(L1,v1)+k1))
        h[n2:n_lin]=np.dot(relu1_sign,(np.dot(L1,v2)+k1))

        # Equality constraint restricted to the row space of L2 D1 L1.
        A=np.dot(L2,np.dot(D1,L1_eff))
        U,Sigma,VT=np.linalg.svd(A)
        slice_idx=np.where(Sigma>1e-8)
        VT_slice=VT[slice_idx]
        A_pre=np.hstack((VT_slice,-1.0*VT_slice))
        A_final=np.hstack((np.zeros((len(slice_idx[0]),1)),A_pre))

        if use_svd:
            b=np.dot(VT_slice,np.dot(VT1,v2)-np.dot(VT1,v1))
        else:
            b=np.dot(VT_slice,v2-v1)

        sol=solvers.conelp(matrix(c),matrix(G),matrix(h),dims,matrix(A_final),matrix(b))
        # conelp returns None for x (primal infeasible) or z (dual infeasible).
        if sol['x'] is None or sol['z'] is None:
            raise RuntimeError(
                'conelp returned no solution (status %r) at iter %s for agent %s'
                % (sol['status'],num_iters,agent_num))

        curr_loss=np.array(sol['x'])[0][0]
        if num_iters>1:
            loss_condition=bool(curr_loss<perturb_norm[-1])
        perturb_norm.append(curr_loss)

        z=np.array(sol['z'])
        z1=z[:n2]
        z2=z[n2:n_lin]

        z1_nz=set(np.where(z1>1e-5)[0])
        z2_nz=set(np.where(z2>1e-5)[0])

        # Units blocking both inputs get their pattern bit flipped.
        intersect_idx=z1_nz.intersection(z2_nz)
        relu_condition=len(intersect_idx)>0
        for item in intersect_idx:
            s1[item]=1.0-s1[item]

    x_opt=np.array(sol['x'])
    delta_1=x_opt[1:n1+1].reshape(n1)
    delta_2=x_opt[n1+1:].reshape(n1)

    if use_svd:
        # Map the perturbations back from VT1 coordinates to input space.
        delta_1=np.dot(VT1.T,delta_1)
        delta_2=np.dot(VT1.T,delta_2)

    v1_end=v1+delta_1
    v2_end=v2+delta_2

    # NOTE(review): the second layer is linear here, yet its outputs are
    # clipped below 1e-12 as if they were ReLU activations. Confirm this is
    # intended; it changes 'deeper_exact_dists' whenever outputs are negative.
    out_end_1=L2@(D1@(L1@v1_end+k1))+k2
    out_end_2=L2@(D1@(L1@v2_end+k1))+k2
    out_end_1[np.where(out_end_1<1e-12)]=0.0
    out_end_2[np.where(out_end_2<1e-12)]=0.0

    cone_result_dict['start_dists']=source_dists
    cone_result_dict['perturb_magnitudes']=float(x_opt[0,0])
    cone_result_dict['deeper_exact_dists']=float(np.linalg.norm(out_end_1-out_end_2))
    cone_result_dict['num_iters']=num_iters

    perturb_store[str(num_outputs-3)]=np.hstack((delta_1,delta_2))

    curr_key=str(args.start_idx+agent_num)+ ',(' + str(source_idx[0])+','+str(source_idx[1]) + ')'
    cone_all_result_dict[curr_key]=cone_result_dict
    cone_all_perturb_dict[curr_key]=perturb_store

    print('Completed run %s' % (args.start_idx+agent_num))



def cone_greedy_first_relu(args,agent_num,num_outputs,v1,v2,L1,k1,dim,source_idx,source_dists,cone_all_result_dict,cone_all_perturb_dict):
    """Greedy cone-program search for perturbations that make two inputs agree
    after a single ReLU layer.

    With ``D = diag(s)`` the current 0/1 activation pattern of ``L1 x + k1``,
    each iteration solves with cvxopt ``solvers.conelp``

        minimize    t           over x = [t, delta_1, delta_2]
        subject to  ||delta_1||_2 <= t,  ||delta_2||_2 <= t,
                    v1 + delta_1 and v2 + delta_2 keep the pattern s,
                    D L1 (v1 + delta_1) == D L1 (v2 + delta_2).

    The equality is imposed on the row space of ``D L1``. Units whose sign
    constraint has a dual value > 1e-5 for both inputs get their pattern bit
    flipped, and the program is re-solved. The loop stops when no bit is
    flipped or the optimal ``t`` fails to decrease. In the latter case the
    last (non-improving) solution is the one reported.

    Args:
        args: namespace read for ``alg_type`` (``'svd'`` substring enables
            optimising in the coordinates of the SVD of ``L1``; requires
            ``analysis_norm == 'l2'``) and ``start_idx``.
        agent_num: index added to ``args.start_idx`` for logging / result key.
        num_outputs: perturbations are stored under ``str(num_outputs-2)``.
        v1, v2: the two inputs (presumably 1-D of length ``L1.shape[1]`` --
            TODO confirm against caller).
        L1, k1: weight and bias of the layer.
        dim: unused.
        source_idx: pair of indices used to build the result key.
        source_dists: stored unchanged under ``'start_dists'``.
        cone_all_result_dict, cone_all_perturb_dict: output dicts, written at
            key ``'<start_idx+agent_num>,(<source_idx[0]>,<source_idx[1]>)'``.

    Raises:
        RuntimeError: if conelp returns no primal or dual solution.

    Side effects: sets global cvxopt ``solvers.options`` and prints progress.
    """

    solvers.options['show_progress'] = False
    print('Initial distance: %s' % np.linalg.norm(v1-v2))

    use_svd='svd' in args.alg_type
    if use_svd:
        print('Using SVD for dimension lowering')
        assert args.analysis_norm=='l2'
        # L1 = U1 diag(Sigma1) VT1, so L1 @ v == L1_mod @ (VT1 @ v).
        U1, Sigma1, VT1 = np.linalg.svd(L1,full_matrices=False)
        L1_eff=np.dot(U1,np.diag(Sigma1))
    else:
        L1_eff=L1
    n1=L1_eff.shape[1]
    n2=L1_eff.shape[0]

    # Cone layout of s = h - G x: 2*n2 linear rows [layer1(v1), layer1(v2)]
    # followed by the cones (t, delta_1) and (t, delta_2).
    n_lin=2*n2
    dims={'l':n_lin,'q':[n1+1,n1+1],'s':[]}

    cone_result_dict={}
    perturb_store={}

    # ReLU pattern initialised at the midpoint of v1 and v2.
    s=np.zeros(L1.shape[0])
    v1_start=v1+(v2-v1)/2
    L1_out=np.dot(L1,v1_start)+k1
    s[np.where(L1_out>1e-12)]=1.0
    relu_condition=True
    loss_condition=True
    perturb_norm=[]
    num_iters=0

    # Constant part of G: the two second-order cones ||delta_i|| <= t.
    G=np.zeros((n_lin+1+n1+1+n1,1+n1+n1))
    G[n_lin,0]=-1.0
    G[n_lin+1:n_lin+1+n1,1:1+n1]=np.eye(n1)
    G[n_lin+1+n1,0]=-1.0
    G[n_lin+2+n1:n_lin+2+2*n1,1+n1:1+2*n1]=np.eye(n1)

    h=np.zeros(n_lin+1+n1+1+n1)

    # Objective: minimise t = x[0].
    c=np.zeros(2*n1+1)
    c[0]=1

    solvers.options['feastol'] = 1e-7
    solvers.options['abstol'] = 1e-7
    solvers.options['reltol'] = 1e-6
    while relu_condition and loss_condition:
        num_iters+=1
        print('Running iter %s for agent %s' % (num_iters,agent_num))
        D=np.diag(s)

        # +1 for active units, -1 for inactive ones.
        relu_sign=2*D-np.eye(n2)

        # Sign constraints on the pre-activations for both inputs.
        relu_block=-1.0*(np.dot(relu_sign,L1_eff))
        G[:n2,1:1+n1]=relu_block
        G[n2:n_lin,1+n1:1+2*n1]=relu_block

        h[:n2]=np.dot(relu_sign,(np.dot(L1,v1)+k1))
        h[n2:n_lin]=np.dot(relu_sign,(np.dot(L1,v2)+k1))

        # Equality constraint restricted to the row space of D L1.
        A=np.dot(D,L1_eff)
        U,Sigma,VT=np.linalg.svd(A)
        slice_idx=np.where(Sigma>1e-8)
        VT_slice=VT[slice_idx]
        A_pre=np.hstack((VT_slice,-1.0*VT_slice))
        A_final=np.hstack((np.zeros((len(slice_idx[0]),1)),A_pre))

        if use_svd:
            b=np.dot(VT_slice,np.dot(VT1,v2)-np.dot(VT1,v1))
        else:
            b=np.dot(VT_slice,v2-v1)

        sol=solvers.conelp(matrix(c),matrix(G),matrix(h),dims,matrix(A_final),matrix(b))
        # conelp returns None for x (primal infeasible) or z (dual infeasible).
        if sol['x'] is None or sol['z'] is None:
            raise RuntimeError(
                'conelp returned no solution (status %r) at iter %s for agent %s'
                % (sol['status'],num_iters,agent_num))

        curr_loss=np.array(sol['x'])[0][0]
        if num_iters>1:
            loss_condition=bool(curr_loss<perturb_norm[-1])
        perturb_norm.append(curr_loss)

        z=np.array(sol['z'])
        z1=z[:n2]
        z2=z[n2:n_lin]

        z1_nz=set(np.where(z1>1e-5)[0])
        z2_nz=set(np.where(z2>1e-5)[0])

        # Units blocking both inputs get their pattern bit flipped.
        intersect_idx=z1_nz.intersection(z2_nz)
        relu_condition=len(intersect_idx)>0
        for item in intersect_idx:
            s[item]=1.0-s[item]

    x_opt=np.array(sol['x'])
    delta_1=x_opt[1:n1+1].reshape(n1)
    delta_2=x_opt[n1+1:].reshape(n1)

    if use_svd:
        # Map the perturbations back from VT1 coordinates to input space.
        delta_1=np.dot(VT1.T,delta_1)
        delta_2=np.dot(VT1.T,delta_2)

    v1_end=v1+delta_1
    v2_end=v2+delta_2

    # ReLU outputs at the perturbed inputs (values below 1e-12 zeroed).
    out_end_1=np.dot(L1,v1_end)+k1
    out_end_1[np.where(out_end_1<1e-12)]=0.0

    out_end_2=np.dot(L1,v2_end)+k1
    out_end_2[np.where(out_end_2<1e-12)]=0.0

    cone_result_dict['start_dists']=source_dists
    cone_result_dict['perturb_magnitudes']=float(x_opt[0,0])
    cone_result_dict['deeper_exact_dists']=float(np.linalg.norm(out_end_1-out_end_2))
    cone_result_dict['num_iters']=num_iters

    perturb_store[str(num_outputs-2)]=np.hstack((delta_1,delta_2))

    curr_key=str(args.start_idx+agent_num)+ ',(' + str(source_idx[0])+','+str(source_idx[1]) + ')'
    cone_all_result_dict[curr_key]=cone_result_dict
    cone_all_perturb_dict[curr_key]=perturb_store

    print('Completed run %s' % (args.start_idx+agent_num))

def linear_exact_intersection(args,A,proj_mat,agent_num,x_curr,num_outputs,source_idx,source_dists,linear_all_result_dict,linear_all_perturb_dict):
    """Apply the closed-form symmetric perturbation for a linear layer ``A``.

    The two samples ``x1 = x_curr[0]`` and ``x2 = x_curr[1]`` (flattened) are
    moved by ``t = 0.5 * proj_mat @ (x1 - x2)`` towards each other:
    ``x1 - t`` and ``x2 + t``. The residual ``A @ (x1 - t) - A @ (x2 + t)``
    is then measured.

    ``proj_mat`` is presumably the projection onto the row space of ``A``
    (``A.T @ inv(A @ A.T) @ A``), in which case the residual is zero.
    TODO confirm against the caller.

    Args:
        args: namespace read for ``start_idx``.
        A: linear map applied to the flattened samples.
        proj_mat: square matrix applied to ``x1 - x2``.
        agent_num: index added to ``args.start_idx`` for logging / result key.
        x_curr: array whose first two entries are the samples; each is
            flattened, so any per-sample shape is accepted.
        num_outputs: ``t`` is stored under ``str(num_outputs-1)``.
        source_idx: pair of indices used to build the result key.
        source_dists: stored unchanged under ``'start_dists'``.
        linear_all_result_dict, linear_all_perturb_dict: output dicts, written
            at key ``'<start_idx+agent_num>,(<source_idx[0]>,<source_idx[1]>)'``.
    """

    linear_result_dict={}
    perturb_store={}

    # Flatten each sample (previously hard-coded to 784 = 28*28 pixels).
    x1=x_curr[0:1].reshape(-1)
    x2=x_curr[1:2].reshape(-1)
    t_vec=0.5*np.dot(proj_mat,x1-x2)
    x1_tilde=x1-t_vec
    x2_tilde=x2+t_vec
    deeper_layer_diff=np.dot(A,x1_tilde)-np.dot(A,x2_tilde)

    linear_result_dict['start_dists']=source_dists
    linear_result_dict['perturb_magnitudes']=float(np.linalg.norm(t_vec))
    linear_result_dict['linear_exact_dists']=float(np.linalg.norm(deeper_layer_diff))

    perturb_store[str(num_outputs-1)]=t_vec

    curr_key=str(args.start_idx+agent_num)+ ',(' + str(source_idx[0])+','+str(source_idx[1]) + ')'
    linear_all_result_dict[curr_key]=linear_result_dict
    linear_all_perturb_dict[curr_key]=perturb_store

    print('Completed run %s' % (args.start_idx+agent_num))


def standard_pgd_intersection(args,x_source,net,n_closest_distances):
    """Push pairs of inputs together in each layer's output space with L2 PGD.

    For every output index ``k + 1`` with ``k in range(net.n_outputs - 1)``
    and every pair ``(x_source[j], x_source[args.closest_n + j])``, both
    members take ``n_steps`` normalised-gradient descent steps on
    ``loss_fn`` between their ``k + 1``-th network outputs. After each step,
    each member's total perturbation is projected onto the L2 ball of
    radius ``args.input_eps`` around its start point, and the pair is
    clipped to ``[0, 1]``.

    Args:
        args: namespace read for ``closest_n`` and ``input_eps``.
        x_source: tensor of inputs; row ``j`` is paired with row
            ``args.closest_n + j``.
        net: model with an ``n_outputs`` attribute whose ``forward`` returns an
            indexable sequence of outputs.
        n_closest_distances: stored unchanged under ``'start_dists'``.

    Returns:
        dict with ``'start_dists'``, ``'start_all_layer_dists'``,
        ``'end_dists'`` (input-space distance after the attack) and
        ``'end_all_layer_dists'``. The last three hold one list per ``k``,
        each with one entry per pair.
    """
    print('Computing intersections for %s closest samples using standard PGD' % args.closest_n)
    result_dict = {}
    end_dists_all=[]
    start_layer_dists_all=[]
    end_layer_dists_all=[]

    index_1=0
    index_2=1

    input_eps=args.input_eps

    n_steps=1000
    eps_step=(input_eps*2.5)/n_steps
    clip_min=0.0
    clip_max=1.0

    def _per_sample(norms,like):
        # Reshape per-sample norms so they broadcast against `like` whatever
        # its rank (replaces the separate 2-D and 4-D branches).
        return norms.view(-1,*([1]*(like.dim()-1)))

    def _l2_norms(t):
        # Per-sample L2 norms, floored at 1e-9 to avoid division by zero.
        # clamp keeps the computation on t's device (no hard-coded .cuda()).
        return t.view(t.size(0),-1).norm(2,1).clamp(min=1e-9)

    for k in range(net.n_outputs-1):
        print('Optimization for output type %s' % k)
        end_dists=[]
        start_layer_dists=[]
        end_layer_dists=[]
        for j in range(args.closest_n):
            print('Running for pair of examples %s' % j)
            x_curr=torch.cat((x_source[j:j+1],x_source[args.closest_n+j:args.closest_n+j+1]))
            x_var = Variable(x_curr, requires_grad= True)
            for i in range(n_steps):
                # Same as the removed torch.autograd.gradcheck.zero_gradients.
                if x_var.grad is not None:
                    x_var.grad.detach_()
                    x_var.grad.zero_()
                output_list_1 =net.forward(x_var[index_1:index_1+1])
                output_list_2 =net.forward(x_var[index_2:index_2+1])
                loss=loss_fn(output_list_1[k+1],output_list_2[k+1])
                if i==0:
                    start_layer_dists.append(np.sqrt(loss.item()))
                loss.backward()
                raw_grad = x_var.grad.data

                x_adv=x_curr.detach().clone()
                for idx in (index_1,index_2):
                    sl=slice(idx,idx+1)
                    grad_i=raw_grad[sl]
                    grad_dir=grad_i/_per_sample(_l2_norms(grad_i),grad_i)
                    # Descent step along the normalised gradient.
                    adv_temp=x_var.data[sl]-eps_step*grad_dir
                    # Project the total perturbation onto the eps-ball.
                    total_grad=adv_temp-x_curr[sl]
                    total_norm=_l2_norms(total_grad)
                    total_dir=total_grad/_per_sample(total_norm,total_grad)
                    x_adv[sl]+=_per_sample(total_norm.clamp(max=input_eps),total_grad)*total_dir
                x_adv = torch.clamp(x_adv, clip_min, clip_max)
                x_var.data = x_adv

            # Measure the layer distance at the final iterate: the last
            # `loss` inside the loop predates the final update.
            with torch.no_grad():
                final_out_1=net.forward(x_var[index_1:index_1+1])
                final_out_2=net.forward(x_var[index_2:index_2+1])
                final_loss=loss_fn(final_out_1[k+1],final_out_2[k+1])
            end_layer_dists.append(np.sqrt(final_loss.item()))
            end_dists.append(torch.norm(x_var[index_1:index_1+1]-x_var[index_2:index_2+1]).item())

            diff_array = (x_var.data-x_curr.data).cpu().numpy()
            diff_array = diff_array.reshape(len(diff_array),-1)

            print("peturbation1= {}".format(
               np.linalg.norm(diff_array[index_1:index_1+1])))
            print("peturbation2= {}".format(
               np.linalg.norm(diff_array[index_2:index_2+1])))
        start_layer_dists_all.append(start_layer_dists)
        end_layer_dists_all.append(end_layer_dists)
        end_dists_all.append(end_dists)

    result_dict['start_dists']=n_closest_distances
    result_dict['start_all_layer_dists']=start_layer_dists_all
    result_dict['end_dists']=end_dists_all
    result_dict['end_all_layer_dists']=end_layer_dists_all

    return result_dict


def modified_pgd_intersection(args,x_source,net,n_closest_distances):
    """Variant of the standard L2 PGD intersection with a split gradient step.

    Same outer loops and projection as the standard version. The difference
    is in how each step is taken: for each sample, the gradient is split into
    (a) its component along the current total perturbation and (b) the
    orthogonal remainder. The sample then moves ``eps_step`` along each
    normalised component. Each total perturbation is projected onto the L2
    ball of radius ``args.input_eps``, and the pair is clipped to ``[0, 1]``.

    Args:
        args: namespace read for ``closest_n`` and ``input_eps``.
        x_source: tensor of inputs; row ``j`` is paired with row
            ``args.closest_n + j``.
        net: model with an ``n_outputs`` attribute whose ``forward`` returns an
            indexable sequence of outputs.
        n_closest_distances: stored unchanged under ``'start_dists'``.

    Returns:
        dict with ``'start_dists'``, ``'start_all_layer_dists'``,
        ``'end_dists'`` and ``'end_all_layer_dists'``. The last three hold one
        list per output index, each with one entry per pair.
    """
    print('Computing intersections for %s closest samples using modified PGD' % args.closest_n)
    result_dict_pgd_mod={}
    end_dists_all_pgd_mod=[]
    start_layer_dists_all_pgd_mod=[]
    end_layer_dists_all_pgd_mod=[]

    input_eps=args.input_eps

    index_1=0
    index_2=1

    n_steps=1000
    eps_step=(input_eps*2.5)/n_steps
    clip_min=0.0
    clip_max=1.0

    def _per_sample(norms,like):
        # Reshape per-sample norms so they broadcast against `like` whatever
        # its rank (2-D features or 4-D images).
        return norms.view(-1,*([1]*(like.dim()-1)))

    def _l2_norms(t):
        # Per-sample L2 norms, floored at 1e-9 to avoid division by zero.
        return t.view(t.size(0),-1).norm(2,1).clamp(min=1e-9)

    def _unit(t):
        # Per-sample normalisation; an all-zero t stays all-zero.
        return t/_per_sample(_l2_norms(t),t)

    for k in range(net.n_outputs-1):
        print('Optimization for output type %s' % k)
        end_dists=[]
        start_layer_dists=[]
        end_layer_dists=[]
        for j in range(args.closest_n):
            print('Running for pair of examples %s' % j)
            x_curr=torch.cat((x_source[j:j+1],x_source[args.closest_n+j:args.closest_n+j+1]))
            x_var = Variable(x_curr, requires_grad= True)
            for i in range(n_steps):
                # Same as the removed torch.autograd.gradcheck.zero_gradients.
                if x_var.grad is not None:
                    x_var.grad.detach_()
                    x_var.grad.zero_()
                output_list_1 =net.forward(x_var[index_1:index_1+1])
                output_list_2 =net.forward(x_var[index_2:index_2+1])
                loss=loss_fn(output_list_1[k+1],output_list_2[k+1])
                if i==0:
                    start_layer_dists.append(np.sqrt(loss.item()))
                loss.backward()
                raw_grad = x_var.grad.data

                x_adv=x_curr.detach().clone()
                for idx in (index_1,index_2):
                    sl=slice(idx,idx+1)
                    grad_i=raw_grad[sl]
                    x_i=x_var.data[sl]
                    # Unit direction of the current total perturbation.
                    delta_dir=_unit(x_i-x_curr[sl])
                    # Split the gradient into the part along delta_dir and
                    # the orthogonal remainder (full contraction over all dims).
                    projection_mag=torch.tensordot(grad_i,delta_dir,dims=grad_i.dim())
                    grad_along=projection_mag*delta_dir
                    grad_perp=grad_i-grad_along
                    # Step eps_step along each normalised component.
                    adv_temp=x_i-eps_step*_unit(grad_along)-eps_step*_unit(grad_perp)
                    # Project the total perturbation onto the eps-ball.
                    total_grad=adv_temp-x_curr[sl]
                    total_norm=_l2_norms(total_grad)
                    total_dir=total_grad/_per_sample(total_norm,total_grad)
                    x_adv[sl]+=_per_sample(total_norm.clamp(max=input_eps),total_grad)*total_dir
                x_adv = torch.clamp(x_adv, clip_min, clip_max)
                x_var.data = x_adv

            # One end measurement per pair (previously appended every step),
            # taken at the final iterate.
            with torch.no_grad():
                final_out_1=net.forward(x_var[index_1:index_1+1])
                final_out_2=net.forward(x_var[index_2:index_2+1])
                final_loss=loss_fn(final_out_1[k+1],final_out_2[k+1])
            end_layer_dists.append(np.sqrt(final_loss.item()))
            end_dists.append(torch.norm(x_var[index_1:index_1+1]-x_var[index_2:index_2+1]).item())

            diff_array = (x_var.data-x_curr.data).cpu().numpy()
            diff_array = diff_array.reshape(len(diff_array),-1)

            print("peturbation1= {}".format(
               np.linalg.norm(diff_array[index_1:index_1+1])))
            print("peturbation2= {}".format(
               np.linalg.norm(diff_array[index_2:index_2+1])))
        start_layer_dists_all_pgd_mod.append(start_layer_dists)
        end_layer_dists_all_pgd_mod.append(end_layer_dists)
        end_dists_all_pgd_mod.append(end_dists)

    result_dict_pgd_mod['start_dists']=n_closest_distances
    result_dict_pgd_mod['start_all_layer_dists']=start_layer_dists_all_pgd_mod
    result_dict_pgd_mod['end_dists']=end_dists_all_pgd_mod
    result_dict_pgd_mod['end_all_layer_dists']=end_layer_dists_all_pgd_mod

    return result_dict_pgd_mod


def _apgd_checkpoints(n_steps, p1=0.2, decay=0.03, min_gap=0.06):
    """Return the strictly increasing APGD checkpoint iterations for ``n_steps``.

    Checkpoints are ``w_j = ceil(p_j * n_steps)`` with ``p_0 = 0``, ``p_1 = p1``
    and ``p_{j+1} = p_j + max(p_j - p_{j-1} - decay, min_gap)``.  The list
    starts at 0 and, for ``n_steps > 0``, ends with the first checkpoint
    beyond ``n_steps`` so callers can always look one checkpoint ahead.
    For ``n_steps <= 0`` only ``[0]`` is returned.
    """
    if n_steps <= 0:
        # ceil(p * 0) is always 0, so the schedule below would never terminate.
        return [0]
    checkpoints = [0]
    p_prev, p_curr = 0.0, p1
    while checkpoints[-1] <= n_steps:
        w = int(np.ceil(p_curr * n_steps))
        # Small n_steps can map consecutive p_j onto the same iteration; keep
        # the list strictly increasing so every checkpoint stays reachable.
        if w > checkpoints[-1]:
            checkpoints.append(w)
        p_prev, p_curr = p_curr, p_curr + max(p_curr - p_prev - decay, min_gap)
    return checkpoints


def _apgd_clip_l2(delta, eps):
    """Rescale each sample (dim 0) of ``delta`` so its L2 norm is at most ``eps``.

    Works for tensors of any rank; samples already inside the ball are left
    unchanged up to floating-point rounding.
    """
    # Floor the norm so an all-zero perturbation does not divide by zero.
    norm = delta.reshape(delta.size(0), -1).norm(2, 1).clamp(min=1e-9)
    shape = (-1,) + (1,) * (delta.dim() - 1)
    return norm.clamp(max=eps).view(shape) * (delta / norm.view(shape))


def _apgd_project_pair(x_orig, x_cand, index_1, index_2, eps, clip_min, clip_max):
    """Project rows ``index_1``/``index_2`` of ``x_cand`` onto the feasible set.

    Returns a new tensor equal to ``x_orig`` except that those two rows become
    the original row plus the candidate perturbation clipped to L2 norm
    ``eps``; the whole result is then clamped to ``[clip_min, clip_max]``.
    """
    base = x_orig.detach()
    cand = x_cand.detach()
    x_new = base.clone()
    for idx in (index_1, index_2):
        rows = slice(idx, idx + 1)
        x_new[rows] += _apgd_clip_l2(cand[rows] - base[rows], eps)
    return torch.clamp(x_new, clip_min, clip_max)


def _apgd_pair_loss(net, x, k, index_1, index_2):
    """Return ``loss_fn`` between the k-th ``net.forward`` outputs of the two rows."""
    out_1 = net.forward(x[index_1:index_1 + 1])
    out_2 = net.forward(x[index_2:index_2 + 1])
    return loss_fn(out_1[k], out_2[k])


def _apgd_pair_loss_value(net, x, k, index_1, index_2):
    """Evaluate the pair loss as a Python float without building a graph."""
    with torch.no_grad():
        return _apgd_pair_loss(net, x, k, index_1, index_2).item()


def _apgd_pair_loss_and_grad(net, x, k, index_1, index_2):
    """Return ``(loss, grad)`` of the pair loss with respect to the batch ``x``.

    ``torch.autograd.grad`` is used so no gradients accumulate into the
    parameters of ``net``.
    """
    x_in = x.detach().clone().requires_grad_(True)
    loss = _apgd_pair_loss(net, x_in, k, index_1, index_2)
    (grad,) = torch.autograd.grad(loss.sum(), x_in, allow_unused=True)
    if grad is None:
        # The loss did not depend on the input at all: no descent direction.
        grad = torch.zeros_like(x_in)
    return loss.item(), grad


def apgd_intersection(args, agent_num, x_curr, net, num_outputs, source_idx,
                      source_dists, apgd_all_result_dict, apgd_all_perturb_dict):
    """Pull a pair of inputs together in feature space with an APGD-style L2 attack.

    For each output index ``k`` in ``1 .. num_outputs - 1`` the rows
    ``x_curr[0]`` and ``x_curr[1]`` are perturbed to minimise
    ``loss_fn(net.forward(row_0)[k], net.forward(row_1)[k])``.  Each row stays
    inside an L2 ball of radius ``args.input_eps`` around its original value
    and the batch is clamped to [0, 1].  The step size starts at
    ``input_eps / 1000``; at every APGD checkpoint it is halved (and the
    iterate restarted from the best point found) when too few steps decreased
    the loss, or when neither the step size nor the best loss changed since
    the previous checkpoint.

    Args:
        args: namespace providing ``n_steps``, ``input_eps`` and ``start_idx``.
        agent_num: offset added to ``args.start_idx`` to label this run.
        x_curr: input batch whose rows 0 and 1 form the attacked pair
            (assumed to have at least two rows -- confirm with callers).
        net: model whose ``forward`` returns an indexable collection of
            outputs (assumed from the ``[k]`` indexing); ``loss_fn`` on a
            one-row batch must reduce to a single element.
        num_outputs: number of such outputs; index 0 is skipped.
        source_idx: pair of source indices, used only in the result key.
        source_dists: stored verbatim under ``'start_dists'``.
        apgd_all_result_dict: receives the distance summary for this run.
        apgd_all_perturb_dict: receives ``{str(k): perturbation}`` arrays of
            shape ``(batch, -1)`` measured at the best point for each ``k``.

    Returns:
        None.  Results are written into the two dicts under the key
        ``"<start_idx + agent_num>,(<source_idx[0]>,<source_idx[1]>)"``.
        Layer distances are reported as ``sqrt(loss)`` both at the start and
        at the end, and the end input distance / perturbation are taken at
        the best point (the one whose loss is reported).
    """
    # APGD hyper-parameters: rho is the success-rate threshold below which the
    # step size is halved at a checkpoint; alpha weights the projected
    # gradient step against the momentum term.
    rho = 0.75
    alpha = 0.75
    n_steps = args.n_steps
    attack_eps = args.input_eps
    index_1, index_2 = 0, 1
    clip_min, clip_max = 0.0, 1.0

    checkpoints = _apgd_checkpoints(n_steps)
    x_orig = x_curr.detach()

    result_dict = {}
    perturb_store = {}
    end_dists_all_pgd = []
    start_layer_dists_all_pgd = []
    end_layer_dists_all_pgd = []

    def project(x_cand):
        return _apgd_project_pair(x_orig, x_cand, index_1, index_2,
                                  attack_eps, clip_min, clip_max)

    for k in range(1, num_outputs):
        print('Optimization for output type %s for run %s' % (k,args.start_idx+agent_num))
        learning_rate = attack_eps / 1000.0

        # Step 0 -> 1: a plain projected gradient-descent step from the clean pair.
        x_0 = x_orig.clone()
        loss_0, grad = _apgd_pair_loss_and_grad(net, x_0, k, index_1, index_2)
        start_layer_dists_all_pgd.append(np.sqrt(loss_0))
        x_1 = project(x_0 - learning_rate * grad)
        loss_1 = _apgd_pair_loss_value(net, x_1, k, index_1, index_2)

        # Copies, so later iterates can never overwrite the stored best point.
        if loss_0 < loss_1:
            best_x, best_loss = x_0.clone(), loss_0
        else:
            best_x, best_loss = x_1.clone(), loss_1
        best_loss_prev = best_loss
        learning_rate_prev = learning_rate

        x_prev, x_iter = x_0, x_1
        prev_loss = loss_1
        checkpoint_idx = 0
        loss_dec_count = 0

        for i in tqdm(range(2, n_steps)):
            _, grad = _apgd_pair_loss_and_grad(net, x_iter, k, index_1, index_2)
            z = project(x_iter - learning_rate * grad)

            # APGD momentum update around the *current* iterate:
            # x + alpha * (z - x) + (1 - alpha) * (x - x_prev), then project.
            x_cand = x_iter + alpha * (z - x_iter) + (1 - alpha) * (x_iter - x_prev)
            x_prev, x_iter = x_iter, project(x_cand)

            loss_k = _apgd_pair_loss_value(net, x_iter, k, index_1, index_2)
            if loss_k < prev_loss:
                loss_dec_count += 1
            prev_loss = loss_k
            if loss_k < best_loss:
                best_x, best_loss = x_iter.clone(), loss_k

            next_idx = checkpoint_idx + 1
            if next_idx < len(checkpoints) and i >= checkpoints[next_idx]:
                window = checkpoints[next_idx] - checkpoints[checkpoint_idx]
                cond_1 = loss_dec_count < rho * window
                cond_2 = (learning_rate == learning_rate_prev) and (best_loss == best_loss_prev)
                # Remember the step size in effect at *this* checkpoint (before
                # any halving) so cond_2 compares consecutive checkpoints.
                learning_rate_prev = learning_rate
                if cond_1 or cond_2:
                    learning_rate /= 2.0
                    # Restart at the best point found so far.
                    x_iter = best_x.clone()
                loss_dec_count = 0
                best_loss_prev = best_loss
                checkpoint_idx = next_idx

        # Report everything at the best point, in sqrt(loss) units like the start.
        pair_gap = best_x[index_1:index_1 + 1] - best_x[index_2:index_2 + 1]
        end_layer_dists_all_pgd.append(np.sqrt(best_loss))
        end_dists_all_pgd.append(torch.norm(pair_gap).item())

        diff_array = (best_x - x_orig).cpu().numpy()
        perturb_store[str(k)] = diff_array.reshape(len(diff_array), -1)

    result_dict['start_dists'] = source_dists
    result_dict['start_all_layer_dists'] = start_layer_dists_all_pgd
    result_dict['end_dists'] = end_dists_all_pgd
    result_dict['end_all_layer_dists'] = end_layer_dists_all_pgd

    curr_key = '%s,(%s,%s)' % (args.start_idx + agent_num, source_idx[0], source_idx[1])
    apgd_all_result_dict[curr_key] = result_dict
    apgd_all_perturb_dict[curr_key] = perturb_store

    print('Completed run %s' % (args.start_idx+agent_num))

    return