import torch
import torch.nn as nn
import os

import numpy as np
import time
import argparse
from torchsummary import summary

import pickle
import json

from multiprocessing import Process, Manager, Pool
import ctypes

from utils.mnist_models import cnn_3l, cnn_3l_bn, cnn_3l_bn_2, fc_3l_2, fc_3l
# from utils.resnet_cifar import resnet
# from utils.cifar10_models import WideResNet
from utils.test_utils import test, robust_test
from utils.data_utils import load_dataset, load_dataset_custom
from utils.io_utils import init_dirs, model_naming, test_argparse
# from utils.deeper_layer_utils import standard_pgd_intersection, linear_exact_intersection, modified_pgd_intersection, apgd_intersection


def linear_intersect_fn(i,X_c1,X_c2,A,proj_mat,data_dim,linear_perturb_mag_dict,final_linear_dist_dict):
    """Shift every class-1 sample and class-2 sample ``i`` towards each other.

    For the column ``x2 = X_c2[:, i]``, each column ``x1`` of ``X_c1`` and
    ``x2`` are moved by ``T = 0.5 * proj_mat @ (x1 - x2)`` (to ``x1 - T`` and
    ``x2 + T`` respectively), and both moved points are mapped through ``A``.

    Args:
        i: Column index into ``X_c2``; ``str(i)`` is also the result key.
        X_c1: Array of shape ``(data_dim, n1)``, one sample per column.
        X_c2: Array with ``data_dim`` rows, one sample per column.
        A: Matrix with ``data_dim`` columns, applied to the moved points.
        proj_mat: ``(data_dim, data_dim)`` matrix applied to the differences.
        data_dim: Number of features per sample.
        linear_perturb_mag_dict: Mapping that receives, under ``str(i)``, a
            list of the ``n1`` column norms of ``T``.
        final_linear_dist_dict: Mapping that receives, under ``str(i)``, a
            list of the ``n1`` column norms of ``A @ (x1 - T) - A @ (x2 + T)``.

    Returns:
        None. Results are written into the two mappings; start/end progress
        messages are printed.
    """
    print('Started %s' % i)
    key = str(i)

    # Single class-2 sample as a (data_dim, 1) column so it broadcasts
    # against all class-1 columns at once.
    anchor = X_c2[:, i].reshape((data_dim, 1))

    # Half of the projected difference is the shared shift applied to both sides.
    shift = 0.5 * np.dot(proj_mat, X_c1 - anchor)
    moved_c1 = X_c1 - shift
    moved_c2 = anchor + shift

    gap = np.dot(A, moved_c1) - np.dot(A, moved_c2)

    linear_perturb_mag_dict[key] = np.linalg.norm(shift, axis=0).tolist()
    final_linear_dist_dict[key] = np.linalg.norm(gap, axis=0).tolist()
    print('Ended %s' % i)

def main():
    """Compute first-linear-layer "cone collision" distances for two classes.

    Parses the command line (``test_argparse`` plus ``--input_eps``,
    ``--subsample`` and ``--subsample_size``), loads the pre-filtered
    two-class data for classes 3 and 7 and the saved model parameters, and
    for every (class-1 sample, class-2 sample) pair calls
    ``linear_intersect_fn`` with the first parameter matrix ``L1`` and the
    projection ``V @ VT`` built from ``L1``'s thin SVD.

    Writes, under ``dl_output/<dataset>/<class_1>_<class_2>/``:
      * ``results/<name>.json``: perturbation magnitudes keyed by the
        class-2 sample index (as a string).
      * ``perturbs/<name>_final_dists.json``: post-layer distances keyed the
        same way. NOTE: despite the ``.json`` extension this file is written
        with ``pickle``; the name is kept for existing readers.
      * ``results/<name>.npy``: the perturbation magnitudes as a dense
        matrix of shape ``(n_class_1, n_class_2)``.
    """
    rng = np.random.default_rng(77)

    parser = test_argparse()
    parser.add_argument('--input_eps', type=float, default=0.0)
    parser.add_argument('--subsample', action='store_true')
    parser.add_argument('--subsample_size', type=int, default=0)
    args = parser.parse_args()

    # parser.error (not assert) so the checks survive `python -O`.
    if args.n_classes != 2:
        parser.error('n_classes must be 2 for this script')
    if args.subsample and args.subsample_size <= 0:
        parser.error('--subsample requires a positive --subsample_size')

    if args.num_samples is None:
        args.num_samples = 'All'
    if args.drop_eps == 0:
        args.drop_eps = args.epsilon

    args.eps_step = args.epsilon * args.gamma / args.attack_iter

    args.trial_num = 1
    args.class_1 = 3
    args.class_2 = 7
    # Returned directory names are not used here; presumably called for its
    # side effects (e.g. creating directories) -- TODO confirm.
    init_dirs(args, train=False)
    _, model_name = model_naming(args)
    print('Loading %s' % model_name)

    print('Loading two-class filtered data')
    pair_tag = '%s_%s' % (args.class_1, args.class_2)
    input_dir = 'input_data/' + args.dataset_in + '/' + pair_tag + '/'
    X_curr = np.load(input_dir + '%s_%s_X.npy' % (pair_tag, args.dataset_in))
    Y_curr = np.load(input_dir + '%s_%s_Y.npy' % (pair_tag, args.dataset_in))

    if 'MNIST' in args.dataset_in:
        data_dim = 784
    else:
        # Infer the flattened feature size from the per-sample shape.
        data_dim = int(np.prod(X_curr.shape[1:]))

    # 'All' means "however many samples the file holds"; an explicit count
    # keeps the original strict reshape (fails loudly on a mismatch).
    n_rows = -1 if args.num_samples == 'All' else args.num_samples
    X_c1 = X_curr[np.where(Y_curr == 0)].reshape(n_rows, data_dim)
    X_c2 = X_curr[np.where(Y_curr == 1)].reshape(n_rows, data_dim)

    if args.subsample:
        # rng.integers draws with replacement, so indices may repeat.
        indices_1 = rng.integers(X_c1.shape[0], size=args.subsample_size)
        indices_2 = rng.integers(X_c2.shape[0], size=args.subsample_size)
        print(indices_1, indices_2)
        X_c1 = X_c1[indices_1]
        X_c2 = X_c2[indices_2]

    # SECURITY: allow_pickle=True can execute arbitrary code from the file;
    # only load trusted model files.
    params_list = np.load('models/' + args.dataset_in + '_' + model_name + '.npy',
                          allow_pickle=True)
    L1 = params_list[0]

    # Thin SVD L1 = U diag(sigma) VT. V @ VT projects onto the span of L1's
    # right singular vectors, which is L1's row space when L1 has full rank.
    _, _, VT = np.linalg.svd(L1, full_matrices=False)
    proj_mat = np.dot(VT.T, VT)

    output_root = 'dl_output/' + args.dataset_in + '/' + pair_tag + '/'
    output_dir_results = output_root + 'results/'
    output_dir_perturbs = output_root + 'perturbs/'
    os.makedirs(output_dir_results, exist_ok=True)
    os.makedirs(output_dir_perturbs, exist_ok=True)

    if args.subsample:
        cone_file_name = (model_name + '_' + str(args.input_eps) + '_subsample'
                          + str(args.subsample_size) + '_cone_collision_first_linear')
    else:
        # NOTE: the doubled underscore is intentional-for-compatibility:
        # it matches the names of outputs produced by earlier runs.
        cone_file_name = (model_name + '_' + str(args.input_eps) + '_subsample'
                          + str(args.num_samples) + '__cone_collision_first_linear')

    # linear_intersect_fn expects one sample per column.
    X_c1 = X_c1.T
    X_c2 = X_c2.T
    n_class_1 = X_c1.shape[1]
    n_class_2 = X_c2.shape[1]

    # Plain dicts suffice: pairs are processed serially in this process.
    linear_perturb_mag_dict = {}
    final_linear_dist_dict = {}

    time1 = time.perf_counter()
    # Each call pairs class-2 sample i with *every* class-1 sample, so the
    # loop runs over the class-2 columns.
    for i in range(n_class_2):
        linear_intersect_fn(i, X_c1, X_c2, L1, proj_mat, data_dim,
                            linear_perturb_mag_dict, final_linear_dist_dict)

    with open(output_dir_results + cone_file_name + '.json', 'w', encoding='utf-8') as f1:
        json.dump(linear_perturb_mag_dict, f1, ensure_ascii=False, indent=4)

    # Pickle format despite the '.json' extension (kept for existing readers).
    with open(output_dir_perturbs + cone_file_name + '_final_dists.json', 'wb') as f2:
        pickle.dump(final_linear_dist_dict, f2)

    # Row = class-1 sample, column = class-2 sample.
    linear_perturb_mag_mat = np.zeros((n_class_1, n_class_2))
    for i in range(n_class_2):
        linear_perturb_mag_mat[:, i] = linear_perturb_mag_dict[str(i)]

    time2 = time.perf_counter()
    print(time2 - time1)

    np.save(output_dir_results + cone_file_name + '.npy', linear_perturb_mag_mat)


if __name__ == "__main__":
    main()
