import sys
import os

from consts import *
from data import *
from network import *
from utils import *

def run_network(network, X, Y, readonce, noise_size, sess):
    """Train ``network`` until it converges or the step budget runs out.

    Each iteration performs one ``network.update_network(sess, LR)`` call,
    which yields ``(global_minimum, local_minimum, non_zero_loss_count)``.
    Training stops as soon as either minimum flag is set, or after
    ``MAX_STEPS`` updates. Every ``PRINT_STEP_JUMP`` steps a progress line is
    printed whose "accuracy" is ``X.shape[0]`` minus the reported
    non-zero-loss count.

    ``Y``, ``readonce`` and ``noise_size`` are not used in this function; they
    are accepted so the call signature stays unchanged.

    Returns:
        True if the last update reported a global minimum; False otherwise
        (local minimum, step budget exhausted, or no update was run).
    """
    steps_done = 0
    at_global, at_local = False, False
    while steps_done < MAX_STEPS:
        at_global, at_local, failing = network.update_network(sess, LR)
        steps_done += 1
        if steps_done % PRINT_STEP_JUMP == 0:
            sample_count = X.shape[0]
            print("Step number: {0}, Accuracy: {1} / {2}".format(steps_done, sample_count - failing, sample_count))
        # Either kind of minimum ends training.
        if at_global or at_local:
            break
    if at_local and not at_global:
        print("Got to local minima")
    return at_global

def _train_network(X, Y, readonce, noise_size, epsilon):
    """Train a fresh network on (X, Y) and report whether it hit a global minimum.

    The network is built in a brand-new TensorFlow graph. Its weights start as
    ``epsilon * X`` (one row per input combination) and its biases as zeros.

    Returns:
        ``(network, reached_global_minimum)``. The network object is returned
        so the caller can prune it after this graph/session has been closed.
    """
    tf.reset_default_graph()
    with tf.Graph().as_default():
        W_init = X * epsilon
        B_init = np.zeros([len(X)], dtype=TYPE)

        network = Network(W_init, B_init)
        network.prepere_update_network(X, Y)

        with tf.Session() as sess:
            # NOTE(review): initialize_all_variables is deprecated since TF 0.12
            # in favour of global_variables_initializer; kept unchanged for
            # compatibility with whatever TF version this script targets.
            sess.run(tf.initialize_all_variables())
            reached_global_minimum = run_network(network, X, Y, readonce, noise_size, sess)
    return network, reached_global_minimum


def _prune_and_reconstruct(network, X, Y, readonce, noise_size, partition, epsilon):
    """Prune a converged network and check whether it reconstructs the target DNF.

    Keeps only the rows of ``network.W`` and entries of ``network.B`` at the
    indexes chosen by ``find_indexes_above_half_of_max``, rebuilds a smaller
    network from them in a fresh graph, and prints whether
    ``check_reconstruction`` succeeds for ``partition``.
    """
    tf.reset_default_graph()
    with tf.Graph().as_default():
        print("We got to minimum point")
        print("Prone the network by inf norm")
        kept_indexes = find_indexes_above_half_of_max(network, PRUNE_FACTOR_WEIGHT, PRUNE_FACTOR_TOTAL_NORM)
        prune_network = Network(network.W[kept_indexes], network.B[kept_indexes])
        prune_network.prepere_update_network(X, Y)
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            if check_reconstruction(prune_network, readonce, noise_size, RECONSTRUCTION_FACTOR_WEIGHT, RECONSTRUCTION_FACTOR_NORM):
                print("Theorem proved for D={0} and target DNF: {1} using epsilon={2}".format(D, partition, epsilon))
            else:
                print("After pruning we can't succeeded to reconstruct the DNF")


def main():
    """Run the read-once DNF experiment over every epsilon and DNF size.

    For each ``epsilon`` in ``range(MIN_EPSILON, MAX_EPSILON, STEP_EPSILON)``
    and each ``dnf_size`` in ``2..D``, every balanced partition (except the
    first and last) is turned into a ``ReadOnceDNF`` target. A network is then
    trained on all input combinations. If it reaches a global minimum, it is
    pruned and checked for reconstruction of the target.
    """
    X = np.array(get_all_combinations(), dtype=TYPE)

    for epsilon in range(MIN_EPSILON, MAX_EPSILON, STEP_EPSILON):
        for dnf_size in range(2, D + 1):
            noise_size = D - dnf_size

            print("Generate all balanced partitions in size {}".format(dnf_size))
            # remove the DNF of all 1 and the DNF with one term
            # (relies on get_all_balanced_partitions placing them first/last)
            all_partitions = get_all_balanced_partitions(dnf_size)[1:-1]
            if len(all_partitions) == 0:
                print("No relevant phrases. Skippinig..")
                continue

            print("Running for all dnf in size: {0} and initialization: {1}".format(dnf_size, epsilon))
            for partition in all_partitions:
                readonce = ReadOnceDNF(partition)
                Y = np.array([readonce.get_label(x) for x in X], dtype=TYPE)

                network, reached_global_minimum = _train_network(X, Y, readonce, noise_size, epsilon)
                if reached_global_minimum:
                    _prune_and_reconstruct(network, X, Y, readonce, noise_size, partition, epsilon)
                else:
                    # run_network returns False both on a local minimum and when
                    # MAX_STEPS runs out, so do not claim a local minimum here.
                    print("Did not reach a global minimum (local minimum or MAX_STEPS exhausted)")
    
# Only run the experiment when executed as a script, not on import.
if __name__ == "__main__":
    main()