import numpy as np
import scipy.io as sio
import os
import datetime
import time
import os

# Set before the project imports below, presumably so it is already in the
# environment when those modules (and the libraries they pull in) load.
# NOTE(review): KMP_DUPLICATE_LIB_OK is read by the Intel/LLVM OpenMP runtime and
# suppresses its abort when a second copy of the runtime is loaded. It masks the
# duplicate-runtime conflict rather than fixing it — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

from HGRDecNetwork_Res50 import train_mynetwork_HighMemory, train_mynetwork_LowMemory
from datadPreprocessing import Get_data_pad, Get_lable_indexes, Get_data_set

# ---- Centralized training configuration --------------------------------------

# 1: cut every test patch during preprocessing and hand them to the trainer;
# 0: hand the trainer the padded images plus the test indexes instead.
is_HighMemory = 1

# Dataset selection: Berlin if is_Berlin is set, else Houston2018 if
# is_Houston2018 is set, otherwise Augsburg. is_Berlin wins if both are set.
is_Berlin = 1
is_Houston2018 = 0

Max_epochs = 200  # maximum number of training epochs
nminibatch_size = 32  # mini-batch size per training iteration

# Maximum core-area size. Sparsification is activated when pad_width exceeds
# corewidth + 1 (implemented in the datadPreprocessing helpers).
corewidth = 4
kmethod = 0  # sparsification method: 0 = radius first, 1 = chord (perimeter) priority
# pad_width is derived from corewidth; with corewidth = 4 it is 11, which is
# larger than corewidth + 1, so sparsification is active.
# Alternative that was tried: pad_width = corewidth * 2 + 1
pad_width = (corewidth + 1) * 2 + 1
# Data augmentation: 0 = none; 1-3 = add 1-3 extra copies of each training
# sample, each rotated by a further 90 degrees.
kaugmentation = 3

# Saving thresholds, combined with OR: saving happens when the overall accuracy
# (named "AO" here) reaches min_AO_Save or the average accuracy reaches
# min_AA_Save.
min_AO_Save = 0.7
min_AA_Save = 0.7

flearning_rate = 1e-3  # learning rate
bis_scheduler = 1  # 1: adjust the learning rate automatically

bis_show = 0  # 1: display images
bis_random = 1  # 1: enable randomization


def _load_mat_array(mat_file, key, dtype=None):
    """Return variable *key* from the MATLAB file *mat_file*.

    When *dtype* is given, the array is cast with ``astype(dtype)``.
    """
    array = sio.loadmat(mat_file)[key]
    return array if dtype is None else array.astype(dtype)


# Load the hyperspectral cube (HSI), the auxiliary raster (kept under the name
# SAR for every dataset: SAR for Berlin, LiDAR for Houston2018, DSM for
# Augsburg) and the train/test label maps of the selected dataset.
# All paths use forward slashes, which work on Windows and POSIX alike;
# backslashes such as '\d' or '\T' are invalid string escapes (SyntaxWarning on
# Python 3.12+) and are not path separators on POSIX.
if is_Berlin:
    HSI = _load_mat_array('HS-SAR Berlin/data_HS_LR.mat', 'data_HS_LR', np.float32)
    SAR = _load_mat_array('HS-SAR Berlin/data_SAR_HR.mat', 'data_SAR_HR', np.float32)
    TR_map = _load_mat_array('HS-SAR Berlin/TrainImage.mat', 'TrainImage')
    TE_map = _load_mat_array('HS-SAR Berlin/TestImage.mat', 'TestImage')

elif is_Houston2018:
    HSI = _load_mat_array('Houston2018/houston_hsi.mat', 'houston_hsi', np.float32)
    SAR = _load_mat_array('Houston2018/houston_lidar.mat', 'houston_lidar', np.float32)
    TR_map = _load_mat_array('Houston2018/Houston2018_TR.mat', 'TR_map')
    TE_map = _load_mat_array('Houston2018/Houston2018_TE.mat', 'TE_map')

else:
    HSI = _load_mat_array('HS-SAR-DSM Augsburg/data_HS_LR.mat', 'data_HS_LR', np.float32)
    # The DSM raster is used as the auxiliary modality. To use the SAR raster
    # instead, load 'HS-SAR-DSM Augsburg/data_SAR_HR.mat' with key 'data_SAR_HR'.
    SAR = _load_mat_array('HS-SAR-DSM Augsburg/data_DSM.mat', 'data_DSM', np.float32)
    TR_map = _load_mat_array('HS-SAR-DSM Augsburg/TrainImage.mat', 'TrainImage')
    TE_map = _load_mat_array('HS-SAR-DSM Augsburg/TestImage.mat', 'TestImage')



# A 2-D auxiliary raster gets an explicit trailing axis of length 1, so SAR is
# always 3-D (presumably (rows, cols, bands), matching the HSI cube — confirm).
if SAR.ndim == 2:
    SAR = SAR[:, :, np.newaxis]

# Create ./results/<YYYY-MM-DD HH>/ and start the run log <MM_SS>logger.txt in it.
# A single clock reading feeds both the folder name and the log-file prefix, so
# they always describe the same instant. Two separate reads could straddle an
# hour boundary.
_run_started = datetime.datetime.now()
nt = _run_started.strftime('%M_%S')
cur_time = _run_started.strftime('%Y-%m-%d %H')
path = './results/' + cur_time
# makedirs also creates ./results/ when missing, and exist_ok avoids the
# check-then-create race of os.path.exists + os.mkdir.
os.makedirs(path, exist_ok=True)

script_path = os.path.abspath(__file__)
script_name = os.path.basename(script_path)
with open(path + '/' + nt + 'logger.txt', 'a') as log_file:
    log_file.write('the currently executing file: {}\n'.format(script_name))
# ---- Data preprocessing ------------------------------------------------------
# Pad both modalities, cut the (augmented) training patches and, in high-memory
# mode, all test patches. The elapsed wall-clock time is printed and logged.
prep_start = time.time()

SARpad = Get_data_pad(SAR, pad_width)
HSIpad = Get_data_pad(HSI, pad_width)

# Training set: kaugmentation is forwarded to both helpers.
(TR_lable_indexes, TR_lable_pad, TR_lable,
 TR_lable_onehot, k_maxL) = Get_lable_indexes(TR_map, pad_width, -1, kaugmentation)
TR_SAR_DataSet, patchsize, SARz = Get_data_set(
    SARpad, TR_lable_indexes, pad_width, corewidth, kmethod, kaugmentation)
TR_HSI_DataSet, _, HSIz = Get_data_set(
    HSIpad, TR_lable_indexes, pad_width, corewidth, kmethod, kaugmentation)

# Test set: no kaugmentation argument is passed (presumably the helpers'
# default means no augmentation — confirm in datadPreprocessing).
(TE_lable_indexes, TE_lable_pad, TE_lable,
 TE_lable_onehot, k_maxL2) = Get_lable_indexes(TE_map, pad_width, -1)
if is_HighMemory:
    # Materialize every test patch now. The low-memory trainer receives
    # SARpad/HSIpad and TE_lable_indexes instead.
    TE_SAR_DataSet, _, _ = Get_data_set(SARpad, TE_lable_indexes, pad_width, corewidth, kmethod)
    TE_HSI_DataSet, _, _ = Get_data_set(HSIpad, TE_lable_indexes, pad_width, corewidth, kmethod)

prep_seconds = time.time() - prep_start
prep_message = f'Data preprocessing Time: {prep_seconds}'
print(prep_message)
with open(path + '/' + nt + 'logger.txt', 'a') as log_file:
    log_file.write(prep_message + '\n')



# ---- Training ----------------------------------------------------------------
# Both trainers return a 4-tuple. The first three items are sequences of test
# loss / accuracy / AA values, reduced with min()/max() at the end of the script.
# The fourth item ("feature") is not used further in this script.
if is_HighMemory:
    # Test patches were already cut during preprocessing.
    training_outputs = train_mynetwork_HighMemory(
        TR_HSI_DataSet, TR_SAR_DataSet,
        TE_HSI_DataSet, TE_SAR_DataSet,
        TR_lable_onehot, TE_lable_onehot,
        HSIz, SARz, patchsize, k_maxL, pad_width, kaugmentation,
        nt, path,
        min_AA_Save, min_AO_Save, flearning_rate, Max_epochs, nminibatch_size,
        -1,  # NOTE(review): undocumented positional argument; meaning lives in HGRDecNetwork_Res50 — confirm
        bis_scheduler, bis_show, bis_random,
    )
else:
    # Test patches are presumably cut inside the trainer from the padded
    # images and TE_lable_indexes — confirm in HGRDecNetwork_Res50.
    training_outputs = train_mynetwork_LowMemory(
        TR_HSI_DataSet, TR_SAR_DataSet,
        TE_lable_indexes,
        TR_lable_onehot, TE_lable_onehot,
        HSIz, SARz, patchsize, k_maxL, pad_width, kaugmentation,
        nt, path,
        SARpad, HSIpad, corewidth, kmethod,
        min_AA_Save, min_AO_Save, flearning_rate, Max_epochs, nminibatch_size,
        -1,  # NOTE(review): undocumented positional argument — confirm
        bis_scheduler, bis_show, bis_random,
    )
val_lost, val_acc, val_aa, feature = training_outputs


# ---- Report ------------------------------------------------------------------
# Save the label maps next to the log, then report the best test metrics
# (lowest loss, highest accuracy and AA) to stdout and to the run log.
sio.savemat(path + '/TestLabel.mat', {'TestLabel': TE_lable, 'TrainLabel': TR_lable})

# Reduce each metric once so that stdout and the log always agree.
best_loss = min(val_lost)
best_acc = max(val_acc)
best_aa = max(val_aa)
summary_lines = [
    "corewidth:{} , kmethod:{} , pad_width:{} , kaugmentation:{}".format(
        corewidth, kmethod, pad_width, kaugmentation),
    "Test Loss: %f" % best_loss,
    "Test Accuracy: %f" % best_acc,
    "Test AA: %f" % best_aa,
]
for summary_line in summary_lines:
    print(summary_line)

with open(path + '/' + nt + 'logger.txt', 'a') as log_file:
    for summary_line in summary_lines:
        log_file.write(summary_line + '\n')
    # pad_width is an int; %d avoids logging it as "11.000000".
    log_file.write("pad_width: %d\n" % pad_width)
    log_file.write(str(datetime.datetime.now()))
    log_file.write('\n')

print("Results were saved in '" + path + "'")
