import os
import numpy as np
import keras.regularizers
import tensorflow as tf
from sklearn.model_selection import train_test_split
from keras import backend as K
import random
import itertools
from models_mnist import CAE, CAE_SE, CAE_SE_FC
from tfdeterminism import patch
from utilities import Utils
import sys

# Reproducibility setup: apply the tfdeterminism patch first, then pin the
# visible GPU and seed every random number generator used by this script.
patch()
SEED = 43
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
# only influences child processes, not the hash seed of this process.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    "PYTHONHASHSEED": str(SEED),
    "HOROVOD_FUSION_THRESHOLD": "0",
})
# Same seed for Python's `random`, NumPy and the TensorFlow graph-level RNG.
for seed_fn in (random.seed, np.random.seed, tf.random.set_random_seed):
    seed_fn(SEED)

# Create the TensorFlow session that Keras will use. `allow_growth=True`
# asks TensorFlow to allocate GPU memory on demand instead of reserving it
# all up front.
gpu_options = tf.GPUOptions(allow_growth=True)
session_config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.Session(config=session_config)
K.tensorflow_backend.set_session(sess)

# Global experiment configuration and raw data.
num_experiments = 15    # number of repetitions run by the main loop
num_classes = 10
data_per_class = 400    # forwarded to utils.balanced_sampling for every run

# Load the training images and append a trailing singleton axis
# (presumably the channel axis expected by the conv models -- confirm).
X = np.load("./datasets/mnist_train_data.npy")[..., np.newaxis]
y = np.load("./datasets/mnist_train_targets.npy")
utils = Utils(num_classes)

# Every (self-expression loss, coefficient regularizer) pairing, in the same
# order itertools.product would yield them.
loss_combinations = [
    (se_name, c_name)
    for se_name in ("correntropy", "sq_fro")
    for c_name in ("BD", "L2")
]

# Create the root output directory for all runs. `exist_ok=True` makes
# re-running the script harmless, while genuine failures (e.g. permission
# errors) are raised instead of being silently swallowed by a bare `except`.
os.makedirs("MNIST_experiments_CorrX0p1", exist_ok=True)

# Main experiment loop.
#
# For each of the `num_experiments` repetitions:
#   1. draw a balanced subset (chunk `c_num`) and split it 70/30 into
#      train/test,
#   2. standardize both splits with the *training* mean and std,
#   3. for every enabled loss combination, train the three-stage pipeline
#      CAE -> CAE_SE -> CAE_SE_FC, passing weights from one stage to the next
#      via `transfer_weights`, and save logs / weights / affinity matrices
#      under "MNIST_experiments_CorrX0p1/<c_num>/<stage>".
#
# NOTE(review): new models are built in the same TF session on every
# iteration; whether the session/graph should be reset between runs (and how
# that would interact with the seeding above) is not visible here -- confirm.
for c_num in range(num_experiments):
    run_path = "MNIST_experiments_CorrX0p1/"+str(c_num)
    # exist_ok tolerates re-runs; real filesystem errors are raised here
    # rather than surfacing later when results are saved.
    os.makedirs(run_path, exist_ok=True)

    X_bal, labels_bal = utils.balanced_sampling(X, y, data_per_class, chunk_num=c_num, random_state=SEED)
    data_train, data_test, labels_orig_train, labels_orig_test = train_test_split(X_bal,labels_bal,train_size=0.7,random_state=SEED)
    # Original author's note: keeping the data in order is for visualization
    # and does not affect the learning results.
    data_train, labels_orig_train = utils.sort_dataset(data_train, labels_orig_train)
    data_test, labels_orig_test = utils.sort_dataset(data_test, labels_orig_test)
    # Global (scalar) standardization; the test split reuses the training
    # statistics so no information from the test set is used.
    mean_train = np.mean(data_train)
    std_train = np.std(data_train)
    data_train = (data_train-mean_train)/std_train
    data_test = (data_test-mean_train)/std_train

    # The plain CAE is trained once per run and shared by all combinations.
    cae_trained = False
    for se_reg,c_reg in loss_combinations:
        print(c_num,se_reg,c_reg)

        # MNIST architecture setup (`num_classes` comes from module level).
        input_shape = (28,28,1)
        enc_conv_filters = [10,20,30]
        enc_kernel_sizes = [(5,5),(3,3),(3,3)]
        dec_conv_filters = [20,10,1]
        dec_kernel_sizes = [(3,3),(3,3),(5,5)]
        cae_activation = "relu"
        batch_size = data_train.shape[0]  # not passed to any call in this loop

        # Loss weights and optimizer settings. They are re-assigned on every
        # iteration because `lambda_se` is rescaled below.
        lambda_rec = 1
        lambda_se = 30
        lambda_cross = 72
        lambda_cent = 36
        lambda_cq = 1
        lambda_weights = 1e-3
        lr = 1e-3
        lambda_bd = 1
        lambda_ca = 1
        corr_sigma = 0.2
        lambda_l2 = 1

        # Flags forwarded to the models' compile() calls.
        corren = se_reg == "correntropy"
        bd_r = c_reg == "BD"

        # Only the (correntropy, L2) combination is run by this script; all
        # other combinations are skipped.
        if se_reg == "sq_fro" or c_reg == "BD":
            continue

        # Scale the self-expression weight by 0.1 for (correntropy, L2).
        # With the guard above this condition is always true when reached.
        if se_reg == "correntropy" and c_reg == "L2":
            lambda_se /= 10

        weight_reg = keras.regularizers.l2(lambda_weights)

        ################################################################################
        # CAE
        if not cae_trained:
            run_model_path = run_path+"/CAE"
            os.makedirs(run_model_path, exist_ok=True)

            epochs = 5000
            es_patience = 10
            lr_patience = 5
            decay = 0.5
            min_lr = 1e-6

            cae = CAE(input_shape=input_shape, enc_conv_filters=enc_conv_filters, enc_kernel_sizes=enc_kernel_sizes, dec_conv_filters=dec_conv_filters,
                      dec_kernel_sizes=dec_kernel_sizes, cae_activation=cae_activation, data_shape=data_train.shape, weight_reg=weight_reg)

            cae.compile(lambda_rec=lambda_rec, lr=lr)
            cae.train(data_train, epochs=epochs, es_patience=es_patience, lr_patience=lr_patience, decay=decay, min_lr=min_lr)
            cae.save_logs("cae_logs", path=run_model_path+"/")
            cae.save_weights("cae_weights", path=run_model_path+"/")
            cae_trained = True
        else:
            print("Using pretrained CAE model")

        ################################################################################
        # CAE_SE
        run_model_path = run_path+"/CAE_SE"+"_"+se_reg+"_"+c_reg
        os.makedirs(run_model_path, exist_ok=True)

        epochs = 10000
        warm_up = 20
        es_patience = 15
        lr_patience = 5
        decay = 0.5
        min_lr = 1e-5

        cae_se = CAE_SE(input_shape=input_shape, enc_conv_filters=enc_conv_filters, enc_kernel_sizes=enc_kernel_sizes, dec_conv_filters=dec_conv_filters,
                        dec_kernel_sizes=dec_kernel_sizes, cae_activation=cae_activation, data_shape=data_train.shape, weight_reg=weight_reg)

        cae_se.compile(lambda_rec=lambda_rec, lambda_se=lambda_se, sigma=corr_sigma, lambda_ca=lambda_ca, lambda_bd=lambda_bd, lambda_l2=lambda_l2, num_classes=num_classes,
                       lr=lr, correntropy=corren, bd_r=bd_r)

        # Start from the trained CAE.
        cae_se.transfer_weights(cae)
        cae_se.train(data_train, warm_up=warm_up, epochs=epochs, es_patience=es_patience, lr_patience=lr_patience, decay=decay, min_lr=min_lr)
        cae_se.save_affinity_mat("cae_se_affinity", path=run_model_path+"/")
        cae_se.save_weights("cae_se_weights", path=run_model_path+"/")
        cae_se.save_logs("cae_se_logs", path=run_model_path+"/")

        ################################################################################
        # CAE_SE_FC
        run_model_path = run_path+"/CAE_SE_FC"+"_"+se_reg+"_"+c_reg
        os.makedirs(run_model_path, exist_ok=True)

        central_embedding = 2
        Tmax = 9000
        T0 = 30
        warm_up = 50
        es_patience = 15
        lr_patience = 5
        decay = 0.5
        min_lr = 1e-6
        known_dim = 12
        suppression = 12
        cooldown = 3
        alpha_min, alpha_max = 0.04, 0.8
        alpha_granularity = 15

        cae_se_fc = CAE_SE_FC(n_classes=num_classes, central_embedding=central_embedding, input_shape=input_shape, enc_conv_filters=enc_conv_filters,
                              enc_kernel_sizes=enc_kernel_sizes, dec_conv_filters=dec_conv_filters, dec_kernel_sizes=dec_kernel_sizes, cae_activation=cae_activation,
                              data_shape=data_train.shape, weight_reg=weight_reg, known_dim=known_dim, suppression=suppression, utils=utils)

        cae_se_fc.compile(lambda_rec=lambda_rec, lambda_se=lambda_se, sigma=corr_sigma, lambda_ca=lambda_ca, lambda_cq=lambda_cq, lambda_bd=lambda_bd, lambda_l2=lambda_l2,
                          lambda_cent=lambda_cent, lambda_cross=lambda_cross, lr=lr, correntropy=corren, bd_r=bd_r)
        # Start from the trained CAE_SE.
        cae_se_fc.transfer_weights(cae_se)
        # Labels are 2-D here; only their first column is handed to the model.
        cae_se_fc.cross_validate_alpha(alpha_min, alpha_max, alpha_granularity, labels_orig=labels_orig_train[:,0])
        cae_se_fc.get_initial_labels()
        cae_se_fc.train(data_train, Tmax=Tmax, T0=T0, warm_up=warm_up,
                        es_patience=es_patience, lr_patience=lr_patience, decay=decay, cooldown=cooldown, min_lr=min_lr)
        cae_se_fc.test_final_model(data_train, labels_orig_train[:,0], data_test, labels_orig_test[:,0])
        cae_se_fc.save_affinity_mat("cae_se_fc_affinity", path=run_model_path+"/")
        cae_se_fc.save_weights("cae_se_fc_weights", path=run_model_path+"/")
        cae_se_fc.save_logs("cae_se_fc_logs", path=run_model_path+"/")
