import os
import numpy as np
import keras.regularizers
import tensorflow as tf
from sklearn.model_selection import train_test_split, StratifiedKFold
from keras import backend as K
import random
import itertools
from models_coil import CAE, CAE_SE, CAE_SE_FC
from tfdeterminism import patch
from utilities import Utils
import sys

# ---------------------------------------------------------------------------
# Reproducibility and device setup (executed once, before any model is built).
# ---------------------------------------------------------------------------
patch()  # `patch` is imported from tfdeterminism

SEED = 43

# Process environment: visible GPU, hash seed and Horovod fusion threshold.
# NOTE(review): PYTHONHASHSEED is read when the interpreter starts, so setting
# it here presumably only reaches child processes - confirm this is intended.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "1",
    "PYTHONHASHSEED": str(SEED),
    "HOROVOD_FUSION_THRESHOLD": "0",
})

# Seed the Python, NumPy and TensorFlow random number generators alike.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_random_seed(SEED)

# Build a session with `allow_growth` enabled and register it as the session
# used by the Keras backend.
gpu_options = tf.GPUOptions(allow_growth=True)
session_config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.Session(config=session_config)
K.tensorflow_backend.set_session(sess)

# Experiment-wide constants. `num_classes` is handed to `Utils` here and to
# the model constructors / compile calls further down the script.
num_experiments, num_classes = 5, 20

# Load the data array and its label array from the local ./datasets folder.
X, y = (
    np.load(npy_path)
    for npy_path in (
        "./datasets/coil20_DSCNet_data.npy",
        "./datasets/coil20_DSCNet_labels.npy",
    )
)

# Project helper object, parameterised by the number of classes.
utils = Utils(num_classes)

# Every (self-expression regulariser, coefficient regulariser) pair, in the
# order produced by itertools.product:
#   ("correntropy", "BD"), ("correntropy", "L2"), ("sq_fro", "BD"), ("sq_fro", "L2")
loss_combinations = list(itertools.product(["correntropy", "sq_fro"], ["BD", "L2"]))

# Root output directory for this run. `exist_ok=True` tolerates a directory
# left over from a previous run, while genuine failures (permissions, a file
# in the way, ...) are raised here instead of being swallowed by a bare
# `except:` and resurfacing later when results are written.
run_path = "Coil20_experiments_CorrX10/all"
os.makedirs(run_path, exist_ok=True)

# The "train" and "test" splits are both the complete dataset (X, y); each is
# passed through utils.sort_dataset separately.
data_train, labels_orig_train = utils.sort_dataset(X, y)
data_test, labels_orig_test = utils.sort_dataset(X, y)

# Scalar mean / standard deviation over the whole (sorted) training array.
mean_train = np.mean(data_train)
std_train = np.std(data_train)

# Standardise both splits with the training statistics.
data_train, data_test = (
    (split - mean_train) / std_train for split in (data_train, data_test)
)

# Three-stage training pipeline, run once per loss combination:
#   1. CAE        - trained only on the first non-skipped combination and
#                   reused afterwards (tracked by `cae_trained`)
#   2. CAE_SE     - initialised from the CAE weights
#   3. CAE_SE_FC  - initialised from the CAE_SE weights
# Each stage writes its weights and logs (and, for the last two, the affinity
# matrix) into its own sub-directory of `run_path`.
cae_trained = False
for se_reg, c_reg in loss_combinations:
    print("\n", se_reg, c_reg)

    # Coil20 ARCHITECTURE SETUP
    input_shape = (32, 32, 1)
    enc_conv_filters = [15]
    enc_kernel_sizes = [(3, 3)]
    dec_conv_filters = [15]
    dec_kernel_sizes = [(3, 3)]
    cae_activation = "relu"
    # Full-batch size: one batch holds every sample.
    batch_size = data_train.shape[0]

    # Loss-term weights forwarded to the compile() calls below.
    lambda_rec = 1
    lambda_se = 300
    lambda_cross = 6
    lambda_cent = 3
    lambda_cq = 8
    lambda_bd = 1
    lambda_ca = 1
    corr_sigma = 1.
    lambda_l2 = 1

    # Reset on every iteration because the CAE_SE_FC stage divides it by 10.
    lr = 1e-3

    # Boolean switches passed to compile() as `correntropy` / `bd_r`.
    corren = se_reg == "correntropy"
    bd_r = c_reg == "BD"

    # Only the ("correntropy", "BD") combination is actually run; every
    # other combination is skipped.
    if c_reg == "L2" or se_reg == "sq_fro":
        continue

    weight_reg = None
    ################################################################################
    #CAE

    if not cae_trained:
        run_model_path = run_path + "/CAE"
        # exist_ok tolerates a directory from an earlier run; any other
        # failure is raised before the training starts.
        os.makedirs(run_model_path, exist_ok=True)

        epochs = 2000
        es_patience = 10
        lr_patience = 5
        decay = 0.5
        min_lr = 1e-6

        cae = CAE(input_shape=input_shape, enc_conv_filters=enc_conv_filters, enc_kernel_sizes=enc_kernel_sizes,
                  dec_conv_filters=dec_conv_filters, dec_kernel_sizes=dec_kernel_sizes, cae_activation=cae_activation,
                  data_shape=data_train.shape, weight_reg=weight_reg)
        cae.compile(lambda_rec=lambda_rec, lr=lr)
        cae.train(data_train, epochs=epochs, es_patience=es_patience, lr_patience=lr_patience, decay=decay, min_lr=min_lr)
        cae.save_logs("cae_logs", path=run_model_path + "/")
        cae.save_weights("cae_weights", path=run_model_path + "/")
        cae_trained = True
    else:
        print("Using pretrained CAE model")

    ################################################################################
    #CAE_SE
    run_model_path = run_path + "/CAE_SE" + "_" + se_reg + "_" + c_reg
    os.makedirs(run_model_path, exist_ok=True)

    epochs = 2500
    warm_up = 20
    es_patience = 15
    lr_patience = 5
    decay = 0.5
    min_lr = 1e-6

    cae_se = CAE_SE(input_shape=input_shape, enc_conv_filters=enc_conv_filters, enc_kernel_sizes=enc_kernel_sizes,
                    dec_conv_filters=dec_conv_filters, dec_kernel_sizes=dec_kernel_sizes, cae_activation=cae_activation,
                    data_shape=data_train.shape, weight_reg=weight_reg)

    cae_se.compile(lambda_rec=lambda_rec, lambda_se=lambda_se, sigma=corr_sigma, lambda_ca=lambda_ca, lambda_bd=lambda_bd,
                   lambda_l2=lambda_l2, num_classes=num_classes, lr=lr, correntropy=corren, bd_r=bd_r)

    # Start from the CAE trained above (or on an earlier iteration).
    cae_se.transfer_weights(cae)
    cae_se.train(data_train, warm_up=warm_up, epochs=epochs, es_patience=es_patience, lr_patience=lr_patience,
                 decay=decay, min_lr=min_lr)
    cae_se.save_affinity_mat("cae_se_affinity", path=run_model_path + "/")
    cae_se.save_weights("cae_se_weights", path=run_model_path + "/")
    cae_se.save_logs("cae_se_logs", path=run_model_path + "/")

    ################################################################################
    #CAE_SE_FC
    run_model_path = run_path + "/CAE_SE_FC" + "_" + se_reg + "_" + c_reg
    os.makedirs(run_model_path, exist_ok=True)

    central_embedding = 2
    Tmax = 8000
    T0 = 50
    warm_up = 100
    es_patience = 15
    lr_patience = 5
    decay = 0.5
    min_lr = 1e-6
    known_dim = 11
    suppression = 8
    cooldown = 3
    alpha_min, alpha_max = 0.08, 0.8
    alpha_granularity = 15
    # This stage uses a tenth of the base learning rate.
    lr = lr / 10

    cae_se_fc = CAE_SE_FC(n_classes=num_classes, central_embedding=central_embedding, input_shape=input_shape,
                          enc_conv_filters=enc_conv_filters, enc_kernel_sizes=enc_kernel_sizes,
                          dec_conv_filters=dec_conv_filters, dec_kernel_sizes=dec_kernel_sizes,
                          cae_activation=cae_activation, data_shape=data_train.shape, weight_reg=weight_reg,
                          known_dim=known_dim, suppression=suppression, utils=utils)

    cae_se_fc.compile(lambda_rec=lambda_rec, lambda_se=lambda_se, sigma=corr_sigma, lambda_ca=lambda_ca, lambda_cq=lambda_cq,
                      lambda_bd=lambda_bd, lambda_l2=lambda_l2, lambda_cent=lambda_cent, lambda_cross=lambda_cross,
                      lr=lr, correntropy=corren, bd_r=bd_r)
    # Start from the CAE_SE model trained in the previous stage.
    cae_se_fc.transfer_weights(cae_se)
    cae_se_fc.cross_validate_alpha(alpha_min, alpha_max, alpha_granularity, labels_orig=labels_orig_train)
    cae_se_fc.get_initial_labels()
    cae_se_fc.train(data_train, Tmax=Tmax, T0=T0, warm_up=warm_up,
                    es_patience=es_patience, lr_patience=lr_patience, decay=decay, cooldown=cooldown, min_lr=min_lr)
    cae_se_fc.test_final_model(data_train, labels_orig_train, data_test, labels_orig_test)
    cae_se_fc.save_affinity_mat("cae_se_fc_affinity", path=run_model_path + "/")
    cae_se_fc.save_weights("cae_se_fc_weights", path=run_model_path + "/")
    cae_se_fc.save_logs("cae_se_fc_logs", path=run_model_path + "/")