from keras.layers import *
from custom_losses import *
from custom_layers import *
import keras.backend as K
import matplotlib.pyplot as plt
from keras.models import Model
from keras.optimizers import Adam
from keras.callbacks import *
from sklearn.metrics import accuracy_score
import json

class CAE():
    """Convolutional auto-encoder trained only on a reconstruction loss.

    The encoder is a stack of stride-2 ``Conv2D`` layers named ``enc_0``,
    ``enc_1``, ...; the decoder is a stack of stride-2 ``Conv2DTranspose``
    layers named ``dec_0``, ``dec_1``, ... whose last layer is ``dec_fin``.
    The wrapped Keras model is available as ``self.cae``.
    """

    def __init__(self, input_shape, enc_conv_filters, enc_kernel_sizes, dec_conv_filters, dec_kernel_sizes, cae_activation, data_shape, weight_reg):
        """Build the encoder/decoder graph.

        Args:
            input_shape: shape of one sample without the batch axis (given to ``Input``).
            enc_conv_filters: number of filters of each encoder layer.
            enc_kernel_sizes: kernel size of each encoder layer (indexed like
                ``enc_conv_filters``).
            dec_conv_filters: number of filters of each decoder layer. The entry
                for the last layer is ignored: ``dec_fin`` always outputs
                ``data_shape[-1]`` channels.
            dec_kernel_sizes: kernel size of each decoder layer.
            cae_activation: activation of the encoder layers (decoder layers get
                no activation argument).
            data_shape: shape of the training data; ``data_shape[0]`` is used as
                the batch size and ``data_shape[-1]`` as the output channels.
            weight_reg: kernel regularizer for the encoder and hidden decoder layers.

        Raises:
            ValueError: if ``enc_conv_filters`` or ``dec_conv_filters`` is empty.
        """
        if not enc_conv_filters:
            raise ValueError("enc_conv_filters must not be empty")
        if not dec_conv_filters:
            raise ValueError("dec_conv_filters must not be empty")

        self.batch_size = data_shape[0]
        self.input_features = Input(shape=input_shape, name="input_cae")

        encoded = self.input_features
        for i, filters in enumerate(enc_conv_filters):
            # input_shape is only passed to the first layer, as in earlier versions.
            first_layer_kwargs = {"input_shape": input_shape} if i == 0 else {}
            encoded = Conv2D(filters, enc_kernel_sizes[i], strides=(2, 2),
                             kernel_regularizer=weight_reg, padding="same",
                             activation=cae_activation, name="enc_" + str(i),
                             **first_layer_kwargs)(encoded)

        # Previously the final-layer branch compared i against len(...) (never
        # reached) and hidden layers were named from an undefined `conv_filters`,
        # so any decoder with more than one layer failed to build.
        specs = self._decoder_layer_specs(dec_conv_filters, data_shape[-1])
        decoded = encoded
        for i, (filters, layer_name) in enumerate(specs):
            is_final = i == len(specs) - 1
            # A lone decoder layer keeps weight_reg; in a deeper decoder only the
            # hidden layers are regularized (same as the original intent).
            reg = None if (is_final and len(specs) > 1) else weight_reg
            decoded = Conv2DTranspose(filters, dec_kernel_sizes[i], strides=(2, 2),
                                      kernel_regularizer=reg, padding="same",
                                      name=layer_name)(decoded)
        self.decoded = decoded

        self.cae = Model(self.input_features, self.decoded)

    @staticmethod
    def _decoder_layer_specs(dec_conv_filters, out_channels):
        """Return ``(filters, layer_name)`` for every decoder layer.

        Hidden layers are named ``dec_<i>`` and use ``dec_conv_filters[i]``;
        the last layer is named ``dec_fin`` and produces ``out_channels``.
        """
        last = len(dec_conv_filters) - 1
        return [(out_channels, "dec_fin") if i == last else (filters, "dec_" + str(i))
                for i, filters in enumerate(dec_conv_filters)]

    def compile(self, lambda_rec, lr):
        """Attach the weighted MSE reconstruction loss and compile with Adam.

        Args:
            lambda_rec: weight of the reconstruction loss.
            lr: Adam learning rate.
        """
        # NOTE(review): `tf` is not imported explicitly in this file; presumably
        # it comes from one of the star imports — confirm.
        recon_loss = lambda_rec*tf.losses.mean_squared_error(K.flatten(self.input_features), K.flatten(self.decoded))
        self.cae.add_loss(recon_loss)
        self.cae.compile(Adam(lr = lr))

        # Report the loss term separately in the training logs (relies on the
        # `metrics_tensors` attribute of older standalone Keras).
        self.cae.metrics_tensors.append(recon_loss)
        self.cae.metrics_names.append("reconstruction_loss")

    def train(self, data, epochs, es_patience=10, lr_patience=5, decay=0.5,  min_lr=1e-6):
        """Fit on ``data`` without targets (the loss is attached via ``add_loss``).

        Early stopping and LR reduction both monitor the training loss. The
        Keras ``History`` is stored in ``self.cae_history``.
        """
        earlystopper_cae = EarlyStopping(monitor='loss', patience=es_patience, verbose=1)
        reduce_lr_cae = ReduceLROnPlateau(monitor='loss', factor=decay, patience=lr_patience, cooldown=2, min_lr=min_lr,verbose=1)

        self.cae_history = self.cae.fit(data,
                    batch_size=self.batch_size,
                    epochs=epochs,
                    shuffle=False,
                    callbacks=[earlystopper_cae,reduce_lr_cae],
                    verbose=0)

    def get_layer(self,name):
        """Return the layer called ``name`` from the underlying model."""
        return self.cae.get_layer(name)

    def predict(self,data):
        """Return the reconstructions of ``data``."""
        return self.cae.predict(data)

    def save_logs(self, name, path="./"):
        """Write the string form of the training history to ``path + name + '.json'``.

        Requires ``train`` to have been called. ``path`` is used as a plain prefix.
        """
        with open(path+name+'.json', 'w') as fp:
            json.dump(str(self.cae_history.history), fp)

    def save_weights(self, name, path="./"):
        """Save the model weights to ``path + name``."""
        self.cae.save_weights(path+name)
            
class CAE_SE():
    """Convolutional auto-encoder with a ``ZC`` layer between encoder and decoder.

    The flattened code ``self.z`` is fed through ``ZC_Layer``, and the result
    ``self.zc`` is reshaped back and decoded. The kernel of the ``ZC`` layer is
    returned by ``get_C_mat``. The wrapped Keras model is ``self.cae_se``.
    """

    def __init__(self, input_shape, enc_conv_filters, enc_kernel_sizes, dec_conv_filters, dec_kernel_sizes, cae_activation, data_shape, weight_reg=None):
        """Build encoder -> Flatten ("Z") -> ZC -> Reshape -> decoder.

        Args:
            input_shape: shape of one sample without the batch axis.
            enc_conv_filters / enc_kernel_sizes: per-layer encoder configuration.
            dec_conv_filters / dec_kernel_sizes: per-layer decoder configuration;
                the filter entry of the last layer is ignored because ``dec_fin``
                outputs ``data_shape[-1]`` channels.
            cae_activation: activation of the encoder layers.
            data_shape: shape of the data; ``data_shape[0]`` is the batch size
                (also passed to ``ZC_Layer``), ``data_shape[-1]`` the channels.
            weight_reg: kernel regularizer for the encoder and hidden decoder
                layers (``dec_fin`` is never regularized).

        Raises:
            ValueError: if ``enc_conv_filters`` or ``dec_conv_filters`` is empty.
        """
        if not enc_conv_filters:
            raise ValueError("enc_conv_filters must not be empty")
        if not dec_conv_filters:
            raise ValueError("dec_conv_filters must not be empty")

        self.batch_size = data_shape[0]
        self.input_features = Input(shape=input_shape, name="input_cae_se")

        encoded = self.input_features
        for i, filters in enumerate(enc_conv_filters):
            first_layer_kwargs = {"input_shape": input_shape} if i == 0 else {}
            encoded = Conv2D(filters, enc_kernel_sizes[i], strides=(2, 2),
                             kernel_regularizer=weight_reg, padding="same",
                             activation=cae_activation, name="enc_" + str(i),
                             **first_layer_kwargs)(encoded)

        self.z = Flatten(name="Z")(encoded)
        # Constructed with the batch size; presumably its kernel is
        # (batch_size, batch_size) — NOTE(review): confirm in custom_layers.
        self.zc_layer = ZC_Layer(self.batch_size, kernel_regularizer = None, no_diag = True, name = "ZC")
        self.zc = self.zc_layer(self.z)
        zc_reshape = Reshape(encoded.shape.as_list()[1:], name="ZC_reshape")(self.zc)

        # The old final-layer branch was unreachable and hidden-layer names used
        # an undefined `conv_filters`; multi-layer decoders could not be built.
        specs = self._decoder_layer_specs(dec_conv_filters, data_shape[-1])
        decoded = zc_reshape
        for i, (filters, layer_name) in enumerate(specs):
            reg = None if i == len(specs) - 1 else weight_reg
            decoded = Conv2DTranspose(filters, dec_kernel_sizes[i], strides=(2, 2),
                                      kernel_regularizer=reg, padding="same",
                                      name=layer_name)(decoded)
        self.decoded = decoded

        self.cae_se = Model(self.input_features, self.decoded)

    @staticmethod
    def _decoder_layer_specs(dec_conv_filters, out_channels):
        """Return ``(filters, layer_name)`` for every decoder layer.

        Hidden layers are named ``dec_<i>``; the last is ``dec_fin`` with
        ``out_channels`` filters.
        """
        last = len(dec_conv_filters) - 1
        return [(out_channels, "dec_fin") if i == last else (filters, "dec_" + str(i))
                for i, filters in enumerate(dec_conv_filters)]

    def compile(self, lambda_rec, lambda_se, sigma, lambda_ca, lambda_bd, lambda_l2, num_classes, lr=1e-3, correntropy=True, bd_r = True):
        """Attach all loss terms and compile with Adam.

        Losses: weighted MSE reconstruction; a self-expression loss
        (``se_corr_loss`` when ``correntropy`` else ``se_loss``); and on the ZC
        kernel either ``bd_reg`` + ``c_a_reg`` (``bd_r=True``) or ``l2_loss``.
        Each term is also reported as a named metric.
        """
        self.n_classes = num_classes
        # NOTE(review): `tf` is presumably provided by a star import — confirm.
        recon_loss = lambda_rec*tf.losses.mean_squared_error(K.flatten(self.input_features), K.flatten(self.decoded))
        self.cae_se.add_loss(recon_loss)
        if correntropy:
            selfExp_loss = se_corr_loss(self.z,self.zc,self.batch_size,alpha=lambda_se,sigm=sigma)
        else:
            selfExp_loss = se_loss(self.z,self.zc,self.batch_size,alpha=lambda_se)
        self.cae_se.add_loss(selfExp_loss)

        tracked = [(recon_loss, "reconstruction_loss"), (selfExp_loss, "se_loss")]
        if bd_r:
            c_a_reg_loss = c_a_reg(self.zc_layer.kernel,alpha=lambda_ca)
            bd_reg_loss = bd_reg(self.zc_layer.kernel,k=self.n_classes,alpha=lambda_bd)
            self.cae_se.add_loss(bd_reg_loss)
            self.cae_se.add_loss(c_a_reg_loss)
            tracked += [(c_a_reg_loss, "c_a_reg_loss"), (bd_reg_loss, "bd_reg_loss")]
        else:
            l2_reg_loss = l2_loss(self.zc_layer.kernel,alpha=lambda_l2)
            self.cae_se.add_loss(l2_reg_loss)
            tracked.append((l2_reg_loss, "l2_reg_loss"))

        self.cae_se.compile(Adam(lr = lr))
        # Expose each loss term in the training logs (older-Keras metrics API).
        for tensor, metric_name in tracked:
            self.cae_se.metrics_tensors.append(tensor)
            self.cae_se.metrics_names.append(metric_name)

    def transfer_weights(self, model_from): #hardcoded for MNIST setup. Needs polishing
        """Copy ``enc_0`` and ``dec_fin`` weights from ``model_from``.

        Only these two layers are copied, so this assumes one encoder and one
        decoder layer.
        """
        self.cae_se.get_layer("enc_0").set_weights(model_from.get_layer("enc_0").get_weights())
        self.cae_se.get_layer("dec_fin").set_weights(model_from.get_layer("dec_fin").get_weights())

    def train(self, data, epochs, warm_up=20, es_patience=15, lr_patience=5, decay=0.5,  min_lr=1e-6):
        """Fit for ``warm_up`` epochs without callbacks, then up to ``epochs`` more.

        The second phase uses LR reduction and early stopping on the training
        loss. Its ``History`` is stored in ``self.cae_se_history``.
        """
        earlystopper_caese = EarlyStopping(monitor='loss', patience=es_patience, verbose=1)
        reduce_lr_caese = ReduceLROnPlateau(monitor='loss', factor=decay, patience=lr_patience, cooldown=2, min_lr=min_lr, verbose=1)

        print("CAE_SE_Warming up...")
        self.cae_se.fit(data,
                   batch_size=self.batch_size,
                   epochs=warm_up,
                   shuffle=False,
                   verbose=0)

        print("\nCAE_SE_ Model fitting...")
        self.cae_se_history = self.cae_se.fit(data,
                                    batch_size=self.batch_size,
                                    epochs=epochs,
                                    shuffle=False,
                                    callbacks=[reduce_lr_caese, earlystopper_caese],
                                    verbose=0)

    def predict(self, data):
        """Return the reconstructions of ``data``."""
        # Fixed: previously referenced the undefined bare name `cae_se`.
        return self.cae_se.predict(data)

    def save_logs(self, name, path="./"):
        """Write the string form of the main training history to ``path + name + '.json'``."""
        with open(path+name+'.json', 'w') as fp:
            json.dump(str(self.cae_se_history.history), fp)

    def get_C_mat(self):
        """Return the kernel of the ``ZC`` layer as a NumPy array."""
        return self.cae_se.get_layer("ZC").get_weights()[0]

    def get_layer(self,name):
        """Return the layer called ``name`` from the underlying model."""
        return self.cae_se.get_layer(name)

    def get_output_dim(self,name):
        """Return the second entry (first non-batch axis) of layer ``name``'s output shape."""
        # Fixed: previously used the undefined name `cae_se` and returned nothing.
        return self.cae_se.get_layer(name).output_shape[1]

    def save_affinity_mat(self, name, path="./"):
        """Save the symmetrised C matrix ``0.5 * (C + C.T)`` as ``path + name + '.png'``."""
        C_mat = self.get_C_mat()
        affinity_mat = 0.5*(C_mat+C_mat.T)
        plt.imsave(path+name+".png", affinity_mat)

    def save_weights(self, name, path="./"):
        """Save the model weights to ``path + name``."""
        self.cae_se.save_weights(path+name)

class CAE_SE_FC():
    """CAE with a ZC layer plus a softmax head and a center-loss head.

    ``self.cae_se_fc`` takes ``[features, one_hot_labels]`` and outputs the
    reconstruction, the center-loss layer output and the softmax. A second
    model ``self.final_model`` maps features to the softmax only. Labels used
    for training are refined by spectral clustering of the ZC kernel.
    """

    def __init__(self, n_classes, central_embedding, input_shape, enc_conv_filters, enc_kernel_sizes, dec_conv_filters, dec_kernel_sizes, cae_activation, data_shape, weight_reg, known_dim, suppression, utils):
        """Build the graph.

        Args:
            n_classes: number of clusters/classes (softmax width, one-hot width).
            central_embedding: width of the ``fc_pre`` layer fed to the center loss.
            input_shape: shape of one sample without the batch axis.
            enc_conv_filters / enc_kernel_sizes: per-layer encoder configuration.
            dec_conv_filters / dec_kernel_sizes: per-layer decoder configuration;
                the filter entry of the last layer is ignored (``dec_fin``
                outputs ``data_shape[-1]`` channels).
            cae_activation: activation of the encoder layers and ``fc_pre``.
            data_shape: shape of the data; ``data_shape[0]`` is the batch size.
            weight_reg: kernel regularizer for encoder and hidden decoder layers.
            known_dim: stored as ``known_dim + 1`` and passed to ``utils.post_proC``.
            suppression: passed to ``utils.post_proC``.
            utils: helper object providing clustering/label utilities
                (``get_one_hot``, ``spectral_clustering``, ``map_by_hungarian``,
                ``thrC``, ``post_proC``).

        Raises:
            ValueError: if ``enc_conv_filters`` or ``dec_conv_filters`` is empty.
        """
        if not enc_conv_filters:
            raise ValueError("enc_conv_filters must not be empty")
        if not dec_conv_filters:
            raise ValueError("dec_conv_filters must not be empty")

        self.utils = utils
        self.batch_size = data_shape[0]
        self.n_classes = n_classes
        self.known_dim = known_dim + 1
        self.suppression = suppression
        # Bookkeeping for validate(); initialised here because validate() used
        # to fail with AttributeError (self.validations was never created).
        self.validations = []
        self.cooldown = 0
        self.input_features = Input(shape = input_shape, name = "input_cae_sefc")
        # Fixed batch dimension: the label input only accepts batches of batch_size.
        self.input_labels_oh = Input(batch_shape = (self.batch_size,n_classes), name = "labels_oh_in")

        encoded = self.input_features
        for i, filters in enumerate(enc_conv_filters):
            first_layer_kwargs = {"input_shape": input_shape} if i == 0 else {}
            encoded = Conv2D(filters, enc_kernel_sizes[i], strides=(2, 2),
                             kernel_regularizer=weight_reg, padding="same",
                             activation=cae_activation, name="enc_" + str(i),
                             **first_layer_kwargs)(encoded)

        self.z = Flatten(name="Z")(encoded)
        self.zc_layer = ZC_Layer(self.batch_size, kernel_regularizer = None, no_diag = True, name = "ZC")
        self.zc = self.zc_layer(self.z)
        zc_reshape = Reshape(encoded.shape.as_list()[1:], name="ZC_reshape")(self.zc)

        # The old final-layer branch was unreachable and hidden-layer names used
        # an undefined `conv_filters`; multi-layer decoders could not be built.
        specs = self._decoder_layer_specs(dec_conv_filters, data_shape[-1])
        decoded = zc_reshape
        for i, (filters, layer_name) in enumerate(specs):
            reg = None if i == len(specs) - 1 else weight_reg
            decoded = Conv2DTranspose(filters, dec_kernel_sizes[i], strides=(2, 2),
                                      kernel_regularizer=reg, padding="same",
                                      name=layer_name)(decoded)
        self.decoded = decoded

        fc_pre = Dense(central_embedding, activation = cae_activation, name="fc_pre")(self.z)
        # The softmax head reads the flattened code directly, not fc_pre.
        self.fc = Dense(n_classes, activation = "softmax", name="softmax")(self.z)
        self.cL = CenterLossLayer(num_features = central_embedding, num_classes = n_classes, alpha=0.5, name='centerlosslayer')([fc_pre,self.input_labels_oh])

        self.cae_se_fc = Model(inputs=[self.input_features,self.input_labels_oh], outputs=[self.decoded,self.cL,self.fc])
        self.final_model = Model(inputs=[self.input_features], outputs=[self.fc])

    @staticmethod
    def _decoder_layer_specs(dec_conv_filters, out_channels):
        """Return ``(filters, layer_name)`` for every decoder layer.

        Hidden layers are named ``dec_<i>``; the last is ``dec_fin`` with
        ``out_channels`` filters.
        """
        last = len(dec_conv_filters) - 1
        return [(out_channels, "dec_fin") if i == last else (filters, "dec_" + str(i))
                for i, filters in enumerate(dec_conv_filters)]

    def compile(self, lambda_rec, lambda_se, sigma, lambda_ca, lambda_cq, lambda_bd, lambda_l2, lambda_cent, lambda_cross, lr, correntropy=True, bd_r=True):
        """Attach all loss terms, compile both models, and expose per-term metrics.

        Losses added to ``cae_se_fc``: reconstruction MSE, ``c_q_loss`` on the ZC
        kernel, cross-entropy on the softmax, center loss, a self-expression
        loss (``se_corr_loss`` if ``correntropy`` else ``se_loss``), and either
        ``bd_reg`` + ``c_a_reg`` (``bd_r=True``) or ``l2_loss``. ``lr`` is kept in
        ``self.init_lr`` for ``save_logs``.
        """
        self.init_lr = lr

        # NOTE(review): `tf` is presumably provided by a star import — confirm.
        recon_loss = lambda_rec*tf.losses.mean_squared_error(K.flatten(self.input_features), K.flatten(self.decoded))
        c_norm_q_loss = c_q_loss(self.zc_layer.kernel,self.input_labels_oh, lambda_cq, self.batch_size)
        cross_loss = crossentropy_loss(self.input_labels_oh, self.fc, alpha = lambda_cross)
        cent_loss = center_loss(self.cL, self.batch_size, alpha=lambda_cent)

        self.cae_se_fc.add_loss(recon_loss)
        self.cae_se_fc.add_loss(c_norm_q_loss)
        self.cae_se_fc.add_loss(cross_loss)
        self.cae_se_fc.add_loss(cent_loss)

        if correntropy:
            selfExp_loss = se_corr_loss(self.z,self.zc,self.batch_size,alpha=lambda_se,sigm=sigma)
        else:
            selfExp_loss = se_loss(self.z,self.zc,self.batch_size,alpha=lambda_se)
        self.cae_se_fc.add_loss(selfExp_loss)

        tracked = [(recon_loss, "reconstruction_loss"),
                   (selfExp_loss, "se_loss"),
                   (c_norm_q_loss, "cq_loss"),
                   (cross_loss, "crossentropy_loss"),
                   (cent_loss, "cent_loss")]
        if bd_r:
            c_a_reg_loss = c_a_reg(self.zc_layer.kernel,alpha=lambda_ca)
            bd_reg_loss = bd_reg(self.zc_layer.kernel,k=self.n_classes,alpha=lambda_bd)
            self.cae_se_fc.add_loss(bd_reg_loss)
            self.cae_se_fc.add_loss(c_a_reg_loss)
            tracked += [(c_a_reg_loss, "c_a_reg_loss"), (bd_reg_loss, "bd_reg_loss")]
        else:
            l2_reg_loss = l2_loss(self.zc_layer.kernel,alpha=lambda_l2)
            self.cae_se_fc.add_loss(l2_reg_loss)
            tracked.append((l2_reg_loss, "l2_reg_loss"))

        self.cae_se_fc.compile(Adam(lr = lr))
        # Expose each loss term in the training logs (older-Keras metrics API).
        for tensor, metric_name in tracked:
            self.cae_se_fc.metrics_tensors.append(tensor)
            self.cae_se_fc.metrics_names.append(metric_name)

        self.final_model.compile("adam","categorical_crossentropy")

    def transfer_weights(self, model_from): #hardcoded for COIL20 setup. Needs polishing
        """Copy ``enc_0``, ``Z``, ``ZC`` and ``dec_fin`` weights from ``model_from``."""
        for layer_name in ("enc_0", "Z", "ZC", "dec_fin"):
            self.cae_se_fc.get_layer(layer_name).set_weights(model_from.get_layer(layer_name).get_weights())

    def train(self, data, Tmax, T0, cooldown=3, warm_up=20, es_patience=15, lr_patience=7, decay=0.5,  min_lr=1e-6):
        """Alternate fitting with label refinement for ``Tmax // T0`` rounds.

        Round 0 fits ``warm_up`` epochs without callbacks; later rounds fit
        ``T0`` epochs with LR reduction and early stopping on the training loss.
        After every round the labels are re-estimated by spectral clustering of
        ``0.5 * (|C| + |C.T|)`` (C = ZC kernel with zeroed diagonal) and matched
        to the previous labels via ``utils.map_by_hungarian``; on any exception
        the previous labels are kept. Training stops early if a round ran fewer
        than ``T0`` epochs. Requires ``get_initial_labels`` to have been called.
        """
        self.histories = []
        self.cooldown = 0
        labels = self.init_labels
        labels_oh = self.utils.get_one_hot(labels)
        earlystopper_caesefc = EarlyStopping(monitor='loss', patience=es_patience, verbose=1)
        reduce_lr_caesefc = ReduceLROnPlateau(monitor='loss', factor=decay, patience=lr_patience, cooldown=cooldown, min_lr=min_lr, verbose=1)

        for i in range(0,Tmax//T0):
            if i==0:
                print("Warming up...")
                history = self.cae_se_fc.fit([data,labels_oh],
                          batch_size=self.batch_size,
                          epochs=warm_up,
                          shuffle=False,
                          verbose=0)
            else:
                print("\nLabels refined. Real epoch:",i*T0)
                history = self.cae_se_fc.fit([data,labels_oh],
                                batch_size=self.batch_size,
                                epochs=T0,
                                shuffle=False,
                                callbacks=[reduce_lr_caesefc,earlystopper_caesefc],
                                verbose=0)

            self.histories.append(history)

            C_mat = self.get_C_mat()
            C_mat = C_mat - np.diag(np.diag(C_mat))
            C_mat_aff = 0.5*(np.abs(C_mat)+np.abs(C_mat.T))
            try:
                labels_new = self.utils.spectral_clustering(C_mat_aff)
                labels = self.utils.map_by_hungarian(labels_new,labels)
                labels_oh = self.utils.get_one_hot(labels)
            except Exception as e:
                # Deliberate best-effort: keep training with the previous labels.
                print(e)
                print("Exception occured. Using previous labels")

            # Fewer epochs than requested means EarlyStopping fired.
            if i!=0 and len(history.history["loss"])<T0:
                print("Stopping the training process")
                break

    @staticmethod
    def _epochs_since_best(validations):
        """Return ``len(validations)`` minus the index of the LAST maximum.

        This counts the entries from the most recent best score to the end,
        inclusive, so it is 1 when the latest score is the best one.
        """
        scores = np.asarray(validations)
        last_best = np.flatnonzero(scores == scores.max())[-1]
        return len(scores) - int(last_best)

    def validate(self,data,labels,es_patience,lr_patience,min_lr,init_cooldown,decay):
        """Score ``final_model`` on ``data`` and apply patience-based LR/ES logic.

        ``labels`` are one-hot; accuracy is computed after Hungarian matching and
        appended to ``self.validations``. Returns ``False`` when training should
        stop (see ``_epochs_since_best`` vs ``es_patience``), otherwise ``True``.
        When ``lr_patience`` is reached the LR is multiplied by ``decay`` (floored
        at ``min_lr``) and a cooldown of ``init_cooldown`` calls starts.
        """
        # final_model is built from the same layer objects as cae_se_fc, so these
        # copies appear to be no-ops; kept to preserve behavior.
        for layer_name in ("enc_0", "Z", "softmax"):
            self.final_model.get_layer(layer_name).set_weights(self.cae_se_fc.get_layer(layer_name).get_weights())
        val_loss = self.final_model.evaluate(data,labels)
        labels = np.argmax(labels,axis=-1)
        labels_final_pred = np.argmax(self.final_model.predict(data),axis=-1)
        labels_ref = self.utils.map_by_hungarian(labels_final_pred,labels)
        self.val_acc = accuracy_score(labels,labels_ref)
        self.validations.append(self.val_acc)
        print("Validation accuracy:", self.val_acc)
        print("Validation loss:",val_loss)
        if self.cooldown!=0:
            print("Cooldown")
            self.cooldown-=1
            return True

        since_best = self._epochs_since_best(self.validations)
        if since_best >= es_patience:
            return False
        if since_best >= lr_patience:
            new_lr = K.eval(self.cae_se_fc.optimizer.lr)*decay
            if new_lr <= min_lr:
                K.set_value(self.cae_se_fc.optimizer.lr, min_lr)
            else:
                K.set_value(self.cae_se_fc.optimizer.lr, new_lr)
                print("Decreasing LR to:", new_lr)
            self.cooldown = init_cooldown
        return True

    def cross_validate_alpha(self,minimum,maximum,granularity,labels_orig):
        """Grid-search the ``thrC`` threshold alpha against ground-truth labels.

        Alphas are ``np.arange(minimum, maximum, (maximum - minimum) / granularity)``.
        A failing alpha scores 0.0. The best alpha is stored in ``self.best_alpha``.
        """
        alphas = []
        accs = []
        C_mat = self.get_C_mat()
        C_mat = C_mat - np.diag(np.diag(C_mat))
        print("Alpha crossvalidation")
        for alpha in np.arange(minimum,maximum,(maximum-minimum)/granularity):
            alphas.append(alpha)
            c_thr = self.utils.thrC(C_mat,alpha)
            try:
                labels, _ = self.utils.post_proC(c_thr,self.n_classes,self.known_dim,self.suppression,with_diag=True)
                labels_ref = self.utils.map_by_hungarian(labels,labels_orig)
                accs.append(accuracy_score(labels_orig,labels_ref))
                print("Alpha",alpha,"| Acc =",accs[-1])
            except Exception as e:
                print(e)
                accs.append(0.0)

        ind = np.argmax(accs)
        self.best_alpha = alphas[ind]
        print("Best alpha is",self.best_alpha,"with accuracy score:",accs[ind])

    def get_initial_labels(self):
        """Cluster the thresholded C matrix with ``self.best_alpha``; store in ``self.init_labels``."""
        C_mat = self.get_C_mat()
        C_mat = C_mat - np.diag(np.diag(C_mat))
        c_thr = self.utils.thrC(C_mat,self.best_alpha)
        labels, _ = self.utils.post_proC(c_thr,self.n_classes,self.known_dim,self.suppression,with_diag=True)
        self.init_labels = labels
        print("Labels initialized for alpha:",self.best_alpha)

    @staticmethod
    def _merge_history_dicts(history_dicts, init_lr):
        """Concatenate per-round history dicts into one, without mutating inputs.

        If the first dict has no ``"lr"`` key, it is back-filled with ``init_lr``
        for each of its ``"loss"`` entries.
        """
        merged = {}
        for i, hist in enumerate(history_dicts):
            for key, values in hist.items():
                # setdefault + extend copies the lists; the old code aliased the
                # first history's lists and then extended them in place.
                merged.setdefault(key, []).extend(values)
            if i == 0 and "lr" not in merged:
                merged["lr"] = [init_lr] * len(merged["loss"])
        return merged

    def save_logs(self, name, path="./"):
        """Write merged histories plus ``test_accuracy`` to ``path + name + '.json'``.

        Requires ``train``, ``compile`` and ``test_final_model`` to have run. The
        file contains the ``str()`` of the dict, JSON-encoded as a string.
        """
        new_hist = self._merge_history_dicts([h.history for h in self.histories], self.init_lr)
        new_hist["test_accuracy"] = self.test_acc
        with open(path+name+'.json', 'w') as fp:
            json.dump(str(new_hist), fp)

    def test_final_model(self,data_train,labels_train,data_test,labels_test):
        """Compute Hungarian-matched accuracy of ``final_model`` on train and test data.

        ``labels_train`` / ``labels_test`` are integer labels. Results are
        stored in ``self.train_acc`` and ``self.test_acc``.
        """
        labels_final_pred = np.argmax(self.final_model.predict(data_train),axis=-1)
        labels_ref = self.utils.map_by_hungarian(labels_final_pred,labels_train)
        self.train_acc = accuracy_score(labels_train,labels_ref)
        print("Train accuracy:", self.train_acc)
        labels_final_pred = np.argmax(self.final_model.predict(data_test),axis=-1)
        labels_ref = self.utils.map_by_hungarian(labels_final_pred,labels_test)
        self.test_acc = accuracy_score(labels_test,labels_ref)
        print("Test accuracy:", self.test_acc)

    def get_C_mat(self):
        """Return the kernel of the ``ZC`` layer as a NumPy array."""
        return self.cae_se_fc.get_layer("ZC").get_weights()[0]

    def save_affinity_mat(self, name, path="./"):
        """Save the symmetrised C matrix ``0.5 * (C + C.T)`` as ``path + name + '.png'``."""
        C_mat = self.get_C_mat()
        affinity_mat = 0.5*(C_mat+C_mat.T)
        plt.imsave(path+name+".png", affinity_mat)

    def save_weights(self, name, path="./"):
        """Save the weights of ``cae_se_fc`` to ``path + name``."""
        self.cae_se_fc.save_weights(path+name)