from jax import numpy as np
from jax import jit, grad, random # for compiling functions for speedup
import numpy as onp
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import copy
import math
import pickle
import os
import scipy.io
import csv
import argparse
from jax.tree_util import tree_flatten

import numpy.matlib
from scipy import linalg
from scipy.spatial import ConvexHull
import scipy.optimize
import random
from scipy.optimize import minimize


# Default absolute tolerance used as the `tolerance` default of the fuzzy_*
# float-comparison helpers and the hyperplane comparison functions below.
global_tolerance = 1e-6

import keras

from math import log, sqrt
from time import time
from pprint import pprint

from hyperopt import hp
from hyperopt.pyll.stochastic import sample

import sys
import os
import time
import datetime

import scipy.sparse as sparse
import osqp
import heapq
from numpy.linalg import norm
from qpsolvers import solve_qp

from scipy.optimize import linprog

class CatchPrint:
    """Context manager that silences ``print`` output inside its body.

    On entry ``sys.stdout`` is redirected to ``os.devnull``; on exit the
    stream that was active at entry is restored. Exceptions raised inside
    the body are not suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference to the sink so that __exit__ closes exactly
        # the handle opened here, even if the body rebinds sys.stdout.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        self._devnull.close()


def ac_hessian(a, b, x):
    """Return the log-barrier Hessian ``a^T diag(d^-2) a`` at x.

    ``d`` is the slack ``b - a.x`` shifted by a small constant (1e-5) for
    numerical stability.
    """
    stabilizer = 1e-5
    slack = (b - a.dot(x) + stabilizer)
    weights = onp.diag(slack ** -2.0)
    weighted_rows = a.T.dot(weights)
    return weighted_rows.dot(a)

def jac(a, b, x):
    """Return log-barrier grad vector at x.

    For the barrier ``-sum_i log(b_i - a_i . x)`` the gradient is
    ``a^T (1 / (b - a x))``. The previous implementation computed the sum of
    the reciprocal slacks and returned nothing.
    """
    inverse_slack = np.reciprocal(b - a.dot(x))
    return a.T.dot(inverse_slack)

def local_norm(h, v):
    """Return the quadratic form ``v^T h v`` for the given Hessian matrix h.

    Note that no square root is taken.
    """
    left_product = v.T.dot(h)
    return left_product.dot(v)

def ellipsoid_axes(e):
    """Return matrix with columns that are the axes of the ellipsoid.

    Each eigenvector of ``e`` is scaled by the inverse square root of its
    eigenvalue.
    """
    eigvals, eigvecs = np.linalg.eigh(e)
    axis_scaling = np.diag(eigvals**(-1/2.0))
    return eigvecs.dot(axis_scaling)

def chebyshev_center(a, b):
    """Return Chebyshev center of the convex polytope ``a x <= b``.

    Solves the linear program ``max r  s.t.  a_i . x + ||a_i|| r <= b_i``
    with ``linprog`` and returns the ``x`` part of the solution.

    The LP is assembled with plain NumPy: the module-level ``np`` is
    ``jax.numpy``, whose arrays are immutable, so the in-place write to the
    cost vector used to fail.

    Raises:
        Exception: if ``linprog`` does not report success.
    """
    a = onp.asarray(a, dtype=float)
    # linprog expects a 1-D right-hand side; accept column vectors as well.
    b = onp.asarray(b, dtype=float).ravel()
    row_norms = onp.linalg.norm(a, axis=1).reshape(a.shape[0], 1)
    # Decision vector is (x, r); minimising -r maximises the radius.
    c = onp.zeros(a.shape[1] + 1)
    c[-1] = -1
    a_lp = onp.hstack((a, row_norms))
    res = linprog(c, A_ub=a_lp, b_ub=b, bounds=(None, None))
    if not res.success:
        raise Exception('Unable to find Chebyshev center')

    return res.x[:-1]

def dikin(A, b, x, r):
    """Return the axes of the ellipsoid defined by ``r`` times the
    log-barrier Hessian of ``A x <= b`` at ``x``.

    NOTE(review): the previous body called an undefined ``hessian`` with an
    undefined ``ac`` argument. It is assumed the intent was ``ac_hessian``
    evaluated at the supplied point ``x`` - confirm against callers.
    """
    h_x = r * ac_hessian(A, b, x)

    return ellipsoid_axes(h_x)

class reduceConstr():
    """
    Reduce a system of linear constraints A x <= b to an (approximately)
    non-redundant subset by computing an approximate convex hull of the dual
    points of the constraints.

    ``args`` is a dict read for the following keys:

    initial_size: The initial size of output convexhull (-1 means #nodes / 20)
    convexhull_size: The max size of output convexhull (-1 means #nodes / 10)
    convergence_distance: The criteria of convergence
    tolerated_qp_error: How much distance error in QP solver is tolerated
    convergence_change_rate: The criteria of convergence for KCHA
    sigma: Sigma, the parameter of Gaussian kernel

    ``initial_size`` and ``convexhull_size`` are overwritten by
    ``nonredundantCstr`` from the number of constraints.
    """
    def __init__(self, A, b, args):  # _TODO_
        self.A = A
        self.b = b
        self.args = args

    def linConstrInteriorScore(self, point):
        """Return max(0, max(A @ point - b)).

        The score is 0 when ``point`` satisfies every constraint and
        positive otherwise.
        """
        A = self.A
        b = self.b
        violation = (A @ point.reshape(A.shape[1], 1) - b).reshape(b.size, 1)
        score = onp.max(onp.append(violation, [0]))
        return score

    def getInteriorPoint(self):
        """Return a point satisfying A x <= b.

        The least-squares solution of A x = b is tried first. If it violates
        a constraint, ``linConstrInteriorScore`` is minimised with
        Nelder-Mead starting from it.

        Raises:
            SystemExit: if the Nelder-Mead result still violates a
                constraint. (This check used to sit in the branch where the
                least-squares point was already feasible, so it could never
                fire.)
        """
        A = self.A
        b = self.b
        # Least-squares solution to Ax = b, expected to lie at least in the
        # vicinity of the feasible region.
        c = linalg.lstsq(A, b)[0]
        needs_search = onp.max((A @ c - b).reshape(b.size, )) > 0
        if not needs_search:
            print('Least squares solution is an interior point of the feasible region.')
            return c

        print('Least squares solution is not an interior point of the feasible region.')
        # x0 is flattened because minimize requires a 1-D start point, and
        # 'xatol' is the Nelder-Mead name of the absolute x tolerance.
        copt = minimize(self.linConstrInteriorScore, x0=onp.ravel(c), method='nelder-mead',
                        options={'xatol': 1e-8, 'disp': True}).x.reshape(A.shape[1], 1)
        print('Interior score max(0, vec(A@copt-b)) is ' + str(onp.max(A @ copt - b)) + ' at point ' + str(copt))
        if onp.max((A @ copt - b).reshape(b.size, )) > 0:
            print(
                'Could not find an interior point of the feasible region using Nelder-Mead method. Try something else.')
            raise SystemExit
        return copt

    def nonredundantCstr(self, copt):
        """Select the constraints whose dual points lie on the approximate hull.

        Args:
            copt: point used as the new origin; ``b - A @ copt`` is the
                translated right-hand side.

        Returns:
            ``(D, hull_ind, hull, A_red, b_red)`` where ``D`` holds the dual
            points ``A / (b - A @ copt)``, ``hull_ind`` is the sorted list of
            selected row indices, ``hull`` is ``D[hull_ind]`` and
            ``A_red`` / ``b_red`` are the corresponding rows of A and b.
            If the hull computation raises, the problem is reported and
            ``(None, list(range(m)), None, None, None)`` is returned.
        """
        A = self.A
        b = self.b
        b_translated = b - (A @ copt).reshape(b.shape[0], 1)  # shift so that copt becomes the origin
        # Element-wise division gives the normalised form; each row of D is
        # the dual point of one primal constraint plane.
        D = A / b_translated
        try:
            # Hull sizes are derived from the number of constraints and
            # stored in args, where the hull routines read them.
            self.args['initial_size'] = len(D) // 20
            self.args['convexhull_size'] = len(D) // 10
            hull_ind = sorted(int(i) for i in self.get_convexhull(D))
            hull = D[hull_ind][:]
        except Exception as e:
            print('Could not obtain dual polytope of primal feasible region using the convex hull approach.')
            print(e)
            return None, list(range(D.shape[0])), None, None, None
        # Take the rows named by hull_ind. The previous loop copied rows
        # 0..k-1 of A and b regardless of which indices were selected.
        A_red = onp.array(A[hull_ind], dtype=float)
        b_red = onp.array(b[hull_ind], dtype=float).reshape(len(hull_ind), 1)
        return D, hull_ind, hull, A_red, b_red

    def init_convexhull(self, node_vectors):
        """Return the initial set of hull candidate indices.

        A Gaussian kernel matrix is built on the points (augmented with a
        constant sqrt(D) coordinate) and a multiplicative update is iterated
        until the relative change drops to ``convergence_change_rate``. The
        ``initial_size`` indices with the largest diagonal entries are taken,
        plus the per-coordinate arg-max and arg-min points.
        """
        N, D = node_vectors.shape
        X = onp.hstack((node_vectors, onp.ones((N, 1)) * math.sqrt(D)))  # N * (D + 1)
        C = onp.ones((N, N)) / N  # N * N
        pairwise_diff = X.reshape((-1, 1, D + 1)) - X.reshape((1, -1, D + 1))  # N * N * (D + 1)
        pairwise_sq_dist = (pairwise_diff * pairwise_diff).sum(axis=2)  # N * N
        sigma = self.args['sigma']
        K = onp.exp(-pairwise_sq_dist / (sigma * sigma))  # N * N
        while True:
            C_last = C
            C = C_last * onp.sqrt(K / K.dot(C_last).clip(min=1e-100))
            if onp.abs((C - C_last)).sum() / onp.abs(C_last).sum() <= self.args['convergence_change_rate']:
                break
        C_diag = C[onp.diag_indices_from(C)]
        # NOTE(review): with initial_size == 0 the slice [-0:] selects every
        # index, so all points start in the hull and are pruned afterwards.
        # Left unchanged on purpose - confirm whether this is intended.
        convexhull_indexes = set(onp.argsort(C_diag)[-self.args['initial_size']:])
        convexhull_indexes.update(node_vectors.argmax(axis=0))
        convexhull_indexes.update(node_vectors.argmin(axis=0))

        return convexhull_indexes

    def distance_point_to_convexhull_all_vector(self, z, X):
        """Return the Euclidean distance from point z to the convex hull of the rows of X.

        Solves ``min 1/2 a^T (X X^T) a - (X z)^T a`` subject to ``a >= 0``
        and ``sum(a) = 1`` with the OSQP backend of ``solve_qp``.
        """
        S, d = X.shape

        # minimize (1/2)x^TPx + q^Tx; subject to Gx <= h & Ax = b
        P = sparse.csc_matrix(X.dot(X.T))
        q = -(X * z).sum(axis=1)
        G = sparse.csc_matrix(-sparse.eye(S))
        h = onp.zeros(S)
        A = onp.ones(S)
        b = 1
        a = solve_qp(P, q, G, h, A, b, solver='osqp')

        residual = z - a.dot(X)
        return math.sqrt(residual.dot(residual))


    def get_distance(self, row_node, convexhull, node_vectors):
        """Return the distance from node ``row_node`` to the hull spanned by ``convexhull``."""
        convexhull_list = list(convexhull)
        X = node_vectors[convexhull_list]  # S*d
        z = node_vectors[row_node]  # d
        return self.distance_point_to_convexhull_all_vector(z, X)


    def find_next_node_to_convexhull(self, remaining_nodes, convexhull, node_vectors):
        """Pick the next remaining node to add to the hull.

        A heap of ``(largest distance so far, candidate, last row evaluated)``
        entries is expanded lazily: the candidate with the smallest running
        maximum is evaluated against one more remaining node at a time, with
        the candidate temporarily added to the hull. The search stops when a
        popped candidate has been evaluated against every remaining node.

        Returns:
            ``(node, distance, covered)`` - the chosen node, its running
            maximum distance, and the set of remaining nodes whose distance
            to the hull including that node is within ``tolerated_qp_error``.
        """
        remaining_nodes = list(remaining_nodes)
        last_row = len(remaining_nodes) - 1
        heap = [(0, i, -1) for i in range(len(remaining_nodes))]
        distance_matrix = [[] for _ in range(len(remaining_nodes))]
        while True:
            last_distance, col, row = heapq.heappop(heap)
            if row == last_row:
                break
            convexhull.add(remaining_nodes[col])
            if col == row + 1:
                # A node is at distance 0 from a hull that contains it.
                distance = 0
            else:
                distance = self.get_distance(remaining_nodes[row + 1], convexhull, node_vectors)
            convexhull.remove(remaining_nodes[col])
            distance_matrix[col].append(distance)
            heapq.heappush(heap, (max(last_distance, distance), col, row + 1))
        tolerance = self.args['tolerated_qp_error']
        nodes_to_be_removed = {remaining_nodes[idx] for idx, dis in enumerate(distance_matrix[col])
                               if abs(dis) <= tolerance}
        return remaining_nodes[col], last_distance, nodes_to_be_removed


    def clean_remaining_nodes(self, remaining_nodes, convexhull, node_vectors):
        """Drop, in place, remaining nodes that lie within tolerance of the current hull."""
        for node in list(remaining_nodes):
            distance = self.get_distance(node, convexhull, node_vectors)
            if abs(distance) <= self.args['tolerated_qp_error']:
                remaining_nodes.remove(node)


    def clean_convexhull(self, convexhull, node_vectors):
        """Drop, in place, hull nodes that the other hull nodes already cover within tolerance."""
        for node in list(convexhull):
            convexhull.remove(node)
            distance = self.get_distance(node, convexhull, node_vectors)
            if abs(distance) >= self.args['tolerated_qp_error']:
                convexhull.add(node)


    def get_convexhull(self, node_vectors):
        """Return the set of row indices of ``node_vectors`` forming the approximate hull.

        Starts from ``init_convexhull`` and adds nodes one at a time until
        the hull exceeds ``convexhull_size``, no nodes remain, or the
        selected distance is at most ``convergence_distance``.
        """
        convexhull = self.init_convexhull(node_vectors)
        remaining_nodes = set(range(node_vectors.shape[0])) - convexhull
        self.clean_convexhull(convexhull, node_vectors)
        self.clean_remaining_nodes(remaining_nodes, convexhull, node_vectors)
        while len(convexhull) <= self.args['convexhull_size']:
            if len(remaining_nodes) == 0:
                break
            node, distance, nodes_in_convexhull = self.find_next_node_to_convexhull(remaining_nodes, convexhull, node_vectors)
            remaining_nodes -= set(nodes_in_convexhull)
            convexhull.add(node)
            self.clean_convexhull(convexhull, node_vectors)
            if distance <= self.args['convergence_distance']:
                break
        return convexhull

    def x1x2vectorsLine2d(self, xvecinput, a1, a2, b):
        """Return (x1, x2) coordinate vectors of the 2-D line a1*x1 + a2*x2 = b.

        ``xvecinput`` is used as the x2 values when a2 is approximately zero
        (and a1 is not), and as the x1 values otherwise.
        """
        if (onp.abs(a2) < 10**-4) and (onp.abs(a1) > 10**-4):  # if a2 \approx 0
            x1vec = b * onp.ones(xvecinput.shape) / a1
            x2vec = xvecinput
        elif (onp.abs(a2) > 10**-4) and (onp.abs(a1) < 10**-4):  # if a1 \approx 0
            x1vec = xvecinput
            x2vec = b * onp.ones(xvecinput.shape) / a2
        else:  # both a1, a2 are not \approx 0
            x1vec = xvecinput
            x2vec = (b * onp.ones(xvecinput.shape) - a1 * xvecinput) / a2
        return x1vec, x2vec

def compute_redun(A, b, copt, args, pl = False):
    """Compute the non-redundant constraint indices of A x <= b and optionally plot them.

    Args:
        A, b: constraint matrix and right-hand side.
        copt: point passed to ``reduceConstr.nonredundantCstr``.
        args: settings dict forwarded to ``reduceConstr``.
        pl: when true, draw the selected constraints as 2-D lines.

    Returns:
        ``(hull_ind, err, p)`` where ``err`` is True when the reduction
        produced dual points (False when it fell back), and ``p`` is always
        the ``matplotlib.pyplot`` module.
    """
    reducer = reduceConstr(A, b, args)
    D, hull_ind, hull, A_red, b_red = reducer.nonredundantCstr(copt)
    err = D is not None
    if pl:
        sample_x = onp.linspace(-1, 4, 100)
        plt.plot(copt[0], copt[1], 'k+', markersize=5, label='interior point')
        for row in range(A.shape[0]):
            x1, x2 = reducer.x1x2vectorsLine2d(sample_x, A[row][0], A[row][1], b[row])
            is_first_selected = (row == min(hull_ind))
            if row not in hull_ind:
                continue
            # Only the first selected constraint gets a legend entry.
            label = 'non-redundant constraints' if is_first_selected else '_no_label_'
            plt.plot(x1, x2, 'r--', linewidth=1, label=label)
        plt.grid()
        plt.title(str(len(hull_ind)) + ' non-redundant constraints')
        plt.legend()
        plt.xlabel('x1')
        plt.ylabel('x2')

    return hull_ind, err, plt

###################
"""
def parse_argument():
    #parser = argparse.ArgumentParser()
    #args = parser.parse_args(args=['--init_size', '-1', 
    #                              '--size', '-1',
    #                             '--convergence', '0.06',
    #                             '--error','0.03',
    #                             '--convergence','0.01', 
    #                             '--sigma','0.3'])
    args = {'initial_size':-1,
            'convexhull_size':-1,
            'convergence_distance':600, # 0.06
            'tolerated_qp_error':300, # 0.03
            'convergence_change_rate':100, # 0.01
            'sigma':3000} # 0.3
    return args
args = parse_argument()
"""
###############

# handle floats which should be integers
# works with flat params
def handle_integers( params ):
    """Return a copy of the flat dict ``params`` with integral floats cast to int.

    A value is converted only when it is a float holding a whole number
    (e.g. ``2.0`` becomes ``2``); float subclasses such as ``numpy.float64``
    are converted too. Infinities and NaN are left untouched - calling
    ``int()`` on them, as the previous check did, raises. Every other value
    is passed through unchanged.
    """
    return {
        k: int( v ) if isinstance( v, float ) and v.is_integer() else v
        for k, v in params.items()
    }

def load_fmnist():
    """Load Fashion-MNIST through keras.

    Returns:
        ``(x_train, x_test, y_train, y_test)`` with each image flattened to
        one row and divided by 255, and labels one-hot encoded over 10
        classes.
    """
    train_split, test_split = keras.datasets.fashion_mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    flat_train = train_images.reshape(train_images.shape[0], -1) / 255
    flat_test = test_images.reshape(test_images.shape[0], -1) / 255

    onehot_train = keras.utils.to_categorical(train_labels, num_classes=10)
    onehot_test = keras.utils.to_categorical(test_labels, num_classes=10)
    return flat_train, flat_test, onehot_train, onehot_test

def l1_norm(tree):
    """Compute the l1 norm of a pytree of arrays. Useful for weight decay.

    The l1 norm is the sum of absolute values over all leaves. The previous
    implementation additionally took a square root of that sum, which is not
    the l1 norm.
    """
    leaves, _ = tree_flatten(tree)
    return sum(np.sum(np.abs(x)) for x in leaves)

"""
will have to fix this at some point
"""
@jit
def act2prof(act):
    """Return the binary activation profile of ``act``.

    The result has ones where the non-negative part of ``act`` is strictly
    positive and zeros elsewhere.
    """
    rectified = np.multiply(act, act>=0)
    is_active = rectified > 0.0
    return np.multiply(np.ones_like(rectified), is_active)

def same_polytope(params, batch, niters=40, eps=0.1, alpha=0.005):
  """Return the unique activation profiles of the clean inputs in ``batch``.

  The inputs are also perturbed for ``niters`` signed-gradient steps (using
  the external ``grad_in``) and clipped to within ``eps`` of the originals;
  the perturbed activations are computed but do not affect the return value.

  Note: the ``alpha`` argument is ignored; the step size is always
  ``eps/niters * 2``.
  """
  clean_inputs, labels = batch
  perturbed = clean_inputs
  alpha = eps/niters * 2

  def chd(p1, p2):
    # Sum of absolute differences between two profiles (currently unused).
    return np.sum(np.abs(p1-p2))

  def act2prof(act):
    # One where the activation is strictly positive, zero elsewhere.
    active = (act > 0.0)
    return np.multiply(np.ones_like(act), active)

  lower_bound = clean_inputs - eps
  upper_bound = clean_inputs + eps
  for _ in range(niters):
    pert = grad_in((perturbed,labels),params)
    perturbed = perturbed + alpha * np.sign(pert[0])
    perturbed = np.clip(perturbed, a_min=lower_bound, a_max=upper_bound)

  inp, act = net_walk(params, clean_inputs)
  advinp,advact = net_walk(params, perturbed)

  clean_profile = act2prof(act[0])
  unique_rows = onp.unique(clean_profile, axis=0)
  advprof = act2prof(advact[0])

  #hamming = vmap(chd)(prof,advprof)
  #return hamming, nunique
  return unique_rows

class Hparams:
    """Command-line hyperparameter definitions.

    ``Hparams.parser`` is an ``argparse.ArgumentParser`` built at class
    definition time. The remaining members are helpers that take no instance
    state and are exposed as static methods, so they can be called both as
    ``Hparams.string(hps)`` and on an instance.
    """
    parser = argparse.ArgumentParser()

    def str_to_bool(value):
        """Parse a textual boolean; booleans are returned unchanged.

        Raises:
            ValueError: if ``value`` is not one of the recognised spellings.
        """
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in {'false', 'f', '0', 'no', 'n'}:
            return False
        elif lowered in {'true', 't', '1', 'yes', 'y'}:
            return True
        raise ValueError(f'{value} is not a valid boolean value')

    parser.add_argument('--dataset', default='mnist', type=str, help="dataset")
    parser.add_argument('--model', default='fc1', type=str, help="model")
    # train
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--num_epochs', default=100, type=int)
    parser.add_argument('--lr', default=0.00001, type=float, help="learning rate")

    # objective
    parser.add_argument('--lm', default=1.0, type=float, help="manifold regularizer weight")
    parser.add_argument('--lc', default=0.001, type=float, help="centroid regularizer weight")
    parser.add_argument('--lh', default=0.1, type=float, help="hamming regularizer weight")
    parser.add_argument('--l2', default=1e-4, type=float, help="l2 regularizer weight")
    parser.add_argument('--l1', default=1e-5, type=float, help="l1 regularizer weight")
    parser.add_argument('--lm_warmup', default=10, type=int, help="manifold regularizer warmup")
    parser.add_argument('--lc_warmup', default=10, type=int, help="centroid regularizer warmup")
    
    parser.add_argument('--attack', default='pgd', type=str, help="attack")
    parser.add_argument('--eps', default=0.1, type=float, help="epsilon")
    parser.add_argument('--amin', default=0.0, type=float, help="clip min")
    parser.add_argument('--amax', default=1.0, type=float, help="clip max")
    parser.add_argument('--log_name', default='', type=str, help="logname")
    parser.add_argument('--logdir', default='log/', type=str, help="logdir")
    parser.add_argument('--verbose', default=1, type=int, help="verbosity")
    # NOTE: the default is the boolean True while command-line values are
    # parsed as int; kept as is so existing invocations parse identically.
    parser.add_argument('--log', default=True, type=int, help="logging")
    parser.add_argument('--step_type', default='reg', type=str, help="step type")
    parser.add_argument('--rec_acts', default=False, type=str_to_bool, help="record activations")

    # The plain function is needed above as an argparse ``type`` callable;
    # wrap it only now so attribute access on instances does not bind self.
    str_to_bool = staticmethod(str_to_bool)

    def __init__(self):
        pass

    @staticmethod
    def string(hps):
        """Return a descriptive string built from the most important hyperparameters.

        Not every hyperparameter fits in a file name, so only a selection of
        the attributes of ``hps`` is included.
        """
        hyperparam_str = (
            f"dataset={hps.dataset} model={hps.model} step={hps.step_type} lr={hps.lr} "
            f"m_lmbd={hps.lm} ac_lmbd={hps.lc} l2_lmbd={hps.l2} "
            f"lm_warmup={hps.lm_warmup} lc_warmup={hps.lc_warmup}"
        )
        return hyperparam_str

    @staticmethod
    def save():
        """Placeholder; not implemented."""
        pass

    @staticmethod
    def load():
        """Placeholder; not implemented."""
        pass
    
def make_plt(fname=None,hps=None,save=False):
    """Plot the training curves stored in a CSV log file.

    Args:
        fname: path of the log file. When None, the path is built from
            ``hps.logdir`` and ``Hparams.string(hps)``.
        hps: hyperparameter namespace; required when ``fname`` is None or
            ``save`` is true.
        save: when true, write the figure next to the log as a ``.png``;
            otherwise show it interactively.
    """
    # pandas is not imported at module level, so bring it in here; the bare
    # ``pd`` reference used to raise NameError.
    import pandas as pd

    if fname is None:
        record = pd.read_csv(hps.logdir+Hparams.string(hps) +'.log')
    else:
        record = pd.read_csv(fname)
    plt.rcParams['figure.figsize'] = 15, 10

    fig, ax = plt.subplots(3, 2)
    ticks = np.arange(0, 101, 5.0)
    plt.xticks(ticks)

    def _decorate(axis):
        # Shared finishing touches applied to every panel.
        axis.legend()
        axis.set_xticklabels(ticks,rotation=45)
        axis.grid()

    ax[0, 0].set_title('cumulative loss')
    ax[0, 0].plot(onp.log(record['train_loss']),label='train')
    ax[0, 0].plot(onp.log(record['valid_loss']),label='validation')
    ax[0, 0].plot(onp.log(record['test_loss']),label='test')
    _decorate(ax[0, 0])

    ax[0, 1].set_title('accuracy')
    ax[0, 1].plot(record['train_acc'],label='train')
    ax[0, 1].plot(record['valid_acc'],label='validation')
    ax[0, 1].plot(record['test_acc'],label='test')
    ax[0, 1].plot(record['valid_adv_acc'],label='valid adversarial')
    ax[0, 1].plot(record['test_adv_acc'],label='test adversarial')
    _decorate(ax[0, 1])

    # Remaining panels each show a single training-loss column.
    single_series_panels = [
        ((1, 0), 'cross entropy loss', 'ce_loss'),
        ((1, 1), 'l2 loss', 'l2_loss'),
        ((2, 0), 'centroid loss', 'ac_loss'),
        ((2, 1), 'manifold loss', 'mani_loss'),
    ]
    for position, title, column in single_series_panels:
        axis = ax[position]
        axis.set_title(title)
        axis.plot(record[column],label='train')
        _decorate(axis)

    if save:
        plt.savefig(hps.logdir+Hparams.string(hps)+'.png')
    else:
        plt.subplots_adjust(hspace=0.28)
        plt.show()

class Logger:
    """Collect per-epoch metrics and free-form messages, and dump the metrics as CSV."""

    def __init__(self):
        self.lst_this_run = []
        self.lst_whole_exp = []

        # One list per metric; the key order defines the CSV column order.
        self.history = {}

        self.history['train_loss'] = []
        self.history['valid_loss'] = []
        self.history['test_loss'] = []

        self.history['train_acc'] = []
        self.history['valid_acc'] = []
        self.history['test_acc'] = []

        self.history['test_adv_acc'] = []
        self.history['valid_adv_acc'] = []

        self.history['ce_loss'] = []
        self.history['l1_loss'] = []
        self.history['l2_loss'] = []
        self.history['ac_loss'] = []
        self.history['mani_loss'] = []

    def update(self, data):
        """Append one record to the history.

        ``data[0:8]`` are read as train loss/acc, valid loss/acc, test
        loss/acc, valid adversarial acc and test adversarial acc, in that
        order; ``data[-1]`` is a sequence holding the ce, l1, l2, ac and
        manifold losses.
        """
        self.history['train_loss'].append(data[0])
        self.history['train_acc'].append(data[1])

        self.history['valid_loss'].append(data[2])
        self.history['valid_acc'].append(data[3])

        self.history['test_loss'].append(data[4])
        self.history['test_acc'].append(data[5])

        self.history['valid_adv_acc'].append(data[6])
        self.history['test_adv_acc'].append(data[7])

        loss_terms = data[-1]
        self.history['ce_loss'].append(loss_terms[0])
        self.history['l1_loss'].append(loss_terms[1])
        self.history['l2_loss'].append(loss_terms[2])
        self.history['ac_loss'].append(loss_terms[3])
        self.history['mani_loss'].append(loss_terms[4])

    def add(self, string):
        """Record a message for this run and echo it to stdout."""
        self.lst_this_run.append(string)
        print(string)

    def clear(self):
        """Forget the messages recorded for this run."""
        self.lst_this_run = []

    def to_file(self, folder, this_run_file):
        """Write the metric history to ``folder + this_run_file + '.log'`` as CSV.

        ``folder`` is concatenated as is, so it should end with a path
        separator. It is created when missing.
        """
        if folder:
            os.makedirs(folder, exist_ok=True)
        # newline='' lets the csv module control line endings itself;
        # without it, platforms that translate newlines get blank rows.
        with open(folder + this_run_file+'.log', "w", newline='') as f:
            writer = csv.writer(f)
            writer.writerow(self.history.keys())
            writer.writerows(zip(*self.history.values()))

##########################################################################
#                                                                        #
#                   PLOT  UTILITIES                                      #
#                                                                        #
##########################################################################
                
def plot2d(params, plot_min, plot_max, max_prob=False, show_data=True, cm='viridis'):
    """Draw a filled contour plot of class probabilities over a square 2-D grid.

    The network (external ``net_apply`` / ``stax.softmax``) is evaluated on a
    200 x 200 grid spanning ``[plot_min, plot_max]`` in both coordinates.
    With ``max_prob`` the largest class probability is plotted, otherwise the
    probability of class 0, both in percent.

    ``show_data`` is currently unused.

    NOTE(review): contour levels are fixed to 50..100 regardless of
    ``max_prob`` - confirm this is intended.

    Returns:
        the ``matplotlib.pyplot`` module.
    """
    n_grid = 200
    x_plot = onp.linspace(plot_min, plot_max, n_grid)
    y_plot = onp.linspace(plot_min, plot_max, n_grid)
    plt.figure(figsize=(10,10))

    # Row-major grid: the outer index walks x_plot and the inner one y_plot,
    # and each point is stored as (y value, x value).
    first_coord, second_coord = onp.meshgrid(y_plot, x_plot)
    points = onp.stack([first_coord.ravel(), second_coord.ravel()], axis=1)

    probs = stax.softmax(net_apply(params,points))
    z_plot = probs.max(1) if max_prob else probs[:, 0]
    z_plot = z_plot.reshape(len(x_plot), len(y_plot)) * 100

    contour_levels = onp.linspace(50, 100, 50)
    plt.contourf(x_plot, y_plot, z_plot, levels=contour_levels,cmap=cm)

    plt.xlim([plot_min, plot_max])
    plt.ylim([plot_min, plot_max])

    return plt

def get_spaced_colors(n):
    """Given number, n, returns n colors which are visually well distributed

    Colors are RGBA tuples with components in [0, 1] and alpha fixed to 1,
    obtained by stepping evenly through the integers below 255**3 and
    reading each as a 6-digit hex color.

    Exactly ``n`` colors are returned: the stepping can produce one extra
    value, which is dropped, and ``n <= 0`` yields an empty list instead of
    dividing by zero.
    """
    count = int(n)
    if count <= 0:
        return []

    max_value = 255**3
    # Guard against a zero step when more colors than values are requested.
    interval = max(1, int(max_value / count))
    hex_codes = [hex(value)[2:].zfill(6) for value in range(0, max_value, interval)[:count]]

    return [
        (int(code[:2], 16) / 255.0, int(code[2:4], 16) / 255.0, int(code[4:], 16) / 255.0, 1)
        for code in hex_codes
    ]


def get_color_dictionary(list):
    """Map each element of the provided sequence to one of a set of evenly
    spaced colors; keys are the elements, values the colors.
    """
    palette = get_spaced_colors(len(list))
    return dict(zip(list, palette))

def fuzzy_equal(x, y, tolerance=global_tolerance):
    """ Approximate float equality for scalars x and y: true when their
        absolute difference is strictly below tolerance.
    """
    gap = abs(x - y)
    return gap < tolerance


def fuzzy_vector_equal(x_vec, y_vec, tolerance=global_tolerance):
    """ Approximate equality for 1d numpy arrays: true when every
        component of x_vec - y_vec is strictly below tolerance in magnitude.
     """
    for component in x_vec - y_vec:
        if not abs(component) < tolerance:
            return False
    return True

def fuzzy_vector_equal_plus(x_vec, y_vec, tolerance=global_tolerance):
    """ Approximate equality for 1d numpy arrays, with per-component detail.

        Returns (all_close, flags) where flags[i] tells whether component i
        of x_vec - y_vec is strictly below tolerance in magnitude.
     """
    flags = []
    for component in x_vec - y_vec:
        flags.append(abs(component) < tolerance)
    return all(flags), flags

def is_same_hyperplane_nocomp(a1, b1, a2, b2, tolerance=global_tolerance):
    """ Check same hyperplane when not comparison form

        Returns True when <a1, x> = b1 and <a2, x> = b2 describe the same
        hyperplane, i.e. (a2, b2) is a non-zero multiple of (a1, b1) up to
        tolerance. Returns False when a1 has no component that differs from
        zero by at least tolerance.
    """

    # Check which normals / offsets are (approximately) zero. The a2 flag
    # used to be computed from a1.
    a1_zero = fuzzy_equal(onp.linalg.norm(a1), 0, tolerance=tolerance)
    a2_zero = fuzzy_equal(onp.linalg.norm(a2), 0, tolerance=tolerance)
    b1_zero = fuzzy_equal(b1, 0.0, tolerance=tolerance)
    b2_zero = fuzzy_equal(b2, 0.0, tolerance=tolerance)

    # If exactly one is zero, then they can't be equal
    if (a1_zero != a2_zero) or (b1_zero != b2_zero):
        return False

    # Then find if there's a ratio between the two, using the first non-zero
    # component of a1. The "nothing found" check must run after the whole
    # vector has been scanned; it used to sit inside the loop, so any a1
    # whose first component was zero was rejected.
    first_nonzero_idx = next(
        (i for i, el in enumerate(a1) if not fuzzy_equal(el, 0.0, tolerance)),
        None)
    if first_nonzero_idx is None:
        return False
    two_one_ratio = a2[first_nonzero_idx] / a1[first_nonzero_idx]

    # If this ratio is zero, return False
    if fuzzy_equal(two_one_ratio, 0.0, tolerance=tolerance):
        return False

    # If the vectors aren't parallel, return False
    if not fuzzy_vector_equal(two_one_ratio * a1, a2, tolerance=tolerance):
        return False

    # If the biases aren't equal, return false, o.w. return True
    return fuzzy_equal(two_one_ratio * b1, b2, tolerance=tolerance)

def is_same_hyperplane(a1, b1, a2, b2, tolerance=global_tolerance):
    """ Tell whether <a1, x> = b1 and <a2, x> = b2 are the same hyperplane.

        Both hyperplanes must be in 'comparison form': either the offset is
        ~0 and the normal has unit norm, or the offset is ~ +-1. This is
        asserted.
    """
    # Comparison-form precondition, checked for each hyperplane in turn.
    for normal, offset in ((a1, b1), (a2, b2)):
        if abs(offset) < tolerance:
            # offset ~ 0, so the normal must be a unit vector
            assert fuzzy_equal(onp.linalg.norm(normal), 1, tolerance=tolerance)
        else:
            # otherwise the offset must be ~ +-1
            assert fuzzy_equal(abs(offset), 1.0, tolerance=tolerance)

    # Offsets must agree in magnitude.
    magnitudes_match = fuzzy_equal(abs(b1), abs(b2), tolerance=tolerance)
    if not magnitudes_match:
        return False

    # Non-zero offsets: compare the normals after dividing by the offsets.
    if not fuzzy_equal(b1, 0, tolerance=tolerance):
        return fuzzy_vector_equal(a1 / b1, a2 / b2, tolerance=tolerance)

    # Zero offsets: normals must match up to a factor of -1.
    if fuzzy_vector_equal(a1, a2, tolerance=tolerance):
        return True
    return fuzzy_vector_equal(a1, -a2, tolerance=tolerance)

def is_same_tight_constraint(a1, b1, a2, b2, tolerance=global_tolerance):
    """ Tell whether <a1, x> <= b1 and <a2, x> <= b2 are the same tight
        constraint.

        Both constraints must be in 'comparison form': either the offset is
        ~0 and the normal has unit norm, or the offset is ~ +-1. This is
        asserted.
    """
    # Comparison-form precondition, checked for each constraint in turn.
    for normal, offset in ((a1, b1), (a2, b2)):
        if abs(offset) < tolerance:
            # offset ~ 0, so the normal must be a unit vector
            assert fuzzy_equal(onp.linalg.norm(normal), 1, tolerance=tolerance)
        else:
            # otherwise the offset must be ~ +-1
            assert fuzzy_equal(abs(offset), 1.0, tolerance=tolerance)

    # Offsets must agree in magnitude.
    magnitudes_match = fuzzy_equal(abs(b1), abs(b2), tolerance=tolerance)
    if not magnitudes_match:
        return False

    if not fuzzy_equal(b1, 0, tolerance=tolerance):
        # Non-zero offsets: both the normals and the offsets must match.
        normals_match = fuzzy_vector_equal(a1, a2, tolerance=tolerance)
        return normals_match and fuzzy_equal(b1, b2, tolerance=tolerance)

    # Zero offsets: normals are accepted up to a factor of -1.
    if fuzzy_vector_equal(a1, a2, tolerance=tolerance):
        return True
    return fuzzy_vector_equal(a1, -a2, tolerance=tolerance)

def comparison_form(A, b, tolerance=global_tolerance):
    """ Given polytope Ax<= b
        Convert each constraint into a_i^Tx <= +-1, 0
        If b_i=0, normalize a_i
        Then sort rows of A lexicographically

    A is a 2d numpy array of shape (m,n)
    b is a 1d numpy array of shape (m)

    NOTE: this function is disabled. The first statement unconditionally
    raises DeprecationWarning, so none of the code below it ever runs; it is
    kept only as a reference implementation.
    """
    # Unconditional raise: everything after this line is unreachable.
    raise DeprecationWarning("DON'T DO THIS OUTSIDE OF BATCH ")
    m, n = A.shape
    # First scale all constraints to have b = +-1, 0
    b_abs = onp.abs(b)
    # 1 for rows whose |b| exceeds the tolerance, 0 for rows with b ~ 0
    rows_to_scale = (b_abs > tolerance).astype(int)
    rows_to_normalize = 1 - rows_to_scale
    # |b| for rows being scaled, 1 + |b| (about 1) for the b ~ 0 rows
    scale_factor = onp.ones(m) - rows_to_scale + b_abs

    b_scaled = (b / scale_factor)
    a_scaled = A / scale_factor[:, None]

    # NOTE(review): this rebinding is not read afterwards - purpose unclear.
    rows_to_scale = 1

    # Only do the row normalization if you have to
    if onp.sum(rows_to_normalize) > 0:
        row_norms = onp.linalg.norm(a_scaled, axis=1)
        # row norm for the b ~ 0 rows, 1 for every other row
        norm_scale_factor = (onp.ones(m) - rows_to_normalize +
                             row_norms * rows_to_normalize)
        a_scaled = a_scaled / norm_scale_factor[:, None]


    # Sort scaled version
    sort_indices = onp.lexsort(a_scaled.T)
    sorted_a = a_scaled[sort_indices]
    sorted_b = b_scaled[sort_indices]

    return sorted_a, sorted_b

def plot_hyperplanes(ub_A, ub_b, styles=None, ax=None, check=False, tol=global_tolerance):
    ''' Plots all hyperplanes defined by each constraint of ub_A and ub_b

    Each row a of ub_A with offset b is drawn as the 2-D line
    a[0]*x + a[1]*y = b on ``ax`` (a new axes is created when None).

    Args:
        styles: one matplotlib format string per constraint; '-' by default.
        check: when true, constraints equal (per
            ``is_same_hyperplane_nocomp`` with tolerance ``tol``) to one
            already drawn are skipped.

    Vertical lines (a[1] == 0) are drawn explicitly across the current
    y-limits instead of dividing by zero; rows with a[0] == a[1] == 0
    describe no line and are not drawn.
    '''
    if ax is None:
        ax = plt.axes() 

    if styles is None:
        styles = ['-' for _ in range(0, onp.shape(ub_A)[0])]

    if check:
      print("removing duplicate hyperplanes...")

    v_A = []
    v_b = []
    for a, b, style in zip(ub_A, ub_b, styles):
        if check and any(is_same_hyperplane_nocomp(a, b, a2, b2, tolerance=tol)
                         for a2, b2 in zip(v_A, v_b)):
            continue
        if a[1] == 0:
            if a[0] != 0:
                # Vertical line x = b / a[0]; slope/intercept form is undefined.
                x0 = float(onp.squeeze(b / a[0]))
                ax.plot([x0, x0], list(ax.get_ylim()), style, c='black')
        else:
            plot_line(-a[0]/a[1], b/a[1], style, ax)
        v_A.append(a)
        v_b.append(b)
        
def plot_line(slope, intercept, style, ax=None):
    """Plot a line from slope and intercept

    The line is drawn in black across the current x-limits of ``ax`` using
    the matplotlib format string ``style``. A new axes is created when
    ``ax`` is None.
    """
    if ax is None:
        ax = plt.axes() 
    # Use the limits of the axes being drawn on; the previous code read them
    # from plt.gca(), which may be a different axes.
    x_vals = onp.array(ax.get_xlim())
    y_vals = intercept + slope * x_vals
    ax.plot(x_vals, y_vals, style, c='black')
    


##########################################################################
#                                                                        #
#                   NEURAL CONFIG UTILITIES                              #
#                                                                        #
##########################################################################

def get_new_configs(old_configs, tight_index):
    """Return a deep copy of old_configs with the entry at flat position
    tight_index flipped (value v replaced by int(1 - v)).
    """
    layer_idx, unit_idx = index_to_config_coord(old_configs, tight_index)
    flipped = copy.deepcopy(old_configs)
    current_value = flipped[layer_idx][unit_idx]
    flipped[layer_idx][unit_idx] = int(1 - current_value)
    return flipped


def config_hamming_distance(config1, config2):
    """ Given two configs (as a list of floattensors, where all elements are
        0.0 or 1.0) computes the hamming distance between them
    """
    # torch is not imported at module level; import it here so the function
    # does not fail with NameError.
    import torch

    hamming_dist = 0
    for comp1, comp2 in zip(config1, config2):
        uneqs = comp1.type(torch.uint8) != comp2.type(torch.uint8)
        # .cpu() keeps the conversion working for tensors on an accelerator.
        hamming_dist += uneqs.sum().cpu().numpy()
    return hamming_dist

def string_hamming_distance(str1, str2):
    """Return the number of positions at which two equal-length sequences differ."""
    assert len(str1) == len(str2)
    paired = zip(str1, str2)
    return sum(left != right for left, right in paired)

def hamming_indices(str1, str2):
    """Return the positions at which two equal-length sequences differ."""
    assert len(str1) == len(str2)
    differing = []
    for position, (left, right) in enumerate(zip(str1, str2)):
        if left != right:
            differing.append(position)

    return differing

def cat_config(conf):
    """ Takes a list of float or uint8 tensors and flattens them into onp.ndarray
    """
    # torch is not imported at module level; import it here so the function
    # does not fail with NameError.
    import torch

    as_bytes = [tensor.cpu().type(torch.uint8).detach() for tensor in conf]
    return torch.cat(as_bytes).numpy()


def flatten_config(config):
    """ Converts a list of floatTensors whose elements are each 1 or 0 into
        a single string of 1s and 0s: a binary representation of the neuron
        config.
    """
    flat_values = cat_config(config)
    return ''.join(map(str, flat_values))


def index_to_config_coord(config, index):
    """ Translates an index into the flattened config into a pair
        (position in the config list, offset within that element).
    """
    sizes = [block.numel() for block in config]
    assert index < sum(sizes)

    position = 0
    for size in sizes:
        if index <= size - 1:
            return (position, index)
        # Not in this element: skip past it.
        index -= size
        position += 1
        
        
def binarize_relu_configs(relu_configs):
    """ Flattens a list of relu configs into one integer array
        (element i of relu_configs is assumed to hold the relu activations
        of layer i).
    """
    flat_values = []
    for layer_code in relu_configs:
        for unit in layer_code:
            flat_values.append(unit.data.numpy())
    return onp.asarray(flat_values).astype(int)


### VAE utils

def gaussian_kl(mu, sigmasq):
  """KL divergence from a diagonal Gaussian (mean mu, variance sigmasq) to the standard Gaussian."""
  per_dimension = 1. + np.log(sigmasq) - mu**2. - sigmasq
  return -0.5 * np.sum(per_dimension)

def gaussian_sample(rng, mu, sigmasq):
  """Sample a diagonal Gaussian.

  Args:
    rng: a JAX PRNG key.
    mu: mean array.
    sigmasq: variance array.
  """
  # The module-level name `random` is rebound to the standard-library module
  # by a later `import random`, and that module has no `normal`; use
  # jax.random explicitly.
  from jax import random as jax_random
  return mu + np.sqrt(sigmasq) * jax_random.normal(rng, mu.shape)

def bernoulli_logpdf(logits, x):
  """Bernoulli log pdf of data x given logits, summed over all elements."""
  signs = np.where(x, -1., 1.)
  per_element = np.logaddexp(0., signs * logits)
  return -np.sum(per_element)

def image_grid(nrow, ncol, imagevecs, imshape):
  """Arrange a stack of image vectors into one nrow x ncol image grid for plotting."""
  image_stream = iter(imagevecs.reshape((-1,) + imshape))
  strips = []
  for _ in range(nrow):
    # Take the next ncol images (transposed), then reverse their order.
    tiles = [next(image_stream).T for _ in range(ncol)]
    tiles.reverse()
    strips.append(np.hstack(tiles))
  return np.vstack(strips).T

