diff --git a/README.md b/README.md
index f8aaedb..c7323ae 100755
--- a/README.md
+++ b/README.md
@@ -248,7 +248,7 @@ will run DeepZ for analyzing property 9 of ACASXu benchmarks. The ACASXU network
 Geometric analysis
 
 ```
-python3 . --netname ../nets/pytorch/mnist/convBig__DiffAI.pyt --geometric --geometric_config ../deepg/examples/example1/config.txt --num_params 1 --dataset mnist
+python3 . --netname ../nets/pytorch/mnist/convBig__DiffAI.pyt --geometric --geometric_config ../deepg/code/examples/example1/config.txt --num_params 1 --dataset mnist
 ```
 will on the fly generate geometric perturbed images and evaluate the network against them. For more information on the geometric configuration file please see [Format of the configuration file in DeepG](https://github.com/eth-sri/deepg#format-of-configuration-file).
 
diff --git a/requirements.txt b/requirements.txt
index 526d25e..de73989 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 numpy
-tensorflow
+scipy
+matplotlib
 onnx==1.5.0
 pycddlib
 tqdm
diff --git a/tf_verify/__main__.py b/tf_verify/__main__.py
index 377d739..3e8c7c0 100644
--- a/tf_verify/__main__.py
+++ b/tf_verify/__main__.py
@@ -40,6 +40,7 @@ import logging
 import torch
 import spatial
 from copy import deepcopy
+import math
 
 #ZONOTOPE_EXTENSION = '.zt'
 EPS = 10**(-9)
@@ -97,6 +98,51 @@ def show_ascii_spec(lb, ub, n_rows, n_cols, n_channels):
         print('  |  ')
     print('==================================================================')
 
+def create_cut_model( netname, dataset, im, model_name, geometric_params ):
+    import tensorflow as tf
+    from read_net_file import read_tensorflow_net
+    from clever_wolf import CutModel
+    sess = tf.Session()
+    filename, file_extension = os.path.splitext(netname)
+    is_trained_with_pytorch = file_extension==".pyt"
+
+    num_pixels = geometric_params
+
+    model, is_conv, means, stds, layers = read_tensorflow_net(netname, num_pixels, is_trained_with_pytorch)
+    pixel_size = np.array( [ 1.0 ] * num_pixels )
+    pgd_means = np.zeros( ( num_pixels, 1 ) ) 
+    pgd_stds = np.ones( ( num_pixels, 1 ) ) 
+
+    im_copy = np.copy( im ); zeros = np.zeros((num_pixels))
+    ones = np.ones((num_pixels))
+    if is_trained_with_pytorch:
+        normalize( zeros, means, stds, is_conv )
+        normalize( ones, means, stds, is_conv )
+
+        if dataset == 'mortgage' or dataset == 'acasxu':
+            pgd_means[ : , 0 ] = means
+            pgd_stds[ : , 0 ] = stds
+            pixel_size =  np.array( [ 1.0 ] * num_pixels ) / stds 
+        else:
+            # TODO Hack - works only on MNIST and CIFAR10 and mortgage and ACAS Xu
+            assert False
+    else:
+        assert dataset == 'mnist'
+        im_copy = np.copy( im )
+
+    print( 'Model created' )
+    tf_out = tf.get_default_graph().get_tensor_by_name( model.name )
+    tf_in = tf.get_default_graph().get_tensor_by_name( 'x:0' )
+    print( 'Tensors created' )
+
+    out = sess.run( tf_out, feed_dict={ tf_in: im_copy } )
+    print( 'Tf out computed' )
+    if model_name is None:
+        cut_model = CutModel( sess, tf_in, tf_out, np.argmax( out ), pixel_size )
+    else:
+        cut_model = CutModel.load( model_name, sess, tf_in, tf_out, np.argmax( out ) )
+    print( 'Cut model created' )
+    return cut_model, is_conv, means, stds, im_copy, pgd_means, pgd_stds, layers, zeros, ones
 
 def normalize(image, means, stds, dataset):
     # normalization taken out of the network
@@ -272,8 +318,6 @@ def acasxu_recursive(specLB, specUB, max_depth=10, depth=0):
         result = failed_already.value and result and acasxu_recursive([lb if i != index else m for i, lb in enumerate(specLB)], specUB, max_depth, depth + 1)
         return result
 
-
-
 def get_tests(dataset, geometric):
     if geometric:
         csvfile = open('../deepg/code/datasets/{}_test.csv'.format(dataset), 'r')
@@ -314,6 +358,10 @@ parser.add_argument('--numproc', type=int, default=config.numproc,  help='number
 parser.add_argument('--sparse_n', type=int, default=config.sparse_n,  help='Number of variables to group by k-ReLU')
 parser.add_argument('--use_default_heuristic', type=str2bool, default=config.use_default_heuristic,  help='whether to use the area heuristic for the DeepPoly ReLU approximation or to always create new noise symbols per relu for the DeepZono ReLU approximation')
 parser.add_argument('--use_milp', type=str2bool, default=config.use_milp,  help='whether to use milp or not')
+parser.add_argument('--eot', action='store_true', default=False, help='Whether to do EOT baseline')
+parser.add_argument('--geom_baseline', type=int, nargs="+", default=[],  help='Number of splits per parameter for the Split baseline')
+parser.add_argument('--geom_box', type=int, nargs="+", default=[],  help='Number of geometric box splits per parameter')
+parser.add_argument('--geom_box_its', type=int, default=20,  help='Number of geom box iterations')
 parser.add_argument('--refine_neurons', action='store_true', default=config.refine_neurons, help='whether to refine intermediate neurons')
 parser.add_argument('--mean', nargs='+', type=float, default=config.mean, help='the mean used to normalize the data with')
 parser.add_argument('--std', nargs='+', type=float, default=config.std, help='the standard deviation used to normalize the data with')
@@ -324,6 +372,7 @@ parser.add_argument('--num_tests', type=int, default=config.num_tests, help='Num
 parser.add_argument('--from_test', type=int, default=config.from_test, help='Number of images to test')
 parser.add_argument('--debug', action='store_true', default=config.debug, help='Whether to display debug info')
 parser.add_argument('--attack', action='store_true', default=config.attack, help='Whether to attack')
+parser.add_argument('--skip_geom_ver', action='store_true', default=False, help='Whether to skip the geometric DeepPoly verification step')
 parser.add_argument('--geometric', '-g', dest='geometric', default=config.geometric, action='store_true', help='Whether to do geometric analysis')
 parser.add_argument('--input_box', default=config.input_box,  help='input box to use')
 parser.add_argument('--output_constraints', default=config.output_constraints, help='custom output constraints to check')
@@ -342,7 +391,6 @@ args = parser.parse_args()
 for k, v in vars(args).items():
     setattr(config, k, v)
 config.json = vars(args)
-
 if config.specnumber and not config.input_box and not config.output_constraints:
     config.input_box = '../data/acasxu/specs/acasxu_prop_' + str(config.specnumber) + '_input_prenormalized.txt'
     config.output_constraints = '../data/acasxu/specs/acasxu_prop_' + str(config.specnumber) + '_constraints.txt'
@@ -588,21 +636,42 @@ elif zonotope_bool:
 
 elif config.geometric:
     from geometric_constraints import *
+    from geometric_symadex import * 
     total, attacked, standard_correct, tot_time = 0, 0, 0, 0
     correct_box, correct_poly = 0, 0
     cver_box, cver_poly = [], []
     if config.geometric_config:
         transform_attack_container = get_transform_attack_container(config.geometric_config)
-        for i, test in enumerate(tests):
-            if config.from_test and i < config.from_test:
+        tf_underbox_tens = gen_tf_underapprox_box( config.num_params )
+        tf_eot_tens = create_EoT_tensors( eran, config, means, stds ) if config.eot else None
+        for img, test in enumerate(tests):
+            if config.from_test and img < config.from_test:
                 continue
-
-            if config.num_tests is not None and i >= config.num_tests:
+            if config.num_tests is not None and img >= config.from_test + config.num_tests:
                 break
-            set_transform_attack_for(transform_attack_container, i, config.attack, config.debug)
+            #if img not in [0,4,10,11,12,14,28,32,38,43,44,45,46,51,54,61,64,66,78,79,82,96,97]:
+            #    continue
+            print( 'Test {}:'.format(img) )
+            test_img = np.array( [ float(t) for t in test[1:] ] ).copy()
+            normalize( test_img, means, stds, dataset )
+            label ,_ = eran.quick_eval( test_img )
+
+            if label != int(test[0]):
+                print('Label {}, but true label is {}, skipping...'.format(label, int(test[0])))
+                print('Standard accuracy: {} percent'.format(standard_correct / float(img + 1) * 100))
+                continue
+            else:
+                standard_correct += 1
+                print('Standard accuracy: {} percent'.format(standard_correct / float(img + 1) * 100))
+
+            st = time.time()
+            reinit( transform_attack_container )
+            lbbox, ubbox = get_geom_box( transform_attack_container )
+            set_transform_attack_for(transform_attack_container, img, config.attack, config.skip_geom_ver, config.debug)
             attack_params = get_attack_params(transform_attack_container)
             attack_images = get_attack_images(transform_attack_container)
-            print('Test {}:'.format(i))
+            en = time.time()
+            print('Attack+Constr time: %.2fs' % (en-st), flush=True )
 
             image = np.float64(test[1:])
             if config.dataset == 'mnist' or config.dataset == 'fashion':
@@ -616,18 +685,17 @@ elif config.geometric:
             normalize(spec_lb, means, stds, config.dataset)
             normalize(spec_ub, means, stds, config.dataset)
 
-            label, nn, nlb, nub,_,_ = eran.analyze_box(spec_lb, spec_ub, 'deeppoly', config.timeout_lp, config.timeout_milp,
-                                                   config.use_default_heuristic)
+            label, nn, nlb, nub, _, _, _ = eran.analyze_box(spec_lb, spec_ub, 'deeppoly', config.timeout_lp, config.timeout_milp, config.use_default_heuristic)  # NOTE(review): unpacks 7 values; other updated analyze_box call sites unpack 6 -- confirm API arity
             print('Label: ', label)
 
             begtime = time.time()
             if label != int(test[0]):
                 print('Label {}, but true label is {}, skipping...'.format(label, int(test[0])))
-                print('Standard accuracy: {} percent'.format(standard_correct / float(i + 1) * 100))
+                print('Standard accuracy: {} percent'.format(standard_correct / float(img + 1) * 100))
                 continue
             else:
                 standard_correct += 1
-                print('Standard accuracy: {} percent'.format(standard_correct / float(i + 1) * 100))
+                print('Standard accuracy: {} percent'.format(standard_correct / float(img + 1) * 100))
 
             dim = n_rows * n_cols * n_channels
 
@@ -636,31 +704,103 @@ elif config.geometric:
 
             attack_imgs, checked, attack_pass = [], [], 0
             cex_found = False
+            succ_attacks = 0
+            output_size = eran.quick_eval(spec_lb)[1].shape[1]
+            idxs = [[] for _ in range(output_size)]
+            boxes = [None for _ in range(output_size)]
             if config.attack:
-                for j in tqdm(range(0, len(attack_params))):
-                    params = attack_params[j]
-                    values = np.array(attack_images[j])
-
-                    attack_lb = values[::2]
-                    attack_ub = values[1::2]
-
-                    normalize(attack_lb, means, stds, config.dataset)
-                    normalize(attack_ub, means, stds, config.dataset)
-                    attack_imgs.append((params, attack_lb, attack_ub))
-                    checked.append(False)
-
-                    predict_label, _, _, _,_ = eran.analyze_box(
-                        attack_lb[:dim], attack_ub[:dim], 'deeppoly',
-                        config.timeout_lp, config.timeout_milp, config.use_default_heuristic)
-                    if predict_label != int(test[0]):
-                        print('counter-example, params: ', params, ', predicted label: ', predict_label)
-                        cex_found = True
-                        break
-                    else:
-                        attack_pass += 1
-            print('tot attacks: ', len(attack_imgs))
-
-            lines = get_transformations(transform_attack_container)
+                lbs = np.array(attack_images)[:,::2]
+                if dataset == 'mnist'  or dataset == 'fashion':
+                    lbs = (lbs - means[0])/stds[0]
+                else:
+                    assert dataset == 'cifar10'
+                    lbs = lbs.reshape( -1, 1024, 3 )
+                    lbs = ( lbs - means.reshape(1,3) )/ stds.reshape(1,3)
+                    lbs = lbs.reshape( -1, 3072 )
+ 
+                eran_sess = eran.tf_session
+                tf_y = eran_sess.graph.get_operation_by_name(model.op.name).outputs[0]
+                tf_x = eran_sess.graph.get_operations()[0].outputs[0]
+                out = eran_sess.run( tf_y, feed_dict={tf_x:lbs[:200].reshape(-1)} ) 
+                labels = np.argmax( out, axis=1 )
+                batch_size = 200
+                batches = math.ceil( len(attack_params)/batch_size )
+                for j in tqdm(range(0, batches)):
+                    st_idx = j * batch_size
+                    en_idx = min( (j + 1) * batch_size, len(attack_params) )
+                    attack_imgs = attack_imgs + list( zip( attack_params[st_idx:en_idx], lbs[st_idx:en_idx] ) )
+                    out = eran_sess.run( tf_y, feed_dict={tf_x:lbs[st_idx:en_idx].reshape(-1)} ) 
+                    labels = np.argmax( out, axis=1 )
+                    correct = labels == int(test[0])
+                    succ_attacks += np.sum( np.logical_not( correct )  )
+                    
+                    adex_labels = set( np.unique ( labels ).tolist() ) - set( [int(test[0])] )
+                    for adex in adex_labels:
+                        idxs[adex] += ( (np.where( labels == adex ))[0] + j * batch_size ).tolist()
+                        params = np.array( attack_params[st_idx:en_idx] ) [ labels == adex ]
+                        params_min = np.min( params, axis=0 )
+                        params_max = np.max( params, axis=0 )
+                        if boxes[adex] is None:
+                            boxes[adex] = (params_min, params_max)
+                        else:
+                            lb_box, ub_box = boxes[adex]
+                            boxes[adex] = ( np.minimum( params_min, lb_box ) , np.maximum( params_max, ub_box ) )
+            
+            print('tot attacks: ', succ_attacks, '/', len(attack_imgs), ';', list(map(len,idxs)))
+            print('attack_boxes:', boxes )
+            for cl, attacks in enumerate(idxs):
+                if len( attacks ) > 0:
+                    #import pdb; pdb.set_trace()
+                    #f = np.load( '4d.npz' ) 
+                    #boxes[cl] = f['lb'], f['ub']
+                    print( boxes[cl][0], boxes[cl][1], 'target:', cl, 'orig:', int(test[0]) )
+                    #exp_over_transform( eran.tf_session, eran.model, config, transform_attack_container, image, boxes[cl][0], boxes[cl][1], 100, 20 )
+                    if len( config.geom_baseline ) > 0:
+                        assert len( config.geom_baseline ) == config.num_params
+                        # Baseline
+                        st = time.time()
+                        lbs, ubs = generate_splits( boxes[cl][0], boxes[cl][1], config.geom_baseline )
+                        res = []
+                        par = []
+     
+                        deltas = (boxes[cl][1] - boxes[cl][0]) / config.geom_baseline
+                        for j in range( lbs.shape[1] ):
+                            out = verify_geometric_box(eran, config, means, stds, transform_attack_container, img, cl, lbs[:,j].copy(), ubs[:,j].copy(), 1, W=None, b=None, update_geom=False, target=float('nan'))
+                            bound = min( map(lambda a : -a[-2], out[-1][0]) )
+                            res.append( bound > 0 )
+                            par.append( (lbs[:,j], ubs[:,j]) )
+
+                        en = time.time()
+                        volume = np.prod(deltas) * np.sum( res ) 
+                        t = en - st 
+                        str_baseline_size = "".join( [str(el) + 'x' for el in config.geom_baseline] )[:-1]
+                        print( 'Baseline', str_baseline_size, 'Time:', t , 's', volume, flush=True )
+                    if config.eot:
+                        st = time.time()
+                        lam_mid, lb_eot, ub_eot, perc_eot, vol_eot = get_eot_lam( eran, tf_eot_tens, config, transform_attack_container, image, lbbox, ubbox, means, stds, cl )
+                        en = time.time()
+                        t = en - st
+                        print ( 'EOT 1x1x1 Time:',  t , 's', vol_eot, 'EOT percent:', perc_eot, flush=True )
+                    if len( config.geom_box ) > 0:
+                        assert len( config.geom_box ) == config.num_params
+                        # Mine
+                        st = time.time()
+                        lbs, ubs = generate_splits( boxes[cl][0], boxes[cl][1], config.geom_box )
+                        vol = 0 
+                        for j in range( lbs.shape[1] ):
+                            lb_under, ub_under = create_geom_poly(eran, tf_underbox_tens, config, means, stds, transform_attack_container, img, lbs[:,j].copy(), ubs[:,j].copy(), cl, config.geom_box_its)
+                            if lb_under is None:
+                                continue
+                            print( 'Vol:', np.prod( ub_under - lb_under ), flush=True )
+                            vol += np.prod( ub_under - lb_under )
+
+                        en = time.time()
+                        t = en - st
+                        str_min_size = "".join( [str(el) + 'x' for el in config.geom_box] )[:-1]
+                        print( 'Mine', str_min_size, 'Time:', t , 's', vol, flush=True)
+            continue  # NOTE(review): unconditionally skips the DeepG verification below (L320+ is dead code) -- debugging leftover? confirm intent
+            set_transform_attack_for(transform_attack_container, img, config.attack, config.skip_geom_ver, config.debug)
+            lines = get_transformations(transform_attack_container) 
             print('Number of lines: ', len(lines))
             assert len(lines) % k == 0
 
@@ -673,102 +813,102 @@ elif config.geometric:
             lexpr_dim, uexpr_dim = [], []
 
             ver_chunks_box, ver_chunks_poly, tot_chunks = 0, 0, 0
+            if not config.skip_geom_ver:
+                for i, line in enumerate(lines):
+                    if i % k < config.num_params:
+                        # read specs for the parameters
+                        values = line
+                        assert len(values) == 2
+                        param_idx = i % k
+                        spec_lb[dim + param_idx] = values[0]
+                        spec_ub[dim + param_idx] = values[1]
+                        if config.debug:
+                            print('parameter %d: [%.4f, %.4f]' % (param_idx, values[0], values[1]))
+                    elif i % k == config.num_params:
+                        # read interval bounds for image pixels
+                        values = line
+                        spec_lb[:dim] = values[::2]
+                        spec_ub[:dim] = values[1::2]
+                        # if config.debug:
+                        #     show_ascii_spec(spec_lb, spec_ub)
+                    elif i % k < k - 1:
+                        # read polyhedra constraints for image pixels
+                        tokens = line
+                        assert len(tokens) == 2 + 2 * config.num_params
 
-            for i, line in enumerate(lines):
-                if i % k < config.num_params:
-                    # read specs for the parameters
-                    values = line
-                    assert len(values) == 2
-                    param_idx = i % k
-                    spec_lb[dim + param_idx] = values[0]
-                    spec_ub[dim + param_idx] = values[1]
-                    if config.debug:
-                        print('parameter %d: [%.4f, %.4f]' % (param_idx, values[0], values[1]))
-                elif i % k == config.num_params:
-                    # read interval bounds for image pixels
-                    values = line
-                    spec_lb[:dim] = values[::2]
-                    spec_ub[:dim] = values[1::2]
-                    # if config.debug:
-                    #     show_ascii_spec(spec_lb, spec_ub)
-                elif i % k < k - 1:
-                    # read polyhedra constraints for image pixels
-                    tokens = line
-                    assert len(tokens) == 2 + 2 * config.num_params
-
-                    bias_lower, weights_lower = tokens[0], tokens[1:1 + config.num_params]
-                    bias_upper, weights_upper = tokens[config.num_params + 1], tokens[2 + config.num_params:]
-
-                    assert len(weights_lower) == config.num_params
-                    assert len(weights_upper) == config.num_params
-
-                    lexpr_cst.append(bias_lower)
-                    uexpr_cst.append(bias_upper)
-                    for j in range(config.num_params):
-                        lexpr_dim.append(dim + j)
-                        uexpr_dim.append(dim + j)
-                        lexpr_weights.append(weights_lower[j])
-                        uexpr_weights.append(weights_upper[j])
-                else:
-                    assert (len(line) == 0)
-                    for p_idx in range(config.num_params):
-                        lexpr_cst.append(spec_lb[dim + p_idx])
-                        for l in range(config.num_params):
-                            lexpr_weights.append(0)
-                            lexpr_dim.append(dim + l)
-                        uexpr_cst.append(spec_ub[dim + p_idx])
-                        for l in range(config.num_params):
-                            uexpr_weights.append(0)
-                            uexpr_dim.append(dim + l)
-                    normalize(spec_lb[:dim], means, stds, config.dataset)
-                    normalize(spec_ub[:dim], means, stds, config.dataset)
-                    normalize_poly(config.num_params, lexpr_cst, lexpr_weights, lexpr_dim, uexpr_cst, uexpr_weights,
-                                   uexpr_dim, means, stds, config.dataset)
-
-                    for attack_idx, (attack_params, attack_lb, attack_ub) in enumerate(attack_imgs):
-                        ok_attack = True
-                        for j in range(num_pixels):
-                            low, up = lexpr_cst[j], uexpr_cst[j]
-                            for idx in range(config.num_params):
-                                low += lexpr_weights[j * config.num_params + idx] * attack_params[idx]
-                                up += uexpr_weights[j * config.num_params + idx] * attack_params[idx]
-                            if low > attack_lb[j] + EPS or attack_ub[j] > up + EPS:
-                                ok_attack = False
-                        if ok_attack:
-                            checked[attack_idx] = True
-                            # print('checked ', attack_idx)
-                    if config.debug:
-                        print('Running the analysis...')
-
-                    t_begin = time.time()
-                    perturbed_label_poly, _, _, _,_ = eran.analyze_box(
-                        spec_lb, spec_ub, 'deeppoly',
-                        config.timeout_lp, config.timeout_milp, config.use_default_heuristic, None,
-                        lexpr_weights, lexpr_cst, lexpr_dim,
-                        uexpr_weights, uexpr_cst, uexpr_dim,
-                        expr_size)
-                    perturbed_label_box, _, _, _,_ = eran.analyze_box(
-                        spec_lb[:dim], spec_ub[:dim], 'deeppoly',
-                        config.timeout_lp, config.timeout_milp, config.use_default_heuristic)
-                    t_end = time.time()
-
-                    print('DeepG: ', perturbed_label_poly, '\tInterval: ', perturbed_label_box, '\tlabel: ', label,
-                          '[Time: %.4f]' % (t_end - t_begin))
-
-                    tot_chunks += 1
-                    if perturbed_label_box != label:
-                        ok_box = False
-                    else:
-                        ver_chunks_box += 1
+                        bias_lower, weights_lower = tokens[0], tokens[1:1 + config.num_params]
+                        bias_upper, weights_upper = tokens[config.num_params + 1], tokens[2 + config.num_params:]
+
+                        assert len(weights_lower) == config.num_params
+                        assert len(weights_upper) == config.num_params
 
-                    if perturbed_label_poly != label:
-                        ok_poly = False
+                        lexpr_cst.append(bias_lower)
+                        uexpr_cst.append(bias_upper)
+                        for j in range(config.num_params):
+                            lexpr_dim.append(dim + j)
+                            uexpr_dim.append(dim + j)
+                            lexpr_weights.append(weights_lower[j])
+                            uexpr_weights.append(weights_upper[j])
                     else:
-                        ver_chunks_poly += 1
+                        assert (len(line) == 0)
+                        for p_idx in range(config.num_params):
+                            lexpr_cst.append(spec_lb[dim + p_idx])
+                            for l in range(config.num_params):
+                                lexpr_weights.append(0)
+                                lexpr_dim.append(dim + l)
+                            uexpr_cst.append(spec_ub[dim + p_idx])
+                            for l in range(config.num_params):
+                                uexpr_weights.append(0)
+                                uexpr_dim.append(dim + l)
+                        normalize(spec_lb[:dim], means, stds, config.dataset)
+                        normalize(spec_ub[:dim], means, stds, config.dataset)
+                        normalize_poly(config.num_params, lexpr_cst, lexpr_weights, lexpr_dim, uexpr_cst, uexpr_weights,
+                                       uexpr_dim, means, stds, config.dataset)
+
+                        for attack_idx, (attack_params, attack_lb, attack_ub) in enumerate(attack_imgs):
+                            ok_attack = True
+                            for j in range(num_pixels):
+                                low, up = lexpr_cst[j], uexpr_cst[j]
+                                for idx in range(config.num_params):
+                                    low += lexpr_weights[j * config.num_params + idx] * attack_params[idx]
+                                    up += uexpr_weights[j * config.num_params + idx] * attack_params[idx]
+                                if low > attack_lb[j] + EPS or attack_ub[j] > up + EPS:
+                                    ok_attack = False
+                            if ok_attack:
+                                checked[attack_idx] = True
+                                # print('checked ', attack_idx)
+                        if config.debug:
+                            print('Running the analysis...')
 
-                    lexpr_cst, uexpr_cst = [], []
-                    lexpr_weights, uexpr_weights = [], []
-                    lexpr_dim, uexpr_dim = [], []
+                        t_begin = time.time()
+                        perturbed_label_poly, _, _, _, _, _ = eran.analyze_box(
+                            spec_lb, spec_ub, 'deeppoly',
+                            config.timeout_lp, config.timeout_milp, config.use_default_heuristic, None,
+                            lexpr_weights, lexpr_cst, lexpr_dim,
+                            uexpr_weights, uexpr_cst, uexpr_dim,
+                            expr_size)
+                        perturbed_label_box, _, _, _, _, _ = eran.analyze_box(
+                            spec_lb[:dim], spec_ub[:dim], 'deeppoly',
+                            config.timeout_lp, config.timeout_milp, config.use_default_heuristic)
+                        t_end = time.time()
+
+                        print('DeepG: ', perturbed_label_poly, '\tInterval: ', perturbed_label_box, '\tlabel: ', label,
+                              '[Time: %.4f]' % (t_end - t_begin))
+
+                        tot_chunks += 1
+                        if perturbed_label_box != label:
+                            ok_box = False
+                        else:
+                            ver_chunks_box += 1
+
+                        if perturbed_label_poly != label:
+                            ok_poly = False
+                        else:
+                            ver_chunks_poly += 1
+
+                        lexpr_cst, uexpr_cst = [], []
+                        lexpr_weights, uexpr_weights = [], []
+                        lexpr_dim, uexpr_dim = [], []
 
             total += 1
             if ok_box:
@@ -776,17 +916,18 @@ elif config.geometric:
             if ok_poly:
                 correct_poly += 1
             if cex_found:
-                assert (not ok_box) and (not ok_poly)
+                if not config.skip_geom_ver:
+                    assert (not ok_box) and (not ok_poly)
                 attacked += 1
-            cver_poly.append(ver_chunks_poly / float(tot_chunks))
-            cver_box.append(ver_chunks_box / float(tot_chunks))
+            if not config.skip_geom_ver:
+                cver_poly.append(ver_chunks_poly / float(tot_chunks))
+                cver_box.append(ver_chunks_box / float(tot_chunks))
             tot_time += time.time() - begtime
 
             print('Verified[box]: {}, Verified[poly]: {}, CEX found: {}'.format(ok_box, ok_poly, cex_found))
-            assert not cex_found or not ok_box, 'ERROR! Found counter-example, but image was verified with box!'
-            assert not cex_found or not ok_poly, 'ERROR! Found counter-example, but image was verified with poly!'
-
-
+            if not config.skip_geom_ver:
+                assert not cex_found or not ok_box, 'ERROR! Found counter-example, but image was verified with box!'
+                assert not cex_found or not ok_poly, 'ERROR! Found counter-example, but image was verified with poly!'
     else:
         for i, test in enumerate(tests):
             if config.from_test and i < config.from_test:
@@ -809,8 +950,8 @@ elif config.geometric:
 
             normalize(spec_lb, means, stds, config.dataset)
             normalize(spec_ub, means, stds, config.dataset)
-
-            label, nn, nlb, nub,_ = eran.analyze_box(spec_lb, spec_ub, 'deeppoly', config.timeout_lp, config.timeout_milp,
+            
+            label, nn, nlb, nub, _, _ = eran.analyze_box(spec_lb, spec_ub, 'deeppoly', config.timeout_lp, config.timeout_milp,
                                                    config.use_default_heuristic)
             print('Label: ', label)
 
@@ -846,9 +987,9 @@ elif config.geometric:
                         attack_imgs.append((params, attack_lb, attack_ub))
                         checked.append(False)
 
-                        predict_label, _, _, _,_ = eran.analyze_box(
+                        predict_label, _, _, _, _, _ = eran.analyze_box(
                             attack_lb[:dim], attack_ub[:dim], 'deeppoly',
-                            config.timeout_lp, config.timeout_milp, config.use_default_heuristic, 0)
+                            config.timeout_lp, config.timeout_milp, config.use_default_heuristic)
                         if predict_label != int(test[0]):
                             print('counter-example, params: ', params, ', predicted label: ', predict_label)
                             cex_found = True
@@ -872,103 +1013,104 @@ elif config.geometric:
 
                 ver_chunks_box, ver_chunks_poly, tot_chunks = 0, 0, 0
 
-                for i, line in enumerate(lines):
-                    if i % k < config.num_params:
-                        # read specs for the parameters
-                        values = np.array(list(map(float, line[:-1].split(' '))))
-                        assert values.shape[0] == 2
-                        param_idx = i % k
-                        spec_lb[dim + param_idx] = values[0]
-                        spec_ub[dim + param_idx] = values[1]
-                        if config.debug:
-                            print('parameter %d: [%.4f, %.4f]' % (param_idx, values[0], values[1]))
-                    elif i % k == config.num_params:
-                        # read interval bounds for image pixels
-                        values = np.array(list(map(float, line[:-1].split(','))))
-                        spec_lb[:dim] = values[::2]
-                        spec_ub[:dim] = values[1::2]
-                        # if config.debug:
-                        #     show_ascii_spec(spec_lb, spec_ub)
-                    elif i % k < k - 1:
-                        # read polyhedra constraints for image pixels
-                        tokens = line[:-1].split(' ')
-                        assert len(tokens) == 2 + 2 * config.num_params + 1
-
-                        bias_lower, weights_lower = float(tokens[0]), list(map(float, tokens[1:1 + config.num_params]))
-                        assert tokens[config.num_params + 1] == '|'
-                        bias_upper, weights_upper = float(tokens[config.num_params + 2]), list(
-                            map(float, tokens[3 + config.num_params:]))
-
-                        assert len(weights_lower) == config.num_params
-                        assert len(weights_upper) == config.num_params
-
-                        lexpr_cst.append(bias_lower)
-                        uexpr_cst.append(bias_upper)
-                        for j in range(config.num_params):
-                            lexpr_dim.append(dim + j)
-                            uexpr_dim.append(dim + j)
-                            lexpr_weights.append(weights_lower[j])
-                            uexpr_weights.append(weights_upper[j])
-                    else:
-                        assert (line == 'SPEC_FINISHED\n')
-                        for p_idx in range(config.num_params):
-                            lexpr_cst.append(spec_lb[dim + p_idx])
-                            for l in range(config.num_params):
-                                lexpr_weights.append(0)
-                                lexpr_dim.append(dim + l)
-                            uexpr_cst.append(spec_ub[dim + p_idx])
-                            for l in range(config.num_params):
-                                uexpr_weights.append(0)
-                                uexpr_dim.append(dim + l)
-                        normalize(spec_lb[:dim], means, stds, config.dataset)
-                        normalize(spec_ub[:dim], means, stds, config.dataset)
-                        normalize_poly(config.num_params, lexpr_cst, lexpr_weights, lexpr_dim, uexpr_cst, uexpr_weights,
-                                       uexpr_dim, means, stds, config.dataset)
-
-                        for attack_idx, (attack_params, attack_lb, attack_ub) in enumerate(attack_imgs):
-                            ok_attack = True
-                            for j in range(num_pixels):
-                                low, up = lexpr_cst[j], uexpr_cst[j]
-                                for idx in range(config.num_params):
-                                    low += lexpr_weights[j * config.num_params + idx] * attack_params[idx]
-                                    up += uexpr_weights[j * config.num_params + idx] * attack_params[idx]
-                                if low > attack_lb[j] + EPS or attack_ub[j] > up + EPS:
-                                    ok_attack = False
-                            if ok_attack:
-                                checked[attack_idx] = True
-                                # print('checked ', attack_idx)
-                        if config.debug:
-                            print('Running the analysis...')
-
-                        t_begin = time.time()
-                        perturbed_label_poly, _, _, _ ,_= eran.analyze_box(
-                            spec_lb, spec_ub, 'deeppoly',
-                            config.timeout_lp, config.timeout_milp, config.use_default_heuristic, 0,
-                            lexpr_weights, lexpr_cst, lexpr_dim,
-                            uexpr_weights, uexpr_cst, uexpr_dim,
-                            expr_size)
-                        perturbed_label_box, _, _, _,_ = eran.analyze_box(
-                            spec_lb[:dim], spec_ub[:dim], 'deeppoly',
-                            config.timeout_lp, config.timeout_milp, config.use_default_heuristic, 0)
-                        t_end = time.time()
-
-                        print('DeepG: ', perturbed_label_poly, '\tInterval: ', perturbed_label_box, '\tlabel: ', label,
-                              '[Time: %.4f]' % (t_end - t_begin))
-
-                        tot_chunks += 1
-                        if perturbed_label_box != label:
-                            ok_box = False
-                        else:
-                            ver_chunks_box += 1
-
-                        if perturbed_label_poly != label:
-                            ok_poly = False
+                if not config.skip_geom_ver:
+                    for i, line in enumerate(lines):
+                        if i % k < config.num_params:
+                            # read specs for the parameters
+                            values = np.array(list(map(float, line[:-1].split(' '))))
+                            assert values.shape[0] == 2
+                            param_idx = i % k
+                            spec_lb[dim + param_idx] = values[0]
+                            spec_ub[dim + param_idx] = values[1]
+                            if config.debug:
+                                print('parameter %d: [%.4f, %.4f]' % (param_idx, values[0], values[1]))
+                        elif i % k == config.num_params:
+                            # read interval bounds for image pixels
+                            values = np.array(list(map(float, line[:-1].split(','))))
+                            spec_lb[:dim] = values[::2]
+                            spec_ub[:dim] = values[1::2]
+                            # if config.debug:
+                            #     show_ascii_spec(spec_lb, spec_ub)
+                        elif i % k < k - 1:
+                            # read polyhedra constraints for image pixels
+                            tokens = line[:-1].split(' ')
+                            assert len(tokens) == 2 + 2 * config.num_params + 1
+
+                            bias_lower, weights_lower = float(tokens[0]), list(map(float, tokens[1:1 + config.num_params]))
+                            assert tokens[config.num_params + 1] == '|'
+                            bias_upper, weights_upper = float(tokens[config.num_params + 2]), list(
+                                map(float, tokens[3 + config.num_params:]))
+
+                            assert len(weights_lower) == config.num_params
+                            assert len(weights_upper) == config.num_params
+
+                            lexpr_cst.append(bias_lower)
+                            uexpr_cst.append(bias_upper)
+                            for j in range(config.num_params):
+                                lexpr_dim.append(dim + j)
+                                uexpr_dim.append(dim + j)
+                                lexpr_weights.append(weights_lower[j])
+                                uexpr_weights.append(weights_upper[j])
                         else:
-                            ver_chunks_poly += 1
-
-                        lexpr_cst, uexpr_cst = [], []
-                        lexpr_weights, uexpr_weights = [], []
-                        lexpr_dim, uexpr_dim = [], []
+                            assert (line == 'SPEC_FINISHED\n')
+                            for p_idx in range(config.num_params):
+                                lexpr_cst.append(spec_lb[dim + p_idx])
+                                for l in range(config.num_params):
+                                    lexpr_weights.append(0)
+                                    lexpr_dim.append(dim + l)
+                                uexpr_cst.append(spec_ub[dim + p_idx])
+                                for l in range(config.num_params):
+                                    uexpr_weights.append(0)
+                                    uexpr_dim.append(dim + l)
+                            normalize(spec_lb[:dim], means, stds, config.dataset)
+                            normalize(spec_ub[:dim], means, stds, config.dataset)
+                            normalize_poly(config.num_params, lexpr_cst, lexpr_weights, lexpr_dim, uexpr_cst, uexpr_weights,
+                                           uexpr_dim, means, stds, config.dataset)
+
+                            for attack_idx, (attack_params, attack_lb, attack_ub) in enumerate(attack_imgs):
+                                ok_attack = True
+                                for j in range(num_pixels):
+                                    low, up = lexpr_cst[j], uexpr_cst[j]
+                                    for idx in range(config.num_params):
+                                        low += lexpr_weights[j * config.num_params + idx] * attack_params[idx]
+                                        up += uexpr_weights[j * config.num_params + idx] * attack_params[idx]
+                                    if low > attack_lb[j] + EPS or attack_ub[j] > up + EPS:
+                                        ok_attack = False
+                                if ok_attack:
+                                    checked[attack_idx] = True
+                                    # print('checked ', attack_idx)
+                            if config.debug:
+                                print('Running the analysis...')
+
+                            t_begin = time.time()
+                            perturbed_label_poly, _, _, _ , _, _ = eran.analyze_box(
+                                spec_lb, spec_ub, 'deeppoly',
+                                config.timeout_lp, config.timeout_milp, config.use_default_heuristic, None,
+                                lexpr_weights, lexpr_cst, lexpr_dim,
+                                uexpr_weights, uexpr_cst, uexpr_dim,
+                                expr_size)
+                            perturbed_label_box, _, _, _, _, _ = eran.analyze_box(
+                                spec_lb[:dim], spec_ub[:dim], 'deeppoly',
+                                config.timeout_lp, config.timeout_milp, config.use_default_heuristic)
+                            t_end = time.time()
+
+                            print('DeepG: ', perturbed_label_poly, '\tInterval: ', perturbed_label_box, '\tlabel: ', label,
+                                  '[Time: %.4f]' % (t_end - t_begin))
+
+                            tot_chunks += 1
+                            if perturbed_label_box != label:
+                                ok_box = False
+                            else:
+                                ver_chunks_box += 1
+
+                            if perturbed_label_poly != label:
+                                ok_poly = False
+                            else:
+                                ver_chunks_poly += 1
+
+                            lexpr_cst, uexpr_cst = [], []
+                            lexpr_weights, uexpr_weights = [], []
+                            lexpr_dim, uexpr_dim = [], []
 
             total += 1
             if ok_box:
@@ -978,8 +1120,9 @@ elif config.geometric:
             if cex_found:
                 assert (not ok_box) and (not ok_poly)
                 attacked += 1
-            cver_poly.append(ver_chunks_poly / float(tot_chunks))
-            cver_box.append(ver_chunks_box / float(tot_chunks))
+            if not config.skip_geom_ver:
+                cver_poly.append(ver_chunks_poly / float(tot_chunks))
+                cver_box.append(ver_chunks_box / float(tot_chunks))
             tot_time += time.time() - begtime
 
             print('Verified[box]: {}, Verified[poly]: {}, CEX found: {}'.format(ok_box, ok_poly, cex_found))
@@ -990,8 +1133,9 @@ elif config.geometric:
     print('[Box]  Provably robust: %.2f percent, %d/%d' % (100.0 * correct_box / total, correct_box, total))
     print('[Poly] Provably robust: %.2f percent, %d/%d' % (100.0 * correct_poly / total, correct_poly, total))
     print('Empirically robust: %.2f percent, %d/%d' % (100.0 * (total - attacked) / total, total - attacked, total))
-    print('[Box]  Average chunks verified: %.2f percent' % (100.0 * np.mean(cver_box)))
-    print('[Poly]  Average chunks verified: %.2f percent' % (100.0 * np.mean(cver_poly)))
+    if not config.skip_geom_ver:
+        print('[Box]  Average chunks verified: %.2f percent' % (100.0 * np.mean(cver_box)))
+        print('[Poly]  Average chunks verified: %.2f percent' % (100.0 * np.mean(cver_poly)))
     print('Average time: ', tot_time / total)
 
 elif config.input_box is not None:
diff --git a/tf_verify/ai_milp.py b/tf_verify/ai_milp.py
index 5e643cc..5b7ca4b 100644
--- a/tf_verify/ai_milp.py
+++ b/tf_verify/ai_milp.py
@@ -581,7 +581,7 @@ def solver_call(ind):
     return soll, solu, addtoindices, runtime
 
 
-def get_bounds_for_layer_with_milp(nn, LB_N0, UB_N0, layerno, abs_layer_count, output_size, nlb, nub, relu_groups, use_milp, candidate_vars, timeout):
+def get_bounds_for_layer_with_milp(nn, LB_N0, UB_N0, layerno, abs_layer_count, output_size, nlb, nub, relu_groups, use_milp, candidate_vars, timeout, W=None, b=None):
     lbi = nlb[abs_layer_count]
     ubi = nub[abs_layer_count]
     #numlayer = nn.numlayer
diff --git a/tf_verify/analyzer.py b/tf_verify/analyzer.py
index 2428fd4..753095d 100644
--- a/tf_verify/analyzer.py
+++ b/tf_verify/analyzer.py
@@ -185,7 +185,26 @@ class Analyzer:
             output_size = self.ir_list[-1].output_length
         else:
             output_size = self.ir_list[-1].output_length#reduce(lambda x,y: x*y, self.ir_list[-1].bias.shape, 1)
-    
+        
+        nllb = []
+        nlub = []
+        if len(self.nn.specLB) > 784:
+            for i, layer in enumerate( self.nn.layertypes ):
+                if layer in [ 'Conv', 'FC' ]:
+                    coef = np.where(  np.logical_and( np.array(nlb[i]) < 0, np.array(nub[i]) > 0 ) )[0]
+                    llb = np.zeros( (coef.shape[0], len( self.nn.specLB ) + 1 ))
+                    lub = np.zeros( (coef.shape[0], len( self.nn.specLB ) + 1 ))
+                    get_linear_bounds(self.man, element, coef, llb, lub, coef.shape[0], i)
+                    nllb.append( ( i, llb ) )
+                    nlub.append( ( i, lub ) )
+                    
+                    w = nlub[-1][1][:,:-1]
+                    b = nlub[-1][1][:,-1]
+                    out1 = np.maximum( w, 0 ) @ self.nn.specUB + np.minimum( w, 0 ) @ self.nn.specLB + b
+                    out2 = np.array(nub[i])[coef]
+            nllb = nllb[:-1]
+            nlub = nlub[:-1]
+ 
         dominant_class = -1
         if(self.domain=='refinepoly'):
 
@@ -205,6 +224,8 @@ class Analyzer:
 
         label_failed = []
         x = None
+        
+        hps = []
         if self.output_constraints is None:
             candidate_labels = []
             if self.label == -1:
@@ -218,6 +239,19 @@ class Analyzer:
                     adv_labels.append(i)
             else:
                 adv_labels.append(self.prop)   
+            
+            if len(self.nn.specLB) > 784 and len(candidate_labels) == 1:
+                for j in adv_labels:
+                    if candidate_labels[0] == j:
+                        continue
+                    test = np.zeros( len(self.nn.specLB)+1 )
+                    linear_output( self.man, element, candidate_labels[0], j, test, self.use_default_heuristic ) 
+                    w, b = test[:-1], test[-1]
+                    decision = np.maximum( w, 0 ) @ self.nn.specUB + np.minimum( w, 0 ) @ self.nn.specLB + b
+                    #assert (decision < 0) == decision2 
+                    #print( 'lb_our:', decision )
+                    hps.append( ( (w, b), decision, j ) )
+
             for i in candidate_labels:
                 flag = True
                 label = i
@@ -227,8 +261,9 @@ class Analyzer:
                             flag = False
                             break
                     else:
-                        if label!=j and not self.is_greater(self.man, element, label, j, self.use_default_heuristic):
 
+
+                        if label!=j and not self.is_greater(self.man, element, label, j, self.use_default_heuristic):
                             if(self.domain=='refinepoly'):
                                 obj = LinExpr()
                                 obj += 1*var_list[counter+label]
@@ -293,4 +328,4 @@ class Analyzer:
                     dominant_class = False
                     break
         elina_abstract0_free(self.man, element)
-        return dominant_class, nlb, nub, label_failed, x
+        return dominant_class, nlb, nub, label_failed, x, ( hps, nllb, nlub ) 
diff --git a/tf_verify/deeppoly_nodes.py b/tf_verify/deeppoly_nodes.py
index 2563fde..f748473 100644
--- a/tf_verify/deeppoly_nodes.py
+++ b/tf_verify/deeppoly_nodes.py
@@ -82,7 +82,7 @@ class DeeppolyInput:
     def __init__(self, specLB, specUB, input_names, output_name, output_shape,
                  lexpr_weights=None, lexpr_cst=None, lexpr_dim=None,
                  uexpr_weights=None, uexpr_cst=None, uexpr_dim=None,
-                 expr_size=0, spatial_constraints=None):
+                 expr_size=0, spatial_constraints=None, W=None, b=None):
         """
         Arguments
         ---------
@@ -150,7 +150,8 @@ class DeeppolyInput:
             )
 
         add_input_output_information_deeppoly(self, input_names, output_name, output_shape)
-
+        self.W = W
+        self.b = b
 
     def transformer(self, man):
         """
@@ -285,6 +286,10 @@ class DeeppolyNonlinearity:
 
 
 class DeeppolyReluNode(DeeppolyNonlinearity):
+    def __init__(self, input_names, output_name, output_shape, W=None, b=None):
+        super(DeeppolyReluNode, self).__init__(input_names, output_name, output_shape)
+        self.W = W
+        self.b = b
     def transformer(self, nn, man, element, nlb, nub, relu_groups, refine, timeout_lp, timeout_milp, use_default_heuristic, testing):
         """
         transforms element with handle_relu_layer
@@ -302,8 +307,8 @@ class DeeppolyReluNode(DeeppolyNonlinearity):
             abstract element after the transformer
         """
         length = self.output_length
-        if refine:
-            refine_activation_with_solver_bounds(nn, self, man, element, nlb, nub, relu_groups, timeout_lp, timeout_milp, use_default_heuristic, 'deeppoly')
+        if refine or not self.W is None:
+            refine_activation_with_solver_bounds(nn, self, man, element, nlb, nub, relu_groups, timeout_lp, timeout_milp, use_default_heuristic, 'deeppoly', self.W, self.b)
         else:
             handle_relu_layer(*self.get_arguments(man, element), use_default_heuristic)
         calc_bounds(man, element, nn, nlb, nub, relu_groups, is_refine_layer=True, use_krelu=False)
diff --git a/tf_verify/eran.py b/tf_verify/eran.py
index a5c6e90..4e1cd0e 100644
--- a/tf_verify/eran.py
+++ b/tf_verify/eran.py
@@ -15,12 +15,12 @@
 """
 
 
+import onnxruntime.backend as rt
 from tensorflow_translator import *
 from onnx_translator import *
 from optimizer import *
 from analyzer import *
 
-
 class ERAN:
     def __init__(self, model, session=None, is_onnx = False):
         """
@@ -49,13 +49,31 @@ class ERAN:
         if is_onnx:
             translator = ONNXTranslator(model)
         else:
+            if session is None:
+                session = tf.get_default_session()
             translator = TFTranslator(model, session)
+            self.tf_session = session
+        self.is_onnx = is_onnx
+        self.model = model
         operations, resources = translator.translate()
         self.optimizer  = Optimizer(operations, resources)
         print('This network has ' + str(self.optimizer.get_neuron_count()) + ' neurons.')
     
-    
-    def analyze_box(self, specLB, specUB, domain, timeout_lp, timeout_milp, use_default_heuristic, output_constraints=None, lexpr_weights= None, lexpr_cst=None, lexpr_dim=None, uexpr_weights=None, uexpr_cst=None, uexpr_dim=None, expr_size=0, testing = False,label=-1, prop = -1, spatial_constraints=None):
+    def quick_eval(self, img):
+        if self.is_onnx:
+            runnable = rt.prepare(self.model, 'CPU')
+            dims = self.model.graph.input[0].type.tensor_type.shape.dim
+            dims = [ d.dim_value for d in dims ]
+            img = img.reshape(dims[0], dims[2], dims[3], dims[1]).astype(np.float32)
+            img = img.transpose(0, 3, 1, 2)
+            outputs = runnable.run(img)
+        else:
+            tf_in = tf.get_default_graph().get_tensor_by_name( 'x:0' )
+            outputs = self.tf_session.run( self.model, feed_dict={ tf_in: img } )
+        dominant_class = np.argmax( outputs )
+        return dominant_class, outputs
+
+    def analyze_box(self, specLB, specUB, domain, timeout_lp, timeout_milp, use_default_heuristic, output_constraints=None, lexpr_weights= None, lexpr_cst=None, lexpr_dim=None, uexpr_weights=None, uexpr_cst=None, uexpr_dim=None, expr_size=0, testing = False,label=-1, prop = -1, spatial_constraints=None, W=None, b=None):
         """
         This function runs the analysis with the provided model and session from the constructor, the box specified by specLB and specUB is used as input. Currently we have three domains, 'deepzono',      		'refinezono' and 'deeppoly'.
         
@@ -84,13 +102,14 @@ class ERAN:
             execute_list, output_info = self.optimizer.get_deepzono(nn,specLB, specUB)
             analyzer = Analyzer(execute_list, nn, domain, timeout_lp, timeout_milp, output_constraints, use_default_heuristic,label, prop, testing)
         elif domain == 'deeppoly' or domain == 'refinepoly':
-            execute_list, output_info = self.optimizer.get_deeppoly(nn, specLB, specUB, lexpr_weights, lexpr_cst, lexpr_dim, uexpr_weights, uexpr_cst, uexpr_dim, expr_size, spatial_constraints)
+            execute_list, output_info = self.optimizer.get_deeppoly(nn, specLB, specUB, lexpr_weights, lexpr_cst, lexpr_dim, uexpr_weights, uexpr_cst, uexpr_dim, expr_size, spatial_constraints, W=W, b=b)
             analyzer = Analyzer(execute_list, nn, domain, timeout_lp, timeout_milp, output_constraints, use_default_heuristic, label, prop, testing)
-        dominant_class, nlb, nub, failed_labels, x = analyzer.analyze()
+        dominant_class, nlb, nub, failed_labels, x, hps = analyzer.analyze()
+        self.hps = hps
         if testing:
-            return dominant_class, nn, nlb, nub, output_info
+            return dominant_class, nn, nlb, nub, output_info, hps
         else:
-            return dominant_class, nn, nlb, nub, failed_labels, x
+            return dominant_class, nn, nlb, nub, failed_labels, x, hps
 
 
     def analyze_zonotope(self, zonotope, domain, timeout_lp, timeout_milp, use_default_heuristic, output_constraints=None, testing = False):
diff --git a/tf_verify/geometric_diff.py b/tf_verify/geometric_diff.py
new file mode 100644
index 0000000..e815c7b
--- /dev/null
+++ b/tf_verify/geometric_diff.py
@@ -0,0 +1,230 @@
+# From https://github.com/kevinzakka/spatial-transformer-network
+import tensorflow as tf
+
+
+def spatial_transformer_network(input_fmap, theta, out_dims=None, **kwargs):
+    """
+    Spatial Transformer Network layer implementation as described in [1].
+
+    The layer is composed of 3 elements:
+
+    - localization_net: takes the original image as input and outputs
+      the parameters of the affine transformation that should be applied
+      to the input image.
+
+    - affine_grid_generator: generates a grid of (x,y) coordinates that
+      correspond to a set of points where the input should be sampled
+      to produce the transformed output.
+
+    - bilinear_sampler: takes as input the original image and the grid
+      and produces the output image using bilinear interpolation.
+
+    Input
+    -----
+    - input_fmap: output of the previous layer. Can be input if spatial
+      transformer layer is at the beginning of architecture. Should be
+      a tensor of shape (B, H, W, C).
+
+    - theta: affine transform tensor of shape (B, 6). Permits cropping,
+      translation and isotropic scaling. Initialize to identity matrix.
+      It is the output of the localization network.
+
+    Returns
+    -------
+    - out_fmap: transformed input feature map. Tensor of size (B, H, W, C).
+
+    Notes
+    -----
+    [1]: 'Spatial Transformer Networks', Jaderberg et al.,
+         (https://arxiv.org/abs/1506.02025)
+
+    """
+    # grab input dimensions
+    B = tf.shape(input_fmap)[0]
+    H = tf.shape(input_fmap)[1]
+    W = tf.shape(input_fmap)[2]
+
+    # reshape theta to (B, 2, 3)
+    theta = tf.reshape(theta, [B, 2, 3])
+
+    # generate grids of same size or upsample/downsample if specified
+    if out_dims:
+        out_H = out_dims[0]
+        out_W = out_dims[1]
+        batch_grids = affine_grid_generator(out_H, out_W, theta)
+    else:
+        batch_grids = affine_grid_generator(H, W, theta)
+
+    x_s = batch_grids[:, 0, :, :]
+    y_s = batch_grids[:, 1, :, :]
+
+    # sample input with grid to get output
+    out_fmap = bilinear_sampler(input_fmap, x_s, y_s)
+
+    return out_fmap
+
+
+def get_pixel_value(img, x, y):
+    """
+    Utility function to get pixel value for coordinate
+    vectors x and y from a 4D tensor image.
+
+    Input
+    -----
+    - img: tensor of shape (B, H, W, C)
+    - x: flattened tensor of shape (B*H*W,)
+    - y: flattened tensor of shape (B*H*W,)
+
+    Returns
+    -------
+    - output: tensor of shape (B, H, W, C)
+    """
+    shape = tf.shape(x)
+    batch_size = shape[0]
+    height = shape[1]
+    width = shape[2]
+
+    batch_idx = tf.range(0, batch_size)
+    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
+    b = tf.tile(batch_idx, (1, height, width))
+
+    indices = tf.stack([b, y, x], 3)
+
+    return tf.gather_nd(img, indices)
+
+
+def affine_grid_generator(height, width, theta):
+    """
+    This function returns a sampling grid, which when
+    used with the bilinear sampler on the input feature
+    map, will create an output feature map that is an
+    affine transformation [1] of the input feature map.
+
+    Input
+    -----
+    - height: desired height of grid/output. Used
+      to downsample or upsample.
+
+    - width: desired width of grid/output. Used
+      to downsample or upsample.
+
+    - theta: affine transform matrices of shape (num_batch, 2, 3).
+      For each image in the batch, we have 6 theta parameters of
+      the form (2x3) that define the affine transformation T.
+
+    Returns
+    -------
+    - normalized grid (-1, 1) of shape (num_batch, 2, H, W).
+      The 2nd dimension has 2 components: (x, y) which are the
+      sampling points of the original image for each point in the
+      target image.
+
+    Note
+    ----
+    [1]: the affine transformation allows cropping, translation,
+         and isotropic scaling.
+    """
+    num_batch = tf.shape(theta)[0]
+
+    # create normalized 2D grid
+    x = tf.linspace(-1.0, 1.0, width)
+    y = tf.linspace(-1.0, 1.0, height)
+    x_t, y_t = tf.meshgrid(x, y)
+
+    # flatten
+    x_t_flat = tf.reshape(x_t, [-1])
+    y_t_flat = tf.reshape(y_t, [-1])
+
+    # reshape to [x_t, y_t , 1] - (homogeneous form)
+    ones = tf.ones_like(x_t_flat)
+    sampling_grid = tf.stack([x_t_flat, y_t_flat, ones])
+
+    # repeat grid num_batch times
+    sampling_grid = tf.expand_dims(sampling_grid, axis=0)
+    sampling_grid = tf.tile(sampling_grid, tf.stack([num_batch, 1, 1]))
+
+    # cast to float32 (required for matmul)
+    theta = tf.cast(theta, 'float32')
+    sampling_grid = tf.cast(sampling_grid, 'float32')
+
+    # transform the sampling grid - batch multiply
+    batch_grids = tf.matmul(theta, sampling_grid)
+    # batch grid has shape (num_batch, 2, H*W)
+
+    # reshape to (num_batch, H, W, 2)
+    batch_grids = tf.reshape(batch_grids, [num_batch, 2, height, width])
+
+    return batch_grids
+
+
+def bilinear_sampler(img, x, y):
+    """
+    Performs bilinear sampling of the input images according to the
+    normalized coordinates provided by the sampling grid. Note that
+    the sampling is done identically for each channel of the input.
+
+    To test if the function works properly, output image should be
+    identical to input image when theta is initialized to identity
+    transform.
+
+    Input
+    -----
+    - img: batch of images in (B, H, W, C) layout.
+    - grid: x, y which is the output of affine_grid_generator.
+
+    Returns
+    -------
+    - out: interpolated images according to grids. Same size as grid.
+    """
+    H = tf.shape(img)[1]
+    W = tf.shape(img)[2]
+    max_y = tf.cast(H - 1, 'int32')
+    max_x = tf.cast(W - 1, 'int32')
+    zero = tf.zeros([], dtype='int32')
+
+    # rescale x and y to [0, W-1/H-1]
+    x = tf.cast(x, 'float32')
+    y = tf.cast(y, 'float32')
+    x = 0.5 * ((x + 1.0) * tf.cast(max_x, 'float32'))
+    y = 0.5 * ((y + 1.0) * tf.cast(max_y, 'float32'))
+
+    # grab 4 nearest corner points for each (x_i, y_i)
+    x0 = tf.cast(tf.floor(x), 'int32')
+    x1 = x0 + 1
+    y0 = tf.cast(tf.floor(y), 'int32')
+    y1 = y0 + 1
+
+    # clip to range [0, H-1/W-1] to not violate img boundaries
+    x0 = tf.clip_by_value(x0, zero, max_x)
+    x1 = tf.clip_by_value(x1, zero, max_x)
+    y0 = tf.clip_by_value(y0, zero, max_y)
+    y1 = tf.clip_by_value(y1, zero, max_y)
+
+    # get pixel value at corner coords
+    Ia = get_pixel_value(img, x0, y0)
+    Ib = get_pixel_value(img, x0, y1)
+    Ic = get_pixel_value(img, x1, y0)
+    Id = get_pixel_value(img, x1, y1)
+
+    # recast as float for delta calculation
+    x0 = tf.cast(x0, 'float32')
+    x1 = tf.cast(x1, 'float32')
+    y0 = tf.cast(y0, 'float32')
+    y1 = tf.cast(y1, 'float32')
+
+    # calculate deltas
+    wa = (x1-x) * (y1-y)
+    wb = (x1-x) * (y-y0)
+    wc = (x-x0) * (y1-y)
+    wd = (x-x0) * (y-y0)
+
+    # add dimension for addition
+    wa = tf.expand_dims(wa, axis=3)
+    wb = tf.expand_dims(wb, axis=3)
+    wc = tf.expand_dims(wc, axis=3)
+    wd = tf.expand_dims(wd, axis=3)
+
+    # compute output
+    out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
+
+    return out
diff --git a/tf_verify/geometric_symadex.py b/tf_verify/geometric_symadex.py
new file mode 100644
index 0000000..2240681
--- /dev/null
+++ b/tf_verify/geometric_symadex.py
@@ -0,0 +1,941 @@
+import numpy as np
+import time
+from geometric_constraints import *
+from multiprocessing import Pool
+from gurobipy import *
+import math
+import tensorflow as tf
+from tensorflow.contrib import graph_editor as ge
+from geometric_diff import spatial_transformer_network
+
def normalize(image, means, stds, dataset):
    """Normalize `image` in place with per-channel `means`/`stds`.

    The normalization is applied here because it was stripped out of the
    network itself.  Three layouts are supported:
    - one mean per pixel (len(means) == len(image)),
    - grayscale datasets ('mnist'/'fashion'): a single mean/std,
    - 'cifar10': flat interleaved R,G,B layout, channel = index % 3.
    """
    if len(means) == len(image):
        # Per-pixel statistics.  Identity check is required: `stds != None`
        # raises for numpy arrays (ambiguous element-wise comparison).
        for i in range(len(image)):
            image[i] -= means[i]
            if stds is not None:
                image[i] /= stds[i]
    elif dataset == 'mnist' or dataset == 'fashion':
        for i in range(len(image)):
            image[i] = (image[i] - means[0]) / stds[0]
    elif dataset == 'cifar10':
        # Pixels are interleaved R,G,B,R,G,B,... so the channel of index i
        # is i % 3.  Each element depends only on itself, so no temporary
        # buffer is needed.
        for i in range(3072):
            image[i] = (image[i] - means[i % 3]) / stds[i % 3]
+
def normalize_poly(num_params, lexpr_cst, lexpr_weights, lexpr_dim, uexpr_cst, uexpr_weights, uexpr_dim, means, stds, dataset):
    """Fold the dataset normalization into polyhedral pixel bounds, in place.

    Constants become (c - mean) / std; weights are only scaled by 1/std
    because the mean shift is absorbed by the constant term.  For color
    datasets the channel of constant i is i % 3 and of weight i is
    (i // num_params) % 3.
    """
    grayscale = dataset == 'mnist' or dataset == 'fashion'
    for i in range(len(lexpr_cst)):
        ch = 0 if grayscale else i % 3
        lexpr_cst[i] = (lexpr_cst[i] - means[ch]) / stds[ch]
        uexpr_cst[i] = (uexpr_cst[i] - means[ch]) / stds[ch]
    for i in range(len(lexpr_weights)):
        ch = 0 if grayscale else (i // num_params) % 3
        lexpr_weights[i] /= stds[ch]
        uexpr_weights[i] /= stds[ch]
+
+
def denormalize(image, means, stds, dataset):
    """Invert `normalize` in place: image[i] = image[i] * std + mean."""
    if dataset == 'mnist' or dataset == 'fashion':
        for idx in range(len(image)):
            image[idx] = image[idx] * stds[0] + means[0]
    elif dataset == 'cifar10':
        # Interleaved R,G,B pixels: the channel of index i is i % 3.  Each
        # element only depends on itself, so no temporary buffer is needed.
        for idx in range(3072):
            image[idx] = image[idx] * stds[idx % 3] + means[idx % 3]
+
def get_grid( transform_attack_container, config, eran, means, stds, lb, ub, target, img, splits ):
    """Evaluate the network on a regular grid of transformation parameters.

    Splits the parameter box [lb, ub] into prod(splits) cells, renders the
    perturbed image for each cell and records whether the network predicts
    `target` there.

    Returns a list of (hit, params) pairs where `hit` is True iff the
    prediction equals `target` and `params` is the cell's lower corner.
    """
    set_grid( transform_attack_container, splits )
    set_transform_attack_for(transform_attack_container, img, False, True, config.debug)
    images = get_grid_images(transform_attack_container)

    values = np.array(images)
    attacks_lb = values[:, ::2]  # every other entry is a pixel lower bound
    idx = []
    # Loop variable renamed so it no longer shadows the `img` parameter.
    for i, attack_img in enumerate( attacks_lb ):
        normalize(attack_img, means, stds, config.dataset)
        predict_label, _ = eran.quick_eval(attack_img)

        # Decode the flat grid index into one coordinate per parameter
        # (mixed-radix decomposition).  Integer division is essential here:
        # the previous `/=` produced fractional digits for every parameter
        # after the first, corrupting the reported grid coordinates.
        i_cp = i
        x = np.zeros( splits.shape[0] )
        for j in range( splits.shape[0] ):
            x[j] = ( ub[j] - lb[j] ) * ( i_cp % splits[j] ) / splits[j]
            x[j] += lb[j]
            i_cp //= splits[j]

        idx.append( (predict_label == target, x) )

    # Reset the container to a single cell so later calls see a clean state.
    set_grid( transform_attack_container, np.ones( splits.shape[0] ) )
    return idx
+
def generate_splits(lb, ub, splits):
    """Tile the box [lb, ub] into a regular grid of sub-boxes.

    Returns (lbs, ubs), each of shape (len(splits), prod(splits)): the lower
    and upper corners of every grid cell, one cell per column.
    """
    axes = [np.linspace(low, high, num=n, endpoint=False)
            for low, high, n in zip(lb, ub, splits)]
    grid = np.meshgrid(*axes)
    lbs = np.stack(grid).reshape(len(splits), -1)
    # Every cell has the same extent along a given axis.
    cell = ((ub - lb) / splits).reshape(-1, 1)
    return lbs, lbs + cell
+
def verify_geometric_box(eran, config, means, stds, transform_attack_container, img, given_label, lb, ub, n_splits, W=None, b=None, update_geom=False, target=float('nan')):
    """Verify the network over the parameter box [lb, ub] of a geometric attack.

    Pulls per-chunk interval and polyhedral pixel constraints from the
    transform-attack container, normalizes them, and runs DeepPoly twice per
    chunk: once with the polyhedral (DeepG) relaxation and once with plain
    interval bounds.

    Returns
    -------
    - If `target` is not NaN: only the normalized polyhedral expressions
      (lexpr/uexpr constants, weights, dims) of the first chunk — no
      analysis is run.
    - If more than one chunk was processed: (results_chunks, params_chunks).
    - Otherwise: (ver_chunks_box, ver_chunks_poly, tot_chunks,
      poly-expressions + attack_lens, hps_box, hps_deeppoly).
    """
    t_begin = time.time()
    update_geom_box(transform_attack_container, lb, ub, n_splits) 
    if update_geom:
        # Restrict the parameter region further with hyperplanes W x + b <= 0.
        for i in range( W.shape[0] ):
            add_geom_hp(transform_attack_container, W[i], b[i])
    set_transform_attack_for(transform_attack_container, img, False, False, config.debug, target=target)
    lines = get_transformations(transform_attack_container)
    attack_lens = get_pixel_lentghts(transform_attack_container); 

    if config.dataset == 'mnist' or config.dataset == 'fashion':
        n_rows, n_cols, n_channels = 28, 28, 1
    else:
        n_rows, n_cols, n_channels = 32, 32, 3
    dim = n_rows * n_cols * n_channels
    # Records per chunk: num_params parameter ranges, one line of pixel
    # intervals, dim lines of polyhedral constraints, one terminator line.
    k = config.num_params + 1 + 1 + dim

    print('Number of splits', n_splits, 'Number of lines: ', len(lines) , flush=True)
    assert len(lines) % k == 0

    # Spec layout: first `dim` entries are pixels, last `num_params` the
    # transformation parameters.
    spec_lb = np.zeros(config.num_params + dim)
    spec_ub = np.zeros(config.num_params + dim)

    expr_size = config.num_params
    lexpr_cst, uexpr_cst = [], []
    lexpr_weights, uexpr_weights = [], []
    lexpr_dim, uexpr_dim = [], []

    ver_chunks_box, ver_chunks_poly, tot_chunks = 0, 0, 0
    results_chunks = [] 
    params_chunks = [] 
    for i, line in enumerate(lines):
        if i % k < config.num_params:
            # read specs for the parameters
            values = line
            assert len(values) == 2
            param_idx = i % k
            spec_lb[dim + param_idx] = values[0]
            spec_ub[dim + param_idx] = values[1]
            print('parameter %d: [%.4f, %.4f]' % (param_idx, values[0], values[1]), flush=True)
        elif i % k == config.num_params:
            # read interval bounds for image pixels
            values = line
            spec_lb[:dim] = values[::2]
            spec_ub[:dim] = values[1::2]
            # if config.debug:
            #     show_ascii_spec(spec_lb, spec_ub)
        elif i % k < k - 1:
            # read polyhedra constraints for image pixels
            tokens = line
            assert len(tokens) == 2 + 2 * config.num_params

            # Line layout: lower bias, lower weights, upper bias, upper weights.
            bias_lower, weights_lower = tokens[0], tokens[1:1 + config.num_params]
            bias_upper, weights_upper = tokens[config.num_params + 1], tokens[2 + config.num_params:]

            assert len(weights_lower) == config.num_params
            assert len(weights_upper) == config.num_params

            lexpr_cst.append(bias_lower)
            uexpr_cst.append(bias_upper)
            for j in range(config.num_params):
                lexpr_dim.append(dim + j)
                uexpr_dim.append(dim + j)
                lexpr_weights.append(weights_lower[j])
                uexpr_weights.append(weights_upper[j])
        else:
            # Empty terminator line: the chunk is complete, run the analysis.
            assert (len(line) == 0)
            # The parameters themselves get constant lower/upper expressions.
            for p_idx in range(config.num_params):
                lexpr_cst.append(spec_lb[dim + p_idx])
                for l in range(config.num_params):
                    lexpr_weights.append(0)
                    lexpr_dim.append(dim + l)
                uexpr_cst.append(spec_ub[dim + p_idx])
                for l in range(config.num_params):
                    uexpr_weights.append(0)
                    uexpr_dim.append(dim + l)
            params_chunks.append( (spec_lb[dim:].copy(),spec_ub[dim:].copy()) )
            normalize(spec_lb[:dim], means, stds, config.dataset)
            normalize(spec_ub[:dim], means, stds, config.dataset)
            normalize_poly(config.num_params, lexpr_cst, lexpr_weights, lexpr_dim, uexpr_cst, uexpr_weights,
                           uexpr_dim, means, stds, config.dataset)
            if not math.isnan( target ):
                # Caller only wanted the normalized expressions.
                return lexpr_cst, lexpr_weights, lexpr_dim, uexpr_cst, uexpr_weights,uexpr_dim
            if config.debug:
                print('Running the analysis...', flush=True)

            # DeepPoly with the polyhedral (DeepG) input relaxation...
            perturbed_label_poly, _, _, _, _, _, hps_deeppoly = eran.analyze_box(
                spec_lb, spec_ub, 'deeppoly',
                config.timeout_lp, config.timeout_milp, config.use_default_heuristic, None,
                lexpr_weights, lexpr_cst, lexpr_dim,
                uexpr_weights, uexpr_cst, uexpr_dim,
                expr_size, label=given_label, W=W, b=b)
            # ...and with plain interval bounds, for comparison.
            perturbed_label_box, _, _, _, _, _, hps_box = eran.analyze_box(
                spec_lb[:dim], spec_ub[:dim], 'deeppoly',
                config.timeout_lp, config.timeout_milp, config.use_default_heuristic, label=given_label)
            t_end = time.time()

            print('DeepG: ', perturbed_label_poly, '\tInterval: ', perturbed_label_box, '\tlabel: ', given_label,
                  '[Time: %.4f]' % (t_end - t_begin), flush=True)

            tot_chunks += 1
            # NOTE(review): ok_box / ok_poly are assigned but never read in
            # this function — presumably leftovers from an earlier
            # aggregate-result version; confirm before removing.
            if perturbed_label_box != given_label:
                ok_box = False
            else:
                ver_chunks_box += 1

            if perturbed_label_poly != given_label:
                ok_poly = False
            else:
                ver_chunks_poly += 1

            results_chunks.append( perturbed_label_poly != -1 or perturbed_label_box != -1 )
            #lexpr_cst, uexpr_cst = [], []
            #lexpr_weights, uexpr_weights = [], []
            #lexpr_dim, uexpr_dim = [], []

    if tot_chunks != 1:
        return results_chunks, params_chunks

    return ver_chunks_box, ver_chunks_poly, tot_chunks, (lexpr_cst, lexpr_weights, lexpr_dim, uexpr_cst, uexpr_weights,uexpr_dim, attack_lens), hps_box, hps_deeppoly
+
def gen_tf_underapprox_box( num_params ):
    """Build the TF1 graph used to shrink a parameter box below a hyperplane.

    Two trainable vectors (params_lb / params_ub) are optimized to maximize
    the box volume (sum of log side lengths) while staying inside the
    original box [lb_old, ub_old]; the projection ops additionally rescale
    the box so the worst case of the hyperplane `tf_hp` over it respects
    `tf_bound`.

    Returns the tensors/ops consumed by create_underapprox_box_lp:
    (lr, opt_step, init_lb, init_ub, project_hp_lb, project_hp_ub,
     tf_hp, tf_bound, lb_old, ub_old).
    """
    params_lb = tf.Variable( np.zeros(  num_params, dtype=np.float64 ), trainable=True)
    params_ub = tf.Variable( np.zeros(  num_params, dtype=np.float64 ), trainable=True)
    
    lb_old = tf.placeholder( tf.float64, ( num_params ) )
    ub_old = tf.placeholder( tf.float64, ( num_params ) )
    # Project the optimized bounds back into [lb_old, ub_old] with lb <= ub.
    project_lb = tf.clip_by_value( tf.clip_by_value( params_lb, lb_old, params_ub ), lb_old, ub_old ) 
    project_ub = tf.clip_by_value( params_ub, project_lb, ub_old )

    tf_hp = tf.placeholder( tf.float64, ( num_params ) )
    tf_bound = tf.placeholder( tf.float64, () )
    # Worst case of the linear form tf_hp . x over the projected box:
    # positive weights hit the upper bound, negative ones the lower bound.
    hp_val = tf.reduce_sum( project_ub * tf.nn.relu( tf_hp ) ) + tf.reduce_sum( project_lb * -tf.nn.relu( -tf_hp ) )
    # Scale factor pushing the worst case down to tf_bound (1.0 when the
    # signs of hp_val and tf_bound differ, i.e. no rescaling needed).
    project_hp = tf.where( hp_val * tf_bound > 0 , tf_bound / hp_val, 1.0 )
    project_hp_ub_try = tf.clip_by_value(  project_ub * project_hp,  project_lb, ub_old )
    project_hp_lb_try = tf.clip_by_value(  project_lb * project_hp, lb_old, project_ub )
    # Only rescale the bound on the side the hyperplane actually stresses.
    project_hp_ub = tf.where( tf_hp > 0, project_hp_ub_try, project_ub ) 
    project_hp_lb = tf.where( tf_hp < 0, project_hp_lb_try, project_lb )
    
    init_lb = tf.assign( params_lb, lb_old )
    init_ub = tf.assign( params_ub, ub_old )
    # Maximize volume == minimize the negative sum of log side lengths; the
    # clip keeps the log finite for (near-)empty boxes.
    obj = -tf.reduce_sum( tf.log( tf.clip_by_value( project_hp_ub - project_hp_lb, 1e-100, 1e12 ) ) )
    lr = tf.placeholder(tf.float64,())
    opt_step = tf.train.GradientDescentOptimizer( lr ).minimize( obj, var_list=[ params_lb, params_ub ] )
    
    return lr, opt_step, init_lb, init_ub, project_hp_lb, project_hp_ub, tf_hp, tf_bound, lb_old, ub_old
+
def create_underapprox_box_lp(eran, sess, tf_tensors, config, means, stds, transform_attack_container, img, lb, ub, y_tar, its):
    """Iteratively shrink [lb, ub] to a box verified to reach target `y_tar`.

    Each iteration runs DeepPoly on the current box; if no output hyperplane
    is violated the box is certified and returned.  Otherwise the most
    violated hyperplane is selected and the TF graph from
    gen_tf_underapprox_box is used (from many random restarts) to find the
    largest sub-box whose worst case over that hyperplane meets a relaxed
    target bound.

    Returns (lb, ub) of the verified box, or (None, None) on failure.

    NOTE(review): the Gurobi model built at the top (and `copy`) feeds only
    the commented-out LP-based shrinking variants below; the active path is
    the TF gradient search.
    """
    st = time.time()
    
    model = Model()
    model.setParam( 'OutputFlag', 0 )
    lbs = model.addVars( len(lb), lb=lb, ub=ub )
    lbs = [ lbs[k] for k in lbs ] 
    ubs = model.addVars( len(lb), lb=lb, ub=ub )
    ubs = [ ubs[k] for k in ubs ] 
    for i in range( len(lb) ):
        # Keep each box dimension at least 1e-3 wide.
        model.addConstr( lbs[i] + 1e-3, GRB.LESS_EQUAL, ubs[i] ) 
    obj_coefs = np.ones( len(lb) ) 
    #obj_coefs = 1/(ub-lb)
    model.setObjective( ubs @ obj_coefs - lbs @ obj_coefs, GRB.MAXIMIZE ) 
    model.update()

    learning_rate, optim_step, init_lb, init_ub, project_hp_lb, project_hp_ub, tf_hp, tf_bound, lb_old, ub_old = tf_tensors
    for it in range(its):
        out = verify_geometric_box(eran, config, means, stds, transform_attack_container, img, y_tar, lb, ub, 1)
        deeppoly_hps = out[-1][0]
        deeppoly_hps = filter(lambda a : a[-2] > 0, deeppoly_hps) # Only problematic hps
        deeppoly_hps = map(lambda a : ((a[0][0][-len(lb):],a[0][1]), -a[-2]), deeppoly_hps)  # Extract HPs from full deeppoly hps, reverse bound
        deeppoly_hps = list( deeppoly_hps )

        if len( deeppoly_hps ) == 0: 
            # No violated hyperplane left: the current box is verified.
            end = time.time()
            print( 'Verified underapprox box:', end - st, 's', flush=True )
            return lb, ub

        
        # Pick the hyperplane with the worst (most negative) bound.
        worst_obj = None
        worst_idx = -1
        for i, el in enumerate(deeppoly_hps):
            obj_val = el[1]
            if worst_obj is None or worst_obj > obj_val:
                worst_obj = obj_val
                worst_idx = i

        #tar_bound = worst_obj * 0.9
        alpha = 0.75
        # Best achievable bound of the worst hyperplane over the full box.
        best_bound = np.maximum( deeppoly_hps[worst_idx][0][0], 0 ) @ lb + np.minimum( deeppoly_hps[worst_idx][0][0], 0 ) @ ub +  deeppoly_hps[worst_idx][0][1]
        
        #if np.abs( worst_obj ) < 0.1:
        #    alpha = 0.9

        # Interpolate between the current worst bound and the best achievable
        # one; never ask for more than 1e-5.
        tar_bound = np.minimum( worst_obj * alpha + -best_bound * (1-alpha), 1e-5 )
        #if np.abs( worst_obj ) < 0.1 and best_bound < -1e-7:
        #    tar_bound = 1e-7
        

        print( 'It:', it, 'Bound:', worst_obj, 'Target:', tar_bound, flush=True )
        copy = model.copy()
        '''
        norm = np.max( np.abs( deeppoly_hps[worst_idx][0][0] ) )
        ws =  deeppoly_hps[worst_idx][0][0] / norm
        sq = ws * ws
        ab = np.abs( ws )
        coef =( ab @ (ub-lb) )/ ( sq @ (ub-lb) ) * 0.8
        if coef > 1:
            coef = 1
        lb_new = np.maximum( -ws, 0 ) * coef * lb + (1 - coef * np.maximum( -ws, 0 )) * ub
        lb_new =  np.maximum( np.sign( ws ), 0 ) * lb + np.maximum( -np.sign( ws ), 0 ) * lb_new
        ub_new = np.maximum( ws, 0 ) * coef * ub + (1 - coef * np.maximum( ws, 0 )) * lb
        ub_new =  np.maximum( -np.sign( ws ), 0 ) * ub + np.maximum( np.sign( ws ), 0 ) * ub_new
        lb = lb_new
        ub = ub_new
        assert np.all( lb_new <= ub ) and np.all( ub_new <= ub ) and np.all( lb_new >= lb ) and np.all( ub_new >= lb ) and np.all( lb_new <= ub_new )
        '''
        # Gradient-based search for the biggest sub-box respecting tar_bound,
        # restarted from 500 random shrink factors.
        best_vol = -1
        best_ub = best_lb = None
        inits = np.random.uniform( 0, 0.5, size=(500,lb.shape[0]) ) 
        for init in inits:
            sess.run( init_lb, feed_dict={lb_old: (1-init)*lb+init*ub} )
            sess.run( init_ub, feed_dict={ub_old: init*lb+(1-init)*ub} )
            for i in range(50):
                sess.run( [optim_step], feed_dict={learning_rate:5e-5, tf_hp: deeppoly_hps[worst_idx][0][0], tf_bound: -deeppoly_hps[worst_idx][0][1] - tar_bound, lb_old: lb, ub_old: ub} )
                lb_new, ub_new = sess.run( (project_hp_lb, project_hp_ub), feed_dict={ tf_hp: deeppoly_hps[worst_idx][0][0], tf_bound: -deeppoly_hps[worst_idx][0][1] - tar_bound, lb_old: lb, ub_old: ub} )
                if not ( np.all( lb <= lb_new) and np.all( lb_new <= ub_new) and np.all( ub_new <= ub ) ):
                    import pdb; pdb.set_trace()
                vol = np.prod( ub_new - lb_new )
                
                forcebound = np.minimum( worst_obj * 0.98 + -best_bound * 0.02, 0 )
                calcbound = np.maximum( deeppoly_hps[worst_idx][0][0], 0 ) @ ub_new + np.minimum( deeppoly_hps[worst_idx][0][0], 0 ) @ lb_new +  deeppoly_hps[worst_idx][0][1]
                calcbound *= -1

                # Accept the candidate only if it improves volume, meets the
                # bound, and every side is wider than 1e-5.
                # Fixed: the last condition previously read
                # `np.all(ub_new - lb_new) > 1e-5`, which compares a boolean
                # to 1e-5 instead of testing all side lengths.
                if vol > best_vol and calcbound > forcebound and np.all(ub_new - lb_new > 1e-5):
                    best_vol = vol
                    best_ub = ub_new
                    best_lb = lb_new
                
        if best_lb is None:
            return None, None
        
        lb = np.array( best_lb )
        ub = np.array( best_ub )
 
        '''
        lin = LinExpr()
        for i in range( len(lb) ):
            if deeppoly_hps[worst_idx][0][0][i] > 0: 
                lin += deeppoly_hps[worst_idx][0][0][i] * copy.getVarByName( ubs[i].VarName )
            else:
                lin += deeppoly_hps[worst_idx][0][0][i] * copy.getVarByName( lbs[i].VarName )
        
        copy.addConstr( lin + deeppoly_hps[worst_idx][0][1], GRB.LESS_EQUAL, -tar_bound )
        copy.optimize()
        constr += deeppoly_hps[worst_idx][0][1]
        if copy.Status != 2:
            break
        lb = [ copy.getVarByName( lbs[i].VarName ).X for i in range( len(lb) ) ]
        ub = [ copy.getVarByName( ubs[i].VarName ).X for i in range( len(lb) ) ]

        for i in range( len(lb) ):
            lbs[i].LB = lb[i]
            ubs[i].LB = lb[i]
            lbs[i].UB = ub[i]
            ubs[i].UB = ub[i]

        lb = np.array( lb )
        ub = np.array( ub )
        '''
        #TODO remove
        if np.any(ub - lb < 1e-6):
            print( 'Problem:', ub - lb < 1e-6, flush=True )
            break

        model.update()
    return None, None
+
def pool_func_deeppoly( idx ):
    """Worker: maximize row `idx` of the shared linear map over the shared LP.

    Reads the process-global `global_model` (a Gurobi model with variables
    named x0..xN) and `global_eq` = (weights, biases), both installed by
    get_hp_bias_under before the pool is spawned.  Returns the LP upper
    bound of the row's linear expression plus its bias.
    """
    thread_model = global_model.copy()
    input_size = global_eq[ 0 ].shape[ 1 ]
    xs = [ thread_model.getVarByName( 'x' + str( i ) ) for i in range( input_size ) ]
    obj = global_eq[ 0 ][ idx, : ] @ xs
    thread_model.reset()
    thread_model.setObjective( obj, GRB.MAXIMIZE )
    thread_model.optimize()
    assert thread_model.SolCount == 1
    ub = thread_model.objbound + global_eq[ 1 ][ idx, 0 ]
    return ub
+
def get_hp_bias_under( W, b, lb_under, ub_under, W_model, b_model, lb_model, ub_model, checkZero=True, poly_c=0 ):
    """Shift hyperplane biases so that W x + b <= 0 holds on the box.

    W has one hyperplane per row, b the matching biases.  Each bias is
    decreased by the hyperplane's worst-case value over the box
    [lb_under, ub_under], computed in closed form via the interval
    mid-point / absolute-deviation trick.  With `checkZero` only positive
    worst cases are subtracted; otherwise the full bound is.

    If poly_c > 0, an additional LP refinement is run over the polyhedron
    {x : W_model x <= -b_model, lb_model <= x <= ub_model} (Gurobi, one LP
    per hyperplane, parallelized over all cores) and a poly_c-fraction of
    the resulting positive bounds is subtracted as well.

    Returns (W, b) with the original weights and the shifted biases
    (b has shape (n_hyperplanes, 1)).
    """
    W = W.copy().T                       # one hyperplane per column internally
    b = b.copy().reshape( -1, 1 )
    # Closed-form max of w . x over the box: value at the mid-point plus the
    # absolute deviation term.
    dev = ( ub_under - lb_under ) / 2.0
    mid = ( ub_under + lb_under ) / 2.0
    dev = dev.reshape(-1,1)
    mid = mid.reshape(-1,1)

    lp_cost_1 = np.sum( np.abs( W * dev ), axis=0, keepdims=True ).T
    lp_cost_2 = b + W.T @ mid
    lp_cost_under = ( lp_cost_1 + lp_cost_2 ).reshape( -1, 1 )
    if checkZero:
        b -= np.maximum( lp_cost_under + 1e-13, 0 )
    else:
        b -= lp_cost_under + 1e-13

    if poly_c > 1e-6:
        # `os` is not imported at module level in this file (NameError on
        # this path before); import locally so the common poly_c == 0 path
        # is unaffected.
        import os

        out = [W.T.copy(), b.copy()]
        
        model = Model()
        model.setParam('OutputFlag', 0)
        xs = [ model.addVar(lb=lb_model[i], ub=ub_model[i], vtype=GRB.CONTINUOUS, name='x' + str(i) ) for i in range( W.shape[0] ) ]
        constrs = W_model @ xs
        for i, c in enumerate( constrs ):
            model.addConstr( c, GRB.LESS_EQUAL, -b_model[i] )
        model.optimize()
        if model.SolCount != 1:
            import pdb; pdb.set_trace()

        # Hand the model and equations to the pool workers via globals.
        global global_model, global_eq
        global_model = model
        global_eq = out
        ncpus = os.sysconf("SC_NPROCESSORS_ONLN")
        with Pool(ncpus) as pool:
            solver_result = pool.map( pool_func_deeppoly, list( range ( W.shape[1] ) ) )
        del globals()[ 'global_model' ]
        del globals()[ 'global_eq' ]
        bounds = np.array( [p for p in solver_result] ).reshape(-1,1)
        b -= poly_c*np.maximum( bounds + 1e-13, 0 )

    return ( W.T, b )
+
def simplify_region( Ws, bs, lb, ub ):
    """Drop redundant constraints from {x : Ws x <= -bs, lb <= x <= ub}.

    For each linear constraint in turn, a copy of the model is built where
    that constraint is relaxed by 1 and its own row is optimized; if the
    optimum still satisfies the original RHS, the constraint is implied by
    the remaining ones and is removed from the model.

    Returns (gurobi vars, simplified model, surviving rows of Ws, matching bs).
    """
    model = Model()
    xs = []
    for i in range( len(lb) ):
        xs.append( model.addVar(lb=lb[i], ub=ub[i], vtype=GRB.CONTINUOUS, name='x%i'%i) )
    
    constrs = Ws@xs
    for i in range( len( constrs ) ):
        model.addConstr( constrs[i] <= -bs[i] )
    model.update()
    model.setParam('OutputFlag', 0)
    model.setParam( 'DualReductions', 0 )

    tried_names = []
    j = 0
    dels = 0
    while True:
        # Pick the next constraint not yet tested for redundancy.
        c = None
        for cs in model.getConstrs():
            if not cs.ConstrName in tried_names:
                c = cs
                break
        if c is None:
            break
        # Work on a copy: relax the chosen constraint by 1 and optimize its
        # own row in the direction the constraint restricts.
        mc = model.copy()
        c = mc.getConstrByName( c.ConstrName )
        le = mc.getRow(c)
        rhs_orig = c.RHS
        if c.Sense == '>':
            c.RHS -= 1
            mc.setObjective( le, GRB.MINIMIZE )
        elif c.Sense == '<':
            c.RHS += 1
            mc.setObjective( le, GRB.MAXIMIZE )
        else:
            assert False
        mc.update()
        mc.reset()
        mc.optimize()
        if c.Sense == '<':
            if not mc.Status == 2:
                import pdb; pdb.set_trace()
            # Optimum never exceeds the original RHS => constraint redundant.
            if mc.objval <= rhs_orig + 1e-10:
                model.remove(model.getConstrByName(c.ConstrName))
                model.update()
                dels += 1
            else:
                tried_names.append( c.ConstrName )
        else:
            if mc.objval >= rhs_orig:
                model.remove(model.getConstrByName(c.ConstrName))
                model.update()
                dels += 1
            else:
                tried_names.append( c.ConstrName )
        del mc
        j += 1
        #if j % 10 == 0:
        #    print( dels, '/', j )

    # Gurobi auto-names constraints 'R<row>'; strip the leading character to
    # recover the indices of the surviving rows.
    c_idx = [ int(c.ConstrName[1:]) for c in model.getConstrs() ]
    return xs, model, Ws[c_idx, :], bs[c_idx]
+
def create_geom_poly(eran, tf_underbox_tens, config, means, stds, transform_attack_container, img, lb, ub, y_tar, its):
    """Compute an under-approximated region of adversarial transform parameters.

    Currently this only runs the LP/gradient box shrinking
    (create_underapprox_box_lp) and returns its (lb_under, ub_under).

    NOTE(review): the function returns immediately after that call, so
    everything below the `return` — the iterative polyhedral refinement that
    grows (W, b) from DeepPoly hyperplanes, simplifies the region and plots
    it per iteration — is unreachable experimental code kept for reference.
    """
    
    if config.dataset == 'mnist' or config.dataset == 'fashion':
        n_rows, n_cols, n_channels = 28, 28, 1
    else:
        n_rows, n_cols, n_channels = 32, 32, 3
    num_pixels = n_rows * n_cols * n_channels

    lb_under, ub_under = create_underapprox_box_lp(eran, eran.tf_session, tf_underbox_tens, config, means, stds, transform_attack_container, img, lb, ub, y_tar, its)
    return lb_under, ub_under
    #np.savez( 'geometric', lb_under=lb_under, ub_under=ub_under ) 
    
    #f = np.load( 'geometric.npz', allow_pickle=True)
    #lb_under, ub_under = f['lb_under'], f['ub_under'] 
    
    # --- unreachable experimental refinement below (see docstring) ---
    W = np.zeros( (0, lb_under.shape[0]) ) 
    b = np.zeros( (0, 1) ) 
    #poly_c = 0.0
    poly_c = 0.9
    for i in range(100):
        out = verify_geometric_box(eran, config, means, stds, transform_attack_container, img, y_tar, lb, ub, 1, W=W, b=b, update_geom=True)
        grid_idx = get_grid(  transform_attack_container, config, eran, means, stds, lb, ub, y_tar, img, np.array([20,20,1]) )
        bound = min( map(lambda a : -a[-2], out[-1][0]) )
        print( '\n\n\nBound: ', bound, "It:", i,"\n\n\n" , flush=True )

        if bound > -0.1 or i > 5:
            poly_c = 0
            import pdb; pdb.set_trace()
        else: 
            poly_c *= 0.9
        if bound > 0: 
            # Every hyperplane satisfied: region verified.
            return W, b, lb_under, ub_under

        # Out hps
        hps_out_layer = list( filter( lambda a : a[-2] > 0, out[-1][0] ) )
        hps_out_layer_weights = list( map(lambda a : a[0][0][-len(lb):].reshape(1,-1), hps_out_layer) )
        hps_out_layer_biases = list( map(lambda a : a[0][1].reshape(1,-1), hps_out_layer) )
        hps_out_layer_weights = np.concatenate( hps_out_layer_weights, axis=0 )  
        hps_out_layer_biases = np.concatenate( hps_out_layer_biases, axis=0 )  
        hps_out_layer = hps_out_layer_weights, hps_out_layer_biases
        
        # Interim hps
        hps_interim_layers_lbs = list( map(lambda a : a[1][:,-config.num_params-1:], out[-1][1] ))
        hps_interim_layers_lbs = np.concatenate( hps_interim_layers_lbs, axis=0 )
        hps_interim_layers_ubs = list( map(lambda a : a[1][:,-config.num_params-1:], out[-1][2] ))
        hps_interim_layers_ubs = np.concatenate( hps_interim_layers_ubs, axis=0 )
        hps_interim_layers = np.concatenate( (hps_interim_layers_lbs, hps_interim_layers_ubs), axis=0 )
        hps_interim_layers = hps_interim_layers[:,:-1], hps_interim_layers[:,-1]
        
        # Input hps
        lexpr_cst, lexpr_weights, lexpr_dim, uexpr_cst, uexpr_weights,uexpr_dim, attack_lens = out[-3]
        target = np.max( attack_lens ) * poly_c + (1-poly_c) * np.min( attack_lens ) 
        lexpr_cst, lexpr_weights, lexpr_dim, uexpr_cst, uexpr_weights,uexpr_dim = verify_geometric_box(eran, config, means, stds, transform_attack_container, img, y_tar, lb, ub, 1, W=W, b=b, update_geom=True, target=target)
        lexpr_cst, lexpr_weights = -np.array(lexpr_cst), -np.array(lexpr_weights)
        uexpr_cst, uexpr_weights = np.array(uexpr_cst), np.array(uexpr_weights)
        lexpr_weights = (lexpr_weights.reshape( -1, config.num_params ))[:num_pixels]
        uexpr_weights = (uexpr_weights.reshape( -1, config.num_params ))[:num_pixels]
        lexpr_cst = lexpr_cst[:num_pixels]
        uexpr_cst = uexpr_cst[:num_pixels]
        hps_input = np.concatenate( (lexpr_weights, uexpr_weights), axis=0 ), np.concatenate( (lexpr_cst, uexpr_cst) )
        # Drop hyperplanes that are constant over random parameter samples.
        samples = np.random.uniform( lb, ub, size=(10,lb.shape[0]) ).T
        good_hp = []
        for idx, v in enumerate( hps_input[0] @ samples ):
            if ( np.unique( v ).shape[0] != 1 ):
                good_hp.append( idx )
        good_hp = np.array( good_hp )
        hps_input = hps_input[0][good_hp], hps_input[1][good_hp]
       
        # Filter biases
        hps_ws = np.concatenate( ( hps_out_layer[0], hps_interim_layers[0] ), axis=0)
        hps_bs = np.concatenate( ( hps_out_layer[1][:,0], hps_interim_layers[1] ), axis=0)
        hps_input = get_hp_bias_under( hps_input[0], hps_input[1], lb_under, ub_under, W, b, lb, ub, checkZero=True )
        hps_rest = get_hp_bias_under( hps_ws, hps_bs, lb_under, ub_under, W, b, lb, ub, checkZero=True, poly_c=poly_c )
        hps_ws = np.concatenate( ( hps_input[0], hps_rest[0] ), axis=0 )
        hps_bs = np.concatenate( ( hps_input[1], hps_rest[1] ), axis=0 )
        
        # Normalize each hyperplane to unit weight norm.
        norms = np.linalg.norm( hps_ws, axis=1 ).reshape( -1, 1 )
        hps_ws /= norms
        hps_bs /= norms

        # Keep only hyperplanes that are not (near-)axis-aligned.
        good_hp = np.logical_and( np.abs( hps_ws[:,0]/hps_ws[:,1] ) > 0.1 , np.abs( hps_ws[:,1]/hps_ws[:,0] ) > 0.1 )
        print( '\n\n\n\nNon-zero hps:', np.sum( good_hp ),  '/', good_hp.shape[0] , flush=True )
        if poly_c > 0 :
            hps_ws = hps_ws[ good_hp ] 
            hps_bs = hps_bs[ good_hp ] 

        # Final
        W = np.concatenate( ( W, hps_ws ), axis=0 )
        b = np.concatenate( ( b, hps_bs ), axis=0 )
        xs_new, model_new, W, b = simplify_region( W, b, lb, ub )
        del model_new, xs_new
       
        w_show = np.concatenate( ( W[:,:-1], b.reshape(-1,1) ), axis=1 ) 
        draw2d_region( w_show, lb[:-1], ub[:-1], lb_under[:-1], ub_under[:-1], 'geom_it' + str(i), (lb[:-1], ub[:-1]), draw_hp=W.shape[0], grid_idx=grid_idx )
+
def draw2d_region_boxes( lb, ub, res, param, name, bounds, grid_idx ):
    """Plot per-cell verification results as colored rectangles.

    The outer box [lb, ub] is drawn dashed blue; each (lb_p, ub_p) cell from
    `param` is drawn green when the matching entry of `res` is truthy, red
    otherwise.  Optional `grid_idx` (success, point) pairs are marked 'gx'
    on success and 'kx' otherwise.  The figure is saved to `name`.
    """
    import matplotlib.pyplot as plt

    def rect(lo, hi, style):
        # Closed rectangle outline through the four corners.
        xs = [lo[0], lo[0], hi[0], hi[0], lo[0]]
        ys = [lo[1], hi[1], hi[1], lo[1], lo[1]]
        plt.plot(xs, ys, style)

    plt.figure()
    bounds_lb, bounds_ub = bounds
    pad = (np.array(bounds_ub) - bounds_lb) / 20.0
    ax = plt.gca()
    ax.set_xlim([bounds_lb[0] - pad[0], bounds_ub[0] + pad[0]])
    ax.set_ylim([bounds_lb[1] - pad[1], bounds_ub[1] + pad[1]])
    rect(lb, ub, 'b--')

    if grid_idx is not None:
        for succ, pt in grid_idx:
            plt.plot(pt[0], pt[1], 'gx' if succ else 'kx')

    for ok, (lb_p, ub_p) in zip(res, param):
        rect(lb_p, ub_p, 'g-' if ok else 'r-')

    plt.savefig(name)
    plt.close()
+ 
+
def draw2d_region( W, lb, ub, lb_sh, ub_sh, name, bounds, draw_hp=None, grid_idx=None ):
    """Plot the 2-D region cut out of the box [lb, ub] by hyperplanes W.

    Each row of W is [w0, w1, bias] with the feasible side w . x + bias <= 0.
    Region vertices are enumerated by intersecting every pair of constraints
    (including the box faces) and keeping intersections feasible for all
    constraints; the convex hull is drawn together with the shrunk box
    [lb_sh, ub_sh], optionally the first `draw_hp` hyperplanes as red lines,
    and grid sample points.

    Saves the figure to `name` and returns the hull volume (0 when the
    vertex set is degenerate).
    """
    # Express the box faces as hyperplanes in the same [w | bias] layout.
    lbs = np.concatenate( ( -np.eye( lb.shape[ 0 ] ), lb[:,np.newaxis] ), axis=1 ) 
    ubs = np.concatenate( ( np.eye( ub.shape[ 0 ] ), -ub[:,np.newaxis] ), axis=1 )
    W_full = np.concatenate( ( lbs, ubs ), axis=0 )
    W_full = np.concatenate( ( W_full, W ), axis=0 )
    verts = []


    for i in range( W_full.shape[0]):
        for j in range( i + 1, W_full.shape[0]):
            # Skip a box lower face against its own upper face (parallel).
            if j == i + lbs.shape[0] and i < lbs.shape[0]:
                continue
            idx = np.array( [i,j] )
            try:
                x = np.linalg.solve( W_full[idx, :-1], -W_full[idx, -1] )
            except np.linalg.LinAlgError as e:
                # Parallel constraints: no intersection point.
                continue
            # Keep the intersection only if it satisfies every constraint.
            m = np.matmul( W_full, x.tolist() + [1] )
            if np.any( m > 1e-6 ):
                continue
            #print( i ,j , m, x )
            verts.append( x )
    verts = np.array( verts )
    import matplotlib.pyplot as plt
    from scipy.spatial import ConvexHull, convex_hull_plot_2d
    try:
        hull = ConvexHull(verts)
        hs = hull.simplices
        vol = hull.volume
    except:
        # Degenerate vertex set (collinear or too few points): fall back to
        # a single segment along the axis with the larger spread, volume 0.
        # NOTE(review): spread is computed without abs(); assumes the max
        # deviation from the mean is on the positive side — confirm.
        dist_0 = np.max( verts[ :, 0 ] - np.average( verts[ :, 0 ] ) ) 
        dist_1 = np.max( verts[ :, 1 ] - np.average( verts[ :, 1 ] ) ) 
        if dist_0 < dist_1:
            hs = []
            hs.append( ( np.argmin( verts[ :, 1 ] ), np.argmax( verts[ :, 1 ] ) ) )
        else:
            hs = []
            hs.append( ( np.argmin( verts[ :, 0 ] ), np.argmax( verts[ :, 0 ] ) ) )
        vol = 0
    plt.figure()
    plt.plot(verts[:,0], verts[:,1], 'ko')
    plt.plot([lb_sh[0],lb_sh[0], ub_sh[0], ub_sh[0]], [lb_sh[1],ub_sh[1], ub_sh[1], lb_sh[1]], 'bo')
    axes = plt.gca()
    bounds_lb, bounds_ub = bounds
    margin = ( np.array( bounds_ub ) - bounds_lb ) / 20.0
    axes.set_xlim([ bounds_lb[ 0 ] - margin[ 0 ], bounds_ub[ 0 ] + margin[ 0 ] ])
    axes.set_ylim([ bounds_lb[ 1 ] - margin[ 1 ], bounds_ub[ 1 ] + margin[ 1 ] ])


    for simplex in hs:
        plt.plot(verts[simplex, 0], verts[simplex, 1], 'k-')
    if not draw_hp is None:
        # Draw the first `draw_hp` hyperplanes as full red lines: solve for
        # y at the x-limits and for x at the y-limits of the plot.
        ls = W[ : draw_hp, : ]
        ls_pt = []
        for l in ls:
            y_min = ( l[ 0 ] * ( bounds_lb[ 0 ] - margin[ 0 ] ) + l[ 2 ] ) / -l[ 1 ]
            y_max = ( l[ 0 ] * ( bounds_ub[ 0 ] + margin[ 0 ] ) + l[ 2 ] ) / -l[ 1 ]
            plt.plot([bounds_lb[ 0 ] - margin[ 0 ], bounds_ub[ 0 ] + margin[ 0 ]], [y_min,y_max], 'r-')
            x_min = ( l[ 1 ] * ( bounds_lb[ 1 ] - margin[ 1 ] ) + l[ 2 ] ) / -l[ 0 ]
            x_max = ( l[ 1 ] * ( bounds_ub[ 1 ] + margin[ 1 ] ) + l[ 2 ] ) / -l[ 0 ]
            plt.plot([x_min,x_max], [bounds_lb[ 1 ] - margin[ 1 ], bounds_ub[ 1 ] + margin[ 1 ]], 'r-')
    if not grid_idx is None:
        # Grid samples: green where the prediction matched the target label.
        for succ, pt in grid_idx:
            if succ:
                plt.plot( pt[0], pt[1], 'gx' )
            else:
                plt.plot( pt[0], pt[1], 'kx' )
    plt.plot([lb_sh[0],lb_sh[0], ub_sh[0], ub_sh[0],lb_sh[0]], [lb_sh[1],ub_sh[1], ub_sh[1], lb_sh[1],lb_sh[1]], 'b-')
    plt.plot([lb[0],lb[0], ub[0], ub[0],lb[0]], [lb[1],ub[1], ub[1], lb[1],lb[1]], 'b--')
    plt.title( 'Volume:' + str(vol) )
    plt.savefig( name )
    plt.close()
    return vol
+
+def exp_over_transform( sess, model, config, transform_attack_container, x_orig, lb_param, ub_param, it, batch_size  ):
+    #https://arxiv.org/pdf/1707.07397.pdf
+    tf_y = sess.graph.get_operation_by_name(model.op.name).outputs[0]
+    tf_x = sess.graph.get_operations()[0].outputs[0]
+
+    if config.dataset == 'mnist' or config.dataset == 'fashion':
+        n_rows, n_cols, n_channels = 28, 28, 1
+    else:
+        n_rows, n_cols, n_channels = 32, 32, 3
+
+    import tensorflow as tf
+    import tensorflow_addons as tfa
+    
+    tf_img = tf.Variable(np.zeros( (1, n_rows, n_cols, n_channels), dtype=np.float32 ), trainable=True)
+    tf_img_orig = tf.placeholder(tf.float32, [1, n_rows, n_cols, n_channels] )
+    tf_imgs = tf.repeat( tf_img, batch_size, axis=0 ) 
+    tf_imgs_orig = tf.repeat( tf_img_orig, batch_size, axis=0 ) 
+    tf_params = tf.placeholder(tf.float32, (batch_size, 3) )
+    rot = ( tfa.image.rotate, [ tf_params[:,0], "BILINEAR" ] )
+    scale = ( tfa.image.transform, [ 1.0/tf_params[:,1], "BILINEAR" ] )
+    shear = ( tfa.image.shear_x, [ tf_params[:,2] ] )
+
+    ops = [ rot, scale, shear ]
+    for op in ops:
+        tf_imgs = op[0]( tf_imgs, *op[1] )
+        tf_imgs_orig = op[0]( tf_imgs_orig, *op[1] )
+    
+    geom_obj = tf.norm( tf_imgs_orig - tf_imgs, axis=0 )
+
+
+    # NOTE(review): removed stray 'import pdb; pdb.set_trace()' debug breakpoint left in committed code
+
+    out = sess.run( geom_obj, feed_dict={ tf_img: x_orig+np.random.uniform( -0.01, 0.01, size=x_orig.shape), tf_img_orig: x_orig } )
+    x_orig2 = get_attack_by_params( transform_attack_container, np.zeros(lb_param.shape[0]) )
+
+
+    batch = np.random.uniform( lb_param, ub_param, size=( lb_param.shape[0], batch_size ) )
+    imgs = []
+    for p in batch:
+        vals = get_attack_by_params(transform_attack_container, p)
+        vals_lb = vals[::2]
+        vals_ub = vals[1::2]
+        vals = (vals_lb + vals_ub)/2.0
+        imgs.append( vals )
+    
+def create_EoT_tensors( eran, config, means, stds ):
+    #https://arxiv.org/pdf/1707.07397.pdf
+    num_params = config.num_params
+    sess = eran.tf_session
+    model = eran.model
+
+    tf_y = sess.graph.get_operation_by_name(model.op.name).outputs[0]
+    tf_x = sess.graph.get_operations()[0].outputs[0]
+    
+    if config.dataset == 'mnist' or config.dataset == 'fashion':
+        n_rows, n_cols, n_channels = 28, 28, 1
+    else:
+        n_rows, n_cols, n_channels = 32, 32, 3
+
+    
+    def tf_image_translate(x_size, y_size, tx, ty):
+        zeros = tf.zeros( tf.shape( ty )[0], dtype=tf.float32 )
+        ones = tf.ones( tf.shape( tx )[0], dtype=tf.float32 )
+        transforms = [ones, zeros, tx, zeros, ones, ty, zeros, zeros]
+        flat = tf.concat( [ transforms, [ones] ], 0 )
+        m = tf.reshape( tf.transpose(flat), (-1,3,3) )
+        return m  
+    
+    def tf_image_scale( x_size, y_size, s ):
+        tx = (x_size - 1) / 2.0 * ( 1 - s )
+        ty = (y_size - 1) / 2.0 * ( 1 - s )
+
+        zeros = tf.zeros( tf.shape( s )[0], dtype=tf.float32 )
+        ones = tf.ones( tf.shape( s )[0], dtype=tf.float32 )
+ 
+        transforms = [s, zeros, zeros, zeros, s, zeros, zeros, zeros]
+        flat = tf.concat( [ transforms, [ones] ], 0 )
+        m = tf.reshape( tf.transpose(flat), (-1,3,3) )
+        return m  
+
+    def tf_image_shear( x_size, y_size, s ):
+        ty = (y_size - 1) / 2.0
+        zeros = tf.zeros( tf.shape( s )[0], dtype=tf.float32 )
+        ones = tf.ones( tf.shape( s )[0], dtype=tf.float32 )
+        transforms = [ones, s, zeros, zeros, ones, zeros, zeros, zeros]
+        flat = tf.concat( [ transforms, [ones] ], 0 )
+        m = tf.reshape( tf.transpose(flat), (-1,3,3) )
+        return m  
+
+    def tf_image_rotate( x_size, y_size, phi ):
+        tx = (x_size - 1) / 2.0 
+        ty = (y_size - 1) / 2.0
+        sin = tf.sin( phi )
+        cos = tf.cos( phi )
+        tx_final = tx * ( 1 - cos ) + ty * sin 
+        ty_final = ty * ( 1 - cos ) - tx * sin
+        zeros = tf.zeros( tf.shape( phi )[0], dtype=tf.float32 )
+        ones = tf.ones( ( tf.shape( phi )[0]), dtype=tf.float32 )
+        transforms = [cos, -sin, zeros, sin, cos, zeros, zeros, zeros]
+        flat = tf.concat( [ transforms, [ones] ], 0 )
+        m = tf.reshape( tf.transpose(flat), (-1,3,3) )
+        return m  
+
+    def normalize_tf(dataset, means, stds, img):
+        if dataset == 'mnist'  or dataset == 'fashion':
+            return (img - means[0])/stds[0]
+        else:
+            raise NotImplementedError( "TODO implement tf cifar normalization " )  # was assert(False, msg): a parenthesized tuple is always truthy, so the assert could never fire
+
+    lb_old = tf.placeholder( tf.float32, ( 3 ) )
+    ub_old = tf.placeholder( tf.float32, ( 3 ) )
+ 
+    tf_img_orig = tf.placeholder(tf.float32, [1, n_rows, n_cols, n_channels] )
+    
+    label_placeholder = tf.placeholder(tf.int64,(None))
+    labels = tf.one_hot( label_placeholder, 10 )   
+
+    # PGD x'
+    params_lb_init_pl = tf.placeholder( tf.float32, (3) )
+    params_lb = tf.Variable( np.zeros( (3), dtype=np.float32 ), trainable=True)
+    params_lb_init = tf.assign( params_lb, params_lb_init_pl )
+
+    params_ub_init_pl = tf.placeholder( tf.float32, (3) )
+    params_ub = tf.Variable( np.zeros( (3), dtype=np.float32 ), trainable=True)
+    params_ub_init = tf.assign( params_ub, params_ub_init_pl )
+
+    tf_batch_size = tf.placeholder( tf.int64, () )
+    eps = tf.random.uniform( [ tf_batch_size, 3 ], minval=0, maxval=1 )
+    params = eps * ( params_ub - params_lb ) + params_lb
+
+    projected_lbs = tf.clip_by_value( tf.clip_by_value( params_lb, lb_old,  ub_old ), lb_old, params_ub )
+    projected_lbs = tf.assign( params_lb, projected_lbs )
+    
+    projected_ubs = tf.clip_by_value( tf.clip_by_value( params_ub, lb_old,  ub_old ), params_lb, ub_old )
+    projected_ubs = tf.assign( params_ub, projected_ubs )
+    
+    m_rot = tf_image_rotate( n_rows, n_cols, params[:, 2] )
+    m_scale = tf_image_scale( n_rows, n_cols, params[:, 1] )
+    m_shear = tf_image_shear( n_rows, n_cols, params[:, 0] )
+    m = m_rot @ m_scale @ m_shear
+    m_vec = tf.reshape( m, [-1,9] )[:,0:6]
+    imgs = tf.tile( tf_img_orig, [tf.shape( params )[0],1,1,1] )
+    tr = spatial_transformer_network( imgs, m_vec ) 
+    tr_norm = normalize_tf(config.dataset, means, stds, tr)
+    tr_norm = tf.cast( tr_norm, tf.float64 )
+    tr_norm = tf.reshape( tr_norm, [tf.shape( tr_norm )[0],-1] )
+    tf_logits = ge.graph_replace( tf_y , {tf_x: tr_norm} )
+    average_loss = tf.reduce_sum( tf.nn.softmax_cross_entropy_with_logits( logits=tf_logits, labels=labels ) ) / tf.cast( tf_batch_size, tf.float64 ) 
+    predicted_labels = tf.argmax( tf_logits, axis=1 )
+    predicted_labels_sum = tf.reduce_sum( tf.cast( tf.equal( predicted_labels, label_placeholder ), dtype=tf.int32 ) )
+    
+    lam = tf.placeholder( tf.float32, () )
+    loss2 = tf.cast( lam * tf.reduce_sum( tf.log( tf.maximum( params_ub - params_lb, 1e-5 ) ) ), tf.float64 )
+    vol = tf.reduce_prod( params_ub - params_lb )
+    loss = average_loss - loss2
+
+    learning_rate = tf.placeholder(tf.float32,())
+    optim_step = tf.train.GradientDescentOptimizer( learning_rate ).minimize( loss, var_list=[ params_lb, params_ub ] )#, params_logvars ] )
+
+    return tf_x, tf_y, optim_step, average_loss, loss2, predicted_labels_sum, vol, lam, tf_batch_size, learning_rate, label_placeholder, tf_img_orig, params_lb_init, params_ub_init, params_lb_init_pl, params_ub_init_pl, projected_lbs, projected_ubs, params_lb, params_ub, lb_old, ub_old
+ 
+def exp_over_transform( tf_tensors, sess, model, config, transform_attack_container, x_orig, lbbox, ubbox, means, stds, it, batch_size, target, lam_np ):
+    if config.dataset == 'mnist' or config.dataset == 'fashion':
+        n_rows, n_cols, n_channels = 28, 28, 1
+    else:
+        n_rows, n_cols, n_channels = 32, 32, 3
+
+    x_orig = x_orig.reshape( 1, n_rows, n_cols, n_channels )
+    tf_x, tf_y, optim_step, average_loss, loss2, predicted_labels_sum, vol, lam, tf_batch_size, learning_rate, label_placeholder, tf_img_orig, params_lb_init, params_ub_init, params_lb_init_pl, params_ub_init_pl, projected_lbs, projected_ubs, params_lb, params_ub, lb_old, ub_old = tf_tensors
+
+    sess.run( params_lb_init, feed_dict={ params_lb_init_pl: 0.6 * lbbox + 0.4 * ubbox } )
+    sess.run( params_ub_init, feed_dict={ params_ub_init_pl: 0.4 * lbbox + 0.6 * ubbox } )
+    
+    labels = [target] * batch_size
+    avg_loss = 0
+    avg_pred_label = 0
+    avg_vol_loss = 0
+    avg_vol = 0
+    for i in range(it):
+        _, loss_transform, loss_vol, pred_lab, vol_np = sess.run( [ optim_step, average_loss, loss2, predicted_labels_sum,vol ], feed_dict={ lam: lam_np, tf_batch_size: batch_size, learning_rate: 5e-4, label_placeholder: labels, tf_img_orig: x_orig } )
+        #sess.run( projected_logvar, feed_dict={ upper_logvar: (ubbox - lbbox)/2.1 } )
+        sess.run( projected_lbs, feed_dict={lb_old: lbbox, ub_old: ubbox} )
+        sess.run( projected_ubs, feed_dict={lb_old: lbbox, ub_old: ubbox} )
+        avg_loss +=  loss_transform
+        avg_pred_label += pred_lab
+        avg_vol_loss += loss_vol 
+        avg_vol += vol_np 
+        if(i+1)%10==0:
+            print('Step: %d, Loss Transform: %g, Vol_loss: %g Percent: %g, Vol %g' % (i+1,avg_loss/10.0, avg_vol_loss/10.0, avg_pred_label/10.0, avg_vol/10.0))
+            avg_loss = 0
+            avg_pred_label = 0
+            avg_vol_loss = 0
+            avg_vol = 0
+    lb_params, ub_params = sess.run( ( params_lb, params_ub ) )
+    #best_params, labels_tf, params = sess.run( [ params_means, predicted_labels, params ], feed_dict={ tf_batch_size: 500, tf_img_orig: x_orig, params_logvars: np.log(0.3*(ubbox-lbbox)) } )
+    params = np.random.uniform( lb_params, ub_params, size=[500,lb_params.shape[0]] )
+    labels_mislav = [] 
+    for i in range( params.shape[0] ):
+        x_p = get_attack_by_params( transform_attack_container, params[i].astype(np.float64) )[1::2]  # NOTE(review): keeps only odd-indexed (upper) values; the unused exp_over_transform averaged [::2] and [1::2] — confirm intended
+        normalize( x_p, means, stds, config.dataset )
+        tar_x = np.argmax( sess.run( tf_y, feed_dict={tf_x: x_p} ) ) 
+        labels_mislav.append( tar_x )
+    percent = np.sum( np.array( labels_mislav ) == target ) / 5.0
+    print( 'Percent: ', percent )
+    return lb_params, ub_params, percent
+
+def get_eot_lam( eran, tf_eot_tens, config, transform_attack_container, image, lbbox, ubbox, means, stds, cl, lam_st=0.8, num_its=100, percent=97, batch_size=100 ):
+    lam = lam_st
+    bad = False
+    while True:
+        print( 'lambda:', lam ) 
+        lb_eot, ub_eot, perc_eot = exp_over_transform( tf_eot_tens, eran.tf_session, eran.model, config, transform_attack_container, image, lbbox, ubbox, means, stds, num_its, batch_size, cl , lam)
+        lam /= 2.0
+        if perc_eot >= percent:
+            break
+        if lam < 1e-2:
+            bad = True
+            break
+    if bad:
+        vol_eot = 0.0
+        return 0.0, None, None, None, vol_eot
+
+    lam_end = lam * 4.0 
+    lam_begin = lam * 2.0
+    print( 'Lam Start:', lam_begin, 'Lam End:', lam_end ) 
+    while True:
+        lam_mid = (lam_begin + lam_end) / 2.0
+        print( 'lambda:', lam_mid ) 
+        lb_eot, ub_eot, perc_eot = exp_over_transform( tf_eot_tens, eran.tf_session, eran.model, config, transform_attack_container, image, lbbox, ubbox, means, stds, num_its, batch_size, cl , lam_mid)
+        if perc_eot >= percent:
+            lam_begin = lam_mid
+        else:
+            lam_end = lam_mid
+        if lam_end - lam_begin < 1e-2:
+            break
+    print( 'Lam Final:', lam_mid )
+    
+    lb_eot = lb_eot.astype( np.float64 ) 
+    ub_eot = ub_eot.astype( np.float64 )
+    vol_eot =  np.prod( ub_eot - lb_eot)
+
+    return lam_mid, lb_eot, ub_eot, perc_eot, vol_eot
diff --git a/tf_verify/optimizer.py b/tf_verify/optimizer.py
index 0b05798..6539433 100644
--- a/tf_verify/optimizer.py
+++ b/tf_verify/optimizer.py
@@ -179,12 +179,15 @@ class Optimizer:
                     execute_list.append(DeeppolyPoolNode(image_shape, window_size, strides, pad_top, pad_left, input_names, output_name, output_shape, is_maxpool))
                 i += 1
             elif self.operations[i] == "Relu":
-                #self.resources[i][domain].append(refine)
+                if 'W' in dir( self ):
+                    inp = (*self.resources[i][domain], self.W, self.b)
+                else:
+                    inp = self.resources[i][domain]
                 nn.layertypes.append('ReLU')
                 if domain == 'deepzono':
                     execute_list.append(DeepzonoRelu(*self.resources[i][domain]))
                 else:
-                    execute_list.append(DeeppolyReluNode(*self.resources[i][domain]))
+                    execute_list.append(DeeppolyReluNode(*inp))
                 nn.numlayer += 1
                 i += 1
             elif self.operations[i] == "Sigmoid":
@@ -361,7 +364,7 @@ class Optimizer:
 
 
 
-    def get_deeppoly(self, nn, specLB, specUB, lexpr_weights, lexpr_cst, lexpr_dim, uexpr_weights, uexpr_cst, uexpr_dim, expr_size, spatial_constraints=None):
+    def get_deeppoly(self, nn, specLB, specUB, lexpr_weights, lexpr_cst, lexpr_dim, uexpr_weights, uexpr_cst, uexpr_dim, expr_size, spatial_constraints=None, W=None, b=None):
         """
         This function will go through self.operations and self.resources and create a list of Deeppoly-Nodes which then can be run by an Analyzer object.
         It is assumed that self.resources[i]['deeppoly'] holds the resources for an operation of type self.operations[i].
@@ -385,12 +388,14 @@ class Optimizer:
         """
         execute_list = []
         output_info = []
+        self.W = W
+        self.b = b
         domain = 'deeppoly'
         assert self.operations[0] == "Placeholder", "the optimizer for Deeppoly cannot handle this network "
         input_names, output_name, output_shape = self.resources[0][domain]
         output_info.append(self.resources[0][domain][-2:])
         execute_list.append(DeeppolyInput(specLB, specUB, input_names, output_name, output_shape,
-                                            lexpr_weights, lexpr_cst, lexpr_dim, uexpr_weights, uexpr_cst, uexpr_dim, expr_size, spatial_constraints))
+                                            lexpr_weights, lexpr_cst, lexpr_dim, uexpr_weights, uexpr_cst, uexpr_dim, expr_size, spatial_constraints, W=W, b=b))
 
         self.get_abstract_element(nn, 1, execute_list, output_info, 'deeppoly')
         self.set_predecessors(nn, execute_list)
diff --git a/tf_verify/pgd_div.py b/tf_verify/pgd_div.py
new file mode 100644
index 0000000..f80ae8a
--- /dev/null
+++ b/tf_verify/pgd_div.py
@@ -0,0 +1,57 @@
+import tensorflow as tf
+import numpy as np
+from tensorflow.contrib import graph_editor as ge
+
+# https://arxiv.org/pdf/2003.06878.pdf
+def create_pgd_graph( lb, ub, sess, tf_input, tf_output, target):
+    # Replace graph
+    tf_image = tf.Variable(lb, trainable=True)
+    tf_output = ge.graph_replace(tf_output, {tf_input: tf_image + 0.0})
+
+    # Output diversification
+    tf_dir = tf.placeholder( shape=(tf_output.shape[1]), dtype=tf.float64 )
+    tf_eps_init = tf.placeholder( shape=lb.shape, dtype=tf.float64 )
+    tf_init_error = tf.reduce_sum( tf_dir * tf_output )
+    tf_init_grad = tf.gradients( tf_init_error, [tf_image] )[0]
+    tf_train_init = tf_image + tf_eps_init * tf.sign( tf_init_grad ) 
+    tf_train_init = tf.assign( tf_image, tf_train_init )
+   
+    # PGD
+    tf_train_error = tf.keras.utils.to_categorical( target, num_classes=tf_output.shape[-1] )
+    tf_eps_pgd = tf.placeholder( shape=lb.shape, dtype=tf.float64 )
+    tf_train_error = tf.keras.losses.categorical_crossentropy( tf_train_error, tf_output, from_logits=True)
+    tf_train_grad = tf.gradients( tf_train_error, [tf_image] )[0]
+    tf_train_pgd = tf_image - tf_eps_pgd * tf.sign( tf_train_grad ) 
+    tf_train_pgd = tf.assign( tf_image, tf_train_pgd )
+    
+    # Clip
+    tf_train_clip = tf.clip_by_value( tf_image, lb, ub ) 
+    tf_train_clip = tf.assign( tf_image, tf_train_clip )
+
+    # Seed
+    tf_seed_pl = tf.placeholder( shape=lb.shape, dtype=tf.float64 )
+    tf_seed = tf.assign( tf_image, tf_seed_pl )
+
+    return tf_image, tf_dir, tf_seed_pl, tf_eps_init, tf_eps_pgd, tf_output, tf_train_init, tf_train_pgd, tf_train_clip, tf_seed
+
+def pgd(sess, lb, ub, 
+        tf_image, tf_dir, tf_seed_pl, tf_eps_init, tf_eps_pgd, 
+        tf_output, tf_train_init, tf_train_pgd, tf_train_clip, tf_seed, 
+        eps_init, eps_pgd, odi_its, pgd_its):
+
+    seed = np.random.uniform( lb, ub, size=lb.shape )
+    d = np.random.uniform( -1, 1, size=(tf_output.shape[1]) )
+
+    sess.run( tf_seed, feed_dict={ tf_seed_pl: seed } )
+    for i in range(odi_its):
+        sess.run( tf_train_init, feed_dict={tf_dir : d, tf_eps_init : eps_init} )
+        sess.run( tf_train_clip )
+    seed = sess.run( tf_image )
+
+    sess.run( tf_seed, feed_dict={ tf_seed_pl: seed } )
+    for i in range(pgd_its):
+        sess.run( tf_train_pgd, feed_dict={tf_eps_pgd : eps_pgd} )
+        sess.run( tf_train_clip )
+    seed = sess.run( tf_image )
+
+    return seed
diff --git a/tf_verify/read_net_file.py b/tf_verify/read_net_file.py
index 2a7cc79..8dcb9b6 100644
--- a/tf_verify/read_net_file.py
+++ b/tf_verify/read_net_file.py
@@ -57,7 +57,7 @@ def extract_std(text):
     return std_array
 
 def numel(x):
-    return product([int(i) for i in x.shape])
+    return product([int(i) for i in x.shape[1:]])
 
 def parseVec(net):
     return np.array(eval(net.readline()[:-1]))
@@ -88,7 +88,9 @@ def read_tensorflow_net(net_file, in_len, is_trained_with_pytorch):
     mean = 0.0
     std = 0.0
     net = open(net_file,'r')
-    x = tf.placeholder(tf.float64, [in_len], name = "x")
+    x = tf.placeholder(tf.float64, [None], name = "x")
+    x = tf.reshape(x, [-1,in_len], name = "imgs")
+    batch_size =tf.shape( x )[0]
     y = None
     z1 = None
     z2 = None
@@ -133,13 +135,13 @@ def read_tensorflow_net(net_file, in_len, is_trained_with_pytorch):
             #b = myConst(b.reshape([1, numel(b)]))
             b = myConst(b)
             if(curr_line=="Affine"):
-                x = tf.nn.bias_add(tf.matmul(tf.reshape(x, [1, numel(x)]),W), b)
+                x = tf.nn.bias_add(tf.matmul(tf.reshape(x, [batch_size, numel(x)]),W), b)
             elif(curr_line=="ReLU"):
-                x = tf.nn.relu(tf.nn.bias_add(tf.matmul(tf.reshape(x, [1, numel(x)]),W), b))
+                x = tf.nn.relu(tf.nn.bias_add(tf.matmul(tf.reshape(x, [batch_size, numel(x)]),W), b))
             elif(curr_line=="Sigmoid"):
-                x = tf.nn.sigmoid(tf.nn.bias_add(tf.matmul(tf.reshape(x, [1, numel(x)]),W), b))
+                x = tf.nn.sigmoid(tf.nn.bias_add(tf.matmul(tf.reshape(x, [batch_size, numel(x)]),W), b))
             else:
-                x = tf.nn.tanh(tf.nn.bias_add(tf.matmul(tf.reshape(x, [1, numel(x)]),W), b))
+                x = tf.nn.tanh(tf.nn.bias_add(tf.matmul(tf.reshape(x, [batch_size, numel(x)]),W), b))
             print("\tOutShape: ", x.shape)
             print("\tWShape: ", W.shape)
             print("\tBShape: ", b.shape)
@@ -161,7 +163,7 @@ def read_tensorflow_net(net_file, in_len, is_trained_with_pytorch):
             ksize =  [1] + args['pool_size'] + [1]
             print("MaxPool", args)
 
-            x = tf.nn.max_pool(tf.reshape(x, [1] + args["input_shape"]), padding=padding_arg, strides=stride, ksize=ksize)
+            x = tf.nn.max_pool(tf.reshape(x, [batch_size] + args["input_shape"]), padding=padding_arg, strides=stride, ksize=ksize)
             print("\tOutShape: ", x.shape)
         elif curr_line == "Conv2D":
             is_conv = True
@@ -199,10 +201,10 @@ def read_tensorflow_net(net_file, in_len, is_trained_with_pytorch):
             else:
                 stride_arg = [1,1,1,1]
 
-            x = tf.nn.conv2d(tf.reshape(x, [1] + args["input_shape"]), filter=W, strides=stride_arg, padding=padding_arg)
+            x = tf.nn.conv2d(tf.reshape(x, [batch_size] + args["input_shape"]), filter=W, strides=stride_arg, padding=padding_arg)
 
             b = myConst(parseVec(net))
-            h, w, c = [int(i) for i in x.shape ][1:]
+            h, w, c = [ int(i) for i in x.shape[1:] ]
             print("Conv2D", args, "W.shape:",W.shape, "b.shape:", b.shape)
             print("\tOutShape: ", x.shape)
             if("ReLU" in line):
diff --git a/tf_verify/refine_activation.py b/tf_verify/refine_activation.py
index ffdc605..0de72fe 100644
--- a/tf_verify/refine_activation.py
+++ b/tf_verify/refine_activation.py
@@ -16,6 +16,7 @@
 
 
 import numpy as np
+from gurobipy import *
 from zonoml import *
 from elina_interval import *
 from elina_abstract0 import *
@@ -50,7 +51,36 @@ def update_activation_expr_bounds(man, element, layerno, lower_bound_expr, upper
             uexpr = np.ascontiguousarray(uexpr, dtype=np.double)
             #update_activation_upper_bound_for_neuron(man, element, layerno, var, uexpr, varsid, k)
 
-def refine_activation_with_solver_bounds(nn, self, man, element, nlb, nub, relu_groups, timeout_lp, timeout_milp, use_default_heuristic, domain):
+def pool_func_deeppoly( idx ):
+    thread_model = global_model.copy()
+    input_size = global_nllb.shape[1] - 1 
+    xs = [ thread_model.getVarByName( 'x' + str( i ) ) for i in range( input_size ) ]
+    obj = global_nllb[ idx, : -1] @ xs
+    thread_model.reset()
+    thread_model.setObjective( obj, GRB.MINIMIZE )
+    thread_model.optimize()
+    # NOTE(review): removed 'import pdb; pdb.set_trace()' on solver failure — this code
+    # runs inside a multiprocessing.Pool worker, where an interactive breakpoint hangs the pool
+    assert thread_model.SolCount == 1
+    lb = thread_model.objbound + global_nllb[ idx, -1 ]
+    
+    '''bad_exam_lb = []
+    for p in range( input_size ) :
+        bad_exam_lb.append( xs[p].x )
+    bad_exam_lb = np.array( bad_exam_lb )'''
+    
+    obj = global_nlub[ idx, : -1] @ xs
+    thread_model.reset()
+    thread_model.setObjective( obj, GRB.MAXIMIZE )
+    thread_model.optimize()
+    # NOTE(review): removed 'import pdb; pdb.set_trace()' on solver failure (worker process —
+    # an interactive breakpoint would hang the pool); the assert below fails loudly instead
+    assert thread_model.SolCount == 1
+    ub = thread_model.objbound + global_nlub[ idx, -1 ]
+    return lb, ub
+
+
+def refine_activation_with_solver_bounds(nn, self, man, element, nlb, nub, relu_groups, timeout_lp, timeout_milp, use_default_heuristic, domain, W=None, b=None):
     """
     refines the relu transformer
 
@@ -88,7 +118,7 @@ def refine_activation_with_solver_bounds(nn, self, man, element, nlb, nub, relu_
             first_FC = i
             break
             
-    if nn.activation_counter==0:
+    if nn.activation_counter==0 and W is None:
         if domain=='deepzono':
             encode_kactivation_cons(nn, man, element, offset, predecessor_index, length, lbi, ubi, relu_groups, False, 'refinezono', nn.layertypes[layerno])
             if nn.layertypes[layerno]=='ReLU':
@@ -117,16 +147,64 @@ def refine_activation_with_solver_bounds(nn, self, man, element, nlb, nub, relu_
             use_milp = 0
             timeout = timeout_lp
         use_milp = use_milp and config.use_milp
+        if not W is None:
+            use_milp = False
         candidate_vars = []
+        indices = []
         for i in range(length):
-            if((lbi[i]<0 and ubi[i]>0) or (lbi[i]>0)):
+            if((lbi[i]<0 and ubi[i]>0) or (lbi[i]>0 and W is None)):
                  candidate_vars.append(i)
         #TODO handle residual layers here
-        if config.refine_neurons==True:
-            resl, resu, indices = get_bounds_for_layer_with_milp(nn, nn.specLB, nn.specUB, predecessor_index, predecessor_index, length, nlb, nub, relu_groups, use_milp,  candidate_vars, timeout)
+
+        if not W is None and len( candidate_vars ) > 0:
+            np_cv = np.array( candidate_vars )
+            llb = np.zeros( (np_cv.shape[0], len( nn.specLB ) + 1 ))
+            lub = np.zeros( (np_cv.shape[0], len( nn.specLB ) + 1 ))
+            get_linear_bounds(man, element, np_cv, llb, lub, np_cv.shape[0], predecessor_index)
+            
+            model = Model()
+            model.setParam("OutputFlag",0)
+            num_params = W.shape[1]
+            num_pixels = len( nn.specLB ) - num_params
+            xs = [model.addVar( nn.specLB[num_pixels+i], nn.specUB[num_pixels+i], name='x'+str(i)) for i in range( num_params )]
+            constrs = W @ xs 
+            for i, constr in enumerate( constrs ):
+                model.addConstr( constr, GRB.LESS_EQUAL, -b[i] ) 
+            llb = -llb[:, num_pixels :]
+            lub = lub[:, num_pixels :]
+            model.update()
+            global global_model, global_nllb, global_nlub
+            global_model = model
+            global_nllb = llb
+            global_nlub = lub
+            ncpus = os.sysconf("SC_NPROCESSORS_ONLN")
+            with multiprocessing.Pool(ncpus) as pool:
+                solver_result = pool.map( pool_func_deeppoly, list( range ( np_cv.shape[0] ) ) )
+            del globals()[ 'global_model' ]
+            del globals()[ 'global_nllb' ]
+            del globals()[ 'global_nlub' ]
+            lbi, ubi = zip( *solver_result )
+            lbi_old = np.array( nlb[predecessor_index] )
+            ubi_old = np.array( nub[predecessor_index] )
+            indices = np.logical_or( lbi > lbi_old[np_cv], ubi < ubi_old[np_cv] )
+            indices = np_cv[ indices ]
+            lbi_old[np_cv] = np.maximum( lbi, lbi_old[np_cv] )
+            ubi_old[np_cv] = np.minimum( ubi, ubi_old[np_cv] )
+            nlb[predecessor_index] = lbi_old.tolist()
+            nub[predecessor_index] = ubi_old.tolist()
+            for cv_idx in np_cv.tolist():
+                update_bounds_for_neuron(man, element,  predecessor_index, cv_idx, lbi_old[cv_idx], ubi_old[cv_idx])
+
+            new_count = 0
+            for i in range(length):
+                if(nlb[predecessor_index][i]<0 and nub[predecessor_index][i]>0):
+                    new_count += 1
+            print( 'Layer: ', predecessor_index, ' Neurons refined: ', np_cv.shape[0], ' New count: ', new_count ) 
+ 
+        elif config.refine_neurons==True:
+            resl, resu, indices = get_bounds_for_layer_with_milp(nn, nn.specLB, nn.specUB, predecessor_index, predecessor_index, length, nlb, nub, relu_groups, use_milp,  candidate_vars, timeout, W, b)
             nlb[predecessor_index] = resl
             nub[predecessor_index] = resu
-
         lbi = nlb[predecessor_index]
         ubi = nub[predecessor_index]
             
@@ -148,17 +226,19 @@ def refine_activation_with_solver_bounds(nn, self, man, element, nlb, nub, relu_
                 element = relu_zono_layerwise(man,True,element,offset, length, use_default_heuristic)
                 return element
         else:
-            if config.refine_neurons==True:
+            if config.refine_neurons==True or not W is None:
                 for j in indices:
-                    update_bounds_for_neuron(man,element,predecessor_index,j,resl[j],resu[j])
-            lower_bound_expr, upper_bound_expr = encode_kactivation_cons(nn, man, element, offset, predecessor_index, length, lbi, ubi, relu_groups, False, 'refinepoly', nn.layertypes[layerno])
+                    update_bounds_for_neuron(man,element,predecessor_index,j,nlb[predecessor_index][j],nub[predecessor_index][j])
+            if W is None:
+                lower_bound_expr, upper_bound_expr = encode_kactivation_cons(nn, man, element, offset, predecessor_index, length, lbi, ubi, relu_groups, False, 'refinepoly', nn.layertypes[layerno])
             if nn.layertypes[layerno] == 'ReLU':
                 handle_relu_layer(*self.get_arguments(man, element), use_default_heuristic)
             elif nn.layertypes[layerno] == 'Sigmoid':
                 handle_sigmoid_layer(*self.get_arguments(man, element))
             else:
                 handle_tanh_layer(*self.get_arguments(man, element))
-            update_activation_expr_bounds(man, element, layerno, lower_bound_expr, upper_bound_expr, lbi, ubi)
+            if W is None:
+                update_activation_expr_bounds(man, element, layerno, lower_bound_expr, upper_bound_expr, lbi, ubi)
       
      
     
diff --git a/tf_verify/tensorflow_translator.py b/tf_verify/tensorflow_translator.py
index 58a2027..79a77e4 100644
--- a/tf_verify/tensorflow_translator.py
+++ b/tf_verify/tensorflow_translator.py
@@ -41,7 +41,7 @@ def tensorshape_to_intlist(tensorshape):
 	output : list
 	    list of ints corresponding to tensorshape
 	"""
-	return list(map(lambda j: 1 if j is None else int(j), tensorshape))
+	return list(map(lambda j: 1 if j.value is None else int(j), tensorshape))
 
 
 def calculate_padding(padding_str, image_shape, filter_shape, strides):
