import numpy as np
import argparse
import socket
import importlib
import time
from scipy.spatial.distance import cdist
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import scipy.misc
import math
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
sys.path.append(os.path.join(BASE_DIR, 'models'))
sys.path.append(os.path.join(BASE_DIR, 'utils'))
import provider
import pc_util
# from attack import craft_adversarial_samples
import keras
from keras import backend as K
import time
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from tensorflow.python.client import device_lib
import gc
from scipy.spatial.distance import cdist

# Command-line options for the adversarial evaluation run.
# NOTE: parse_args() runs at import time, so importing this module consumes sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
parser.add_argument('--model', default='pointnet_cls', help='Model name: pointnet_cls or pointnet_cls_basic [default: pointnet_cls]')
# The help text previously advertised "[default: 1]" although the default is 100.
parser.add_argument('--batch_size', type=int, default=100, help='Batch size during evaluation [default: 100]')
parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]')
parser.add_argument('--model_path', default='log/model.ckpt', help='model checkpoint file path [default: log/model.ckpt]')
parser.add_argument('--dump_dir', default='dump', help='dump folder path [dump]')
parser.add_argument('--visu', action='store_true', help='Whether to dump image for error case [default: False]')
FLAGS = parser.parse_args()


# Run configuration derived from the command line.
BATCH_SIZE = FLAGS.batch_size
NUM_POINT = FLAGS.num_point
MODEL_PATH = FLAGS.model_path
GPU_INDEX = FLAGS.gpu
MODEL = importlib.import_module(FLAGS.model) # import network module
DUMP_DIR = FLAGS.dump_dir
# makedirs (rather than mkdir) so that a nested --dump_dir also works.
if not os.path.exists(DUMP_DIR):
    os.makedirs(DUMP_DIR)
LOG_FOUT = open(os.path.join(DUMP_DIR, 'log_evaluate.txt'), 'w')
LOG_FOUT.write(str(FLAGS)+'\n')
# The attack log lives next to the evaluation log. It used to be hard-coded to
# './dump/attack.txt', which ignored --dump_dir and failed when ./dump was missing.
# With the default --dump_dir ('dump') the resulting path is unchanged.
LOG_FOUT2 = open(os.path.join(DUMP_DIR, 'attack.txt'), 'w')

NUM_CLASSES = 40
# Read the class names through a context manager so the file handle is released.
with open(os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/shape_names.txt')) as _shape_names_file:
    SHAPE_NAMES = [line.rstrip() for line in _shape_names_file]

HOSTNAME = socket.gethostname()

# ModelNet40 official train/test split
TRAIN_FILES = provider.getDataFiles( \
    os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'))
TEST_FILES = provider.getDataFiles(\
    os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'))

def log_string(out_str):
    """Write `out_str` (converted with str()) as one line to LOG_FOUT and echo it to stdout."""
    message = str(out_str)
    line = message + '\n'
    LOG_FOUT.write(line)
    # Flush right away so the log file is current even if the run is interrupted.
    LOG_FOUT.flush()
    print(message)

def log_string2(out_str):
    """Write `out_str` (converted with str()) as one line to LOG_FOUT2 and echo it to stdout."""
    message = str(out_str)
    line = message + '\n'
    LOG_FOUT2.write(line)
    # Flush right away so the log file is current even if the run is interrupted.
    LOG_FOUT2.flush()
    print(message)


def random_unit_vector(R):
    """
    Draw a random 3D point inside the ball of radius R centred at the origin.

    Despite the name, the result is NOT normalised to unit length: the direction is
    uniform on the sphere (uniform azimuth, uniform cosine of the polar angle) and
    the length is R * cbrt(u) with u ~ U(0, 1), which spreads the samples uniformly
    over the volume of the ball.

    Uses the global NumPy random state (three draws: azimuth, cos(polar), u).

    :param R: radius of the ball, i.e. the maximum length of the returned vector
    :return: numpy array [x, y, z]
    """
    azimuth = np.random.uniform(0, np.pi*2)
    cos_polar = np.random.uniform(-1, 1)
    radial_sample = np.random.uniform(0, 1)

    polar = np.arccos(cos_polar)
    # The cube root compensates for the volume growing with the cube of the radius.
    radius = R * np.cbrt(radial_sample)

    # Spherical -> Cartesian conversion.
    coords = [
        radius * np.sin(polar) * np.cos(azimuth),
        radius * np.sin(polar) * np.sin(azimuth),
        radius * np.cos(polar),
    ]
    return np.array(coords)




def fps(best_ind, images, num):
    """
    Greedily pick candidate points that lie far away from the points picked so far.

    Starting from a random candidate (the first one after shuffling), each step picks
    the remaining candidate with the largest score and appends it to the result.
    NOTE: the score is the SUM of squared distances to all points picked so far
    (see the `dist += ...` update), not the minimum distance used by textbook
    farthest point sampling.

    :param best_ind: per-sample candidate indices into the point axis of `images`;
                     every row is shuffled IN PLACE (global NumPy random state)
    :param images: point clouds, indexed as images[i, best_ind[i], :]
    :param num: number of greedy picks made after the seed point; it must not exceed
                the number of candidates minus one, otherwise `dist` is empty when
                argmax is taken
    :return: array of shape (batch, num + 1, dims): the seed followed by the picks
    """
    # Randomise the candidate order per sample (this mutates best_ind).
    [np.random.shuffle(b) for b in best_ind]
    # Gather the candidate coordinates: roots[i] = images[i, best_ind[i], :].
    roots = np.array([images[i, best_ind[i], :] for i in range(images.shape[0])])
    # The first (shuffled) candidate of every sample seeds the selection.
    points = roots[:, :1, :]
    # Squared distance from the seed to every other candidate; column c of dist
    # corresponds to roots[:, 1 + c].
    dist = np.sum((np.repeat(points[:, -1:, :], roots.shape[1]-1, axis=1) - roots[:, 1:, :])**2, axis=2)
    for i in range(num):
        # At step i, dist column c corresponds to roots[:, i + 1 + c], hence the offset.
        far = np.argmax(dist, axis=1) + i + 1
        # Append the picked point; concatenate returns a new array, so from here on
        # `points` no longer aliases `roots`.
        points = np.concatenate((points, roots[np.arange(roots.shape[0]), far, :].reshape((roots.shape[0], 1, roots.shape[2]))), axis=1)
        # Swap the picked candidate into slot i+1 so that roots[:, :i+2] holds the
        # picked points and roots[:, i+2:] the candidates that are still available.
        roots[np.arange(roots.shape[0]), far, :] = roots[:, i+1, :]
        roots[:, i+1, :] = points[:, -1, :]
        # Mirror the swap in dist, then drop the column of the point just picked.
        dist[np.arange(dist.shape[0]), far-i-1] = dist[:, 0]
        dist = dist[:, 1:]
        # Add the squared distance to the newly picked point to every remaining score.
        dist += np.sum((np.repeat(points[:, -1:, :], roots.shape[1]-i-2, axis=1) - roots[:, i+2:, :])**2, axis=2)
    return points




def remove_outliers_fn(x, top_k = 10, num_std = 1.0):
    """
    Overwrite statistical outliers of every point set with another point of that set.

    A point is flagged when the mean distance to its `top_k` nearest neighbours
    exceeds the per-sample mean of that quantity by more than `num_std` standard
    deviations. Flagged points are replaced by the first non-flagged point of the
    same sample (index 0 if argmin finds no non-flagged point), so the tensor shape
    is preserved.

    :param x: rank-3 tensor indexed as (batch, point, coordinate)
    :param top_k: number of nearest neighbours used for the mean distance
    :param num_std: threshold in standard deviations above the mean
    :return: tensor shaped like `x`, wrapped in tf.stop_gradient
    """
    # Pairwise Euclidean distances between all points of each sample.
    offsets = x[:, tf.newaxis] - x[:, :, tf.newaxis]
    pairwise = tf.linalg.norm(offsets, axis=3)

    # Put +inf on the diagonal so a point is never counted as its own neighbour.
    on_diagonal = tf.eye(tf.shape(x)[1], batch_shape=[tf.shape(x)[0]]) > 0.0
    infinite = tf.fill(tf.shape(pairwise), float("inf"))
    pairwise = tf.where(on_diagonal, infinite, pairwise)

    # top_k returns the LARGEST entries, so negate to get the k smallest distances.
    negated_nearest = tf.nn.top_k(pairwise * -1.0, k=top_k, sorted=False)[0]
    knn_mean = tf.reduce_mean(negated_nearest * -1.0, axis=2)

    mean, variance = tf.nn.moments(knn_mean, axes=[1], keep_dims=True)
    spread = num_std * tf.sqrt(variance)
    is_outlier = knn_mean > mean + spread

    # argmin over the 0/1 mask yields the first index whose flag is the smallest,
    # i.e. the first point that is not an outlier when such a point exists.
    keeper_idx = tf.argmin(tf.to_float(is_outlier), axis=1)
    keeper_mask = tf.one_hot(keeper_idx, tf.shape(x)[1])
    keeper = tf.reduce_sum(x * keeper_mask[:, :, tf.newaxis], axis=1, keep_dims=True)

    # Broadcast both the mask and the replacement point to the full shape of x.
    full_mask = is_outlier[:, :, tf.newaxis] & tf.fill(tf.shape(x), True)
    cleaned = tf.where(full_mask, keeper + tf.zeros_like(x), x)

    return tf.stop_gradient(cleaned)

def craft_adversarial_samples(sess, images, labels, pred, labels_pl, input_layer, is_training, gradient, loss1, loss2):
    """
    Craft adversarial point clouds by appending perturbed copies of high-gradient points.

    The 200 points with the largest loss-gradient norm are selected per sample,
    ordered with `fps`, and appended to the cloud. The appended points are then
    moved for 300 steps along the normalised loss gradient (or by a random offset
    when their gradient is exactly zero) and, after every step, pulled back so that
    each of them stays within `epsilon` of its nearest original point.

    :param sess: session used to evaluate `gradient` and `loss1` via sess.run
    :param images: clean point clouds, rank 3 (batch, points, 3); needs at least
                   200 points per sample
    :param labels: labels fed to `labels_pl`
    :param pred: unused, kept for interface compatibility
    :param labels_pl: placeholder receiving `labels`
    :param input_layer: placeholder receiving the point clouds
    :param is_training: placeholder that is always fed False here
    :param gradient: fetch whose first element is the loss gradient w.r.t. the input
    :param loss1: fetch printed every 20 steps as a progress indicator
    :param loss2: unused, kept for interface compatibility
    :return: array (batch, points + 200, 3): the clean points followed by the
             adversarial points
    """
    # Hyperparameters
    start_eps = 0.025  # initial epsilon
    end_eps = 0.025  # final epsilon
    num_points = 200  # number of points
    steps = 300  # number of steps
    end_lr = 0.002  # final step-size
    start_lr = 0.1  # initial step-size

    # These four lines were tab-indented inside a space-indented body, which made
    # the whole file fail to compile (TabError / IndentationError).
    learning_rate = start_lr
    num_best = num_points
    warm_up = 0
    epsilon = start_eps

    images = np.array(images)
    advesarial_images = np.array(images)

    # Rank the clean points by gradient magnitude and keep the strongest ones as seeds.
    grad = sess.run(gradient, feed_dict={input_layer:images, labels_pl:labels, is_training:False})[0]
    norms = np.sqrt(np.sum(grad**2, axis=2))
    print(np.mean(np.sum(norms == 0, axis=1)))
    best_ind = np.array([np.argpartition(row, -num_best)[-num_best:] for row in norms])
    seeds = fps(best_ind, images, num_points-1)
    advesarial_images = np.concatenate((advesarial_images, seeds), axis=1)

    for i in range(steps):
        # Linear decay of the radius and the step size after the warm-up phase.
        if i >= warm_up:
            epsilon -= (start_eps - end_eps) / (steps - warm_up)
            learning_rate -= (start_lr - end_lr) / (steps - warm_up)

        # Only the appended points are optimised, so keep just their gradients.
        gr = sess.run(gradient, feed_dict={input_layer:advesarial_images, labels_pl:labels, is_training:False})[0][:, -num_points:, :]
        nr = np.linalg.norm(gr, axis=2)
        has_grad = nr != 0
        zero_grad = nr == 0

        for j in range(advesarial_images.shape[0]):
            # Basic slicing returns a view, so writes to `added` update advesarial_images.
            added = advesarial_images[j, -num_points:]

            if has_grad[j].any():
                # Step of length learning_rate along the unit gradient direction.
                moving = has_grad[j]
                added[moving] = added[moving] + learning_rate * gr[j, moving, :] / nr[j, moving][:, np.newaxis]
            if zero_grad[j].any():
                # No gradient signal: jitter by a random offset of length <= learning_rate.
                stuck = zero_grad[j]
                vc = np.array([random_unit_vector(learning_rate) for _ in range(int(np.count_nonzero(stuck)))])
                added[stuck] = added[stuck] + vc

            # Project every added point back into the epsilon-ball around its nearest
            # clean point. The norm is > 0 for projected points because their nearest
            # distance exceeds epsilon.
            dist = cdist(added, images[j])
            proj = np.min(dist, axis=1) > epsilon
            proj_ind = np.argmin(dist, axis=1)
            if proj.any():
                anchors = images[j, proj_ind[proj]]
                offsets = added[proj] - anchors
                norm = np.sqrt(np.sum(offsets**2, axis=1))[:, np.newaxis]
                added[proj] = anchors + epsilon * offsets / norm

        if i % 20 == 19:
            print(sess.run(loss1, feed_dict={input_layer:advesarial_images, labels_pl:labels, is_training:False}))

    # Sanity check on sample 10: largest distance of an added point to the clean cloud.
    # Guarded so that batches with 10 or fewer samples no longer raise IndexError.
    if advesarial_images.shape[0] > 10:
        print(np.max(np.min(cdist(advesarial_images[10, -num_points:], images[10]), axis=1)))
    return np.array(advesarial_images)


def evaluate(num_votes):
    """
    Build the evaluation graph, restore the checkpoint and run the adversarial evaluation.

    Creates the model inputs and outputs through MODEL, adds a mean softmax
    cross-entropy loss together with its gradient w.r.t. the input point cloud,
    restores the variables from MODEL_PATH and hands everything to eval_one_epoch.

    :param num_votes: number of rotation votes, forwarded to eval_one_epoch
    """
    log_string2(device_lib.list_local_devices())

    with tf.device('/gpu:'+str(GPU_INDEX)):
        pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
        is_training_pl = tf.placeholder(tf.bool, shape=())

        # simple model
        pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl)
        loss = MODEL.get_loss(pred, labels_pl, end_points)

        # Per-sample cross-entropy (loss2), its batch mean (loss1) and the gradient of
        # that mean w.r.t. the input points; the attack ascends this gradient.
        loss2 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=labels_pl)
        loss1 = tf.reduce_mean(loss2)
        gradient = K.gradients(loss1, pointclouds_pl)
        saver = tf.train.Saver()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = True
    sess = tf.Session(config=config)

    # Close the session even when restoring or evaluating fails, so its resources
    # are released (it was previously never closed).
    try:
        # Restore variables from disk.
        saver.restore(sess, MODEL_PATH)
        log_string("Model restored.")

        ops = {'pointclouds_pl': pointclouds_pl,
               'labels_pl': labels_pl,
               'is_training_pl': is_training_pl,
               'pred': pred,
               'loss': loss,
               'grad': gradient,
               'l1': loss1,
               'l2': loss2}

        eval_one_epoch(sess, ops, num_votes)
    finally:
        sess.close()

   
def _pad_to_full_batches(data, labels, batch_size):
    """Append cyclic repeats of the leading samples until len(data) is a multiple of batch_size."""
    num_real = data.shape[0]
    pad = -num_real % batch_size
    if pad == 0:
        # Already aligned: do not append a whole redundant batch.
        return data, labels
    # Wrap around so the padding is complete even when pad > num_real.
    fill = np.arange(pad) % num_real
    return (np.concatenate((data, data[fill]), axis=0),
            np.concatenate((labels, labels[fill]), axis=0))


def eval_one_epoch(sess, ops, num_votes=1, topk=1):
    """
    Attack every test batch and report the accuracy of the model on the adversarial clouds.

    For each file in TEST_FILES the data is padded to a multiple of BATCH_SIZE
    (the attack and the network are always fed full batches). Padded duplicates
    are excluded from the accuracy, the per-class statistics and pred_label.txt,
    so every test sample is counted exactly once.

    Side effects: writes 'pred_label.txt' (and, with --visu, images of the clean
    cloud of each misclassified sample) into DUMP_DIR, logs through log_string and
    saves [mean loss, accuracy, mean class accuracy, per-class accuracies...] to
    'res.npy'.

    :param sess: session used for the attack and for the forward pass
    :param ops: dict with the keys 'pointclouds_pl', 'labels_pl', 'is_training_pl',
                'pred', 'loss', 'grad', 'l1' and 'l2'
    :param num_votes: number of rotations around the up axis whose scores are summed
    :param topk: unused, kept for interface compatibility
    """
    error_cnt = 0
    is_training = False
    total_correct = 0
    total_seen = 0
    loss_sum = 0
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]

    # Context manager so the prediction file is flushed and closed even on errors.
    with open(os.path.join(DUMP_DIR, 'pred_label.txt'), 'w') as fout:
        for fn in range(len(TEST_FILES)):
            log_string('----'+str(fn)+'----')
            current_data, current_label = provider.loadDataFile(TEST_FILES[fn])
            current_data = current_data[:,0:NUM_POINT,:]
            # atleast_1d keeps a single-sample file indexable after the squeeze.
            current_label = np.atleast_1d(np.squeeze(current_label))

            num_real = current_data.shape[0]
            if num_real == 0:
                continue
            current_data, current_label = _pad_to_full_batches(current_data, current_label, BATCH_SIZE)
            num_batches = current_data.shape[0] // BATCH_SIZE

            log_string('shape  ' + str(current_data.shape))
            for batch_idx in range(num_batches):
                log_string(str(batch_idx)+'/'+str(num_batches))
                start_idx = batch_idx * BATCH_SIZE
                end_idx = (batch_idx+1) * BATCH_SIZE
                cur_batch_size = end_idx - start_idx
                # Number of genuine (non-padded) samples in this batch; always >= 1
                # because fewer than BATCH_SIZE samples are ever appended.
                num_valid = min(end_idx, num_real) - start_idx
                batch_labels = current_label[start_idx:end_idx]

                # Aggregating BEG
                batch_loss_sum = 0 # sum of losses for the batch
                batch_pred_sum = np.zeros((cur_batch_size, NUM_CLASSES)) # score for classes
                for vote_idx in range(num_votes):
                    rotated_data = provider.rotate_point_cloud_by_angle(current_data[start_idx:end_idx, :, :],
                                                      vote_idx/float(num_votes) * np.pi * 2)
                    adversarial_samples = craft_adversarial_samples(sess, rotated_data, batch_labels, ops['pred'], ops['labels_pl'], ops['pointclouds_pl'], ops['is_training_pl'], ops['grad'], ops['l1'], ops['l2'])

                    feed_dict = {ops['pointclouds_pl']: adversarial_samples,
                                 ops['labels_pl']: batch_labels,
                                 ops['is_training_pl']: is_training}
                    loss_val, pred_val = sess.run([ops['loss'], ops['pred']],
                                              feed_dict=feed_dict)
                    batch_pred_sum += pred_val
                    # loss_val is a single value for the whole batch (presumably the
                    # batch mean; verify against MODEL.get_loss), so it is weighted
                    # by the number of genuine samples.
                    batch_loss_sum += (loss_val * num_valid / float(num_votes))
                pred_val = np.argmax(batch_pred_sum, 1)
                # Aggregating END

                correct = np.sum(pred_val[:num_valid] == batch_labels[:num_valid])
                total_correct += correct
                total_seen += num_valid
                loss_sum += batch_loss_sum

                print('progress')
                print(total_correct)
                print(total_seen)
                print(loss_sum)

                for i in range(start_idx, start_idx + num_valid):
                    l = current_label[i]
                    total_seen_class[l] += 1
                    total_correct_class[l] += (pred_val[i-start_idx] == l)
                    fout.write('%d, %d\n' % (pred_val[i-start_idx], l))

                    if pred_val[i-start_idx] != l and FLAGS.visu: # ERROR CASE, DUMP!
                        img_filename = '%d_label_%s_pred_%s.jpg' % (error_cnt, SHAPE_NAMES[l],
                                                               SHAPE_NAMES[pred_val[i-start_idx]])
                        img_filename = os.path.join(DUMP_DIR, img_filename)
                        # The dumped image shows the clean cloud, not the adversarial one.
                        output_img = pc_util.point_cloud_three_views(np.squeeze(current_data[i, :, :]))
                        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; left
                        # as-is to avoid introducing a new dependency.
                        scipy.misc.imsave(img_filename, output_img)
                        error_cnt += 1

    mean_loss = loss_sum / float(total_seen)
    accuracy = total_correct / float(total_seen)
    # Builtin float instead of the np.float alias, which was removed in NumPy 1.24.
    class_accuracies = np.array(total_correct_class)/np.array(total_seen_class, dtype=float)
    mean_class_accuracy = np.mean(class_accuracies)

    log_string('eval mean loss: %f' % (mean_loss))
    log_string('eval accuracy: %f' % (accuracy))
    log_string('eval avg class acc: %f' % (mean_class_accuracy))

    result = [mean_loss, accuracy, mean_class_accuracy]
    for i, name in enumerate(SHAPE_NAMES):
        log_string('%10s:\t%0.3f' % (name, class_accuracies[i]))
        result.append(class_accuracies[i])
    np.save('res', np.array(result))
    


if __name__=='__main__':
    # Script entry point: run one evaluation pass in a fresh graph and always
    # release both log files, even when the evaluation raises.
    try:
        with tf.Graph().as_default():
            evaluate(num_votes=1)
    finally:
        LOG_FOUT.close()
        LOG_FOUT2.close()
