import os

import numpy as np
import time
import argparse
from torchsummary import summary

import pickle
import json

# import torch.multiprocessing as mp

from multiprocessing import Process, Manager, Pool

# from utils.mnist_models import cnn_3l, cnn_3l_bn, cnn_3l_bn_2, fc_3l_2, fc_3l
# # from utils.resnet_cifar import resnet
# # from utils.cifar10_models import WideResNet
# from utils.test_utils import test, robust_test
# from utils.data_utils import load_dataset, load_dataset_custom
from utils.io_utils import init_dirs, model_naming, test_argparse
from utils.deeper_layer_utils import standard_pgd_intersection, linear_exact_intersection, modified_pgd_intersection, apgd_intersection, cone_greedy_first_relu, cone_greedy_second_linear, cone_greedy_second_relu

def flat_to_array_index(index,num_samples):
    """Convert a flat (row-major) index into a ``(row, column)`` pair.

    Args:
        index: Non-negative position in the flattened matrix.
        num_samples: Number of columns in the matrix (row length).

    Returns:
        Tuple ``(i, j)`` where ``i`` is the row and ``j`` the column.
    """
    # Integer divmod instead of int(index / num_samples): true division goes
    # through a float and silently rounds once the index exceeds 2**53.
    row, col = divmod(index, num_samples)
    return int(row), col

def split(delimiters, string, maxsplit=0):
    """Split ``string`` on any of the literal ``delimiters``.

    Args:
        delimiters: Iterable of delimiter strings; each is matched literally
            (regex metacharacters are escaped).
        string: Text to split.
        maxsplit: Maximum number of splits to perform; 0 means no limit.

    Returns:
        List of the pieces between delimiters. Adjacent delimiters, or a
        delimiter at either end of ``string``, produce empty strings.
    """
    import re
    pattern = '|'.join(map(re.escape, delimiters))
    # maxsplit is passed by keyword: the positional form of re.split is
    # deprecated since Python 3.13.
    return re.split(pattern, string, maxsplit=maxsplit)

def main():
    """Compute collision distances between pairs of two-class samples.

    Parses the command line, loads the two-class filtered data from
    ``input_data/``, selects the sample pairs to process (the pairs ranked
    ``[start_idx, end_idx)`` by pairwise Euclidean distance when ``--closest``
    is given, otherwise every combination of a random subsample of each
    class), runs the algorithm chosen by ``--alg_type`` on each pair and
    writes the results below ``dl_output/``.

    Raises:
        ValueError: If the dataset or model named on the command line is not
            one whose dimensions this script knows.
    """
    # torch.random.manual_seed(7)
    rng = np.random.default_rng(77)

    parser = test_argparse()

    # Distance compute args
    parser.add_argument('--closest', action='store_true')
    parser.add_argument('--start_idx', type=int, default=0)
    parser.add_argument('--end_idx', type=int, default=10)
    parser.add_argument('--analysis_norm',type=str,default='l2')
    parser.add_argument('--input_eps', type=float, default=0.0)
    parser.add_argument('--n_steps', type=int, default=1000)
    parser.add_argument('--alg_type', type=str)
    parser.add_argument('--subsample', action='store_true')
    parser.add_argument('--subsample_size', type=int, default=0)

    args = parser.parse_args()

    # Fail fast on unusable options. parser.error is used instead of assert
    # so the checks still run under `python -O`.
    if args.alg_type is None:
        parser.error('--alg_type is required')
    if args.subsample and args.subsample_size <= 0:
        parser.error('--subsample requires --subsample_size > 0')

    if args.closest:
        print('Starting run %s to %s' % (args.start_idx,args.end_idx))
    else:
        print('Starting subsample %s run' % args.subsample_size)

    if args.num_samples is None:
        # Kept as a label on args; numeric sizes are derived from the data below.
        args.num_samples = 'All'
    if args.drop_eps==0:
        args.drop_eps=args.epsilon

    args.checkpoint_path = 'trained_models'

    args.eps_step = args.epsilon*args.gamma/args.attack_iter

    args.trial_num = 1
    args.class_1=3
    args.class_2=7
    # The returned directory names are not needed by this script.
    init_dirs(args, train=False)
    _, model_name = model_naming(args)
    print('Loading %s' % (args.dataset_in+'_'+model_name))

    if 'MNIST' in args.dataset_in:
        data_dim=784
    else:
        raise ValueError('Unsupported dataset %s: input dimension is unknown' % args.dataset_in)

    if 'fc_3l' in args.model:
        num_outputs=6
    else:
        raise ValueError('Unsupported model %s: number of outputs is unknown' % args.model)

    print('Loading two-class filtered data')

    input_dir='input_data/' + args.dataset_in + '/' + str(args.class_1) + '_' + str(args.class_2) + '/'

    x_loc=input_dir+'%s_%s_%s_X.npy' % (args.class_1,args.class_2, args.dataset_in)
    y_loc=input_dir+'%s_%s_%s_Y.npy' % (args.class_1,args.class_2, args.dataset_in)

    X_curr=np.load(x_loc)
    Y_curr=np.load(y_loc)

    # With an explicit --num_samples the reshape doubles as a size check;
    # otherwise (-1) every sample of the class is used.
    rows = args.num_samples if isinstance(args.num_samples, int) else -1
    X_c1=X_curr[np.where(Y_curr==0)].reshape(rows,data_dim)
    X_c2=X_curr[np.where(Y_curr==1)].reshape(rows,data_dim)
    n_c1, n_c2 = len(X_c1), len(X_c2)

    # Pairwise distances between the two classes, cached on disk.
    dist_path= input_dir + str(args.class_1) + '_' + str(args.class_2) + '_' + str(args.dataset_in)+ '_dists_subsample_' + str(args.subsample_size) + '.npy'
    if not os.path.exists(dist_path):
        import scipy.spatial.distance
        dist_mat=scipy.spatial.distance.cdist(X_c1,X_c2,metric='euclidean')
        np.save(dist_path,dist_mat)
    else:
        dist_mat=np.load(dist_path)

    if args.closest:
        sort_idx_path= input_dir + str(args.class_1) + '_' + str(args.class_2) + '_' + str(args.dataset_in)+ '_sort_indices_2d.npy'
        if not os.path.exists(sort_idx_path):
            sort_indices=np.argsort(dist_mat.flatten())
            # Vectorised flat -> (row, col) conversion; one (i, j) pair per row.
            close_indices=np.column_stack(np.unravel_index(sort_indices, dist_mat.shape))
            np.save(sort_idx_path, close_indices)
        else:
            print('Loading sorted indices')
            close_indices=np.load(sort_idx_path)

        # Choosing closest points to test
        n_closest=close_indices[args.start_idx:args.end_idx]
        # Use the actual slice length so an end_idx past the last pair
        # cannot cause out-of-range indexing below.
        samples_curr_run=len(n_closest)

        n_closest_distances=[dist_mat[i,j] for i,j in n_closest]

        x_source_1=X_c1[n_closest[:,0]]
        x_source_2=X_c2[n_closest[:,1]]

        x_source=np.vstack((x_source_1,x_source_2))
    else:
        indices_1 = rng.integers(n_c1, size=args.subsample_size)
        indices_2 = rng.integers(n_c2, size=args.subsample_size)

        x_source_1 = X_c1[indices_1]
        x_source_2 = X_c2[indices_2]

        x_source=np.vstack((x_source_1,x_source_2))

        # Every subsampled class-1 point is paired with every class-2 point.
        samples_curr_run=args.subsample_size*args.subsample_size

    output_dir_results = 'dl_output/'+args.dataset_in+'/'+str(args.class_1)+'_'+str(args.class_2)+'/results/'
    output_dir_perturbs = 'dl_output/'+args.dataset_in+'/'+str(args.class_1)+'_'+str(args.class_2)+'/perturbs/'

    os.makedirs(output_dir_results, exist_ok=True)
    os.makedirs(output_dir_perturbs, exist_ok=True)

    # Cone-program based distance computation

    print(args.alg_type)
    if 'cone_collision' in args.alg_type:
        params_list=np.load('models/'+args.dataset_in+'_'+model_name+'.npy',allow_pickle=True)
        L1=params_list[0]
        k1=params_list[1]
        L2=params_list[2]
        k2=params_list[3]

        print('Using greedy method for collision finding for %s samples' % samples_curr_run)

        # Pick the worker once; it does not change between pairs.
        if 'first_relu' in args.alg_type:
            worker, layer_params = cone_greedy_first_relu, (L1,k1)
        elif 'second_linear' in args.alg_type:
            worker, layer_params = cone_greedy_second_linear, (L1,k1,L2,k2)
        elif 'second_relu' in args.alg_type:
            worker, layer_params = cone_greedy_second_relu, (L1,k1,L2,k2)
        else:
            worker, layer_params = None, ()

        manager = Manager()

        cone_all_result_dict=manager.dict()
        cone_all_perturb_dict=manager.dict()
        # Parallel run of all agents on a fixed pool of 30 worker processes.
        pool = Pool(30)
        pending=[]

        for agent_num in range(samples_curr_run):
            if args.closest:
                print('Starting process %s' % (args.start_idx+agent_num))
                source_idx=n_closest[agent_num]
                source_dists=n_closest_distances[agent_num]
                # x_source stacks the class-1 points first, then the class-2 points.
                v1=x_source[agent_num].reshape(data_dim)
                v2=x_source[samples_curr_run+agent_num].reshape(data_dim)
            else:
                v1_idx=agent_num//args.subsample_size
                v2_idx=agent_num%args.subsample_size
                print(samples_curr_run,v1_idx,v2_idx)
                v1=x_source[v1_idx].reshape(data_dim)
                v2=x_source[args.subsample_size+v2_idx].reshape(data_dim)
                source_idx=(indices_1[v1_idx],indices_2[v2_idx])
                source_dists=dist_mat[indices_1[v1_idx],indices_2[v2_idx]]
            if worker is not None:
                worker_args=(args,agent_num,num_outputs,v1,v2)+layer_params+(data_dim,source_idx,source_dists,cone_all_result_dict,cone_all_perturb_dict)
                # For debugging, call worker(*worker_args) directly instead.
                pending.append((agent_num, pool.apply_async(worker, args=worker_args)))
        pool.close()
        pool.join()

        # apply_async hides worker exceptions until get() is called; report
        # them so missing results are not silent. Successful pairs are kept.
        for agent_num, async_result in pending:
            try:
                async_result.get()
            except Exception as exc:
                print('Process %s failed: %r' % (agent_num, exc))

        if args.closest:
            cone_file_name = model_name + '_' + str(args.input_eps) + '_' + str(args.start_idx) + '_' + str(args.end_idx) + '_' + args.alg_type
        else:
            cone_file_name = model_name + '_' + str(args.input_eps) + '_subsample' + str(args.subsample_size) + '_' + args.alg_type

        cone_results=cone_all_result_dict.copy()

        with open(output_dir_results+cone_file_name+'.json', 'w') as f1:
            json.dump(cone_results, f1, ensure_ascii=False, indent=4)

        with open(output_dir_perturbs+cone_file_name+'.pkl','wb') as f2:
            pickle.dump(cone_all_perturb_dict.copy(), f2)

        # Dumping to matrix. This lives inside the cone branch because it
        # depends on the cone results and file name computed above.
        if args.closest:
            # Indexed by position within each class, hence the same shape as
            # the pairwise distance matrix.
            curr_dist_mat=np.zeros(dist_mat.shape)
        else:
            curr_dist_mat=np.zeros((args.subsample_size,args.subsample_size))

        delimiters="(", ",", ")"

        # NOTE(review): the key layout (fields 0, 2 and 3 of the split) is
        # produced by the cone_greedy_* workers - confirm against them.
        for k,v in cone_results.items():
            curr_split=split(delimiters,k)
            if args.closest:
                real_idx1=int(curr_split[2])
                real_idx2=int(curr_split[3])
                curr_dist_mat[real_idx1,real_idx2]=v['perturb_magnitudes']
            else:
                curr_idx=int(curr_split[0])
                curr_dist_mat[curr_idx//args.subsample_size,curr_idx%args.subsample_size]=v['perturb_magnitudes']

        print('Dumping to matrix')
        np.save(output_dir_results+cone_file_name+'.npy',curr_dist_mat)

    if 'exact' in args.alg_type:
        print('Computing exact intersections for %s after first linear layer' % samples_curr_run)

        # NOTE(review): `net` and `torch` are not defined or imported anywhere
        # in this file, and `n_closest` / `n_closest_distances` exist only in
        # --closest mode, so this branch cannot run as written. Confirm how
        # the model is meant to be loaded before relying on it.
        params_list=list(net.parameters())

        A=params_list[0].detach().cpu().numpy()

        # Precomputing projection matrix
        U, sigma, VT = np.linalg.svd(A,full_matrices=False)

        V = np.transpose(VT)

        proj_mat = np.dot(V,VT)

        # Pairs are processed serially below, so no worker pool is created.
        manager = Manager()
        linear_all_result_dict=manager.dict()
        linear_all_perturb_dict=manager.dict()

        # time.clock() was removed in Python 3.8.
        time1=time.perf_counter()
        for agent_num in range(samples_curr_run):
            print('Starting process %s' % (args.start_idx+agent_num))
            n_closest_curr=n_closest[agent_num]
            n_closest_distances_curr=n_closest_distances[agent_num]
            x_curr=torch.cat((x_source[agent_num:agent_num+1],x_source[samples_curr_run+agent_num:samples_curr_run+agent_num+1]))
            linear_exact_intersection(args,A,proj_mat,agent_num,x_curr,net,num_outputs,n_closest_curr,n_closest_distances_curr,linear_all_result_dict,linear_all_perturb_dict)

        time2=time.perf_counter()
        time_spent=(time2-time1)
        print('Avg. time spent for linear exact solution for %s pairs is %s' % (samples_curr_run,(time_spent/samples_curr_run)))

        linear_file_name = model_name + '_' + str(args.start_idx) + '_' + str(args.end_idx) + '_linear'

        with open(output_dir_results+linear_file_name+'.json', 'w') as f3:
            json.dump(linear_all_result_dict.copy(), f3, ensure_ascii=False, indent=4)

        with open(output_dir_perturbs+linear_file_name+'.pkl','wb') as f4:
            pickle.dump(linear_all_perturb_dict.copy(), f4)
    
# Run the experiment only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    # mp.set_start_method('spawn')
    main()
