import torch
import torch.nn as nn
import os

import numpy as np
import time
import argparse
from torchsummary import summary

import pickle
import json

# import torch.multiprocessing as mp

from multiprocessing import Process, Manager

from utils.mnist_models import cnn_3l, cnn_3l_bn, cnn_3l_bn_2, fc_3l_2, fc_3l
# from utils.resnet_cifar import resnet
# from utils.cifar10_models import WideResNet
from utils.test_utils import test, robust_test
from utils.data_utils import load_dataset, load_dataset_custom
from utils.io_utils import init_dirs, model_naming, test_argparse
from utils.deeper_layer_utils import standard_pgd_intersection, linear_exact_intersection, modified_pgd_intersection, apgd_intersection, cone_greedy_first_relu

def flat_to_array_index(index,num_samples):
    """Convert a flat (row-major) index into a ``(row, column)`` pair.

    Args:
        index: position in the flattened matrix.
        num_samples: number of columns in the matrix.

    Returns:
        Tuple ``(i, j)`` such that ``index == i * num_samples + j``.
    """
    # Integer divmod instead of int(index / num_samples): true division goes
    # through a float and loses precision for very large indices.
    i, j = divmod(index, num_samples)
    return i, j

def main():
    """Run the APGD deeper-layer intersection analysis on the closest class pairs.

    Parses command-line options (the shared test options plus the distance
    options added below), loads a trained MNIST model checkpoint onto the GPU,
    loads the pre-filtered two-class data for classes 3 and 7, ranks all
    cross-class pairs by Euclidean distance (caching the distance matrix and
    the sorted index pairs as ``.npy`` files next to the input data) and, when
    ``apgd`` is listed in ``--alg_type``, runs ``apgd_intersection`` in a
    process pool for the pairs ranked ``start_idx`` to ``end_idx``.  The
    collected results are written as JSON and the perturbations as a pickle
    under ``dl_output/``.

    Raises:
        ValueError: if no CUDA device is available, if the requested index
            range is empty or exceeds the available pairs, or if the
            dataset/model combination is not supported.
    """
    # Imported locally because the module-level import of this package is
    # commented out; without it the pool/manager code below has no `mp`.
    import torch.multiprocessing as mp

    torch.random.manual_seed(7)

    parser = test_argparse()

    # Distance compute args
    parser.add_argument('--start_idx', type=int, default=0)
    parser.add_argument('--end_idx', type=int, default=10)
    parser.add_argument('--analysis_norm',type=str,default='l2')
    parser.add_argument('--input_eps', type=float, default=2.0)
    parser.add_argument('--n_steps', type=int, default=1000)
    parser.add_argument('--alg_type', nargs='+')

    args = parser.parse_args()

    print('Starting run %s to %s' % (args.start_idx,args.end_idx))

    if torch.cuda.is_available():
        print('CUDA enabled')
    else:
        raise ValueError('Needs a working GPU!')

    samples_curr_run = args.end_idx - args.start_idx
    if samples_curr_run < 1:
        raise ValueError('end_idx (%s) must be greater than start_idx (%s)'
                         % (args.end_idx, args.start_idx))

    # --alg_type is optional; treat a missing value as "no algorithms".
    # args itself is left untouched because it is handed to the workers.
    alg_types = args.alg_type or []

    # NOTE(review): 'All' is presumably consumed by the naming helpers, but the
    # reshape calls below need an integer num_samples -- confirm.
    if args.num_samples is None:
        args.num_samples = 'All'
    if args.drop_eps==0:
        args.drop_eps=args.epsilon

    args.checkpoint_path = 'trained_models'

    args.eps_step = args.epsilon*args.gamma/args.attack_iter

    args.trial_num = 1
    args.class_1=3
    args.class_2=7
    model_dir_name, log_dir_name, figure_dir_name, _ = init_dirs(
        args, train=False)
    _, model_name = model_naming(args)

    training_time = False

    data_dir = '/data/nvme/arjun/datasets'
    dataset_loader = load_dataset if args.n_classes == 10 else load_dataset_custom
    loader_train, loader_test, data_details = dataset_loader(
        args, data_dir=data_dir, training_time=training_time)

    data_dim = data_details['h_in']*data_details['w_in']*data_details['n_channels']

    if 'MNIST' in args.dataset_in:
        if 'cnn_3l_bn' in args.model:
            net = cnn_3l_bn_2(args.n_classes)
        elif 'fc_3l' in args.model:
            net = fc_3l_2(args.n_classes)
        else:
            raise ValueError('Unsupported model for MNIST: %s' % args.model)
    else:
        raise ValueError('Unsupported dataset: %s' % args.dataset_in)

    if 'linf' in args.attack:
        args.epsilon /= 255.
        args.eps_step /= 255.

    print('Loading model')

    net.cuda()
    net.eval()
    ckpt_path = 'checkpoint_' + str(args.last_epoch)
    net.load_state_dict(torch.load(model_dir_name + ckpt_path))

    net.share_memory()

    _,num_outputs=test(net,loader_test,'blah')

    print('Loading two-class filtered data')

    class_pair = str(args.class_1) + '_' + str(args.class_2)
    input_dir='input_data/' + args.dataset_in + '/' + class_pair + '/'

    x_loc=input_dir+'%s_%s_%s_X.npy' % (args.class_1,args.class_2, args.dataset_in)
    y_loc=input_dir+'%s_%s_%s_Y.npy' % (args.class_1,args.class_2, args.dataset_in)

    X_curr=torch.from_numpy(np.load(x_loc))
    Y_curr=torch.from_numpy(np.load(y_loc))

    # Label 0 / 1 in the filtered data select the first / second class.
    X_c1=X_curr[np.where(Y_curr==0)].reshape(args.num_samples,data_dim)
    X_c2=X_curr[np.where(Y_curr==1)].reshape(args.num_samples,data_dim)

    zero_indices=np.where(Y_curr==0)[0]
    one_indices=np.where(Y_curr==1)[0]

    # Pairwise distances between the two classes, cached on disk.
    dist_path= input_dir + class_pair + '_' + str(args.dataset_in)+ '_dists.npy'
    if not os.path.exists(dist_path):
        import scipy.spatial.distance
        dist_mat=scipy.spatial.distance.cdist(X_c1,X_c2,metric='euclidean')
        np.save(dist_path,dist_mat)
    else:
        dist_mat=np.load(dist_path)
    flat_dist_mat=dist_mat.flatten()

    # (row, column) pairs of dist_mat ordered by increasing distance, cached on disk.
    sort_idx_path= input_dir + class_pair + '_' + str(args.dataset_in)+ '_sort_indices_2d.npy'
    if not os.path.exists(sort_idx_path):
        sort_indices=np.argsort(flat_dist_mat)
        close_indices=np.array(
            [flat_to_array_index(flat_idx, args.num_samples) for flat_idx in sort_indices])
        np.save(sort_idx_path, close_indices)
    else:
        print('Loading sorted indices')
        close_indices=np.load(sort_idx_path)

    # Choosing closest points to test
    n_closest=close_indices[args.start_idx:args.end_idx]
    if len(n_closest) != samples_curr_run:
        raise ValueError('Requested pairs %s to %s but only %s sorted pairs are available'
                         % (args.start_idx, args.end_idx, len(close_indices)))

    n_closest_distances=[dist_mat[row, col] for row, col in n_closest]

    # x_source holds 2*samples_curr_run examples: first the class-0 member of
    # every selected pair, then the class-1 member of every pair in the same
    # order, so pair k is (x_source[k], x_source[samples_curr_run + k]).
    class_0_rows=[zero_indices[pair[0]] for pair in n_closest]
    class_1_rows=[one_indices[pair[1]] for pair in n_closest]
    x_source=torch.cat([X_curr[row:row+1] for row in class_0_rows + class_1_rows])

    output_root = 'dl_output/'+args.dataset_in+'/'+class_pair
    output_dir_results = output_root+'/results/'
    output_dir_perturbs = output_root+'/perturbs/'

    os.makedirs(output_dir_results, exist_ok=True)
    os.makedirs(output_dir_perturbs, exist_ok=True)

    if 'apgd' in alg_types:
        x_source=x_source.cuda()
        print('Computing intersections using APGD for %s' % samples_curr_run)

        # CUDA is already initialised in this process, so forked workers cannot
        # use it; a local 'spawn' context avoids changing the global start method.
        ctx = mp.get_context('spawn')
        with ctx.Manager() as manager:
            result_proxy=manager.dict()
            perturb_proxy=manager.dict()

            pending=[]
            # The sender must keep CUDA tensors alive while workers use them.
            shared_inputs=[]

            # Parallel run of all agents
            pool = ctx.Pool() #use all available cores, otherwise specify the number you want as an argument
            try:
                for agent_num in range(samples_curr_run):
                    print('Starting process %s' % (args.start_idx+agent_num))
                    partner_idx=samples_curr_run+agent_num
                    x_curr=torch.cat((x_source[agent_num:agent_num+1],
                                      x_source[partner_idx:partner_idx+1]))
                    shared_inputs.append(x_curr)
                    pending.append(pool.apply_async(
                        apgd_intersection,
                        args=(args,agent_num,x_curr,net,num_outputs,
                              n_closest[agent_num],n_closest_distances[agent_num],
                              result_proxy,perturb_proxy)))
                pool.close()
                pool.join()
            finally:
                # No-op after a clean close/join; stops workers on an error.
                pool.terminate()

            # apply_async swallows worker exceptions unless the result is
            # fetched; report them without discarding the successful runs.
            for agent_num, handle in enumerate(pending):
                try:
                    handle.get()
                except Exception as err:
                    print('Process %s failed: %r' % (args.start_idx+agent_num, err))

            # Copy into plain dicts before the manager process shuts down.
            apgd_all_result_dict=result_proxy.copy()
            apgd_all_perturb_dict=perturb_proxy.copy()

        print(apgd_all_result_dict)
        apgd_file_name = model_name + '_' + str(args.input_eps) + '_' + str(args.n_steps) + '_' + str(args.start_idx) + '_' + str(args.end_idx) + '_APGD'

        with open(output_dir_results+apgd_file_name+'.json', 'w') as f1:
            json.dump(apgd_all_result_dict, f1, ensure_ascii=False, indent=4)

        with open(output_dir_perturbs+apgd_file_name+'.pkl','wb') as f2:
            pickle.dump(apgd_all_perturb_dict, f2)
    
    
if __name__ == "__main__":
    # Run the analysis only when this file is executed as a script.
    # NOTE(review): the explicit 'spawn' start-method call below is disabled.
    # mp.set_start_method('spawn')
    main()
