
import enum
import numpy as np
from pymoo.util.misc import stack
from pymoo.model.problem import Problem

    
from tqdm import tqdm
from PIL import Image
import glob
import numpy as np
import random as rand
import math

import difflib

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dense, Dropout, Flatten, Softmax, GlobalAveragePooling2D
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, UpSampling2D, Conv2DTranspose
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import add,concatenate,Dot
from tensorflow.keras import Input,Model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import SGD
from tensorflow.python.framework import ops
from tensorflow.keras.metrics import MeanIoU
from sklearn.metrics import log_loss
import gc
import tensorflow as tf
from tensorflow.keras import backend as K
from losses import dice_loss,mean_iou,MeanIOUWrapper
# import logging
# logging.getLogger("tensorflow").setLevel(logging.ERROR)
import operator
from dataloaders.datasetFromSequence import DatasetFromSequenceClass 
from dataloaders.datasetFromSequenceCityScapes import DatasetFromSequenceCityScapes

from proxies import get_synflow_score


#Regularisation
from tensorflow.keras.regularizers import l2
# kernel_regularizer=l2(0.01), bias_regularizer=l2(0.01)))

# Parallel lookup tables: position 0 holds the "smaller" comparison/reducer and
# position 1 the "larger" one.
# NOTE(review): presumably indexed by a minimise(0)/maximise(1) flag chosen in
# set_objective(), which is not visible in this part of the file -- confirm.
METRIC_OPS = [operator.lt, operator.gt]
METRIC_OBJECTIVES = [min, max]


class PymooGenomeReduced(Problem):
    
    def __init__(self,
                 max_conv_layers,
                 max_filters,
                 max_dense_layers,
                 max_nodes,
                 input_shape,
                 n_classes,
                 batch_size=32,
                 batch_normalization=True,
                 dropout=True,
                 max_pooling=True,
                 optimizers=None,
                 activations=None,
                 skip_ops=None,
                 type_problem='autoencoder',
                 TRAIN_WITH_LOGITS = False,
                 NASWOT = False,
                 SYNFLOW = False,
                 smaller_ss = True
                ):
        """Define the integer search space and register it with pymoo.

        Every convolutional layer slot contributes one gene per entry of
        ``convolutional_layer_shape``, every dense layer slot one gene per
        entry of ``dense_layer_shape``, and a single trailing gene selects the
        optimizer.  Each gene is an index into the matching option list in
        ``layer_params`` (or into ``self.optimizer`` for the last gene), so the
        lower bound of every variable is 0 and the upper bound is
        ``len(options) - 1``.

        Args:
            max_conv_layers: Number of convolutional layer slots (>= 1).
            max_filters: Largest filter count.  The "num filters" options are
                five consecutive powers of two ending at
                ``2 ** floor(log2(max_filters))``.
            max_dense_layers: Number of dense layer slots.
            max_nodes: Stored on the instance; not otherwise used here.
            input_shape: Shape of one input sample, channels last.
            n_classes: Number of output classes.
            batch_size: Default batch size (``feed_data`` overwrites it).
            batch_normalization: Stored as ``self.batch_norm``; ``decode`` adds
                a BatchNormalization layer after each active conv layer when
                true.
            dropout: ``decode`` adds ``Dropout(0.2)`` per active conv layer
                when true.
            max_pooling: When false the "max pooling" gene can only be 0, i.e.
                ``decode`` never inserts its stride-2 down-sampling layer.
            optimizers: Optimizer names; defaults to adadelta and adam.
            activations: Activation names; defaults to relu only.
            skip_ops: Skip-connection operations; defaults to none, add and
                concatenate.
            type_problem: Problem kind.  The visible code compares it against
                'classification', 'ss' and 'ae'.
            TRAIN_WITH_LOGITS: Stored on the instance.
            NASWOT: Score models with the NASWOT proxy instead of training.
            SYNFLOW: Score models with the synflow proxy instead of training
                (takes precedence over ``NASWOT`` in ``_evaluate``).
            smaller_ss: When true ``decode`` does not read a filter count from
                the genome.

        Raises:
            ValueError: If ``max_conv_layers`` is smaller than 1.
        """
        self.smaller_ss = smaller_ss
        self.max_filters = max_filters
        self.max_nodes = max_nodes
        if max_conv_layers < 1:
            raise ValueError(
                "At least one conv layer is required for AE to work"
            )
        if max_filters > 0:
            filter_range_max = int(math.log(max_filters, 2)) + 1
        else:
            filter_range_max = 0

        self.optimizer = optimizers or ['adadelta', 'adam']
        self.activation = activations or ['relu']
        self.skip_op = skip_ops or ['none', 'add', 'concatenate']

        # Gene order inside one convolutional layer slot.
        self.convolutional_layer_shape = [
            "active",
            "kernel_size",
            "activation",
            "max pooling",
            "connections"
        ]
        # NOTE(review): decode() looks up 'num filters' in this mapping when
        # smaller_ss is False, but conv layers carry no such gene -- confirm
        # whether that configuration is still meant to be supported.
        self.convolutional_id_to_param = {
            "active" : 0,
            "kernel_size": 1,
            "activation" : 2,
            "max pooling": 3,
            "connections": 4
        }

        self.dense_id_to_param = {
            "active" : 0,
            "num filters" : 1,
            "activation" : 2,
        }

        # Option list per gene type; a gene value is an index into its list.
        self.layer_params = {
            "active": [0, 1],
            # NOTE: for max_filters < 16 the exponent range starts below zero,
            # which yields fractional entries; behaviour kept as it was.
            "num filters": [2**i for i in range(int(filter_range_max-5), filter_range_max)],
            # Added after paper release
            "kernel_size": [1, 3, 5],
            "activation": list(range(len(self.activation))),
            # Must stay a list even when pooling is disabled: the bounds below
            # take len() of every option list (a bare 0 raised TypeError).
            "max pooling": [0, 1] if max_pooling else [0],
            # Limited set of connection patterns (see p_c_filtered below).
            "connections": list(range(10)),
        }

        # Gene order inside one dense layer slot.
        self.dense_layer_shape = [
            "active",
            "num filters",
            "activation",
        ]

        self.convolution_layers = max_conv_layers
        self.convolution_layer_size = len(self.convolutional_layer_shape)

        self.dense_layers = max_dense_layers
        self.dense_layer_size = len(self.dense_layer_shape)

        self.input_shape = input_shape
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.type_problem = type_problem
        self.TRAIN_WITH_LOGITS = TRAIN_WITH_LOGITS
        self.i_model = 0

        # For archive lookup
        self.generation_performances = []
        self.generation_archive = []
        self.generation_members = []
        self.last_upsampling_index = 0

        self.NASWOT = NASWOT
        self.SYNFLOW = SYNFLOW

        self.dropout = dropout
        self.batch_norm = batch_normalization

        # All connection codes per layer: layer l gets the integers
        # 0 .. max(2**(n - l - 1) - 1, 0).
        self.all_p_c = [
            list(range(max((2**(max_conv_layers-i_layer-1))-1, 0)+1))
            for i_layer in range(max_conv_layers)
        ]

        # At most ten hand-picked connection codes per layer.
        self.p_c_filtered = [
            [
                self.get_filtered_con_layer(i, i_layer, self.all_p_c)
                for i in range(min(len(self.all_p_c[i_layer]), 10))
            ]
            for i_layer in range(max_conv_layers)
        ]

        # conv_layers*conv_layer_size + dense_layers*dense_layer_size + 1 variables
        # 2 objectives, 1 constraint
        number_of_variables = (
            self.convolution_layers*self.convolution_layer_size
            + self.dense_layers*self.dense_layer_size
            + 1
        )
        number_of_objectives = 2

        conv_upper = [
            len(self.layer_params[param])-1
            for _ in range(max_conv_layers)
            for param in self.convolutional_layer_shape
        ]
        dense_upper = [
            len(self.layer_params[param])-1
            for _ in range(self.dense_layers)
            for param in self.dense_layer_shape
        ]

        super().__init__(n_var=number_of_variables,
                         n_obj=number_of_objectives,
                         n_constr=1,
                         xl=np.array([0] * number_of_variables),
                         xu=np.array(conv_upper + dense_upper + [len(self.optimizer)-1]),
                         elementwise_evaluation=True,
                         type_var=int)
        
    def feed_data(self,
                  train_with_gen,
                  dataset,
                  num_generations,
                  pop_size,
                  pss,
                  metric='loss',
                  batch_size = 32,
                  epochs = 5,
                  gen_to_tf_data=True,
                  multilabel=False,
                  normalize=None,
                  double_up = True
                 ):
        self.multilabel = multilabel
        self.gen_to_tf_data = gen_to_tf_data
        #Previously from NE
        self.train_tf_gen = None
        self.val_tf_gen = None
        self.train_with_gen = train_with_gen
        self.dataset = dataset
        self.num_generations = num_generations
        if(train_with_gen):
            self.train_with_gen = True
            self.train_gen,self.val_gen = dataset
        self.generation_times = []
        self.generation_performances = []
        self.generation_archive = []
        self.generation_members = []
        self.pss = pss
        self.pop_size = pop_size
        if(self.pop_size == None and self.pss != None):
            raise ValueError("Please specify pop_size if you use PSS")
#         self.set_objective(metric)
        # If no validation data is given set it to None
        if(not self.train_with_gen):
            if len(dataset) == 2:
                (self.x_train, self.y_train), (self.x_test, self.y_test) = dataset
                self.x_val = None
                self.y_val = None
            else:
                (self.x_train, self.y_train), (self.x_test, self.y_test), (self.x_val, self.y_val) = dataset
            self.x_train_full = self.x_train.copy()
            self.y_train_full = self.y_train.copy()
        
        self.set_objective(metric)
        self.epochs = epochs
        self.batch_size = batch_size
        self.normalize = normalize
        self.double_up = double_up

        
    
    def get_problem(self):
        return self.problem
    
    def _evaluate(self, x, out, *args, **kwargs):

        self.skip_model = False
         #Real version
        x_list = list(x)
        if(x_list in self.generation_members):
            #Novelty score calculation to include as an objective
            # s1 = x
            # novelty_score = 0
            # for arch in self.generation_members:   
            #     sm=difflib.SequenceMatcher(None,s1,arch)
            #     similarity = sm.ratio()
            #     novelty_score+=similarity
            # novelty_score = novelty_score/(len(self.generation_members))

            perf = self.generation_performances[self.generation_members.index(x_list)]
            print('Skipped evaluation')
            out['LC'] = perf['LC']
            # out["F"] = perf["F"] + [avg_similarity]
            out["F"] = perf["F"]
            # out["acc"] = perf["acc"]
            out["M"] = perf["M"]

            out["G"] = perf["G"]
        else:
            # genome = [(self.layer_params[param][math.floor(x[i_param + (i_layer*self.convolution_layer_size)])])for i_layer in range(self.convolution_layers) for i_param,param in enumerate(self.layer_params)]
            genome = [(self.layer_params[param][x[i_param + (i_layer*self.convolution_layer_size)]])for i_layer in range(self.convolution_layers) for i_param,param in enumerate(self.convolutional_layer_shape)]
            conv_layers_len = len(genome)
            if(self.type_problem== 'classification'):
                genome+= [(self.layer_params[param][math.floor(x[i_param + (i_layer*self.dense_layer_size) + conv_layers_len])])for i_layer in range(self.dense_layers) for i_param,param in enumerate(self.dense_layer_shape)] 
            genome+= [math.floor(x[-1])]
            
            #Decode model
            model,level_of_compression,level_of_complexity = None,None,None
            model,level_of_compression,level_of_complexity = self.decode(genome)
            # try:
            #     model,level_of_compression,level_of_complexity = self.decode(genome)
            # except Exception as e:
            #     print('======== CANT DECODE ============')
            #     import traceback
            #     traceback.print_exc()
            #     print(e)

            v = 0
            if(self.SYNFLOW):
                v = 1-get_synflow_score(model,self.input_shape)
            elif(self.NASWOT):
                ds = None
                x_naswot, y_naswot = None, None
                igen = int(self.i_model/self.pop_size)
                if(self.train_with_gen):
                    if(self.pss):
                        ds = self.train_gen[igen%self.pss].__getitem__(0)
                    else:
                        ds = self.train_gen.__getitem__(igen)
                    x_naswot = ds[0]
                    y_naswot = ds[1]
                else:
                    x_naswot = self.x_train_full[igen*self.batch_size:(igen+1)*self.batch_size]
                    y_naswot = self.y_train_full[igen*self.batch_size:(igen+1)*self.batch_size]
                bs = len(x_naswot)
                model.K = np.zeros((bs,bs))
                naswot_score = 1
                preds = model.predict(x_naswot)
                if(type(preds)==type([])):
                    for l_o in preds:
                        l_o_temp = l_o.view()
                        if(len(l_o.shape)>2):
                            l_o_temp = l_o_temp.reshape(bs,-1)
                        x = (l_o_temp > 0)
                        K_temp = x @ x.transpose()
                        K2_temp = (1.-x) @ (1.-x.transpose())
                        model.K = model.K + K_temp + K2_temp
                else:
                    l_o_temp = preds.view()
                    if(len(l_o_temp.shape)>2):
                        l_o_temp = l_o_temp.reshape(bs,-1)
                    x = (l_o_temp > 0)
                    K_temp = x @ x.transpose()
                    K2_temp = (1.-x) @ (1.-x.transpose())
                    model.K = model.K + K_temp + K2_temp
                if(len(np.unique(model.K))>1):
                    s, naswot_score = np.linalg.slogdet(model.K)
                v = naswot_score*-1
                # del model
                # model.save("model-{}.h5".format(self.i_model))
                gc.collect()

            else:
                #Initialise performance metrics list
                performance = []
                
                #Define callbacks
                callbacks = [
                        # EarlyStopping(monitor='loss', patience=1, verbose=0)
                    ]
                
                epochs = self.epochs
                
                #Define fir parameters
                if(not self.train_with_gen):
                    fit_params = {
                        'x': self.x_train_full,
                        'y': self.y_train_full,
                        'validation_split': 0.1,
                        'batch_size':self.batch_size,
                        'shuffle':True,
                        'steps_per_epoch': int(len(self.x_train_full)/self.batch_size),
                        'epochs': epochs,
                        'verbose': 1,
                        'callbacks': callbacks
                    }
                    if self.x_val is not None:
                        fit_params['validation_data'] = (self.x_val, self.y_val)
                        
                #Initialise proxy score
                sc = 0
                # print(model.summary())
                try:
                    if(self.pss):
                        igen = int(self.i_model/self.pop_size)
                        # if(self.gen_to_tf_data):
                        #     temp_tr_gen = self.train_gen[igen%self.pss]
                        #     nEpochs = 5
                        #     out_dims = []
                        #     if(self.type_problem=='classification'):
                        #         out_dims = [self.n_classes]
                        #     else:
                        #         out_dims = self.input_shape[:-1]+[self.n_classes]
                        #     if(self.type_problem=='ss'):
                        #         self.train_tf_gen = DatasetFromSequenceCityScapes(temp_tr_gen, len(temp_tr_gen), nEpochs, 1).unbatch().batch(self.batch_size)
                        #         self.val_tf_gen = DatasetFromSequenceCityScapes(self.val_gen, len(self.val_gen), nEpochs, 1).unbatch().batch(self.batch_size)
                        #         history = model.fit(self.train_tf_gen,epochs=epochs, steps_per_epoch=len(self.train_gen), validation_data=self.val_tf_gen, validation_steps=len(self.val_gen), callbacks=callbacks, verbose=0)
                        #         performance = model.evaluate(self.val_tf_gen, steps=len(self.val_gen),verbose=0)
                        #         print(performance)
                        #     else:
                        #         self.train_tf_gen = DatasetFromSequenceClass(temp_tr_gen, len(temp_tr_gen), nEpochs, self.batch_size, dims=self.input_shape, out_dims=out_dims)
                        #         self.val_tf_gen = DatasetFromSequenceClass(self.val_gen, len(self.val_gen), nEpochs, self.batch_size, dims=self.input_shape, out_dims=out_dims)
                        #         history = model.fit(self.train_tf_gen,epochs=epochs, steps_per_epoch=len(self.train_gen), validation_data=self.val_tf_gen, validation_steps=len(self.val_gen), callbacks=callbacks, verbose=0)
                        #         performance = model.evaluate(self.val_tf_gen, steps=len(self.val_gen),verbose=0)
                        #         print(performance)
                        # else:
                        #     history = model.fit(self.train_gen[igen%self.pss],epochs=epochs, validation_data=self.val_gen, callbacks=callbacks, verbose=0)
                        #     performance = model.evaluate(self.val_gen, verbose=1)
                        # sc = 0
                    else:
                        if(self.train_with_gen):
                            training = self.train_gen
                            testing = self.val_gen
                            if(self.gen_to_tf_data):
                                nEpochs = 5
                                out_dims = []
                                if(self.type_problem=='classification'):
                                    out_dims = [self.n_classes]
                                else:
                                    out_dims = list(self.input_shape)[:-1]+[self.n_classes]
                                if(self.type_problem=='ss'):
                                    self.train_tf_gen = DatasetFromSequenceCityScapes(self.train_gen, len(self.train_gen), nEpochs, 1).unbatch().batch(self.batch_size)
                                    self.val_tf_gen = DatasetFromSequenceCityScapes(self.val_gen, len(self.val_gen), nEpochs, 1).unbatch().batch(self.batch_size)
                                else:
                                    self.train_tf_gen = DatasetFromSequenceClass(self.train_gen, len(self.train_gen), nEpochs, self.batch_size, dims=list(self.input_shape), out_dims=out_dims)
                                    self.val_tf_gen = DatasetFromSequenceClass(self.val_gen, len(self.val_gen), nEpochs, self.batch_size, dims=list(self.input_shape), out_dims=out_dims)
                                    self.train_tf_gen = self.train_tf_gen.shuffle(buffer_size=2000).prefetch(buffer_size=2000).cache()
                                    self.val_tf_gen = self.val_tf_gen.prefetch(buffer_size=2000).cache()
                                # if(self.normalize):
                                #     self.train_tf_gen = self.train_tf_gen.map(self.normalize)
                                #     self.val_tf_gen = self.val_tf_gen.map(self.normalize)
                                # history = model.fit(self.train_tf_gen,epochs=epochs, steps_per_epoch=len(self.train_gen), validation_data=self.val_tf_gen, validation_steps=len(self.val_gen), callbacks=callbacks, verbose=0)
                                # performance = model.evaluate(self.val_tf_gen, steps=len(self.val_gen),verbose=1)
                                #TODO Temp for SIXRAY
                                # history = model.fit(self.train_tf_gen,epochs=epochs, steps_per_epoch=len(self.train_gen), callbacks=callbacks, verbose=0)
                                # performance = model.evaluate(self.val_tf_gen, steps=len(self.val_gen),verbose=1)
                                training =  self.train_tf_gen
                                testing =  self.val_tf_gen
                            history = None
                            performance = None
                            if(self.type_problem=='ss'):
                                history = model.fit(training,steps_per_epoch=len(self.train_gen)//self.batch_size,epochs=epochs,callbacks=callbacks, verbose=1)
                                performance = model.evaluate(testing,steps=len(self.val_gen)//self.batch_size,verbose=1)
                            else:
                                history = model.fit(training,epochs=epochs,callbacks=callbacks, verbose=1)
                                performance = model.evaluate(testing,verbose=1)
                            # Proxy score not used
                            # X_batch, Y_batch = next(iter(self.val_gen))
                            # b_j = get_batch_jacobian(model,X_batch)
                            # b_j_2 = b_j.numpy().reshape(b_j.shape[0], -1)
                            # #Proxy score
                            # sc = eval_score(b_j_2,Y_batch)
                            sc = 0
                        else:
                            history = model.fit(**fit_params)
                            performance = model.evaluate(self.x_test, self.y_test, batch_size=self.batch_size,verbose=1)
                            # Proxy score not used
                            #b_j = get_batch_jacobian(model,self.x_val)
                            #b_j_2 = b_j.numpy().reshape(b_j.shape[0], -1)
                            ##Proxy score
                            #sc = eval_score(b_j_2,self.y_val)
                        # if(self.gen_to_tf_data):
                        #     del self.train_tf_gen
                        #     del self.val_tf_gen
                        #     del training
                        #     del testing
                            
                except Exception as e:
                    # print(e)
                    print('Broken model')
                    performance = self._handle_broken_model(model, e)
                finally:
                    if(self.gen_to_tf_data and self.train_tf_gen):
                            del self.train_tf_gen
                            del self.val_tf_gen
                            del training
                            del testing
                
                # avg_similarity = 0
                # if(len(self.generation_members)>1):
                #     s1 = x
                #     for arch in self.generation_members:   
                #         sm=difflib.SequenceMatcher(None,s1,arch)
                #         similarity = sm.ratio()
                #         avg_similarity+=similarity
                #     avg_similarity = avg_similarity/(len(self.generation_members))
                # v = min(performance[0],4)
                if(self.type_problem=='classification'):
                    v = performance[0]
                else:
                    v = performance[1] + performance[2]
                    v = 2-v
            try:
                if(self.num_generations>40):
                    print('Skip saving')
                    # if(self.i_model/self.pop_size>self.num_generations-4):
                    #     model.save("model-{}.h5".format(self.i_model))
                else:
                    model.save("model-{}.h5".format(self.i_model))
            except:
                print('MODEL {} COUNT NOT BE SAVED!'.format(self.i_model))
                pass
            perf = dict()
            level_of_complexity/=10
            out['LC'] = level_of_complexity
            perf['LC'] = level_of_complexity
            # if(level_of_complexity<0.60):
            #     level_of_complexity = 2-level_of_complexity
            print(self.i_model,level_of_complexity, v)
            print('+++++++++++++++++++')
            # level_of_complexity = max(0.62, level_of_complexity)
            self.i_model+=1
            # v_acc = 1- performance[1]
            # out["F"] = [v,level_of_complexity, avg_similarity]
            out["F"] = [v,level_of_complexity]
            perf["F"] = [v,level_of_complexity]
            # out["G"] = [3 - x]
            # CIFAR 10 - 93.22
            # out["G"] = [v - 1.5]
            # CITY 
            print(v)
            out["G"] = [v - 1.2]
            perf["G"] = [v - 1.2]
            # out["G"] = [v+80]
            # For NASWOT 600 SYNFLOW 80
            # out["G"] = [v+600]
            # out["F"] = [v,v_acc]
            # perf["F"] = [v,v_acc]
            # CIFAR 10 - 93.22
            # perf["G"] = [v - 1.5]
            # perf["G"] = [3 - x]
            # out["acc"] = performance
            # perf["acc"] = performance
            out["M"] = "model-{}".format(self.i_model)
            perf["M"] = "model-{}".format(self.i_model)
            self.generation_members.append(x_list)
            self.generation_performances.append(perf)
            # del self.train_tf_gen
            # del self.val_tf_gen

    def get_filtered_con_layer(self,i,i_layer,all_p_c):
            # Sequential 
            if(i==0):
                return int(len(all_p_c[i_layer])/2)
            else:
                # Seq + super skip (Skip to bottom 1 and 2)
                if(i<3):
                    return int(len(all_p_c[i_layer])/2) + i 
        #             return -1
                #Skip 1-3 layers
                elif(i<6):
                    return int((len(all_p_c[i_layer+(i-2)])/2))
                #Skip 2-3 layers + sequential
                elif(i<8):
                    if(i_layer<len(all_p_c)-6):
                        return int(len(all_p_c[i_layer])/2) + int((len(all_p_c[i_layer+(i-4)])/2))
                    else:
                        if(i==6):
                            return 3
                        else:
                            return 7
                #Inception connection - seq + skip 1 + skip 2 + skip 5
                elif(i<9):
                    if(i_layer<len(all_p_c)-5):
                        return int(len(all_p_c[i_layer])/2) + int((len(all_p_c[i_layer+(i-3)])/2)) + int((len(all_p_c[i_layer+(i-6)])/2)) + int((len(all_p_c[i_layer+(i-7)])/2))
                    else:
                        # Can't inception, so next best is dense connect
                        return max(all_p_c[i_layer])
                #dense block of next 5
                else:
                    if(i_layer<len(all_p_c)-5):
                        return int(len(all_p_c[i_layer])/2) + int((len(all_p_c[i_layer+(i-8)])/2)) + int((len(all_p_c[i_layer+(i-7)])/2)) + int((len(all_p_c[i_layer+(i-6)])/2)) + int((len(all_p_c[i_layer+(i-5)])/2))
                    else:
                        # Can't dense 5, so next best is all dense - the last because of i8 repetition
                        return all_p_c[i_layer][-2:][0]
    

    def decode(self, genome):
        # print([genome[8 + i*self.convolution_layer_size] for i in range(self.convolution_layers) if genome[i*convolution_layer_size]==1])
        if not self.is_compatible_genome(genome):
            raise ValueError("Invalid genome for specified configs")

        active_layers = len([0 for i in range(self.convolution_layers) if genome[i*self.convolution_layer_size]==1])
        p_c_individual = self.p_c_filtered[-active_layers:]

        cons = [genome[self.convolutional_id_to_param['connections'] + i*self.convolution_layer_size] for i in range(self.convolution_layers) if genome[i*self.convolution_layer_size]==1]
        cons = [p_c_individual[i_l][con] if(len(p_c_individual[i_l])>con) else p_c_individual[i_l][0] for i_l,con in enumerate(cons)]
        lays = []
        x = None
        dim = 0
        offset = 0
        optim_offset = 0
        if self.convolution_layers > 0:
            dim = min(self.input_shape[:-1])  # keep track of smallest dimension
        input_layer = True
        dims = []
        gateways = dict()
        temp_features = self.max_filters
        if(self.double_up):
            # temp_features = int(self.max_filters/self.convolution_layers)
            temp_features = 32
        features = dict()
        naswot_outputs = []

        x = Input(shape=self.input_shape,dtype=tf.float32)
        add_layer(cons,lays, x, 0)
        for i in range(self.convolution_layers):
            if genome[offset]:
                if input_layer:
                    if(not self.smaller_ss):
                        temp_features = genome[offset + self.convolutional_id_to_param['num filters']]
                    temp_kernel = genome[offset + self.convolutional_id_to_param['kernel_size']]
                    x =  Convolution2D(
                        temp_features, (temp_kernel, temp_kernel),
                        padding="same",
                        # V 7 new
                        # kernel_initializer = 'he_normal',
                        kernel_initializer = tf.keras.initializers.GlorotNormal(),
                        bias_initializer = tf.keras.initializers.Constant(0.1),
                        # V12
#                         input_shape=self.input_shape,
                        activation=self.activation[genome[offset + self.convolutional_id_to_param['activation']]],
                        # V15 added regularisation
                        # kernel_regularizer=l2(0.005), bias_regularizer=l2(0.005)
                    )
                    lays.append(x)
                    input_layer = False
                else:
                    if(not self.smaller_ss):
                        if(self.type_problem =='classification'):
                            temp_features = genome[offset + self.convolutional_id_to_param['num filters']]
                        else:
                            temp_features = int(min(features[list(features.keys())[-1]],genome[offset + self.convolutional_id_to_param['num filters']]))
                    temp_kernel = genome[offset + self.convolutional_id_to_param['kernel_size']]
                    x = Convolution2D(
                        temp_features, (temp_kernel, temp_kernel),
                        padding="same",
                        kernel_initializer = 'he_normal',
                        activation=self.activation[genome[offset + self.convolutional_id_to_param['activation']]],
                        #V15 added regularisation
                        # kernel_regularizer=l2(0.005), bias_regularizer=l2(0.005)
                    )
                    lays.append(x)
                # if genome[offset + self.convolutional_id_to_param['batch normalisation']]:
                #     x = BatchNormalization()
                #     add_layer(cons,lays,x,len(lays))
                if(self.batch_norm):
                    x = BatchNormalization()
                    add_layer(cons,lays,x,len(lays))
                # x = Activation(self.activation[genome[offset + 4]])
                # add_layer(cons,lays,x,len(lays))
                #Append the gateway to layer for skip connection BEFORE pooling
                # if(not self.skip_op[genome[offset+7]]=='none'):
                #     gateways[offset]=((dim,x))
                max_pooling_type = genome[offset + self.convolutional_id_to_param['max pooling']]
                if max_pooling_type == 1 and dim >= 4:
                    temp_kernel = genome[offset + self.convolutional_id_to_param['kernel_size']]
                    if(self.double_up and temp_features<self.max_filters):
                        temp_features *= 2
                    # x = MaxPooling2D(pool_size=(2, 2), padding="same")
                    x = Convolution2D(temp_features, (temp_kernel,temp_kernel), strides=2,  #V15 added regularisation 
                        # kernel_regularizer=l2(0.005), bias_regularizer=l2(0.005),
                        padding="same")
                    add_layer(cons,lays,x,len(lays))
                    dim /= 2
                # Added dropout 2021/12/20
                if(self.dropout):
                    x = Dropout(0.2)
                    add_layer(cons,lays,x,len(lays))
            dims.append(dim)
            features[i] = temp_features
            dim = int(math.ceil(dim))
            if(i<self.convolution_layers-1):
                offset += self.convolution_layer_size
            else:
                optim_offset = offset + self.convolution_layer_size
        # x = Convolution2D(temp_features,(3,3),padding="same", kernel_initializer='he_uniform')
        # add_layer(cons,lays,x,len(lays))
        # level_of_compression = np.prod(x.get_shape()[1:])
        #level_of_compression is limited to 10 instead of 5 in the original MONCAE paper!
        # level_of_compression = min(math.log(level_of_compression,10),10)
        # TODO TEMP disbaled loc
        level_of_compression = 10
        needed_reductions = [i-2 for i,temp_dim in enumerate(dims) if(math.ceil(temp_dim)!=math.floor(temp_dim))]
        if(self.type_problem=='classification'):
            x = GlobalAveragePooling2D()
            add_layer(cons,lays,x,len(lays))
        if(self.type_problem=='ss' or self.type_problem=='ae'):
            #Reset the offset
            for i in reversed(range(self.convolution_layers)):
                #Done to fix shape when 14->7-> 4 => 4->8->16->14
                if(not(dim in dims) and ((dim-2)*2 in dims or (not(dim*2 in dims) and (dim-2)*2==min(self.input_shape[:-1])))):
                    x = Convolution2D(features[i],(3,3))
                    add_layer(cons,lays,x,len(lays))
                    dim-=2
                if genome[offset]:
                    skipped = False
                    temp_kernel = genome[offset + self.convolutional_id_to_param['kernel_size']]
                    max_pooling_type = genome[offset + self.convolutional_id_to_param['max pooling']]
                    x = Convolution2D(
                        features[i], (temp_kernel, temp_kernel),
                        padding="same",
                        # kernel_initializer = 'he_normal',
                        # V7
                        kernel_initializer = tf.keras.initializers.GlorotNormal(),
                        bias_initializer = tf.keras.initializers.Constant(0.1),
                        activation=self.activation[genome[offset + self.convolutional_id_to_param['activation']]],
                    )
                    add_layer(cons,lays,x,len(lays))
                    # if(not self.type_problem=='ss'):
                    #     x = Activation(self.activation[genome[offset + 4]])
                    #     add_layer(cons,lays,x,len(lays))
                    if (((dim*2 in dims or (dim*2)-2 in dims or (dim*4)-2 or dim==(int(min(self.input_shape[:-1]))/2))) and dim<min(self.input_shape[:-1])):
                        x = UpSampling2D((2, 2))
                        # x = Conv2DTranspose(features[i], (temp_kernel,temp_kernel), strides= 2, padding="same")
                        add_layer(cons,lays,x,len(lays))
                        self.last_upsampling_index = len(lays)-1
                        dim*=2
                if(dim>max(self.input_shape)):
                    import pdb
                    pdb.set_trace()
                offset -= self.convolution_layer_size
            if(not self.type_problem=='ss'):
                x = Convolution2D(self.input_shape[-1], (genome[self.convolutional_id_to_param['kernel_size']],genome[self.convolutional_id_to_param['kernel_size']]), activation=self.activation[genome[self.convolutional_id_to_param['activation']]], padding="same", )
                add_layer(cons,lays,x,len(lays))
            else:
                if(self.SYNFLOW):
                    x = Convolution2D(self.n_classes, self.input_shape[-1], activation=None, padding="same",)
                else:
                    x = Convolution2D(self.n_classes, self.input_shape[-1], activation="softmax", padding="same",)
                add_layer(cons,lays,x,len(lays))
        #Clear connections
        dirty_cons = None
        clean_cons = None
        # print(lays)
        try:
            if(self.type_problem=='ss'):
                add_unet_cons(cons,lays)
            dirty_cons = decode_connections(cons,len(cons))
            clean_cons = clear_cons(dirty_cons,len(cons))
        except Exception as ex:
            print(cons,len(cons))
            print('Failed connections!')
            print(ex)
            
        operations = []
        operations = self.decode_ops(operations,lays,clean_cons)
        if(self.type_problem=='classification'):
            ### Dense layer decoding (classification only)
            
            offset = optim_offset
            has_a_dense_active = False
            for i in range(self.dense_layers):
                if genome[offset]:
                    temp_nodes = genome[offset + self.dense_id_to_param['num filters']]
                    x =  Dense(
                        temp_nodes,
                        # kernel_initializer = 'he_normal',
                        # V7
                        kernel_initializer = tf.keras.initializers.GlorotNormal(),
                        bias_initializer = tf.keras.initializers.Constant(0.1),
                        activation=self.activation[genome[offset + self.dense_id_to_param['activation']]]
                    )
                    operations.append(x(operations[-1]))
                    dr =  Dropout(0.2)
                    operations.append(dr(operations[-1]))


                    has_a_dense_active = True
                offset+= self.dense_layer_size
            if(not has_a_dense_active):
                x = Dense(self.n_classes)
                operations.append(x(operations[-1]))
            if(self.SYNFLOW):
                x = Dense(self.n_classes)
            elif(self.multilabel == True):
                x = Dense(self.n_classes, activation='sigmoid')
            else:
                x = Dense(self.n_classes, activation='softmax')
            operations.append(x(operations[-1]))
            optim_offset = offset
        outs = list()  
        # import pdb
        # pdb.set_trace()
        if(self.NASWOT):
            outs = [op for op in operations if 'relu' in op.name or 'Relu' in op.name]
            if(len(outs)<1):
                outs=operations[-1]
        else:
            outs=operations[-1]
        model = Model(operations[0],outs)
        # TODO changed from binary_crossentropy
        metrics = ["accuracy"]
        # print(self.type_problem, self.TRAIN_WITH_LOGITS)
        if(self.TRAIN_WITH_LOGITS):
            loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        elif(self.type_problem=='ss' or self.type_problem=='ae'):
            loss = 'sparse_categorical_crossentropy'
        else:
            loss = 'categorical_crossentropy'
        if(self.type_problem =='ss'):
            # loss = dice_loss
            loss = 'categorical_crossentropy'
            metrics += [MeanIOUWrapper(num_classes=self.n_classes)]
        #TODO CHANGE THAT!!!!
        if(True):
            id_to_name={
                0:'Gun',
                1:'Knife',
                2:'Wrench',
                3:'Pliers',
                4:'Scissors'
            }
            # metrics+= [tf.keras.metrics.Precision(class_id=idx, name='precision_{}'.format(id_to_name[idx])) for idx in range(len(id_to_name))]
        if(not self.NASWOT and not self.SYNFLOW):
            model.compile(loss=loss,
                        optimizer=self.optimizer[genome[optim_offset]],
                        metrics=metrics)
        num_params = np.sum([np.prod(l.output_shape[1:],dtype=np.int64) for l in model.layers], dtype=np.int64)
        # level_of_complexity = min(math.log(int(model.count_params()),10),10)
        level_of_complexity = min(math.log(int(num_params),10),10)
        return model,level_of_compression,level_of_complexity
    
    def decode_ops(self, operations=None, lays=None, cons=None):
        """Wire the layers in ``lays`` together following the matrix ``cons``.

        Column ``index`` of ``cons`` is inspected for non-zero rows; those row
        numbers select the earlier outputs that feed ``lays[index]``:

        * column 0: ``lays[0]`` is appended unchanged as the first operation;
        * exactly one non-zero row: the layer is applied to the most recent
          output (``operations[-1]``);
        * several non-zero rows: the selected outputs are summed with ``add``
          first.  When their (spatial size, channels) pairs differ they are
          resized towards a reference input and projected with a 1x1
          convolution before the sum; if resizing fails the layer falls back
          to the most recent output;
        * no non-zero row: an error banner is printed and the layer is skipped.

        Args:
            operations: outputs built so far.  The list is copied, never
                mutated; its length offsets every row number read from
                ``cons``.
            lays: sequence of layers; every item but the first must be
                callable on a previous output.
            cons: 2-D array exposing ``transpose()``.  Required in practice:
                leaving it as ``None`` raises ``AttributeError``.

        Returns:
            list: the copied ``operations`` extended with one output per wired
            layer.
        """
        operations = [] if operations is None else list(operations)
        lays = [] if lays is None else lays
        # Outputs supplied by the caller shift every row number stored in cons.
        index_fixer = len(operations)
        for index, con in enumerate(cons.transpose()):
            # First layer in cell
            if index == 0:
                operations.append(lays[0])
                continue

            nz = np.nonzero(con)[0]
            if len(nz) == 0:
                print('======ERRORRORORORR ========')
                continue
            if len(nz) == 1:
                operations.append(lays[index](operations[-1]))
                continue

            # TODO: find a more optimal way of merging several inputs.
            sources = [operations[layer + index_fixer] for layer in nz]
            full_shapes = {(src.shape[1], src.shape[3]) for src in sources}
            can_add = True
            merged = None

            if len(full_shapes) == 1:
                merged = add(sources)
            else:
                # Choose the input every other input is resized to: the last
                # one for 'ss' problems, otherwise the first input with the
                # smallest size along axis 1.
                if self.type_problem == 'ss':
                    reference_ind = len(sources) - 1
                else:
                    reference_ind = min(
                        range(len(sources)),
                        key=lambda position: sources[position].shape[1],
                    )
                reference = sources[reference_ind]

                adjusted = []
                for i_src, src in enumerate(sources):
                    adjust_op = src
                    tries = 0
                    while can_add and adjust_op.shape[1] != reference.shape[1]:
                        if adjust_op.shape[1] < reference.shape[1]:
                            adjust_op = UpSampling2D((2, 2))(adjust_op)
                        else:
                            # NOTE(review): the original call omitted the
                            # required kernel_size argument and therefore
                            # raised TypeError; 3x3 is an assumed default,
                            # confirm the intended size.
                            adjust_op = Convolution2D(
                                256, kernel_size=(3, 3), strides=2,
                                padding="same")(adjust_op)
                        tries += 1
                        if tries > 10:
                            # Sizes that never meet by halving/doubling.
                            can_add = False
                            print('Cannot add {} and {} and {}'.format(src.shape, adjust_op.shape, reference.shape))
                    if i_src != reference_ind:
                        # 1x1 projection so the channel counts match for add().
                        adjust_op = Convolution2D(
                            reference.shape[-1], kernel_size=(1, 1),
                            padding="same")(adjust_op)
                    adjusted.append(adjust_op)
                if can_add:
                    merged = add(adjusted)

            if can_add:
                operations.append(lays[index](merged))
            else:
                # Merge impossible: chain the layer to the previous output.
                operations.append(lays[index](operations[-1]))
        return operations
    
    def decode_model_genome(self, genome):
        """Turn an index-encoded genome into a model via ``self.decode``.

        Every entry of ``genome`` is floored and used as an index into the
        matching list of ``self.layer_params``; the floored last entry is
        appended unchanged.  The resulting list is handed to ``self.decode``.

        Args:
            genome: indexable sequence of numbers.

        Returns:
            The first element of the tuple returned by ``self.decode``.

        Raises:
            Exception: 'Model could not be decoded' when ``self.decode``
                fails; the original error is attached as ``__cause__``.
        """
        x = genome
        conv_size = self.convolution_layer_size
        decoded = [
            self.layer_params[param][math.floor(x[i_param + i_layer * conv_size])]
            for i_layer in range(self.convolution_layers)
            for i_param, param in enumerate(self.layer_params)
        ]
        if self.type_problem == 'classification':
            # NOTE(review): these indices restart at 0, so the dense genes are
            # read from the leading entries of x rather than from the entries
            # after the convolutional genes. Kept as-is; confirm this matches
            # the decoding used elsewhere before changing it.
            dense_size = self.dense_layer_size
            decoded += [
                self.layer_params[param][math.floor(x[i_param + i_layer * dense_size])]
                for i_layer in range(self.dense_layers)
                for i_param, param in enumerate(self.dense_layer_shape)
            ]

        decoded += [math.floor(x[-1])]

        # Decode model
        try:
            model, _, _ = self.decode(decoded)
        except Exception as err:
            # Same exception type and message as before, with the cause kept.
            raise Exception('Model could not be decoded') from err
        return model
        
    
    def convParam(self, i):
        """Return the ``layer_params`` entry for gene ``i`` of a convolutional layer.

        ``self.convolutional_layer_shape[i]`` supplies the parameter name that
        is then looked up in ``self.layer_params``.
        """
        return self.layer_params[self.convolutional_layer_shape[i]]
    
    def denseParam(self, i):
        """Return the ``layer_params`` entry for gene ``i`` of a dense layer.

        ``self.dense_layer_shape[i]`` supplies the parameter name that is then
        looked up in ``self.layer_params``.
        """
        return self.layer_params[self.dense_layer_shape[i]]
        
    def is_compatible_genome(self, genome):
        """Check that ``genome`` has the expected length and only legal values.

        The genome is read as all convolutional-layer genes, then all
        dense-layer genes, then one optimizer index.  Each layer gene must be
        contained in the list returned by ``convParam``/``denseParam`` for its
        position, and the optimizer index must lie in
        ``range(len(self.optimizer))``.

        Returns:
            bool: ``True`` when every check passes, ``False`` otherwise.
        """
        conv_genes = self.convolution_layers * self.convolution_layer_size
        dense_genes = self.dense_layers * self.dense_layer_size
        if len(genome) != conv_genes + dense_genes + 1:
            return False

        position = 0
        for _ in range(self.convolution_layers):
            if any(genome[position + gene] not in self.convParam(gene)
                   for gene in range(self.convolution_layer_size)):
                return False
            position += self.convolution_layer_size

        for _ in range(self.dense_layers):
            if any(genome[position + gene] not in self.denseParam(gene)
                   for gene in range(self.dense_layer_size)):
                return False
            position += self.dense_layer_size

        # The single remaining gene selects the optimizer.
        return genome[position] in range(len(self.optimizer))
    
    def _handle_broken_model(self, model, error):
        """Flag the current model as broken and return a penalising score.

        Sets ``self.skip_model``, prints the model's parameter count together
        with the error that was passed in, and triggers a garbage collection.

        Args:
            model: object exposing ``count_params()`` and ``metrics``.
            error: the problem that made the model unusable; it is printed.

        Returns:
            list: ``10.0`` followed by one ``K.epsilon()`` entry for every
            element of ``model.metrics`` after the first.
        """
        self.skip_model = True
        print('================')
        print('Number of parameters:', str(model.count_params()))
        print('================')

        # Fixed high value in first position (v2: made deliberately
        # unfavourable), backend epsilon for the remaining entries.
        performance = [10.0] + [K.epsilon()] * (len(model.metrics) - 1)
        gc.collect()

        print('An error occurred and the model could not train!')
        # Report the cause instead of silently dropping it.
        print(error)
        print(('Model assigned poor score. Please ensure that your model '
               'constraints live within your computational resources.'))
        return performance
    
    def set_objective(self, metric):
        """
        Set the metric for optimization. Can also be done by passing to
        `run`.

        Args:
            metric (str): either 'hvi', which is treated as a value to
                    maximize, or 'loss', which is minimized

        Raises:
            ValueError: if ``metric`` is neither 'loss' nor 'hvi'.
        """
        if metric not in ['loss', 'hvi']:
            raise ValueError(('Invalid metric name {} provided - should be '
                              '"loss" or "hvi"').format(metric))
        self._metric = metric
        maximise = metric == "hvi"
        self._objective = "max" if maximise else "min"
        #TODO currently loss and accuracy
        self._metric_index = 0
        # The boolean picks entry 0 (minimise) or entry 1 (maximise).
        self._metric_op = METRIC_OPS[maximise]
        self._metric_objective = METRIC_OBJECTIVES[maximise]

    
    

        
def add_layer(cons, lays, layer, pos):
    """Insert ``layer`` at ``pos`` and update the integer connection codes.

    Both lists are modified in place.  The new entry receives the code
    ``2**(size - (pos + 1))`` truncated to an int (``0`` when appending),
    where ``size`` is the length of ``cons`` before insertion.  Every entry
    ``i`` in front of ``pos`` whose code is below ``2**(size - (i + 1))`` has
    half of that value removed (when present) and the full value added.

    Appending four layers one after another yields the codes ``[4, 2, 1, 0]``.

    Args:
        cons: list of integer connection codes.
        lays: list of layers, parallel to ``cons``.
        layer: the layer to insert.
        pos: insertion index.

    Returns:
        tuple: the same ``cons`` and ``lays`` objects that were passed in.
    """
    size = len(cons)
    con = int(2 ** (size - (pos + 1)))
    cons.insert(pos, con)
    lays.insert(pos, layer)
    for i in range(pos):
        positional = 2 ** (size - (i + 1))
        if cons[i] >= positional:
            continue
        if positional == 1:
            # The original line was a no-op comparison (`==`); an assignment
            # was clearly meant.
            cons[i] = 0
        else:
            # Integer arithmetic keeps codes above 2**53 exact, where the
            # former float division silently dropped low bits.
            half = positional // 2
            if cons[i] >= half:
                cons[i] = int(cons[i]) - half
        cons[i] = cons[i] + positional
    return cons, lays

def decode_connections(cons, cell_size):
    """Expand integer connection codes into rows of binary digits.

    Codes that are at least ``2**cell_size`` are first wrapped into range and
    written back to ``cons`` (the list is modified in place).  Each code is
    then rendered in binary, left-padded with zeros to ``cell_size`` digits,
    and split into a list of ints.

    Args:
        cons: mutable sequence of integer codes.
        cell_size: number of binary digits per row.

    Returns:
        numpy.ndarray: array with ``dtype=object`` holding one digit row per
        code.
    """
    overflow = 2 ** cell_size
    for i_con, con in enumerate(cons):
        if con >= overflow:
            # Modulo replaces the former capped subtraction loop, which could
            # leave very large codes above the limit and yield ragged rows.
            cons[i_con] = con % overflow

    # zfill does the zero padding the original obtained by eval-ing an
    # f-string.
    bin_cons = [
        [int(digit) for digit in bin(con)[2:].zfill(cell_size)]
        for con in cons
    ]
    return np.array(bin_cons, dtype=object)

def clear_cons(dirty_cons, cell_size):
    """Keep only forward connections and repair layers left disconnected.

    The strict upper triangle of ``dirty_cons`` is taken, so only entries
    whose column index exceeds their row index survive.  For every index
    ``i >= 1``:

    * a row that lost all of its ones this way is connected to ``i + 1``
      (added in v5 to avoid dropping a layer that was connected, but in a
      wrong way);
    * a column without any one is connected from row 0.

    Args:
        dirty_cons: square 2-D numpy array of 0/1 entries.
        cell_size: unused; kept for interface compatibility.

    Returns:
        numpy.ndarray: the cleaned matrix (a new array).
    """
    clean_cons = np.triu(dirty_cons, k=1)
    n_layers = len(clean_cons)
    for i in range(1, n_layers):
        lost_outputs = 1 not in clean_cons[i, :] and 1 in dirty_cons[i, :]
        # The last row has no column i + 1; writing there used to raise
        # IndexError.
        if lost_outputs and i + 1 < n_layers:
            clean_cons[i, i + 1] = 1
        if 1 not in clean_cons[:, i]:
            clean_cons[0, i] = 1
    return clean_cons


def add_unet_cons(cons,lays):
    """Add extra connection bits pairing stride-2 layers with 'up' layers.

    Layers that have a ``strides`` attribute whose first entry is 2
    contribute the index of the entry just before them; layers whose ``name``
    contains 'up' contribute the exponent ``len(lays) - index - 1``.  The two
    sequences are paired in order and ``2**exponent`` is added in place to
    ``cons`` at each paired index.  Nothing is returned.
    """
    total = len(lays)

    source_ids = []
    for position, layer in enumerate(lays):
        if hasattr(layer, 'strides') and layer.strides[0] == 2:
            source_ids.append(position - 1)

    exponents = []
    for position, layer in enumerate(lays):
        if 'up' in layer.name:
            exponents.append(total - position - 1)

    # zip stops at the shorter sequence, leaving surplus entries unpaired.
    for source_id, exponent in zip(source_ids, exponents):
        cons[source_id] += 2 ** exponent