
import numpy as np
from pymoo.util.misc import stack
from pymoo.model.problem import Problem

    
from tqdm import tqdm
from PIL import Image
import glob
import numpy as np
import random as rand
import math

import difflib

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dense, Dropout, Flatten, Softmax, GlobalAveragePooling2D
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import add,concatenate,Dot
from tensorflow.keras import Input,Model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import SGD
from tensorflow.python.framework import ops
from sklearn.metrics import log_loss
import gc
import tensorflow as tf
from tensorflow.keras import backend as K
from losses import dice_loss,mean_iou
# import logging
# logging.getLogger("tensorflow").setLevel(logging.ERROR)
import operator
from dataloaders.datasetFromSequence import DatasetFromSequenceClass 

# Paired lookup tables: entry 0 pairs the ``<`` comparison with ``min`` and
# entry 1 pairs ``>`` with ``max``.
# NOTE(review): presumably selected by ``set_objective`` (not visible here).
METRIC_OPS = [operator.lt, operator.gt]
METRIC_OBJECTIVES = [min, max]

class PymooGenomeReduced(Problem):
    
    def __init__(self, 
                 max_conv_layers, 
                 max_filters,
                 max_dense_layers,
                 max_nodes,
                 input_shape, 
                 n_classes,
                 batch_size=32,
                 batch_normalization=True,
                 dropout=True, 
                 max_pooling=True,
                 optimizers=None,
                 activations=None,
                 skip_ops=None,
                 type_problem='autoencoder',
                 TRAIN_WITH_LOGITS = False,
                 NASWOT = False,
                 smaller_ss = True
                ):
        """Define the integer search space and register it with pymoo.

        The decision vector is laid out as ``max_conv_layers`` blocks of
        ``convolutional_layer_shape`` genes, then ``max_dense_layers`` blocks of
        ``dense_layer_shape`` genes, then one trailing optimizer index. Every
        gene is an index into the matching ``layer_params`` list, except the
        "connections" gene, whose upper bound shrinks with the layer position.

        Args:
            max_conv_layers: Number of convolutional gene blocks; must be >= 1.
            max_filters: Filter count used by ``decode`` for every conv layer
                when ``smaller_ss`` is True; also bounds the "num filters"
                choices (powers of two up to ``max_filters``).
            max_dense_layers: Number of dense gene blocks.
            max_nodes: Stored as ``self.max_nodes``; not used in this method.
            input_shape: Model input shape, channels last.
            n_classes: Number of output classes.
            batch_size: Initial batch size (``feed_data`` overwrites it).
            batch_normalization: Accepted but not used by this method.
            dropout: Accepted but not used by this method.
            max_pooling: When False the "max pooling" gene can only be 0.
            optimizers: Optimizer names passed to ``model.compile``; defaults
                to ``['adadelta', 'adam']``.
            activations: Activation names; defaults to ``['relu']``.
            skip_ops: Skip-connection op names stored in ``self.skip_op``.
            type_problem: Problem kind string read by ``decode``/``_evaluate``.
            TRAIN_WITH_LOGITS: Compile with from-logits categorical CE.
            NASWOT: Score models with NASWOT instead of training them.
            smaller_ss: Use the reduced search space (fixed filter count).

        Raises:
            ValueError: If ``max_conv_layers < 1``.
        """
        self.smaller_ss = smaller_ss
        self.max_filters = max_filters
        self.max_nodes = max_nodes
        if max_conv_layers < 1:
            raise ValueError(
                "At least one conv layer is required for AE to work"
            )
        # Exclusive exponent bound for the "num filters" choices:
        # floor(log2(max_filters)) + 1. math.log2 is exact for powers of two,
        # whereas math.log(x, 2) can land just off the integer.
        if max_filters > 0:
            filter_range_max = int(math.log2(max_filters)) + 1
        else:
            filter_range_max = 0
        self.optimizer = optimizers or [
            'adadelta',
            'adam'
        ]
        self.activation = activations or [
            'relu',
        ]
        self.skip_op = skip_ops or [
            'none',
            'add',
            'concatenate'
        ]

        # Gene order inside one convolutional block.
        self.convolutional_layer_shape = [
            "active",
            "kernel_size",
            "activation",
            "max pooling",
            "connections"
        ]
        # Offset of each gene inside a convolutional block.
        self.convolutional_id_to_param = {
            "active" : 0,
            "kernel_size": 1,
            "activation" : 2,
            "max pooling": 3,
            "connections": 4
        }

        # Offset of each gene inside a dense block.
        self.dense_id_to_param = {
            "active" : 0,
            "num filters" : 1,
            "activation" : 2,
        }
        self.layer_params = {
            "active": [0, 1],
            "num filters": [2**i for i in range(int(filter_range_max-5), filter_range_max)],
            #Added after paper release
            "kernel_size": [1,3,5],
            "activation": list(range(len(self.activation))),
            # Must always be a list: the bounds below take len() of it and
            # _evaluate indexes into it. With max_pooling=False the only
            # allowed value is 0 (no pooling).
            "max pooling": [0, 1] if max_pooling else [0],
            "connections": list(range(1, 2**(max_conv_layers-1))),
        }
        
        # Gene order inside one dense block.
        self.dense_layer_shape = [
            "active",
            "num filters",
            "activation",
        ]

        self.convolution_layers = max_conv_layers
        self.convolution_layer_size = len(self.convolutional_layer_shape)
        
        self.dense_layers = max_dense_layers
        self.dense_layer_size = len(self.dense_layer_shape)
        
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.type_problem = type_problem
        self.TRAIN_WITH_LOGITS = TRAIN_WITH_LOGITS
        self.i_model = 0
        
        #For archive lookup
        self.generation_performances = []
        self.generation_archive = []
        self.generation_members = []
        self.last_upsampling_index = 0

        self.NASWOT = NASWOT

        # conv genes + dense genes + 1 optimizer gene.
        number_of_variables = (self.convolution_layers*self.convolution_layer_size) + (self.dense_layers*self.dense_layer_size) +1
        # 2 objectives, as written by _evaluate: [score (loss or NASWOT), complexity].
        number_of_objectives = 2

        conv_upper = [
            # Layer i may connect to fewer later layers, so its bound shrinks.
            max((2**(max_conv_layers-i_layer-1))-2, 0) if param == 'connections'
            else len(self.layer_params[param])-1
            for i_layer in range(max_conv_layers)
            for param in self.convolutional_layer_shape
        ]
        dense_upper = [
            len(self.layer_params[param])-1
            for i_layer in range(self.dense_layers)
            for param in self.dense_layer_shape
        ]
        super().__init__(n_var=number_of_variables,
                         n_obj=number_of_objectives,
                         n_constr=0,
                         xl=np.array(
                             [0 for i_layer in range(self.convolution_layers) for param in self.convolutional_layer_shape] + 
                             [0 for i_layer in range(self.dense_layers) for param in self.dense_layer_shape] +
                             [0]),
                         xu=np.array(conv_upper + dense_upper + [len(self.optimizer)-1]),
                         elementwise_evaluation=True,
                         type_var=int)
        
    def feed_data(self,
                  train_with_gen,
                  dataset,
                  num_generations,
                  pop_size,
                  pss,
                  metric='loss',
                  batch_size = 32,
                  epochs = 5,
                  gen_to_tf_data=True,
                  multilabel=False,
                  normalize=None
                 ):
        """Attach training data and search settings before optimisation starts.

        Resets the per-run archives (``generation_*`` lists) and calls
        ``self.set_objective(metric)``.

        Args:
            train_with_gen: If truthy, ``dataset`` is a ``(train_gen, val_gen)``
                pair of batch sources; otherwise it holds in-memory arrays.
            dataset: ``(train_gen, val_gen)`` when ``train_with_gen``; else
                ``((x_train, y_train), (x_test, y_test))`` optionally followed
                by ``(x_val, y_val)``.
            num_generations: Stored as ``self.num_generations``.
            pop_size: Population size; required when ``pss`` is given.
            pss: Used by ``_evaluate`` as ``train_gen[generation % pss]``;
                falsy disables that path.
            metric: Forwarded to ``self.set_objective``.
            batch_size: Batch size used when training candidate models.
            epochs: Epochs used when training candidate models.
            gen_to_tf_data: Wrap generators in ``DatasetFromSequenceClass``.
            multilabel: Stored; ``decode`` reads it to pick the output activation.
            normalize: Optional callable mapped over the tf.data pipelines.

        Raises:
            ValueError: If ``pss`` is given without ``pop_size``, or if an
                in-memory ``dataset`` does not contain 2 or 3 ``(x, y)`` pairs.
        """
        self.multilabel = multilabel
        self.gen_to_tf_data = gen_to_tf_data
        #Previously from NE
        self.train_with_gen = train_with_gen
        self.dataset = dataset
        self.num_generations = num_generations
        if train_with_gen:
            self.train_with_gen = True
            self.train_gen, self.val_gen = dataset
        self.generation_times = []
        self.generation_performances = []
        self.generation_archive = []
        self.generation_members = []
        self.pss = pss
        self.pop_size = pop_size
        if self.pop_size is None and self.pss is not None:
            raise ValueError("Please specify pop_size if you use PSS")
        if not self.train_with_gen:
            if len(dataset) == 2:
                # No validation split given: _evaluate falls back to validation_split.
                (self.x_train, self.y_train), (self.x_test, self.y_test) = dataset
                self.x_val = None
                self.y_val = None
            elif len(dataset) == 3:
                (self.x_train, self.y_train), (self.x_test, self.y_test), (self.x_val, self.y_val) = dataset
            else:
                raise ValueError(
                    "dataset must contain 2 or 3 (x, y) pairs, got {}".format(len(dataset))
                )
            self.x_train_full = self.x_train.copy()
            self.y_train_full = self.y_train.copy()
        
        self.set_objective(metric)
        self.epochs = epochs
        self.batch_size = batch_size
        self.normalize = normalize

        
    
    def get_problem(self):
        """Return the object stored in the ``problem`` attribute.

        NOTE(review): no code visible in this class assigns ``self.problem``;
        unless it is set elsewhere, this raises AttributeError — confirm.
        """
        stored_problem = self.problem
        return stored_problem
    
    def _evaluate(self, x, out, *args, **kwargs):
        """Score one decision vector ``x`` and write the results into ``out``.

        A vector equal to one already evaluated (compared as ``list(x)``)
        replays its archived results. Otherwise ``x`` is mapped to a genome,
        decoded into a Keras model, and scored in one of two ways. With
        ``self.NASWOT`` it gets the negated NASWOT log-determinant. Otherwise
        it is trained and the score is the first entry of the evaluation
        result, clipped at 4.

        Keys written to ``out``:
            "F": ``[score, complexity]``; complexities below 0.6 are mapped to
                ``2 - complexity``.
            "LC": complexity (``decode``'s value / 10) before that mapping.
            "M": model tag ``"model-<i>"`` using the post-increment counter.
            "acc": only when replaying an archived entry that contains it.
        """
        self.skip_model = False
        x_list = list(x)
        if x_list in self.generation_members:
            self._replay_archived_evaluation(x_list, out)
            return

        genome = self._genome_from_decision_vector(x)

        #Decode model
        model, level_of_compression, level_of_complexity = None, None, None
        try:
            model, level_of_compression, level_of_complexity = self.decode(genome)
        except Exception as e:
            print('======== CANT DECODE ============')
            import traceback
            traceback.print_exc()
            print(e)
            # NOTE(review): execution continues with model=None and
            # level_of_complexity=None, so scoring / the division below raises.

        if self.NASWOT:
            v = self._naswot_score(model)
            del model
            gc.collect()
        else:
            performance = self._train_and_evaluate(model)
            print(np.array(performance).shape)
            v = min(performance[0], 4)

        perf = dict()
        level_of_complexity /= 10
        out['LC'] = level_of_complexity
        perf['LC'] = level_of_complexity
        # Complexities below 0.6 are reflected to 2 - c (i.e. above 1.4).
        if level_of_complexity < 0.60:
            level_of_complexity = 2 - level_of_complexity
        print(self.i_model, level_of_complexity)
        print('+++++++++++++++++++')
        self.i_model += 1
        out["F"] = [v, level_of_complexity]
        perf["F"] = [v, level_of_complexity]
        out["M"] = "model-{}".format(self.i_model)
        perf["M"] = "model-{}".format(self.i_model)
        self.generation_members.append(x_list)
        self.generation_performances.append(perf)

    def _replay_archived_evaluation(self, x_list, out):
        """Copy the archived results of an already-evaluated vector into ``out``."""
        # Look up with the list form: list.index() on an ndarray compares
        # element-wise and raises on the ambiguous truth value.
        perf = self.generation_performances[self.generation_members.index(x_list)]
        print('Skipped evaluation')
        # "LC" is a float and "M" a str (immutable); only the objective list
        # needs a copy so pymoo cannot mutate the archive through ``out``.
        out['LC'] = perf['LC']
        out["F"] = list(perf["F"])
        out["M"] = perf["M"]
        if "acc" in perf:
            out["acc"] = perf["acc"]

    def _genome_from_decision_vector(self, x):
        """Map the integer decision vector onto concrete layer parameter values."""
        conv_size = self.convolution_layer_size
        genome = [
            self.layer_params[param][x[i_param + i_layer*conv_size]]
            for i_layer in range(self.convolution_layers)
            for i_param, param in enumerate(self.convolutional_layer_shape)
        ]
        conv_layers_len = len(genome)
        if self.type_problem == 'classification':
            dense_size = self.dense_layer_size
            genome += [
                self.layer_params[param][math.floor(x[i_param + i_layer*dense_size + conv_layers_len])]
                for i_layer in range(self.dense_layers)
                for i_param, param in enumerate(self.dense_layer_shape)
            ]
        # Trailing gene: optimizer index.
        genome.append(math.floor(x[-1]))
        return genome

    def _naswot_score(self, model):
        """Return the negated NASWOT log-determinant for an untrained ``model``.

        Uses the first batch of the current PSS subset when ``self.pss`` is set,
        otherwise batch ``generation`` of ``self.train_gen``.
        """
        igen = int(self.i_model/self.pop_size)
        if self.pss:
            batch = self.train_gen[igen % self.pss][0]
        else:
            batch = self.train_gen[igen]
        x_naswot = batch[0]
        bs = len(x_naswot)
        kernel = np.zeros((bs, bs))
        preds = model.predict(x_naswot)
        outputs = preds if isinstance(preds, list) else [preds]
        for activations in outputs:
            # Binary activation codes as floats so the matmuls count matching
            # units; on a bool array numpy's matmul returns a logical OR of
            # ANDs instead of counts.
            codes = (np.asarray(activations).reshape(bs, -1) > 0).astype(np.float64)
            kernel += codes @ codes.T + (1. - codes) @ (1. - codes.T)
        naswot_score = 1
        if len(np.unique(kernel)) > 1:
            _, naswot_score = np.linalg.slogdet(kernel)
        return -naswot_score

    def _train_and_evaluate(self, model):
        """Train ``model`` on the configured data source and return its evaluation.

        Any exception raised while building pipelines, fitting or evaluating is
        printed and handed to ``self._handle_broken_model``, whose return value
        is used as the evaluation result.
        """
        callbacks = [
            EarlyStopping(monitor='val_loss', patience=1, verbose=0)
        ]
        epochs = self.epochs
        fit_params = None
        if not self.train_with_gen:
            fit_params = {
                'x': self.x_train_full,
                'y': self.y_train_full,
                'validation_split': 0.2,
                'batch_size': self.batch_size,
                'shuffle': True,
                'steps_per_epoch': int(len(self.x_train_full)/self.batch_size),
                'epochs': epochs,
                'verbose': 0,
                'callbacks': callbacks
            }
            if self.x_val is not None:
                fit_params['validation_data'] = (self.x_val, self.y_val)
        try:
            if self.pss:
                return self._fit_pss_subset(model, epochs, callbacks)
            if self.train_with_gen:
                return self._fit_generators(model, epochs, callbacks)
            model.fit(**fit_params)
            return model.evaluate(self.x_test, self.y_test, batch_size=self.batch_size, verbose=0)
        except Exception as e:
            print(e)
            return self._handle_broken_model(model, e)
        finally:
            # Only some paths create the tf.data pipelines (and a failure can
            # occur before they exist), so drop them only if present.
            self._release_tf_datasets()

    def _fit_pss_subset(self, model, epochs, callbacks):
        """Train on the PSS subset of the current generation and evaluate on ``val_gen``."""
        igen = int(self.i_model/self.pop_size)
        subset = self.train_gen[igen % self.pss]
        if not self.gen_to_tf_data:
            model.fit(subset, epochs=epochs, validation_data=self.val_gen, callbacks=callbacks, verbose=0)
            return model.evaluate(self.val_gen, verbose=1)
        out_dims = self._tf_out_dims()
        self.train_tf_gen = DatasetFromSequenceClass(subset, len(subset), 1, self.batch_size, dims=list(self.input_shape), out_dims=out_dims)
        self.val_tf_gen = DatasetFromSequenceClass(self.val_gen, len(self.val_gen), 1, self.batch_size, dims=list(self.input_shape), out_dims=out_dims)
        # One epoch is one pass over this subset, not over the list of subsets.
        model.fit(self.train_tf_gen, epochs=epochs, steps_per_epoch=len(subset),
                  validation_data=self.val_tf_gen, validation_steps=len(self.val_gen),
                  callbacks=callbacks, verbose=0)
        performance = model.evaluate(self.val_tf_gen, steps=len(self.val_gen), verbose=0)
        print(performance)
        return performance

    def _fit_generators(self, model, epochs, callbacks):
        """Train on the full ``train_gen`` and evaluate on ``val_gen``."""
        training = self.train_gen
        testing = self.val_gen
        if self.gen_to_tf_data:
            out_dims = self._tf_out_dims()
            train_ds = DatasetFromSequenceClass(self.train_gen, len(self.train_gen), 1, self.batch_size, dims=list(self.input_shape), out_dims=out_dims)
            val_ds = DatasetFromSequenceClass(self.val_gen, len(self.val_gen), 1, self.batch_size, dims=list(self.input_shape), out_dims=out_dims)
            # Cache before shuffling: a cache placed after shuffle replays the
            # first epoch's order forever.
            train_ds = train_ds.cache().shuffle(buffer_size=50)
            val_ds = val_ds.cache()
            if self.normalize:
                train_ds = train_ds.map(self.normalize)
                val_ds = val_ds.map(self.normalize)
            self.train_tf_gen = train_ds.prefetch(buffer_size=50)
            self.val_tf_gen = val_ds.prefetch(buffer_size=50)
            training = self.train_tf_gen
            testing = self.val_tf_gen
        # NOTE(review): EarlyStopping monitors val_loss but no validation data
        # is passed here (the TODO for SIXRAY), so it never triggers.
        history = model.fit(training, epochs=epochs, callbacks=callbacks, verbose=0)
        print(history.history.keys())
        performance = model.evaluate(testing, verbose=0)
        print(performance)
        return performance

    def _tf_out_dims(self):
        """Target dims for DatasetFromSequenceClass: [n_classes] or spatial dims + [n_classes]."""
        if self.type_problem == 'classification':
            return [self.n_classes]
        # list() so tuple input shapes work too (tuple + list raises TypeError).
        return list(self.input_shape[:-1]) + [self.n_classes]

    def _release_tf_datasets(self):
        """Remove ``train_tf_gen``/``val_tf_gen`` from the instance if they exist."""
        for attr in ('train_tf_gen', 'val_tf_gen'):
            vars(self).pop(attr, None)

    

    def decode(self, genome):
        """Build (and, unless NASWOT, compile) a Keras model from ``genome``.

        ``genome`` holds per-conv-layer values ordered as
        ``convolutional_layer_shape``, then (classification only) the dense
        values, then the optimizer index. The encoder is built from the active
        conv layers. Classification adds global pooling and dense layers, while
        'ss'/'ae' add a mirrored decoder.

        Returns:
            ``(model, level_of_compression, level_of_complexity)``.
            ``level_of_compression`` is currently fixed at 10, and
            ``level_of_complexity`` is ``min(log10(parameter count), 10)``.

        Raises:
            ValueError: If ``is_compatible_genome`` rejects the genome or the
                decoder grows beyond the input size.
            Exception: Whatever ``decode_connections``/``clear_cons`` raise for
                an undecodable connection pattern (printed, then re-raised).
        """
        if not self.is_compatible_genome(genome):
            raise ValueError("Invalid genome for specified configs")
        conv_size = self.convolution_layer_size
        conn_idx = self.convolutional_id_to_param['connections']
        kernel_idx = self.convolutional_id_to_param['kernel_size']
        act_idx = self.convolutional_id_to_param['activation']
        # Connection genes of the active conv layers only.
        cons = [genome[conn_idx + i*conv_size] for i in range(self.convolution_layers) if genome[i*conv_size] == 1]
        lays = []
        dim = 0
        offset = 0
        optim_offset = 0
        if self.convolution_layers > 0:
            dim = min(self.input_shape[:-1])  # keep track of smallest dimension
        input_layer = True
        dims = []
        temp_features = self.max_filters
        features = dict()

        x = Input(shape=self.input_shape, dtype=tf.float32)
        add_layer(cons, lays, x, 0)
        for i in range(self.convolution_layers):
            if genome[offset]:
                if not self.smaller_ss:
                    # NOTE(review): convolutional_id_to_param has no
                    # 'num filters' entry, so smaller_ss=False raises KeyError.
                    requested = genome[offset + self.convolutional_id_to_param['num filters']]
                    if input_layer or self.type_problem == 'classification':
                        temp_features = requested
                    else:
                        temp_features = int(min(list(features.values())[-1], requested))
                input_layer = False
                temp_kernel = genome[offset + kernel_idx]
                x = Convolution2D(
                    temp_features, (temp_kernel, temp_kernel),
                    padding='same',
                    kernel_initializer='he_normal',
                    activation=self.activation[genome[offset + act_idx]]
                )
                lays.append(x)
                x = BatchNormalization()
                add_layer(cons, lays, x, len(lays))
                max_pooling_type = genome[offset + self.convolutional_id_to_param['max pooling']]
                if max_pooling_type == 1 and dim >= 4:
                    x = MaxPooling2D(pool_size=(2, 2), padding="same")
                    add_layer(cons, lays, x, len(lays))
                    dim /= 2
            dims.append(dim)
            features[i] = temp_features
            dim = int(math.ceil(dim))
            if i < self.convolution_layers-1:
                offset += conv_size
            else:
                optim_offset = offset + conv_size
        # Compression objective is currently disabled and reported as a constant.
        level_of_compression = 10
        if self.type_problem == 'classification':
            x = GlobalAveragePooling2D()
            add_layer(cons, lays, x, len(lays))
        if self.type_problem == 'ss' or self.type_problem == 'ae':
            # `offset` still points at the last conv block; walk the blocks in
            # reverse to mirror the encoder.
            input_min = min(self.input_shape[:-1])
            for i in reversed(range(self.convolution_layers)):
                #Done to fix shape when 14->7-> 4 => 4->8->16->14
                if (dim not in dims) and ((dim-2)*2 in dims or ((dim*2 not in dims) and (dim-2)*2 == input_min)):
                    x = Convolution2D(features[i], (3, 3))
                    add_layer(cons, lays, x, len(lays))
                    dim -= 2
                if genome[offset]:
                    temp_kernel = genome[offset + kernel_idx]
                    x = Convolution2D(
                        features[i], (temp_kernel, temp_kernel),
                        padding='same',
                        kernel_initializer='he_normal',
                        activation=self.activation[genome[offset + act_idx]],
                    )
                    add_layer(cons, lays, x, len(lays))
                    # NOTE(review): `(dim*4)-2` is a bare int (never 0 for an
                    # integer dim), not a membership test, so this condition
                    # reduces to `dim < input_min`. Kept as-is because changing
                    # it alters which architectures decode.
                    if ((dim*2 in dims or (dim*2)-2 in dims or (dim*4)-2 or dim == (int(input_min))/2)) and dim < input_min:
                        x = UpSampling2D((2, 2))
                        add_layer(cons, lays, x, len(lays))
                        self.last_upsampling_index = len(lays)-1
                        dim *= 2
                if dim > max(self.input_shape):
                    # Previously dropped into pdb, which hangs unattended runs.
                    raise ValueError(
                        "Decoder dimension {} exceeds input shape {}".format(dim, self.input_shape)
                    )
                offset -= conv_size
            if not self.type_problem == 'ss':
                # Reconstruction head reuses the first conv gene's kernel/activation.
                x = Convolution2D(self.input_shape[-1], (genome[kernel_idx], genome[kernel_idx]), activation=self.activation[genome[act_idx]], padding='same', )
            else:
                # NOTE(review): kernel size here is input_shape[-1] (channel
                # count) — confirm this is intended rather than (1, 1).
                x = Convolution2D(self.n_classes, self.input_shape[-1], activation="softmax", padding="same",)
            add_layer(cons, lays, x, len(lays))
        #Clear connections
        try:
            dirty_cons = decode_connections(cons, len(cons))
            clean_cons = clear_cons(dirty_cons, len(cons))
        except Exception as ex:
            print(cons, len(cons))
            print('Failed connections!')
            print(ex)
            # decode_ops cannot run without a connection matrix; surface the
            # real error rather than a later AttributeError on None.
            raise

        operations = self.decode_ops([], lays, clean_cons)
        if self.type_problem == 'classification':
            ### Dense layer decoding (classification only)
            offset = optim_offset
            has_a_dense_active = False
            for i in range(self.dense_layers):
                if genome[offset]:
                    temp_nodes = genome[offset + self.dense_id_to_param['num filters']]
                    x = Dense(
                        temp_nodes,
                        kernel_initializer='he_normal',
                        activation=self.activation[genome[offset + self.dense_id_to_param['activation']]]
                    )
                    operations.append(x(operations[-1]))
                    has_a_dense_active = True
                offset += self.dense_layer_size
            if not has_a_dense_active:
                x = Dense(self.n_classes)
                operations.append(x(operations[-1]))
            # Accept both the boolean feed_data default and the legacy 'true' string.
            multilabel = self.multilabel is True or str(self.multilabel).lower() == 'true'
            x = Dense(self.n_classes, activation='sigmoid' if multilabel else 'softmax')
            operations.append(x(operations[-1]))
            optim_offset = offset
        if self.NASWOT:
            # NASWOT scores ReLU activations; fall back to the final output.
            outs = [op for op in operations if 'relu' in op.name or 'Relu' in op.name]
            if len(outs) < 1:
                outs = operations[-1]
        else:
            outs = operations[-1]
        model = Model(operations[0], outs)
        # TODO changed from binary_crossentropy
        metrics = ["accuracy"]
        if self.TRAIN_WITH_LOGITS:
            loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        elif self.type_problem == 'ss' or self.type_problem == 'ae':
            loss = 'sparse_categorical_crossentropy'
        else:
            loss = 'categorical_crossentropy'
        if self.type_problem == 'ss':
            metrics += [mean_iou]
        # TODO: dataset-specific (SIXRAY) class names are hard-coded here.
        id_to_name = {
            0: 'Gun',
            1: 'Knife',
            2: 'Wrench',
            3: 'Pliers',
            4: 'Scissors'
        }
        # Precision(class_id=k) reads output column k, so only add metrics for
        # classes the model actually has.
        metrics += [
            tf.keras.metrics.Precision(class_id=idx, name='precision_{}'.format(id_to_name[idx]))
            for idx in range(min(len(id_to_name), self.n_classes))
        ]
        if not self.NASWOT:
            model.compile(loss=loss,
                          optimizer=self.optimizer[genome[optim_offset]],
                          metrics=metrics)
        level_of_complexity = min(math.log10(int(model.count_params())), 10)
        return model, level_of_compression, level_of_complexity
    
    def decode_ops(self, operations=None, lays=None, cons=None):
        """Wire the layers of one cell together following a connection matrix.

        Each column ``j`` of ``cons`` describes layer ``j`` of the cell; its
        non-zero rows name the earlier cell outputs that feed that layer:

        * column 0: ``lays[0]`` is appended to the output list as-is (it is
          not called on anything);
        * exactly one non-zero row: ``lays[j]`` is applied to the most
          recently appended output;
        * several non-zero rows: the referenced outputs are summed with
          ``add``. When their ``(shape[1], shape[3])`` pairs differ, every
          input is first resized with ``UpSampling2D`` / ``MaxPooling2D``
          until its ``shape[1]`` matches a reference input. The reference is
          the last input for ``type_problem == 'ss'``, otherwise the first
          input with the smallest ``shape[1]``. Every non-reference input is
          then projected with a 1x1 ``Convolution2D`` to the reference's
          channel count. If resizing needs more than 10 steps, the sum is
          abandoned and ``lays[j]`` is applied to the most recent output
          instead;
        * no non-zero row: an error is printed and the column is skipped.

        Args:
            operations: outputs preceding the cell. Row ``r`` of ``cons``
                refers to ``operations[len(operations) + r]`` (index into the
                growing list). Defaults to an empty list.
            lays: per-column layers. ``lays[0]`` is used as an output, the
                others are called on a tensor. Defaults to an empty list.
            cons: 2-D connection matrix (must support ``.transpose()``).
                Defaults to an empty matrix, i.e. nothing to wire.

        Returns:
            A new list: the given ``operations`` followed by the outputs
            appended for this cell. The caller's list is not modified.
        """
        if operations is None:
            operations = []
        if lays is None:
            lays = []
        if cons is None:
            cons = np.zeros((0, 0))
        # Offset that maps a row of `cons` onto the growing `operations` list.
        index_fixer = len(operations)
        operations = operations.copy()
        for index, con in enumerate(cons.transpose()):
            if index == 0:
                # First layer in the cell: appended directly, not called.
                operations.append(lays[0])
                continue
            # TODO FIND MORE OPTIMAL WAY OF DOING THIS!!!
            nz = np.nonzero(con)[0]
            if len(nz) > 1:
                inputs = [operations[layer + index_fixer] for layer in nz]
                # (shape[1], shape[3]) presumably = (spatial size, channels)
                # of NHWC tensors -- TODO confirm.
                full_shapes = {(t.shape[1], t.shape[3]) for t in inputs}
                can_add = True
                if len(full_shapes) == 1:
                    op = add(inputs)
                else:
                    if self.type_problem == 'ss':
                        lowest_dim_ind = len(inputs) - 1
                    else:
                        lowest_dim_ind = 0
                        lowest_dim = inputs[0].shape[1]
                        for i_l, candidate in enumerate(inputs):
                            if candidate.shape[1] < lowest_dim:
                                lowest_dim = candidate.shape[1]
                                lowest_dim_ind = i_l
                    reference = inputs[lowest_dim_ind]
                    adjustment_ops = []
                    for i_l, l_to_add in enumerate(inputs):
                        adjust_op = l_to_add
                        tries = 0
                        while adjust_op.shape[1] != reference.shape[1] and can_add:
                            if adjust_op.shape[1] < reference.shape[1]:
                                adjust_op = UpSampling2D((2, 2))(adjust_op)
                            else:
                                adjust_op = MaxPooling2D(pool_size=(2, 2), padding="same")(adjust_op)
                            tries += 1
                            if tries > 10:
                                # Give up: sizes cannot be reconciled by x2 steps.
                                can_add = False
                                print('Cannot add {} and {} and {}'.format(l_to_add.shape, adjust_op.shape, reference.shape))
                        if i_l != lowest_dim_ind:
                            # Match the reference's channel count before adding.
                            adjust_op = Convolution2D(reference.shape[-1], kernel_size=(1, 1), padding="same")(adjust_op)
                        adjustment_ops.append(adjust_op)
                    if can_add:
                        op = add(adjustment_ops)
                if can_add:
                    operations.append(lays[index](op))
                else:
                    operations.append(lays[index](operations[-1]))
            elif len(nz) == 1:
                # NOTE(review): this feeds the most recent output rather than
                # operations[nz[0] + index_fixer]; confirm that is intended.
                operations.append(lays[index](operations[-1]))
            else:
                # NOTE(review): skipping the append shifts the row->output
                # alignment used by `index_fixer` for later columns.
                print('======ERRORRORORORR ========')
        return operations
    
    def decode_model_genome(self, genome):
        """Decode a raw real-valued genome vector and build its model.

        Every gene is floored and used as an index into the option list of
        its layer parameter. The vector layout is: all convolution-layer
        genes, then (classification only) all dense-layer genes, and the
        last element is the optimizer index. This is the same ordering that
        ``is_compatible_genome`` checks.

        Args:
            genome: sequence of real numbers, one per gene.

        Returns:
            The model returned by ``self.decode`` for the decoded genome.

        Raises:
            Exception: 'Model could not be decoded' if ``self.decode`` fails;
                the original error is chained as ``__cause__``.
        """
        x = genome
        conv_genes = self.convolution_layers * self.convolution_layer_size
        # NOTE(review): conv genes enumerate every key of self.layer_params
        # with stride convolution_layer_size; this assumes the two agree.
        decoded = [
            self.layer_params[param][math.floor(x[i_param + i_layer * self.convolution_layer_size])]
            for i_layer in range(self.convolution_layers)
            for i_param, param in enumerate(self.layer_params)
        ]
        if self.type_problem == 'classification':
            # Dense genes start right after the conv genes, not at offset 0.
            decoded += [
                self.layer_params[param][math.floor(x[conv_genes + i_param + i_layer * self.dense_layer_size])]
                for i_layer in range(self.dense_layers)
                for i_param, param in enumerate(self.dense_layer_shape)
            ]
        decoded.append(math.floor(x[-1]))

        try:
            model, _, _ = self.decode(decoded)
        except Exception as err:
            raise Exception('Model could not be decoded') from err
        return model
        
    
    def convParam(self, i):
        """Return the allowed values of the i-th convolution-layer gene.

        ``i`` indexes ``self.convolutional_layer_shape``; the resulting
        parameter name is looked up in ``self.layer_params``.
        """
        return self.layer_params[self.convolutional_layer_shape[i]]
    
    def denseParam(self, i):
        """Return the allowed values of the i-th dense-layer gene.

        ``i`` indexes ``self.dense_layer_shape``; the resulting parameter
        name is looked up in ``self.layer_params``.
        """
        return self.layer_params[self.dense_layer_shape[i]]
        
    def is_compatible_genome(self, genome):
        """Check that a decoded genome has the expected length and values.

        The genome must hold ``convolution_layers * convolution_layer_size``
        conv genes, then ``dense_layers * dense_layer_size`` dense genes,
        then one optimizer index. Each layer gene must be among the options
        returned by ``convParam`` / ``denseParam`` for its position inside
        its layer. The optimizer index must lie in ``range(len(self.optimizer))``.

        Returns:
            bool: True when every check passes, False otherwise.
        """
        conv_span = self.convolution_layers * self.convolution_layer_size
        dense_span = self.dense_layers * self.dense_layer_size
        if len(genome) != conv_span + dense_span + 1:
            return False

        for offset in range(conv_span):
            allowed = self.convParam(offset % self.convolution_layer_size)
            if genome[offset] not in allowed:
                return False

        for offset in range(dense_span):
            allowed = self.denseParam(offset % self.dense_layer_size)
            if genome[conv_span + offset] not in allowed:
                return False

        return genome[conv_span + dense_span] in range(len(self.optimizer))
    
    def _handle_broken_model(self, model, error):
        """Report a model that could not be trained and return a penalty score.

        Sets ``self.skip_model`` to True, prints diagnostics (including the
        error that was raised), triggers garbage collection and returns a
        deliberately poor score.

        Args:
            model: the Keras model that failed.
            error: the exception raised while training/evaluating ``model``.

        Returns:
            list: ``10.0`` as the loss followed by ``K.epsilon()`` for each
            of the remaining ``len(model.metrics) - 1`` metric slots.
        """
        self.skip_model = True
        try:
            n_params = str(model.count_params())
        except ValueError:
            # Keras raises ValueError from count_params() on an unbuilt model;
            # don't let the failure handler itself crash.
            n_params = 'unknown'
        print('================')
        print('Number of parameters:', n_params)
        print('================')

        # Fixed large loss plus near-zero metrics so the model is unattractive.
        performance = [10.0] + [K.epsilon()] * (len(model.metrics) - 1)
        gc.collect()

        print('An error occurred and the model could not train!')
        print('Error: {}'.format(error))
        print(('Model assigned poor score. Please ensure that your model '
               'constraints live within your computational resources.'))
        return performance
    
    def set_objective(self, metric):
        """
        Set the metric for optimization. Can also be done by passing to
        `run`.

        Args:
            metric (str): either 'hvi', which is maximized, or 'loss', which
                is minimized.

        Raises:
            ValueError: if ``metric`` is neither 'loss' nor 'hvi'.
        """
        if metric not in ('loss', 'hvi'):
            raise ValueError(('Invalid metric name {} provided - should be '
                              '"loss" or "hvi"').format(metric))
        self._metric = metric
        self._objective = "max" if metric == "hvi" else "min"
        # TODO currently loss and accuracy
        self._metric_index = 0
        # Bool index: False -> minimise entry, True -> maximise entry.
        maximise = self._objective == 'max'
        self._metric_op = METRIC_OPS[maximise]
        self._metric_objective = METRIC_OBJECTIVES[maximise]

    
    

        
def add_layer(cons, lays, layer, pos):
    """Insert ``layer`` at index ``pos`` and adjust the connection codes.

    ``lays`` and ``cons`` are modified in place (and also returned). Let
    ``size`` be the length of ``cons`` before insertion. The new entry of
    ``cons`` is ``int(2 ** (size - (pos + 1)))``. Each earlier entry ``i``
    (``i < pos``) with ``cons[i] < positional`` is adjusted, where
    ``positional = 2 ** (size - (i + 1))``:

    * if ``positional == 1``, the entry is reset to 0;
    * else, if it is at least ``positional / 2``, that half is removed;

    and then ``positional`` is added to it.

    Args:
        cons: list of integer connection codes, one per layer.
        lays: list of layers, parallel to ``cons``.
        layer: layer to insert.
        pos: insertion index.

    Returns:
        tuple: ``(cons, lays)``, the same (mutated) list objects.
    """
    size = len(cons)
    cons.insert(pos, int(2 ** (size - (pos + 1))))
    lays.insert(pos, layer)
    for i in range(pos):
        positional = 2 ** (size - (i + 1))
        if cons[i] < positional:
            if positional == 1:
                # Was `cons[i] == 0` (a no-op comparison); reset intended.
                cons[i] = 0
            elif cons[i] >= positional / 2:
                cons[i] = int(cons[i] - positional / 2)
            cons[i] = cons[i] + positional
    return cons, lays

def decode_connections(cons, cell_size):
    """Expand integer connection codes into rows of binary digits.

    First, every entry of ``cons`` is reduced **in place** below
    ``2 ** cell_size`` by repeatedly subtracting ``2 ** cell_size``. This is
    bounded to at most ``cell_size * 10 + 1`` subtractions per entry, so very
    large values may stay above the bound. Each entry is then written in
    binary, left-padded with zeros to ``cell_size`` digits.

    Args:
        cons: mutable sequence of non-negative ints (modified in place).
        cell_size: number of binary digits per row.

    Returns:
        numpy.ndarray (dtype=object): one row of 0/1 ints per entry of ``cons``.
    """
    overflow = 2 ** cell_size
    max_wraps = cell_size * 10 + 1
    for i_con in range(len(cons)):
        wraps = 0
        while cons[i_con] >= overflow and wraps < max_wraps:
            cons[i_con] -= overflow
            wraps += 1

    bin_cons = []
    for con in cons:
        binarised = bin(con)[2:]
        # Nested format spec replaces the former eval() of a built f-string.
        padded = f"{binarised:0>{cell_size}}"
        bin_cons.append([int(digit) for digit in padded])
    return np.array(bin_cons, dtype=object)

def clear_cons(dirty_cons, cell_size):
    """Keep only forward connections and patch disconnected layers.

    Takes the strictly upper-triangular part of ``dirty_cons``
    (``np.triu(k=1)``). Then, for each index ``i >= 1`` (over the rows):

    * if row ``i`` lost all of its 1s in that step but had some before, set
      ``clean[i, i + 1] = 1`` (when that column exists);
    * if column ``i`` (when it exists) contains no 1, set ``clean[0, i] = 1``.

    The two fixes are applied in this order per index, so a row fix can
    satisfy a later column check.

    Args:
        dirty_cons: 2-D connection matrix (0/1 entries).
        cell_size: unused; kept for interface compatibility.

    Returns:
        numpy.ndarray: the cleaned matrix (a new array).
    """
    clean_cons = np.triu(dirty_cons, k=1)
    n_cols = clean_cons.shape[1]
    for i in range(1, len(clean_cons)):
        # v5: reconnect a layer that had connections, all of them backwards.
        # Guard: the last row of a square matrix has no column i + 1.
        if i + 1 < n_cols and 1 not in clean_cons[i, :] and 1 in dirty_cons[i, :]:
            clean_cons[i, i + 1] = 1
        if i < n_cols and 1 not in clean_cons[:, i]:
            clean_cons[0, i] = 1
    # NOTE(review): when there are fewer rows than columns, the trailing
    # columns are never checked for an incoming connection.
    return clean_cons
    