import scipy.io as sio
import numpy as np
from utils import convert_to_one_hot

# Data padding: per-band min-max normalization followed by symmetric (mirror) padding
def Get_data_pad(sourcedata, set_pad_width):
    """Min-max normalize every band of a 3-D cube and mirror-pad its borders.

    Parameters
    ----------
    sourcedata : ndarray, shape (rows, cols, bands)
        Input cube. Each band ``sourcedata[:, :, i]`` is rescaled to [0, 1].
        NOTE: the normalized values are also written back into ``sourcedata``
        in place (kept for backward compatibility). For integer dtypes that
        write-back is truncated, but the returned array is always built from
        the exact floating-point values.
    set_pad_width : int or sequence
        ``pad_width`` forwarded to ``np.pad`` for each 2-D band, using
        ``'symmetric'`` mode (mirror including the edge pixel).

    Returns
    -------
    ndarray, float32, shape (padded_rows, padded_cols, bands)
        The normalized, padded cube. A constant band (max == min) becomes all
        zeros instead of NaN from a 0/0 division.
    """
    rows, cols, bands = sourcedata.shape
    # Let np.pad work out the padded 2-D shape (set_pad_width may be an int or
    # a per-axis sequence).
    padded_shape = np.pad(np.zeros((rows, cols), dtype='float32'), set_pad_width, 'symmetric').shape
    data_pad = np.empty(padded_shape + (bands,), dtype='float32')
    for i in range(bands):
        # Normalize in float64. Integer cubes would otherwise be truncated to
        # 0/1, and narrow integer types could overflow in max - min.
        band = sourcedata[:, :, i].astype(np.float64)
        lo = band.min()
        span = band.max() - lo
        if span == 0:
            band = np.zeros_like(band)  # constant band: avoid 0/0 -> NaN
        else:
            band = (band - lo) / span
        # Preserve the historical in-place normalization of the input.
        sourcedata[:, :, i] = band
        # Pad the exact float band, not the (possibly integer) write-back.
        data_pad[:, :, i] = np.pad(band, set_pad_width, 'symmetric')
    return data_pad

# Collect the coordinates of all labelled (non-zero) pixels in the padded label map
def Get_lable_indexes(labledata, set_pad_width, k_maxL =-1, kaugmentation = 0):
    """Locate the labelled pixels of a zero-padded label map and encode them.

    Parameters
    ----------
    labledata : 2-D array
        Label map. Pixels equal to 0 are treated as unlabelled.
    set_pad_width : int or sequence
        ``pad_width`` forwarded to ``np.pad`` in constant (zero-fill) mode, so
        the returned coordinates refer to the padded frame.
    k_maxL : int, optional
        Minimum number of classes for the one-hot encoding. The value actually
        used is ``max(largest label, k_maxL)``.
    kaugmentation : int, optional
        Number of extra copies of the one-hot block to append along axis 0.
        Clamped to ``0..3`` to match ``Get_data_set``, which produces at most
        three rotated copies per sample. Without the clamp, label and data
        row counts would disagree.

    Returns
    -------
    lable_indexes : ndarray, shape (2, N)
        Row (``[0]``) and column (``[1]``) indices of the non-zero labels in
        the padded map, in ``np.nonzero`` (row-major) order.
    lablepad : ndarray
        The zero-padded label map.
    SetLable : ndarray, shape (N, 1), uint8
        Label value at each coordinate. Values outside 0..255 do not fit uint8.
    SetLable_onehot : ndarray
        ``convert_to_one_hot(SetLable - 1, k_maxL).T``, repeated
        ``kaugmentation + 1`` times along axis 0.
    k_maxL
        The class count passed to ``convert_to_one_hot``.
    """
    kaugmentation = min(max(int(kaugmentation), 0), 3)
    lablepad = np.pad(labledata, set_pad_width, 'constant')
    rows, cols = np.nonzero(lablepad)
    lable_indexes = np.vstack((rows, cols))
    SetLable = lablepad[rows, cols].astype('uint8').reshape(-1, 1)

    # Conversion to one-hot encoding (labels are 1-based, hence the "- 1").
    k_maxL = np.max([np.max(SetLable), k_maxL])
    SetLable_onehot = convert_to_one_hot(SetLable - 1, k_maxL).T
    if kaugmentation:
        # One copy per rotation produced by Get_data_set, in the same order.
        SetLable_onehot = np.concatenate([SetLable_onehot] * (kaugmentation + 1), axis=0)
    return lable_indexes, lablepad, SetLable, SetLable_onehot, k_maxL


def _flatten_sample(sample):
    """Flatten a (rows, cols, bands) sample to 1-D in column-major (Fortran) order."""
    return np.reshape(sample, -1, order="F")


def _store_sample(dataset, row, n_samples, sample, n_rotations):
    """Write ``sample`` to ``dataset[row]`` and its rotated copies further down.

    Copy ``j`` (``1 <= j <= n_rotations``) is ``sample`` rotated ``j`` times by
    90 degrees in the (row, col) plane. It goes to
    ``dataset[row + j * n_samples]``, so each rotation occupies its own block
    of ``n_samples`` rows.
    """
    dataset[row, :] = _flatten_sample(sample)
    for j in range(1, n_rotations + 1):
        sample = np.rot90(sample, 1, axes=(0, 1))
        dataset[row + j * n_samples, :] = _flatten_sample(sample)


def _core_with_border(data_pad, x, y, core_width):
    """Copy the (2c+1)-square core centred at (x, y) and add a zero border around it."""
    core = data_pad[x - core_width:x + core_width + 1,
                    y - core_width:y + core_width + 1, :]
    return np.pad(core, ((1, 1), (1, 1), (0, 0)), 'constant')


def _fill_border_radial(sample, data_pad, x, y, pad_width, core_width, det_pw):
    """Fill the one-pixel border of ``sample`` for ``kmethod == 0``.

    The border is walked clockwise from the top-left corner in 8 segments of
    ``core_width + 1`` cells: the four corner diagonals and the four axes.
    Each segment first takes up to ``det_pw`` pixels, stepping inward from the
    outer edge of the ``2 * pad_width + 1`` window along its diagonal or axis.
    Any remaining cells are taken along the outer window edge, next to where
    the segment started.
    """
    p = pad_width
    c1 = core_width + 1
    s = sample.shape[0] - 1  # index of the last (border) row/column
    n_radial = min(det_pw, c1)
    # The 8 segments write disjoint border cells, so they can share one loop.
    for k in range(n_radial):
        sample[0, k, :] = data_pad[x - p + k, y - p + k, :]          # LU diagonal
        sample[0, k + c1, :] = data_pad[x - p + k, y, :]              # MU up axis
        sample[k, s, :] = data_pad[x - p + k, y + p - k, :]           # RU diagonal
        sample[k + c1, s, :] = data_pad[x, y + p - k, :]              # RM right axis
        sample[s, s - k, :] = data_pad[x + p - k, y + p - k, :]       # RD diagonal
        sample[s, s - k - c1, :] = data_pad[x + p - k, y, :]          # MD down axis
        sample[s - k, 0, :] = data_pad[x + p - k, y - p + k, :]       # LD diagonal
        sample[s - k - c1, 0, :] = data_pad[x, y - p + k, :]          # LM left axis
    # Remaining cells of each segment come from the outer window edge, one
    # pixel further along per cell. This runs only when det_pw <= core_width.
    for k in range(c1 - n_radial):
        o = n_radial + k
        sample[0, o, :] = data_pad[x - p, y - p + k + 1, :]
        sample[0, o + c1, :] = data_pad[x - p, y + k + 1, :]
        sample[o, s, :] = data_pad[x - p + k + 1, y + p, :]
        sample[o + c1, s, :] = data_pad[x + k + 1, y + p, :]
        sample[s, s - o, :] = data_pad[x + p, y + p - k - 1, :]
        sample[s, s - o - c1, :] = data_pad[x + p, y - k - 1, :]
        sample[s - o, 0, :] = data_pad[x + p - k - 1, y - p, :]
        sample[s - o - c1, 0, :] = data_pad[x - k - 1, y - p, :]


def _fill_border_uniform(sample, data_pad, x, y, pad_width, move_index):
    """Fill the one-pixel border of ``sample`` for ``kmethod == 1``.

    Border cell ``k`` of each side is taken from the matching outer edge of
    the ``2 * pad_width + 1`` window, at offset ``move_index[k]`` from that
    edge's top/left end. In effect, the window edge is resampled onto the
    border.
    """
    p = pad_width
    s = sample.shape[0] - 1
    for k, m in enumerate(move_index):  # top edge
        sample[0, k, :] = data_pad[x - p, y - p + m, :]
    for k, m in enumerate(move_index):  # right edge
        sample[k, s, :] = data_pad[x - p + m, y + p, :]
    for k, m in enumerate(move_index):  # bottom edge
        sample[s, k, :] = data_pad[x + p, y - p + m, :]
    for k, m in enumerate(move_index):  # left edge
        sample[k, 0, :] = data_pad[x - p + m, y - p, :]


def Get_data_set(data_pad, lable_indexes, set_pad_width, set_corewidth, kmethod = 0, kaugmentation = 0):
    """Build the flattened sample matrix: one row per labelled pixel, plus rotations.

    Parameters
    ----------
    data_pad : ndarray, shape (rows, cols, bands)
        Padded data cube, indexed directly with ``lable_indexes``.
    lable_indexes : ndarray, shape (2, N)
        Row/column coordinates of the sample centres in ``data_pad``.
    set_pad_width : int
        Half-width ``p`` of the neighbourhood window (``2p + 1`` pixels).
    set_corewidth : int
        Half-width ``c`` of the densely copied core.
    kmethod : {0, 1}
        Border sampling scheme, used only when ``p > c + 1``. 0 samples along
        the diagonals and axes, then falls back to the window edge. 1
        resamples the window edge at evenly spaced offsets.
    kaugmentation : int
        Number of 90-degree rotated copies to append per sample, clamped to
        ``0..3`` because a fourth rotation would reproduce the original.

    Returns
    -------
    DataSet : ndarray, float32, shape (N * (kaugmentation + 1), S * S * bands)
        Row ``i + j * N`` is sample ``i`` rotated ``j`` times by 90 degrees
        (``np.rot90``), flattened in column-major (Fortran) order.
    set_pad_size : int
        Side length ``S`` of each sample: ``2p + 1`` when ``p <= c + 1``,
        otherwise ``2c + 3`` (core plus a one-pixel sampled border).
    z_datapad : int
        Number of bands.

    Raises
    ------
    ValueError
        If ``p > c + 1`` and ``kmethod`` is neither 0 nor 1.
    """
    corewidthAdd1 = set_corewidth + 1
    # Sparse border sampling is used only when the window is wider than the
    # core plus a one-pixel border; otherwise the full window is copied.
    detPW = set_pad_width - set_corewidth if set_pad_width > corewidthAdd1 else 0
    set_pad_size = min(2 * set_pad_width + 1, 2 * set_corewidth + 3)
    z_datapad = data_pad.shape[2]
    DataLen = lable_indexes.shape[1]
    kaugmentation = min(max(int(kaugmentation or 0), 0), 3)
    DataSet = np.zeros((DataLen * (kaugmentation + 1), set_pad_size * set_pad_size * z_datapad),
                       dtype='float32')

    if not detPW:
        # Full (2p+1) x (2p+1) window around every labelled pixel.
        for i in range(DataLen):
            x, y = lable_indexes[0, i], lable_indexes[1, i]
            sample = data_pad[x - set_pad_width:x + set_pad_width + 1,
                              y - set_pad_width:y + set_pad_width + 1, :]
            _store_sample(DataSet, i, DataLen, sample, kaugmentation)
        return DataSet, set_pad_size, z_datapad

    if kmethod == 0:
        move_index = None
    elif kmethod == 1:
        # Map the 2c+3 border cells onto offsets 0..2p of the window edge.
        step = (2 * set_pad_width) / (2 * set_corewidth + 2)
        move_index = [int(k * step + 0.5) for k in range(2 * corewidthAdd1 + 1)]
    else:
        raise ValueError(f"kmethod must be 0 or 1, got {kmethod!r}")

    for i in range(DataLen):
        x, y = lable_indexes[0, i], lable_indexes[1, i]
        sample = _core_with_border(data_pad, x, y, set_corewidth)
        if move_index is None:
            _fill_border_radial(sample, data_pad, x, y, set_pad_width, set_corewidth, detPW)
        else:
            _fill_border_uniform(sample, data_pad, x, y, set_pad_width, move_index)
        _store_sample(DataSet, i, DataLen, sample, kaugmentation)
    return DataSet, set_pad_size, z_datapad


