import os
import numpy as np
import keras.regularizers
import tensorflow as tf
from sklearn.model_selection import train_test_split, StratifiedKFold
from keras import backend as K
import random
import itertools
from models_coil import CAE, CAE_SE, CAE_SE_FC
from tfdeterminism import patch
from utilities import Utils
import sys

# Enable deterministic TensorFlow GPU kernels (tfdeterminism) first, then pin
# the process environment and every random number generator to one seed.
patch()
SEED = 43
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "1",  # expose only CUDA device 1 to TensorFlow
    # NOTE: at runtime this only affects child processes; the running
    # interpreter's str-hash seed was already fixed at startup.
    "PYTHONHASHSEED": str(SEED),
    "HOROVOD_FUSION_THRESHOLD": "0",
})
for _seed_rng in (random.seed, np.random.seed, tf.random.set_random_seed):
    _seed_rng(SEED)

# Grow GPU memory on demand instead of reserving it all up front, and make
# Keras run its models inside this session.
gpu_options = tf.GPUOptions(allow_growth=True)
session_config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.Session(config=session_config)
K.tensorflow_backend.set_session(sess)

num_experiments = 5  # number of stratified cross-validation folds
num_classes = 20     # number of classes / clusters in the COIL-20 data

X = np.load("./datasets/coil20_DSCNet_data.npy")
y = np.load("./datasets/coil20_DSCNet_labels.npy")
utils = Utils(num_classes)
skf = StratifiedKFold(n_splits=num_experiments, random_state=SEED, shuffle=True)

# Within each fold, only knowledge of the training split is used
# (e.g. the normalization statistics computed in the fold loop).

# Every (self-expression loss, coefficient regularizer) pair; the fold loop
# decides which of these are actually trained.
loss_combinations = list(itertools.product(["correntropy", "sq_fro"], ["BD", "L2"]))

# exist_ok=True tolerates re-runs, while genuine failures (no permission, a
# regular file with the same name) now raise instead of being silently
# swallowed by a bare `except:` and surfacing later as a confusing save error.
os.makedirs("Coil20_experiments_CorrX20", exist_ok=True)

# ---------------------------------------------------------------------------
# Settings shared by every fold and every loss combination. They were
# previously re-assigned on each inner-loop iteration; hoisting them makes it
# explicit that only `corren`, `bd_r` and `lambda_se` vary per combination.
# ---------------------------------------------------------------------------
# Coil20 architecture setup
input_shape = (32, 32, 1)
enc_conv_filters = [15]
enc_kernel_sizes = [(3, 3)]
dec_conv_filters = [15]
dec_kernel_sizes = [(3, 3)]
cae_activation = "relu"
weight_reg = None

# Loss weights
lambda_rec = 1
lambda_se_base = 30  # multiplied by 20 for the correntropy + L2 setting
lambda_cross = 6
lambda_cent = 3
lambda_cq = 8
#lambda_weights = 1e-3
lambda_bd = 1
lambda_ca = 1
corr_sigma = 1.
lambda_l2 = 1

lr = 1e-3        # learning rate for the CAE and CAE_SE stages
lr_fc = lr / 10  # CAE_SE_FC fine-tunes with a 10x smaller learning rate

for c_num, (train_index, test_index) in enumerate(skf.split(X, y)):
    run_path = "Coil20_experiments_CorrX20/" + str(c_num)
    os.makedirs(run_path, exist_ok=True)

    data_train, data_test = X[train_index], X[test_index]
    labels_orig_train, labels_orig_test = y[train_index], y[test_index]
    data_train, labels_orig_train = utils.sort_dataset(data_train, labels_orig_train)
    data_test, labels_orig_test = utils.sort_dataset(data_test, labels_orig_test)

    # Standardize both splits with training-set statistics only, so nothing
    # about the test fold leaks into training.
    mean_train = np.mean(data_train)
    std_train = np.std(data_train)
    data_train = (data_train - mean_train) / std_train
    data_test = (data_test - mean_train) / std_train

    # The plain CAE does not depend on the loss combination, so it is trained
    # once per fold and its weights are reused by every CAE_SE run of the fold.
    cae_trained = False
    for se_reg, c_reg in loss_combinations:
        print("\n", c_num, se_reg, c_reg)

        # Only the correntropy + L2 combination is currently trained; decide
        # this before deriving any per-combination settings.
        if se_reg == "sq_fro" or c_reg == "BD":
            continue

        corren = se_reg == "correntropy"
        bd_r = c_reg == "BD"
        # Derived per combination rather than mutating a shared value in place.
        if se_reg == "correntropy" and c_reg == "L2":
            lambda_se = lambda_se_base * 20
        else:
            lambda_se = lambda_se_base

        ################################################################################
        #CAE
        if not cae_trained:
            run_model_path = run_path + "/CAE"
            os.makedirs(run_model_path, exist_ok=True)

            epochs = 2000
            es_patience = 10
            lr_patience = 5
            decay = 0.5
            min_lr = 1e-6

            cae = CAE(input_shape=input_shape, enc_conv_filters=enc_conv_filters, enc_kernel_sizes=enc_kernel_sizes,
                      dec_conv_filters=dec_conv_filters, dec_kernel_sizes=dec_kernel_sizes,
                      cae_activation=cae_activation, data_shape=data_train.shape, weight_reg=weight_reg)
            cae.compile(lambda_rec=lambda_rec, lr=lr)
            cae.train(data_train, epochs=epochs, es_patience=es_patience, lr_patience=lr_patience, decay=decay, min_lr=min_lr)
            cae.save_logs("cae_logs", path=run_model_path + "/")
            cae.save_weights("cae_weights", path=run_model_path + "/")
            cae_trained = True
        else:
            print("Using pretrained CAE model")

        ################################################################################
        #CAE_SE (initialized from the fold's CAE)
        run_model_path = run_path + "/CAE_SE" + "_" + se_reg + "_" + c_reg
        os.makedirs(run_model_path, exist_ok=True)

        epochs = 2500
        warm_up = 20
        es_patience = 15
        lr_patience = 5
        decay = 0.5
        min_lr = 1e-6

        cae_se = CAE_SE(input_shape=input_shape, enc_conv_filters=enc_conv_filters, enc_kernel_sizes=enc_kernel_sizes,
                        dec_conv_filters=dec_conv_filters, dec_kernel_sizes=dec_kernel_sizes,
                        cae_activation=cae_activation, data_shape=data_train.shape, weight_reg=weight_reg)
        cae_se.compile(lambda_rec=lambda_rec, lambda_se=lambda_se, sigma=corr_sigma, lambda_ca=lambda_ca, lambda_bd=lambda_bd,
                       lambda_l2=lambda_l2, num_classes=num_classes, lr=lr, correntropy=corren, bd_r=bd_r)
        cae_se.transfer_weights(cae)
        cae_se.train(data_train, warm_up=warm_up, epochs=epochs, es_patience=es_patience, lr_patience=lr_patience, decay=decay, min_lr=min_lr)
        cae_se.save_affinity_mat("cae_se_affinity", path=run_model_path + "/")
        cae_se.save_weights("cae_se_weights", path=run_model_path + "/")
        cae_se.save_logs("cae_se_logs", path=run_model_path + "/")

        ################################################################################
        #CAE_SE_FC (initialized from CAE_SE, evaluated on the held-out fold)
        run_model_path = run_path + "/CAE_SE_FC" + "_" + se_reg + "_" + c_reg
        os.makedirs(run_model_path, exist_ok=True)

        central_embedding = 2
        Tmax = 4000
        T0 = 50
        warm_up = 100
        es_patience = 15
        lr_patience = 5
        decay = 0.5
        min_lr = 1e-6
        known_dim = 11
        suppression = 8
        cooldown = 3
        alpha_min, alpha_max = 0.08, 0.8
        alpha_granularity = 15

        cae_se_fc = CAE_SE_FC(n_classes=num_classes, central_embedding=central_embedding, input_shape=input_shape,
                              enc_conv_filters=enc_conv_filters, enc_kernel_sizes=enc_kernel_sizes,
                              dec_conv_filters=dec_conv_filters, dec_kernel_sizes=dec_kernel_sizes,
                              cae_activation=cae_activation, data_shape=data_train.shape, weight_reg=weight_reg,
                              known_dim=known_dim, suppression=suppression, utils=utils)
        cae_se_fc.compile(lambda_rec=lambda_rec, lambda_se=lambda_se, sigma=corr_sigma, lambda_ca=lambda_ca, lambda_cq=lambda_cq,
                          lambda_bd=lambda_bd, lambda_l2=lambda_l2, lambda_cent=lambda_cent, lambda_cross=lambda_cross,
                          lr=lr_fc, correntropy=corren, bd_r=bd_r)
        cae_se_fc.transfer_weights(cae_se)
        cae_se_fc.cross_validate_alpha(alpha_min, alpha_max, alpha_granularity, labels_orig=labels_orig_train)
        cae_se_fc.get_initial_labels()
        cae_se_fc.train(data_train, Tmax=Tmax, T0=T0, warm_up=warm_up,
                        es_patience=es_patience, lr_patience=lr_patience, decay=decay, cooldown=cooldown, min_lr=min_lr)
        cae_se_fc.test_final_model(data_train, labels_orig_train, data_test, labels_orig_test)
        cae_se_fc.save_affinity_mat("cae_se_fc_affinity", path=run_model_path + "/")
        cae_se_fc.save_weights("cae_se_fc_weights", path=run_model_path + "/")
        cae_se_fc.save_logs("cae_se_fc_logs", path=run_model_path + "/")
