
import numpy as np
from pymoo.util.misc import stack
from pymoo.model.problem import Problem

    
from tqdm import tqdm
from PIL import Image
import glob
import numpy as np
import random as rand
import math

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dense, Dropout, Flatten
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import add,concatenate,Dot
from tensorflow.keras import Input,Model, metrics
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.python.framework import ops
from sklearn.metrics import log_loss
import gc
import tensorflow as tf
from losses import dice_loss
# import logging
# logging.getLogger("tensorflow").setLevel(logging.ERROR)
import operator
# Per-direction comparison operator and aggregate, indexed by the boolean
# ``objective == 'max'``: slot 0 is used when minimising, slot 1 when maximising.
METRIC_OPS = [operator.lt, operator.gt]
METRIC_OBJECTIVES = [min, max]

class PymooGenome(Problem):
    """Pymoo problem whose candidates encode convolutional autoencoder models.

    A candidate is a flat real-valued vector holding
    ``max_conv_layers * len(convolutional_layer_shape)`` per-layer genes
    followed by one optimizer gene. Every gene is floored and used as an
    index into the matching list in ``self.layer_params`` (or
    ``self.optimizer`` for the last gene). Evaluation decodes the vector
    into a compiled Keras model, trains it, and reports two objectives:
    the capped loss and a log10 parameter-count complexity.
    """

    def __init__(self,
                 max_conv_layers,
                 max_filters,
                 input_shape, n_classes,
                 batch_size=256,
                 batch_normalization=True,
                 dropout=True,
                 max_pooling=True,
                 optimizers=None,
                 activations=None,
                 skip_ops=None,
                 type_problem='autoencoder',
                 TRAIN_WITH_LOGITS = False,
                ):
        """Build the search space and register its bounds with pymoo.

        Args:
            max_conv_layers (int): number of encoder layer slots (>= 1; the
                connection gene needs at least 2 to have any values).
            max_filters (int): filter counts are the powers of two from 4
                up to ``2**floor(log2(max_filters))``.
            input_shape (tuple): model input shape, channels last.
            n_classes (int): number of output classes (used by 'ss').
            batch_size (int): training batch size for array datasets.
            batch_normalization (bool): if False the BN gene is always 0.
            dropout (bool): if False the dropout gene is always 0.
            max_pooling (bool): if False the max-pooling gene is always 0.
            optimizers (list, optional): Keras optimizer identifiers.
            activations (list, optional): Keras activation identifiers.
            skip_ops (list, optional): skip-connection operation names.
            type_problem (str): 'autoencoder' (default) or 'ss' (dice loss
                and MeanIoU; presumably semantic segmentation).
            TRAIN_WITH_LOGITS (bool): compile with a ``from_logits`` loss.

        Raises:
            ValueError: if ``max_conv_layers < 1`` or any gene would have no
                admissible value.
        """
        if max_conv_layers < 1:
            raise ValueError(
                "At least one conv layer is required for AE to work"
            )
        if max_filters > 0:
            filter_range_max = int(math.log(max_filters, 2)) + 1
        else:
            filter_range_max = 0
        self.optimizer = optimizers or [
            'rmsprop',
            'adagrad',
            'adadelta',
            'adam'
        ]
        self.activation = activations or [
            'sigmoid',
            'relu',
            None
        ]
        self.skip_op = skip_ops or [
            'none',
            'add',
            'concatenate'
        ]

        self.convolutional_layer_shape = [
            "active",
            "num filters",
            "kernel_size",
            "batch normalization",
            "activation",
            "dropout",
            "max pooling",
            "skip_op",
            "connections"
        ]
        self.layer_params = {
            "active": [0, 1],
            "num filters": [2**i for i in range(2, filter_range_max)],
            # Added after paper release
            "kernel_size": [1, 3, 5, 7],
            "batch normalization": [0, (1 if batch_normalization else 0)],
            "activation": list(range(len(self.activation))),
            "dropout": [(i if dropout else 0) for i in range(11)],
            # Must stay a list even when disabled: the bounds below call len()
            # on every entry (the previous bare ``0`` raised a TypeError).
            "max pooling": [(i if max_pooling else 0) for i in range(3)],
            # In development
            "skip_op": list(range(len(self.skip_op))),
            "connections": [i for i in range(1, 2**(max_conv_layers-1))]
        }

        # An empty value list would give an upper bound below the lower
        # bound and make decoding index into an empty list.
        empty_genes = [name for name, values in self.layer_params.items() if not values]
        if empty_genes:
            raise ValueError(
                "No admissible values for gene(s) {}; use max_filters >= 4 "
                "and max_conv_layers >= 2".format(empty_genes)
            )

        self.convolution_layers = max_conv_layers
        self.convolution_layer_size = len(self.convolutional_layer_shape)
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.type_problem = type_problem
        self.TRAIN_WITH_LOGITS = TRAIN_WITH_LOGITS
        # Counter of evaluated candidates; also names the saved model files.
        self.i_model = 0

        # conv_layers * conv_layer_size + 1 (optimizer) variables.
        # 2 objectives: capped loss and complexity (compression is disabled).
        # No constraints.
        number_of_variables = (self.convolution_layers*self.convolution_layer_size) + 1
        number_of_objectives = 2
        # Upper bounds sit just below len(values) so that math.floor() of any
        # sampled gene is always a valid index.
        close_to_one = 0.9999999999999
        super().__init__(n_var=number_of_variables,
                         n_obj=number_of_objectives,
                         n_constr=0,
                         xl=np.array([0 for i_layer in range(self.convolution_layers) for param in self.layer_params] + [0]),
                         xu=np.array([len(self.layer_params[param])-(1-close_to_one) for i_layer in range(self.convolution_layers) for param in self.layer_params] + [len(self.optimizer)-(1-close_to_one)]),
                         elementwise_evaluation=True)

    def feed_data(self,
                  train_with_gen,
                  dataset,
                  num_generations,
                  pop_size,
                  pss,
                  metric='loss',
                  epochs = 5
                 ):
        """Attach training data and run settings used by ``_evaluate``.

        Args:
            train_with_gen (bool): if truthy, ``dataset`` is
                ``(train_gen, val_gen)``.
            dataset: either ``(train_gen, val_gen)``, or array data
                ``((x_train, y_train), (x_test, y_test))`` optionally followed
                by ``(x_val, y_val)``.
            num_generations (int): stored for reference.
            pop_size (int): population size; required when ``pss`` is set.
            pss: if truthy, ``train_gen`` must be indexable and the subset
                ``train_gen[(i_model // pop_size) % pss]`` is used for
                training.
            metric (str): passed to ``set_objective``.
            epochs (int): training epochs per candidate.

        Raises:
            ValueError: if ``pss`` is given without ``pop_size``, or the
                metric is invalid.
        """
        self.train_with_gen = train_with_gen
        self.dataset = dataset
        self.num_generations = num_generations
        if train_with_gen:
            self.train_with_gen = True
            self.train_gen, self.val_gen = dataset
        self.generation_times = []
        self.generation_performances = []
        self.generation_archive = []
        self.generation_members = []
        self.pss = pss
        self.pop_size = pop_size
        if self.pop_size is None and self.pss is not None:
            raise ValueError("Please specify pop_size if you use PSS")
        # If no validation data is given set it to None
        if not self.train_with_gen:
            if len(dataset) == 2:
                (self.x_train, self.y_train), (self.x_test, self.y_test) = dataset
                self.x_val = None
                self.y_val = None
            else:
                (self.x_train, self.y_train), (self.x_test, self.y_test), (self.x_val, self.y_val) = dataset
            self.x_train_full = self.x_train.copy()
            self.y_train_full = self.y_train.copy()

        self.set_objective(metric)
        self.epochs = epochs

    def get_problem(self):
        # NOTE(review): nothing in this class assigns ``self.problem``, so this
        # raises AttributeError unless a caller sets it first; returning
        # ``self`` may have been intended -- confirm.
        return self.problem

    def _genes_to_genome(self, x):
        """Map a real-valued pymoo vector to the discrete genome ``decode`` expects."""
        genome = [
            self.layer_params[param][math.floor(x[i_param + (i_layer * self.convolution_layer_size)])]
            for i_layer in range(self.convolution_layers)
            for i_param, param in enumerate(self.layer_params)
        ]
        genome.append(math.floor(x[-1]))
        return genome

    def _evaluate(self, x, out, *args, **kwargs):
        """Decode, train and score one candidate (elementwise evaluation).

        Args:
            x: real-valued vector of length ``n_var``.
            out (dict): pymoo output dict. ``out["F"]`` receives
                ``[min(metric, 4), complexity]``. ``out["M"]`` receives the
                basename the model was saved under (``<name>.h5``; nothing is
                written when the genome could not be decoded).
        """
        genome = self._genes_to_genome(x)

        # Decode model; on failure the candidate is scored as broken below.
        model, level_of_compression, level_of_complexity = None, None, None
        try:
            model, level_of_compression, level_of_complexity = self.decode(genome)
        except Exception as e:
            print(e)

        performance = []
        callbacks = [
            EarlyStopping(monitor='val_loss', patience=1, verbose=0)
        ]
        epochs = self.epochs

        fit_params = None
        if not self.train_with_gen:
            fit_params = {
                'x': self.x_train_full,
                'y': self.y_train_full,
                # NOTE(review): Keras lets validation_data override
                # validation_split; without it only 90% of the rows train, so
                # steps_per_epoch (computed on all rows) may exceed the
                # available batches -- confirm.
                'validation_split': 0.1,
                'batch_size': self.batch_size,
                'shuffle': True,
                'steps_per_epoch': int(len(self.x_train_full)/self.batch_size),
                'epochs': epochs,
                'verbose': 0,
                'callbacks': callbacks
            }
            if self.x_val is not None:
                fit_params['validation_data'] = (self.x_val, self.y_val)

        # The Jacobian-based proxy score (eval_score) is disabled.
        try:
            if model is None:
                raise ValueError("Genome could not be decoded into a model")
            if self.pss:
                igen = self.i_model // self.pop_size
                model.fit(self.train_gen[igen % self.pss], epochs=epochs,
                          validation_data=self.val_gen, callbacks=callbacks, verbose=0)
                performance = model.evaluate(self.val_gen, verbose=1)
            elif self.train_with_gen:
                model.fit(self.train_gen, epochs=epochs, validation_data=self.val_gen,
                          callbacks=callbacks, verbose=0)
                performance = model.evaluate(self.val_gen, verbose=1)
            else:
                model.fit(**fit_params)
                performance = model.evaluate(self.x_test, self.y_test, verbose=0)
        except Exception as e:
            print(e)
            performance = self._handle_broken_model(model, e)

        v = min(performance[self._metric_index], 4)

        # Compute the name before advancing the counter so out["M"] matches
        # the file actually written (it used to name the *next* model).
        model_name = "model-{}".format(self.i_model)
        if model is not None:
            model.save(model_name + ".h5")
        if level_of_complexity is None:
            # Undecodable genome: use the worst value decode() can produce.
            level_of_complexity = 10

        self.i_model += 1
        out["F"] = [v, level_of_complexity]
        out["M"] = model_name

    def decode(self, genome):
        """Build and compile the Keras model described by a discrete genome.

        Args:
            genome (list): per-layer gene values followed by the optimizer
                index, as produced from a pymoo vector.

        Returns:
            tuple: ``(model, level_of_compression, level_of_complexity)``.
            ``level_of_compression`` is currently fixed at 10 and
            ``level_of_complexity`` is ``min(log10(param count), 10)``.

        Raises:
            ValueError: if the genome is incompatible with the search space,
                the decoder grows past the input size, or the connection
                genes cannot be decoded.
        """
        if not self.is_compatible_genome(genome):
            raise ValueError("Invalid genome for specified configs")
        # Connection genes (slot index 8) of the active layers. Encoder
        # convolutions are appended to ``lays`` directly so they line up with
        # these entries; every other layer gets its entry from add_layer().
        cons = [genome[8 + i*self.convolution_layer_size]
                for i in range(self.convolution_layers)
                if genome[i*self.convolution_layer_size] == 1]
        lays = []
        dim = 0
        offset = 0
        optim_offset = 0
        min_input_dim = min(self.input_shape[:-1])
        if self.convolution_layers > 0:
            dim = min_input_dim  # keep track of smallest dimension
        input_layer = True
        dims = []
        temp_features = 0
        features = dict()

        x = Input(shape=self.input_shape, dtype=tf.float32)
        add_layer(cons, lays, x, 0)

        # Encoder. The dropout gene (offset + 5) and the skip_op gene
        # (offset + 7) are not used here; skip connections come from the
        # connection genes via decode_ops().
        for i in range(self.convolution_layers):
            if genome[offset]:
                if input_layer:
                    temp_features = genome[offset + 1]
                    temp_kernel = genome[offset + 2]
                    x = Convolution2D(
                        temp_features, (temp_kernel, temp_kernel),
                        padding='same',
                        # V12
                        kernel_initializer='he_normal',
                        activation=self.activation[genome[offset + 4]]
                    )
                    lays.append(x)
                    input_layer = False
                else:
                    # Never widen past the previous slot's filter count.
                    temp_features = int(min(features[list(features.keys())[-1]], genome[offset + 1]))
                    temp_kernel = genome[offset + 2]
                    x = Convolution2D(
                        temp_features, (temp_kernel, temp_kernel),
                        padding='same',
                        activation=self.activation[genome[offset + 4]]
                    )
                    lays.append(x)
                if genome[offset + 3]:
                    x = BatchNormalization()
                    add_layer(cons, lays, x, len(lays))
                # Only pooling type 1 pools; 0 and 2 do not.
                max_pooling_type = genome[offset + 6]
                if max_pooling_type == 1 and dim >= 4:
                    x = MaxPooling2D(pool_size=(2, 2), padding="same")
                    add_layer(cons, lays, x, len(lays))
                    dim /= 2
            # dims keeps the un-rounded size; dim continues rounded up.
            dims.append(dim)
            features[i] = temp_features
            dim = int(math.ceil(dim))
            if i < self.convolution_layers - 1:
                offset += self.convolution_layer_size
            else:
                optim_offset = offset + self.convolution_layer_size

        # Bottleneck.
        x = Convolution2D(temp_features, (3, 3), padding='same')
        add_layer(cons, lays, x, len(lays))
        # Compression objective temporarily disabled (capped at 10 instead of
        # the original MONCAE paper's 5 when it was computed).
        level_of_compression = 10

        # Decoder: walk the slots in reverse, starting from the last offset.
        for i in reversed(range(self.convolution_layers)):
            # Fix odd sizes, e.g. 14->7->4 would otherwise decode 4->8->16;
            # a default-padding 3x3 convolution trims each side by 2.
            if dim not in dims and ((dim - 2) * 2 in dims
                                    or (dim * 2 not in dims and (dim - 2) * 2 == min_input_dim)):
                x = Convolution2D(features[i], (3, 3))
                add_layer(cons, lays, x, len(lays))
                dim -= 2
            if genome[offset]:
                x = Convolution2D(
                    features[i], (genome[offset+2], genome[offset+2]),
                    padding='same',
                    activation=self.activation[genome[offset + 4]],
                )
                add_layer(cons, lays, x, len(lays))
                # NOTE(review): ``dim * 4 - 2`` lacks ``in dims`` and is
                # truthy for any integer dim, so this effectively reduces to
                # ``dim < min_input_dim``; kept as-is because the shape
                # fixing above relies on it -- confirm intent.
                if ((dim * 2 in dims or dim * 2 - 2 in dims or dim * 4 - 2
                     or dim == int(min_input_dim) / 2) and dim < min_input_dim):
                    x = UpSampling2D((2, 2))
                    add_layer(cons, lays, x, len(lays))
                    dim *= 2
            if dim > max(self.input_shape):
                # Previously a pdb breakpoint, which hung unattended runs.
                raise ValueError(
                    "Decoder grew to {} beyond input shape {}".format(dim, self.input_shape))
            offset -= self.convolution_layer_size

        # Output layer. Reuses the first slot's kernel size and activation.
        if not self.type_problem == 'ss':
            x = Convolution2D(self.input_shape[-1], (genome[2], genome[2]),
                              activation=self.activation[genome[4]], padding='same')
        else:
            # NOTE(review): the kernel size here is the input channel count,
            # which looks unintended -- confirm.
            x = Convolution2D(self.n_classes, self.input_shape[-1],
                              activation="softmax", padding="same")
        add_layer(cons, lays, x, len(lays))

        # Turn connection genes into a forward-only adjacency matrix.
        try:
            dirty_cons = decode_connections(cons, len(cons))
            clean_cons = clear_cons(dirty_cons, len(cons))
        except Exception as e:
            # Previously printed 'Failed cons!' and opened a pdb session.
            raise ValueError('Failed cons!') from e

        operations = self.decode_ops([], lays, clean_cons)
        model = Model(operations[0], operations[-1])
        # TODO changed from binary_crossentropy
        compile_metrics = ["accuracy"]
        if self.TRAIN_WITH_LOGITS:
            # NOTE(review): categorical (one-hot) loss here vs. sparse labels
            # in the default branch -- confirm the label format.
            loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        else:
            loss = 'sparse_categorical_crossentropy'
        if self.type_problem == 'ss':
            loss = dice_loss
            compile_metrics += [tf.keras.metrics.MeanIoU(self.n_classes)]
        model.compile(loss=loss,
                      optimizer=self.optimizer[genome[optim_offset]],
                      metrics=compile_metrics)
        level_of_complexity = min(math.log(int(model.count_params()), 10), 10)
        return model, level_of_compression, level_of_complexity

    def decode_ops(self, operations=None, lays=None, cons=None):
        """Wire ``lays`` into a functional graph following the matrix ``cons``.

        Column ``j`` of ``cons`` lists the layers feeding layer ``j``. With one
        input the layer is applied to the previous output. With several
        inputs they are summed when (height, channels) all match, or
        concatenated and projected back by a 1x1 convolution when only the
        heights match. Otherwise the previous output alone is used.

        Args:
            operations (list, optional): already-built tensors to extend;
                not modified (a copy is extended).
            lays (list): layer objects; ``lays[0]`` is used as-is (in
                ``decode`` it is the Input tensor).
            cons: 2-D 0/1 adjacency matrix (rows = sources, columns =
                targets).

        Returns:
            list: the extended list of tensors; the last entry is the output.
        """
        operations = [] if operations is None else operations
        lays = [] if lays is None else lays
        cons = np.zeros((0, 0), dtype=int) if cons is None else np.asarray(cons)
        index_fixer = len(operations)
        operations = list(operations)
        for index, con in enumerate(cons.transpose()):
            # First layer in cell
            if index == 0:
                operations.append(lays[0])
                continue
            # TODO FIND MORE OPTIMAL WAY OF DOING THIS!!!
            nz = np.nonzero(con)[0]
            if len(nz) > 1:
                inputs = [operations[layer + index_fixer] for layer in nz]
                # Presumably NHWC tensors: shape[1] height, shape[-1]/[3]
                # channels -- confirm.
                shapes = {t.shape[-1] for t in inputs}
                k_shapes = {t.shape[1] for t in inputs}
                full_shapes = {(t.shape[1], t.shape[3]) for t in inputs}
                if len(full_shapes) == 1:
                    op = add(inputs)
                elif len(k_shapes) == 1 and len(shapes) != 1:
                    desired_shape = operations[-1].shape[-1]
                    op = concatenate(inputs)
                    # Fix shape with identity
                    op = Convolution2D(desired_shape, (1, 1), padding='same')(op)
                else:
                    operations.append(lays[index](operations[-1]))
                    continue
                operations.append(lays[index](op))
            elif len(nz) == 1:
                operations.append(lays[index](operations[-1]))
            else:
                # NOTE(review): skipping here shifts the indices of all later
                # operations relative to ``lays``/``cons``.
                print('======ERRORRORORORR ========')
        return operations

    def decode_model_genome(self, genome):
        """Decode a real-valued pymoo vector into its compiled Keras model.

        Raises:
            Exception: 'Model could not be decoded', chained to the cause.
        """
        discrete_genome = self._genes_to_genome(genome)
        try:
            model, _, _ = self.decode(discrete_genome)
        except Exception as e:
            raise Exception('Model could not be decoded') from e
        return model

    def convParam(self, i):
        """Return the admissible values of the i-th per-layer gene."""
        key = self.convolutional_layer_shape[i]
        return self.layer_params[key]

    def is_compatible_genome(self, genome):
        """Return True if every gene value is admissible and the length fits."""
        expected_len = self.convolution_layers * self.convolution_layer_size + 1
        if len(genome) != expected_len:
            return False
        ind = 0
        for i in range(self.convolution_layers):
            for j in range(self.convolution_layer_size):
                if genome[ind + j] not in self.convParam(j):
                    return False
            ind += self.convolution_layer_size
        # Final gene: optimizer index.
        if genome[ind] not in range(len(self.optimizer)):
            return False
        return True

    def _handle_broken_model(self, model, error):
        """Release a failed model and return a penalty performance list.

        Args:
            model: the Keras model that failed, or None if decoding failed.
            error: the exception that was raised (reported by the caller).

        Returns:
            list: ``[10 * log_loss(one-hot [1, 0, ...], uniform 1/n),
            log10(input_shape[1] ** 2)]`` with ``n = n_classes``.
        """
        if model is not None:
            print('================')
            print('Number of parameters:', str(model.count_params()))
            print('================')
        del model

        n = self.n_classes
        # v2 Added loss 10 times more for models out of score to make them infavourable
        performance = [log_loss(np.concatenate(([1], np.zeros(n - 1))), np.ones(n) / n)*10,
                       math.log((self.input_shape[1]*self.input_shape[1]), 10)]
        gc.collect()

        # ``K`` was never imported at module level (NameError); use the Keras
        # backend reachable through the already-imported ``tf``.
        K = tf.keras.backend
        if K.backend() == 'tensorflow':
            K.clear_session()
            # Changed from tensorflow
            ops.reset_default_graph()

        print('An error occurred and the model could not train!')
        print(('Model assigned poor score. Please ensure that your model'
               'constraints live within your computational resources.'))
        return performance

    def set_objective(self, metric):
        """
        Set the metric for optimization. Called by ``feed_data``.

        Args:
            metric (str): 'loss' to minimise, or 'hvi' to maximise. In both
                cases the first entry of the evaluation result is used.

        Raises:
            ValueError: for any other metric name.
        """
        if metric not in ['loss', 'hvi']:
            raise ValueError(('Invalid metric name {} provided - should be '
                              '"loss" or "hvi"').format(metric))
        self._metric = metric
        self._objective = "max" if self._metric == "hvi" else "min"
        # TODO currently loss and accuracy
        self._metric_index = 0
        self._metric_op = METRIC_OPS[self._objective == 'max']
        self._metric_objective = METRIC_OBJECTIVES[self._objective == 'max']

    
    

        
def add_layer(cons, lays, layer, pos):
    """Insert ``layer`` at ``pos`` and update the parallel connection genes.

    ``cons`` and ``lays`` are parallel lists. The new layer gets the gene
    ``int(2 ** (len(cons) - (pos + 1)))``, with the length taken before
    insertion, so the gene is 0 when appending at the end. Then each earlier
    entry ``cons[i]`` that is below its positional bit
    ``p = 2 ** (len(cons) - (i + 1))`` is moved onto that bit. If its value
    is at least ``p / 2``, that half is subtracted first. After that, ``p``
    is added.

    Both lists are modified in place and also returned.

    Args:
        cons (list[int]): non-negative connection genes.
        lays (list): layers, parallel to ``cons``.
        layer: the object to insert into ``lays``.
        pos (int): insertion index.

    Returns:
        tuple: ``(cons, lays)``, the same list objects passed in.
    """
    size = len(cons)
    cons.insert(pos, int(2 ** (size - (pos + 1))))
    lays.insert(pos, layer)
    for i in range(pos):
        positional = 2 ** (size - (i + 1))
        if cons[i] < positional:
            if positional == 1:
                # Intended assignment; the original `cons[i]==0` was a no-op
                # comparison (harmless only because cons[i] is 0 here).
                cons[i] = 0
            elif cons[i] >= (positional / 2):
                cons[i] = int(cons[i] - (positional / 2))
            cons[i] = cons[i] + positional
    return cons, lays

def decode_connections(cons, cell_size):
    """Expand integer connection genes into a binary adjacency matrix.

    Row ``i`` holds the bits of ``cons[i]``, most significant bit first,
    left-padded with zeros to ``cell_size`` digits.

    Values that do not fit in ``cell_size`` bits are reduced modulo
    ``2 ** cell_size``, and the reduced value is written back into ``cons``.
    The original code tried to do this with a subtract loop that never
    updated its loop variable, so it spun forever.

    Args:
        cons (list[int]): non-negative connection genes; entries are
            modified in place only when they overflow.
        cell_size (int): number of bits (columns) per row.

    Returns:
        numpy.ndarray: object array of 0/1 ints, one row per gene.
    """
    overflow = 2 ** cell_size
    bin_cons = []
    for i_con, con in enumerate(cons):
        if con >= overflow:
            con %= overflow
            cons[i_con] = con
        # zfill replaces the previous eval() of a built f-string.
        bits = bin(con)[2:].zfill(cell_size)
        bin_cons.append([int(digit) for digit in bits])
    return np.array(bin_cons, dtype=object)

def clear_cons(dirty_cons, cell_size):
    """Turn a raw adjacency matrix into a forward-only one.

    Only entries strictly above the diagonal are kept (``np.triu(k=1)``), so
    a layer can only feed later layers. Layers left without edges are then
    repaired:

    * A layer (row ``i >= 1``) whose connections were all dropped is
      connected to the next layer ``i + 1``. This is skipped when no next
      column exists; the original raised IndexError there.
    * A layer (column ``i >= 1``) without any incoming edge is fed from
      layer 0.

    Args:
        dirty_cons (numpy.ndarray): 0/1 matrix, rows = sources and
            columns = targets; presumably square as produced by
            ``decode_connections`` -- TODO confirm.
        cell_size (int): unused; kept for interface compatibility.

    Returns:
        numpy.ndarray: a new cleaned matrix; ``dirty_cons`` is not modified.
    """
    # Disabled v4: forcing the sub-diagonal (chain) connections to 1.
    clean_cons = np.triu(dirty_cons, k=1)
    n_cols = clean_cons.shape[-1]
    for i in range(1, len(clean_cons)):
        # Added v5 to prevent disconnected layer which was connected but in a wrong way
        if 1 not in clean_cons[i, :] and 1 in dirty_cons[i, :]:
            if i + 1 < n_cols:
                clean_cons[i, i + 1] = 1
        if 1 not in clean_cons[:, i]:
            clean_cons[0, i] = 1
    return clean_cons
    