import time
import numpy as np
import torch as t
from torch import nn, Tensor, optim
from Core.BasicModules import FullConnectNetwork, ConvolutionNeuralNetwork, \
    ModelM3, ModelM5, ModelM7, \
    BinaryFuzzyRelationNetwork, FuzzyPermissibleLoss, L2NormRegularization
from Core.Tool import Tool, LossQueue


class NN_FLM(nn.Module):
    """
    Neural Network based Fuzzy Learning Machine

    A feature-extraction network (selected by fea_ext_net_type; the identity when it is 'None')
    maps samples to hidden features, and a BinaryFuzzyRelationNetwork turns pairs of hidden
    features into fuzzy relation values 0 <= S[i,j] <= 1. Training fits these relation values to
    each batch's class relation matrix (batch['batch_R']) with a FuzzyPermissibleLoss plus a
    weighted L2NormRegularization term.

    Abbreviations used in names:
    B: Binary
    F: Fuzzy
    R: Reflexive
    S: Symmetric
    T: Transitive
    P: Permissible
    L: Loss
    fea: Feature
    ext: Extract
    """

    # Values accepted for fea_ext_net_type.
    _FEA_EXT_NET_TYPES = ('FCN', 'CNN', 'CNN-FCN', 'None', 'M3', 'M5', 'M7')

    def __init__(self, sample_type_shape, fea_ext_net_type: str, fea_ext_net_structure,
                 fuzzy_para_alpha=0.05, fuzzy_para_beta=0.1,
                 gamma_FPL=1.0, gamma_Reg=0.1,
                 model_design_file=None):
        """
        Construction method\n
        :param sample_type_shape: sample's type and shape, dict
               sample_type_shape['type']: string, sample's type  {'vector', 'image'}
               sample_type_shape['feature_num']: int, the number of the feature,
                   be no effect when sample_type_shape['type'] is 'image'
               sample_type_shape['high']: int, the high of the image,
                   be no effect when sample_type_shape['type'] is 'vector'
               sample_type_shape['width']: int, the width of the image,
                   be no effect when sample_type_shape['type'] is 'vector'
               sample_type_shape['channels_num']: int, the number of the channel of the image,
                   be no effect when sample_type_shape['type'] is 'vector'
               The legacy keys 'image_height', 'image_width' and 'image_channel' are accepted
               in place of 'high', 'width' and 'channels_num'.
        :param fea_ext_net_type: str, {'FCN', 'CNN', 'CNN-FCN', 'None', 'M3', 'M5', 'M7'}
        :param fea_ext_net_structure: the structure of the feature extraction network
                   list, [layer1-structure, layer2-structure, ...], when fea_ext_net_type='FCN' or 'CNN'
                   2-D list, [CNN-structure-list, FCN-structure-list] ,when fea_ext_net_type='CNN-FCN'
                   be no effect when fea_ext_net_type='None'
        :param fuzzy_para_alpha: float, 'alpha' passed to FuzzyPermissibleLoss
        :param fuzzy_para_beta: float, 'beta' passed to FuzzyPermissibleLoss.
               NOTE(review): the default 0.1 does not satisfy 0 < alpha < 0.5 < beta < 1 as
               documented on set_fuzzy_parameter (whose default beta is 0.95) — confirm intent.
        :param gamma_FPL: float, non-negative, weight of the fuzzy permissible loss
        :param gamma_Reg: float, non-negative, weight of the L2 regularization term
        :param model_design_file: string, file name, print the structure information to the txt file
        :raises ValueError: if fea_ext_net_type is not one of the supported values
        """
        super(NN_FLM, self).__init__()
        self.__sample_type_shape = sample_type_shape
        self.__fea_ext_net_type = fea_ext_net_type
        self.__fea_ext_net_structure = fea_ext_net_structure
        self.__fea_ext_net = None
        self.__hidden_features = -1
        self.__build_feature_extract_network()
        self.__BFR_net = BinaryFuzzyRelationNetwork()

        self.__fuzzy_para_alpha = fuzzy_para_alpha
        self.__fuzzy_para_beta = fuzzy_para_beta
        self.__FP_loss = FuzzyPermissibleLoss(alpha=self.__fuzzy_para_alpha,
                                              beta=self.__fuzzy_para_beta)
        self.__L2_regular = L2NormRegularization()  # 0.5 and mean in it
        self.__gamma_FPL = gamma_FPL
        self.__gamma_fea_ext_net_l2reg = gamma_Reg

        self.__optimizer = None
        self.build_optimizer()

        if model_design_file is not None:
            self.save_model_design(file_name=model_design_file)

    # ================== feature extraction network
    @staticmethod
    def __shape_value(shape, key, legacy_key):
        """Return shape[key], falling back to shape[legacy_key] when only the legacy name is present."""
        if key not in shape and legacy_key in shape:
            return shape[legacy_key]
        return shape[key]

    def __image_size(self):
        """Return (high, width, channels_num) of an image sample from sample_type_shape."""
        shape = self.__sample_type_shape
        return (NN_FLM.__shape_value(shape, 'high', 'image_height'),
                NN_FLM.__shape_value(shape, 'width', 'image_width'),
                NN_FLM.__shape_value(shape, 'channels_num', 'image_channel'))

    @staticmethod
    def __flat_len(output_shape):
        """Number of features after flattening a CNN output described by get_output_shape()."""
        return output_shape['high'] * output_shape['width'] * output_shape['channels_num']

    def __build_feature_extract_network(self):
        """
        Instantiate the feature-extraction sub-network(s) selected by fea_ext_net_type and record
        the hidden feature dimension in self.__hidden_features (left at -1 for 'M3'/'M5'/'M7',
        whose output size is not computed here).
        :raises ValueError: for an unsupported fea_ext_net_type
        """
        net_type = self.__fea_ext_net_type
        structure = self.__fea_ext_net_structure
        if net_type == 'FCN':
            self.__fea_ext_net = FullConnectNetwork(in_features=self.__sample_type_shape['feature_num'],
                                                    layer_structure_list=structure,
                                                    network_name='Fea-Ext-FCN')
            self.__hidden_features = structure[-1]['node_num']
        elif net_type == 'CNN':
            high, width, channels_num = self.__image_size()
            self.__fea_ext_net = ConvolutionNeuralNetwork(in_high=high,
                                                          in_width=width,
                                                          in_channels=channels_num,
                                                          layer_structure_list=structure,
                                                          network_name='Fea-Ext-CNN')
            self.__hidden_features = NN_FLM.__flat_len(self.__fea_ext_net.get_output_shape())
        elif net_type == 'CNN-FCN':
            cnn_structure_list, fcn_structure_list = structure[0], structure[1]
            high, width, channels_num = self.__image_size()
            self.__fea_ext_cnn_net = ConvolutionNeuralNetwork(in_high=high,
                                                              in_width=width,
                                                              in_channels=channels_num,
                                                              layer_structure_list=cnn_structure_list,
                                                              network_name='Fea-Ext-CNN')
            cnn_output_len = NN_FLM.__flat_len(self.__fea_ext_cnn_net.get_output_shape())
            print('cnn_output_len=' + str(cnn_output_len))
            self.__fea_ext_fcn_net = FullConnectNetwork(in_features=cnn_output_len,
                                                        layer_structure_list=fcn_structure_list,
                                                        network_name='Fea-Ext-FCN')
            self.__hidden_features = fcn_structure_list[-1]['node_num']
        elif net_type == 'M3':
            self.__fea_ext_net = ModelM3()
        elif net_type == 'M5':
            self.__fea_ext_net = ModelM5()
        elif net_type == 'M7':
            self.__fea_ext_net = ModelM7()
        elif net_type == 'None':
            self.__fea_ext_fcn_net = None
            self.__fea_ext_cnn_net = None
            if self.__sample_type_shape.get('type') == 'image':
                # images are flattened by the identity extractor; 'feature_num' is not used for images
                high, width, channels_num = self.__image_size()
                self.__hidden_features = high * width * channels_num
            else:
                self.__hidden_features = self.__sample_type_shape['feature_num']
        else:
            raise ValueError('unsupported fea_ext_net_type: ' + repr(net_type)
                             + ', expected one of ' + str(NN_FLM._FEA_EXT_NET_TYPES))

    def __feature_extraction_forward(self, x: Tensor) -> Tensor:
        """
        Map a batch of samples to hidden features.
        :param x: a batch of samples, first dimension is the batch
        :return: hidden features (2-D for 'FCN', 'CNN', 'CNN-FCN' and 'None')
        :raises ValueError: for an unsupported fea_ext_net_type
        """
        net_type = self.__fea_ext_net_type
        if net_type in ('FCN', 'M3', 'M5', 'M7'):
            # FCN output shape=[batch_size, hidden_features]; M3/M5/M7 outputs are used as returned
            return self.__fea_ext_net(x)
        if net_type == 'CNN':
            h = self.__fea_ext_net(x)  # CNN output shape=[batch_size, channels_num, high, width]
            return h.reshape(h.shape[0], -1)  # flatten, shape=[batch_size, hidden_features]
        if net_type == 'CNN-FCN':
            h = self.__fea_ext_cnn_net(x)  # CNN output shape=[batch_size, channels_num, high, width]
            h = h.reshape(h.shape[0], -1)  # flatten, shape=[batch_size, cnn_hidden_features]
            return self.__fea_ext_fcn_net(h)  # FCN output shape=[batch_size, hidden_features]
        if net_type == 'None':
            # shape=[batch_size, feature_num] or flatten, [batch_size, channels_num*high*width]
            return x.reshape(x.shape[0], -1)
        raise ValueError('unsupported fea_ext_net_type: ' + repr(net_type))

    def forward_X(self, X: Tensor) -> Tensor:
        """
        Fuzzy relation between every pair of samples in one batch.
        :param X: shape=[batch_size, in_features]
        :return: shape=[batch_size, batch_size], 0 <= S[i,j] <= 1
        """
        H = self.__feature_extraction_forward(x=X)  # shape = [batch_size, hidden_features]
        S = self.__BFR_net(X=H)  # shape = [batch_size, batch_size]
        return S

    def forward_X1_X2(self, X1: Tensor, X2: Tensor) -> Tensor:
        """
        Fuzzy relation between every sample of X1 and every sample of X2.
        :param X1: shape=[batch_size1, in_features]
        :param X2: shape=[batch_size2, in_features]
        :return: shape=[batch_size1, batch_size2], 0 <= S[i,j] <= 1
        """
        H1 = self.__feature_extraction_forward(x=X1)  # shape = [batch_size1, hidden_features]
        H2 = self.__feature_extraction_forward(x=X2)  # shape = [batch_size2, hidden_features]
        S = self.__BFR_net.forward_X1_X2(X1=H1, X2=H2)  # shape = [batch_size1, batch_size2]
        return S

    # ================== optimizer and hyper-parameters
    def set_optimizer(self, fea_ext_net_lr=1, equ_rel_net_lr=1):
        """
        Discard the current optimizer and build a fresh one (its accumulated state is reset).
        NOTE(review): the default learning rates here (1) differ from build_optimizer's (1e-3) — confirm.
        :param fea_ext_net_lr: learning rate for the feature-extraction parameters
        :param equ_rel_net_lr: learning rate for the relation-network parameters
        """
        self.__optimizer = None
        self.build_optimizer(fea_ext_net_lr=fea_ext_net_lr, equ_rel_net_lr=equ_rel_net_lr)

    def set_optimizer_lr(self, lr):
        """
        Set the learning rate of every parameter group of the current optimizer.

        The groups are modified in place, so the optimizer type and its accumulated
        moment estimates are preserved.
        :param lr: float, the new learning rate
        """
        for group in self.__optimizer.param_groups:
            group['lr'] = lr

    def set_hyper_parameter(self, gamma_relation_value_fit=1.0, gamma_fea_ext_net_l2reg=0.1):
        """
        Set the trade-off weights of the training objective.
        :param gamma_relation_value_fit: weight of the fuzzy permissible (fit) loss
        :param gamma_fea_ext_net_l2reg: weight of the L2 regularization term
        """
        self.__gamma_FPL = gamma_relation_value_fit
        self.__gamma_fea_ext_net_l2reg = gamma_fea_ext_net_l2reg

    def set_fuzzy_parameter(self, alpha=0.05, beta=0.95):
        """
        Rebuild the FuzzyPermissibleLoss with new fuzzy parameters.
        0 < alpha < 0.5 < beta < 1
        :param alpha:
        :param beta:
        """
        self.__fuzzy_para_alpha = alpha
        self.__fuzzy_para_beta = beta
        self.__FP_loss = FuzzyPermissibleLoss(alpha=self.__fuzzy_para_alpha, beta=self.__fuzzy_para_beta)

    def build_optimizer(self, fea_ext_net_lr=1e-3, equ_rel_net_lr=1e-3):
        """
        Build an Adam optimizer over {Theta, W_same, W_diff, b_same, b_diff} and store it in
        self.__optimizer (nothing is returned). Each parameter gets its own param group.
        :param fea_ext_net_lr: learning rate for Theta (parameters whose name contains 'fea_ext')
        :param equ_rel_net_lr: learning rate for parameters whose name contains 'binary_relation_net'
        """
        params = []
        for name, param in self.named_parameters():
            if 'fea_ext' in name:
                params.append({'params': param, 'lr': fea_ext_net_lr})
            elif 'binary_relation_net' in name:
                # NOTE(review): the relation network is stored as self.__BFR_net, so its parameters
                # are named '_NN_FLM__BFR_net.*'; they only reach the optimizer if their names also
                # contain 'binary_relation_net' — confirm that this is intended.
                params.append({'params': param, 'lr': equ_rel_net_lr})

        self.__optimizer = optim.Adam(params=params, betas=(0.9, 0.999), eps=1e-8)

    def decay_learning_rate(self, decay_factor=0.05):
        """
        Multiply the learning rate of every param group by decay_factor.
        :param decay_factor: float multiplier
        """
        for group in self.__optimizer.param_groups:
            group['lr'] *= decay_factor

    # ================== learning process begin
    @staticmethod
    def __timestamp():
        """Current local time formatted for the training logs."""
        return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())

    def __update_parameters(self, batch, device=t.device('cpu'), update_times=1):
        """
        F: Fuzzy
        E: Equivalence
        R: Relation
        Mat: Matrix
        update the parameters in the network\n
        :param batch: dict with at least 'batch_X' (samples) and 'batch_R' (target relation matrix)
        :param device:
        :param update_times: int >= 1, number of optimizer steps taken on this batch
        :return: {'fit_loss', 'fea_ext_net_regular'} of the last step, as numpy values
        :raises ValueError: if update_times < 1
        """
        if update_times < 1:
            raise ValueError('update_times must be >= 1, got ' + str(update_times))
        self.train()
        class_ER_Mat = t.tensor(data=batch['batch_R'], dtype=t.float64, device=device)
        batch_X = t.tensor(data=batch['batch_X'], dtype=t.float64, device=device)
        for _ in range(update_times):
            self.__optimizer.zero_grad()
            FSR_Mat = self.forward_X(X=batch_X)  # shape = [batch_size, batch_size]
            fit_loss = self.__gamma_FPL * self.__FP_loss(FR_Mat=FSR_Mat,
                                                         class_ER_Mat=class_ER_Mat)
            fea_ext_net_regular = self.__gamma_fea_ext_net_l2reg * self.__L2_regular(nn_brm=self, device=device)
            loss_regular = fit_loss + fea_ext_net_regular
            loss_regular.backward()
            self.__optimizer.step()

        return {'fit_loss': fit_loss.detach().cpu().clone().numpy(),
                'fea_ext_net_regular': fea_ext_net_regular.detach().cpu().clone().numpy()}

    def compute_fit_loss(self, data_set, device=t.device('cpu')):
        """Placeholder: the fit loss on a whole data set is not computed; always returns -1."""
        return -1

    def __train_one_epoch(self, train_data_set, batch_size, data_distortion, device):
        """
        Run one pass over train_data_set, taking one optimizer step per batch.
        :return: the approximate value of FuzzyPermissibleLoss on all the training samples:
                 per-batch fit losses weighted by the number of sample pairs in each batch
        """
        train_data_set.init_epoch(batch_size=batch_size)
        num_sam_pair = 0
        RV_loss_arr = []
        while True:
            batch = train_data_set.get_next_batch(random_distortion=data_distortion, device=device)
            batch_loss = self.__update_parameters(batch=batch, device=device)
            batch_sam_num = batch['index'].shape[0]
            batch_sam_pair_num = batch_sam_num * batch_sam_num
            num_sam_pair += batch_sam_pair_num
            RV_loss_arr.append(batch_loss['fit_loss'] * batch_sam_pair_num)
            if batch['is_last_batch']:
                break
        return np.sum(np.array(RV_loss_arr)) / num_sam_pair

    def __evaluate_on_test(self, train_data_set, test_data_set, best_test_ACC, train_log_file, device):
        """
        Classify test_data_set against 5 exemplars per class selected from train_data_set.
        When the accuracy beats best_test_ACC, the prediction result and the model are saved
        (only if train_log_file is given).
        :return: (test_acc, updated best_test_ACC)
        """
        exemplar = self.select_class_exemplar(train_data_set=train_data_set,
                                              class_exemplar_num=5,
                                              device=device)
        pre_result = self.predict(exemplar_data_set=exemplar['exemplar_data_set'], test_data_set=test_data_set,
                                  device=device)
        test_acc = pre_result['acc']
        if test_acc > best_test_ACC:
            best_test_ACC = test_acc
            if train_log_file is not None:
                NN_FLM.SavePredictResult(result=pre_result,
                                         file_name=train_log_file + '-best-predict-result.txt')
                self.save_model(file_name=train_log_file + 'best-model.model')
        return test_acc, best_test_ACC

    def train_net_with_log(self, train_data_set,
                           max_epoch_num=100, min_loss_gap=1e-5, batch_size=16,
                           decay_lr_how_often=10, decay_lr_rate=0.5,
                           data_distortion=None,
                           device=t.device('cpu'),
                           train_log_file=None, test_data_set=None):
        """
        Train for up to max_epoch_num epochs, optionally evaluating on test_data_set after every
        epoch and logging to '<train_log_file>.txt'. Stops early once the test accuracy reaches 1.

        :param train_data_set:
        :param max_epoch_num:
        :param min_loss_gap: unused here, kept for signature compatibility
        :param batch_size:
        :param decay_lr_rate: multiplier applied to the learning rates every decay_lr_how_often epochs
        :param decay_lr_how_often:
        :param data_distortion: passed to get_next_batch as random_distortion
        :param device:
        :param train_log_file: path prefix for the log / best-result / best-model files, or None
        :param test_data_set: evaluated after every epoch when given
        """
        self.train()
        loss = self.compute_fit_loss(data_set=train_data_set, device=device)
        best_test_ACC = -1
        test_acc = -1
        if test_data_set is not None:
            test_acc, best_test_ACC = self.__evaluate_on_test(train_data_set, test_data_set,
                                                              best_test_ACC, train_log_file, device)
        f_writer = open(file=train_log_file + '.txt', mode='w') if train_log_file is not None else None
        try:
            if f_writer is not None:
                f_writer.write(NN_FLM.__timestamp()
                               + '\tinit loss=\t' + str(loss)
                               + '\ttest ACC=\t' + str(test_acc) + '\n')
                f_writer.flush()

            print(NN_FLM.__timestamp()
                  + ' init loss=' + str(loss)
                  + ' test ACC=' + str(test_acc))

            epoch_index = 0
            while epoch_index < max_epoch_num:
                RV_loss = self.__train_one_epoch(train_data_set, batch_size, data_distortion, device)
                epoch_index += 1
                if epoch_index % decay_lr_how_often == 0:
                    self.decay_learning_rate(decay_factor=decay_lr_rate)
                test_acc = -1
                if test_data_set is not None:
                    test_acc, best_test_ACC = self.__evaluate_on_test(train_data_set, test_data_set,
                                                                      best_test_ACC, train_log_file, device)
                print(NN_FLM.__timestamp()
                      + ' epoch ' + str(epoch_index) + ' end'
                      + ' train-mean-FP-loss=' + str(RV_loss)
                      + ' test ACC=' + str(test_acc))
                if f_writer is not None:
                    f_writer.write(NN_FLM.__timestamp()
                                   + '\tepoch\t' + str(epoch_index)
                                   + '\tend\ttrain-mean-RV-loss=\t' + str(RV_loss)
                                   + '\ttest ACC=\t' + str(test_acc) + '\n')
                    f_writer.flush()

                if best_test_ACC == 1:
                    break
        finally:
            if f_writer is not None:
                f_writer.close()

    def train_net(self, train_data_set,
                  max_epoch_num=100, min_loss_gap=1e-10, batch_size=16,
                  decay_lr_how_often=10, decay_lr_rate=0.5,
                  loss_gap_scope=10,
                  data_distortion=None,
                  device=t.device('cpu')):
        """
        Train until max_epoch_num epochs have run or the epoch losses stop changing significantly.

        :param train_data_set:
        :param max_epoch_num:
        :param min_loss_gap: training stops when LossQueue.Max_minus_Min() <= min_loss_gap
        :param batch_size:
        :param decay_lr_rate: multiplier applied to the learning rates every decay_lr_how_often epochs
        :param decay_lr_how_often:
        :param loss_gap_scope: length of the LossQueue, i.e. the loss values of loss_gap_scope
               consecutive epochs are compared
        :param data_distortion: passed to get_next_batch as random_distortion
        :param device:
        """
        self.train()
        loss = self.compute_fit_loss(data_set=train_data_set, device=device)
        print(NN_FLM.__timestamp() + ' init loss=' + str(loss))
        epoch_index = 0
        loss_queue = LossQueue(length=loss_gap_scope)
        while epoch_index < max_epoch_num and loss_queue.Max_minus_Min() > min_loss_gap:
            RV_loss = self.__train_one_epoch(train_data_set, batch_size, data_distortion, device)
            loss_queue.UpdateQueue(new_ele=RV_loss)
            epoch_index += 1
            if epoch_index % decay_lr_how_often == 0:
                self.decay_learning_rate(decay_factor=decay_lr_rate)

            print(NN_FLM.__timestamp()
                  + ' epoch ' + str(epoch_index) + ' end'
                  + ' train-mean-FP-loss=' + str(RV_loss))

    def select_class_exemplar(self, train_data_set, class_exemplar_num=3, device=t.device('cpu')):
        """
        For every class, pick the class_exemplar_num samples whose summed fuzzy relation to the
        other samples of the same class is largest.
        :param train_data_set:
        :param class_exemplar_num: exemplars per class (capped at the class size)
        :param device:
        :return: {'exemplar_data_set': subset of train_data_set, 'exemplar_sam_indices': indices into it}
        """
        exemplar_parts = [np.empty(shape=[0], dtype=np.int64)]
        for c in train_data_set.get_class_space():
            c_X_sam_ind = train_data_set.get_feature_of_class_sample(class_label=c)
            c_sam_ind = c_X_sam_ind['sam_indices']
            S_c = self.compute_fuzzy_relation_matrix(X1=c_X_sam_ind['X'], X2=c_X_sam_ind['X'], device=device)
            c_sam_score = np.sum(a=S_c, axis=1)  # the bigger the better
            c_sort_ind = np.argsort(a=-c_sam_score)  # sorting the negated score -> descending score order
            c_exemplar_num = max(0, min(class_exemplar_num, c_sam_ind.shape[0]))
            exemplar_parts.append(c_sam_ind[c_sort_ind[:c_exemplar_num]])
        exemplar_indices = np.concatenate(exemplar_parts)
        exemplar_data_set = train_data_set.get_subset_of_myself(exemplar_indices)
        return {'exemplar_data_set': exemplar_data_set, 'exemplar_sam_indices': exemplar_indices}

    # ================== predicting process of the close-world classification problem begin
    def compute_fuzzy_relation_matrix(self, X1, X2, block_size=10000, device=t.device('cpu')):
        """
        Fuzzy relation between every row of X1 and every row of X2, computed block by block
        (in eval mode, without gradients) so that large inputs fit in memory.

        :param X1: shape=[n1, *], n1 may be very large; must support integer-array indexing (e.g. np.ndarray)
        :param X2: shape=[n2, *], n2 may be very large; same requirements as X1
        :param block_size: number of rows fed to the network at once
        :param device:
        :return: np.ndarray of float64, shape=[n1, n2]
        """
        n1, n2 = X1.shape[0], X2.shape[0]
        FR = np.empty(shape=[n1, n2], dtype=np.float64)  # preallocated instead of repeated concatenation
        I2 = NN_FLM.__GenerateBlockIndices(n=n2, block_size=block_size)
        self.eval()
        with t.no_grad():
            for ind1 in NN_FLM.__GenerateBlockIndices(n=n1, block_size=block_size):
                x1 = t.tensor(data=X1[ind1], dtype=t.float64, device=device)
                h1 = self.__feature_extraction_forward(x=x1)  # shape = [len(ind1), d_h]
                rows = slice(ind1[0], ind1[-1] + 1)  # blocks are contiguous index ranges
                for ind2 in I2:
                    x2 = t.tensor(data=X2[ind2], dtype=t.float64, device=device)
                    h2 = self.__feature_extraction_forward(x=x2)  # shape = [len(ind2), d_h]
                    R_12 = self.__BFR_net.forward_X1_X2(X1=h1, X2=h2)  # shape = [len(ind1), len(ind2)]
                    FR[rows, ind2[0]:ind2[-1] + 1] = R_12.cpu().numpy()
        return FR

    def compute_sample_class_score(self, data_set_train, data_set_test, device=t.device('cpu'), batch_size=5000):
        """
        sample-class score: mean fuzzy relation of each test sample to the training samples of each class\n
        :param data_set_train: n1 samples, d features, c classes
        :param data_set_test: n2 samples, d features, ? classes
        :param device:
        :param batch_size: block size passed to compute_fuzzy_relation_matrix
        :return: {'sam_class_score': [0,1]^{n2 x c}, 'R_test_train': [0, 1]^{n2, n1}}
        """
        R_test_train = self.compute_fuzzy_relation_matrix(X1=data_set_test.get_feature_of_all_sample(),
                                                          X2=data_set_train.get_feature_of_all_sample(),
                                                          block_size=batch_size, device=device)  # [n2, n1]
        class_labels = data_set_train.get_class_label_of_all_sample()  # [n1]
        result = Tool.ClassLabels2OneHotMatrix(class_labels)
        one_hot_mat = np.asarray(result['one-hot-mat'], dtype=np.float64)  # shape = [n1, c]
        inv_class_sam_num = 1 / np.asarray(result['class sample num'], dtype=np.float64)  # shape = [c]
        # per-class relation sums scaled by 1 / class size (plain arrays: np.mat was removed in NumPy 2)
        sam_class_score = np.matmul(R_test_train, one_hot_mat) * inv_class_sam_num  # shape = [n2, c]
        return {'sam_class_score': sam_class_score, 'R_test_train': R_test_train}

    def compute_sample_class_score_fuzzy(self, X_exemplar, class_score_exemplar,
                                         data_set_test,
                                         device=t.device('cpu'), batch_size=10000):
        """
        sample-class score: relation-weighted average of the exemplars' class scores\n
        :param X_exemplar: np.array, shape=[n, d]
        :param class_score_exemplar: np.array, shape=[n, c]
        :param data_set_test: m samples, d features, ? classes
        :param device:
        :param batch_size: block size passed to compute_fuzzy_relation_matrix
        :return: {'sam_class_score': [0,1]^{m x c}, 'R_test_train': [0, 1]^{m, n}}
        """
        FR_te_tr = self.compute_fuzzy_relation_matrix(X1=data_set_test.get_feature_of_all_sample(),
                                                      X2=X_exemplar,
                                                      block_size=batch_size, device=device)  # [m, n]
        # row normalisation by broadcasting instead of an [m, m] diagonal matrix
        inv_row_sum = 1 / np.sum(a=FR_te_tr, axis=1)  # shape = [m]
        sam_class_score = np.matmul(FR_te_tr, np.asarray(class_score_exemplar)) * inv_row_sum[:, np.newaxis]
        return {'sam_class_score': sam_class_score, 'R_test_train': FR_te_tr}

    def predictTopK(self, exemplar_data_set, test_data_set, device=t.device('cpu')):
        """
        Rank all classes for every test sample by sample-class score.
        :return: {'pre_topK_class_labels': shape=[m, c], best class first; 'true_class_labels'}
        """
        sam_class_score = self.compute_sample_class_score(data_set_train=exemplar_data_set,
                                                          data_set_test=test_data_set,
                                                          device=device)['sam_class_score']
        class_space = np.unique(exemplar_data_set.get_class_label_of_all_sample())
        # ascending argsort reversed per row -> descending score order
        pre_topK_labels = class_space[np.argsort(a=sam_class_score, axis=1)][:, ::-1]
        return {'pre_topK_class_labels': np.ascontiguousarray(pre_topK_labels),
                'true_class_labels': test_data_set.get_class_label_of_all_sample()}

    def predict(self, exemplar_data_set, test_data_set, device=t.device('cpu')):
        """
        Assign every test sample to the class with the highest sample-class score.
        :param exemplar_data_set: n1 samples, d features, c classes
        :param test_data_set: n2 samples, d features, c classes
        :param device:
        :return: 'acc', 'pre_class_labels': [n2], 'true_class_labels'
        """
        self.eval()
        with t.no_grad():
            sam_class_score = self.compute_sample_class_score(data_set_train=exemplar_data_set,
                                                              data_set_test=test_data_set,
                                                              device=device)['sam_class_score']
            class_space = np.unique(exemplar_data_set.get_class_label_of_all_sample())
            pre_class_index = np.argmax(a=np.asarray(sam_class_score), axis=1)  # shape=[n2]
            pre_class_labels = class_space[pre_class_index]  # shape=[n2]
            true_class_labels = test_data_set.get_class_label_of_all_sample()
            # np.float64 (np.float was removed in NumPy 1.24)
            acc = np.mean(np.equal(pre_class_labels, true_class_labels).astype(np.float64))
        return {'acc': acc, 'pre_class_labels': pre_class_labels, 'true_class_labels': true_class_labels}

    # ================== save and print
    def write_params_2_file(self, file_name):
        """
        Write name, dtype, shape and value of every parameter to a text file.
        :param file_name:
        """
        with open(file=file_name, mode='w') as f_writer:
            for index, (name, param) in enumerate(self.named_parameters(), start=1):
                value = param.detach().cpu().numpy()
                f_writer.write('para ' + str(index) + '\nname=')
                f_writer.write(name)
                f_writer.write('\n')
                f_writer.write('type=' + str(value.dtype) + '\n')
                f_writer.write('shape=' + str(value.shape) + '\n')
                f_writer.write('value=' + str(value) + '\n\n\n')

    def save_model(self, file_name):
        """
        Pickle the whole module (including the optimizer attribute) with torch.save.
        :param file_name:
        """
        t.save(obj=self, f=file_name)

    def print_model_design(self):
        """Print the network structure, the parameter shapes/names and the trade-off weights."""
        print('(1) network structure:')
        print(self)
        print('(2) network parameters:')
        for index, (name, param) in enumerate(self.named_parameters(), start=1):
            print('para ' + str(index), end='\t')
            print(str(param.shape), end='\t')
            print(name)
        print('(3) trade-off parameters:')
        print('gamma_FPL=\t' + str(self.__gamma_FPL))
        print('gamma_fea_ext_net_l2reg=\t' + str(self.__gamma_fea_ext_net_l2reg))

    def save_model_design(self, file_name):
        """Write the information of print_model_design plus parameter / param-group counts to a file."""
        with open(file=file_name, mode='w') as file_writer:
            file_writer.write('(1) network structure:\n')
            print(self, file=file_writer)
            file_writer.write('\n')

            file_writer.write('(2) network parameters:\n')
            param_num = 0
            for param_num, (name, param) in enumerate(self.named_parameters(), start=1):
                file_writer.write('para ' + str(param_num) + '\t')
                file_writer.write(str(param.shape) + '\t')
                file_writer.write(name + '\n')
            file_writer.write('\n')

            file_writer.write('(3) trade-off parameters:\n')
            file_writer.write('gamma_FPL=\t' + str(self.__gamma_FPL) + '\n')
            file_writer.write('gamma_fea_ext_net_l2reg=\t' + str(self.__gamma_fea_ext_net_l2reg) + '\n')
            file_writer.write('\n')

            file_writer.write('(4)\n')
            file_writer.write('number of parameter in network=\t' + str(param_num) + '\n')
            file_writer.write('number of parameter in optimizer=\t' + str(len(self.__optimizer.param_groups)) + '\n')

    # ================== staticmethod
    @staticmethod
    def load_from_file(file_name, map_location=None):
        """
        Load a model written by save_model.

        save_model pickles the whole module, so unpickling is required (weights_only=False,
        which torch >= 2.6 no longer uses by default). Only load files from trusted sources:
        unpickling can execute arbitrary code.
        :param map_location: passed to torch.load
        :param file_name:
        :return: the loaded NN_FLM instance
        """
        try:
            return t.load(f=file_name, map_location=map_location, weights_only=False)
        except TypeError:
            # torch < 1.13 has no weights_only argument (and always unpickles)
            return t.load(f=file_name, map_location=map_location)

    @staticmethod
    def init_fn(m):
        """Kaiming-normal (leaky_relu, a=1e-2, fan_in) init for exactly-typed Conv2d / Linear modules."""
        if type(m) in (nn.modules.conv.Conv2d, nn.modules.linear.Linear):
            nn.init.kaiming_normal_(tensor=m.weight, a=1e-2, mode='fan_in', nonlinearity='leaky_relu')

    @staticmethod
    def SavePredictResult(result, file_name, information=None):
        """Write the misclassified samples (id, true label, predicted label) and the accuracy."""
        true_class_labels = result['true_class_labels']
        pre_class_labels = result['pre_class_labels']
        with open(file=file_name, mode='w') as f_writer:
            f_writer.write('sample id\ttrue label\tpredicting label\n')
            for i, true_label in enumerate(true_class_labels):
                if true_label != pre_class_labels[i]:
                    f_writer.write(str(i) + '\t'
                                   + str(true_label) + '\t'
                                   + str(pre_class_labels[i]) + '\n')
            f_writer.write('\nacc=\t')
            f_writer.write(str(result['acc']))
            if information is not None:
                f_writer.write('\n')
                f_writer.write(information)

    @staticmethod
    def SavePredictResult2(result, file_name, information=None):
        """Write all true and predicted labels (tab separated) and the accuracy."""
        with open(file=file_name, mode='w') as f_writer:
            f_writer.write('true class labels=\t')
            f_writer.write(''.join(str(item) + '\t' for item in result['true_class_labels']))
            f_writer.write('\n')

            f_writer.write('predict class labels=\t')
            f_writer.write(''.join(str(item) + '\t' for item in result['pre_class_labels']))
            f_writer.write('\n\n')

            f_writer.write('acc=\t')
            f_writer.write(str(result['acc']))
            if information is not None:
                f_writer.write('\n')
                f_writer.write(information)

    @staticmethod
    def SavePredictTopKResult(pre_topK_result, file_name, information=None):
        """
        Write the ranked labels of every sample whose top-1 prediction is wrong.
        :param pre_topK_result:
         {'pre_topK_class_labels': np.array(pre_topK_labels), 'true_class_labels': test_data_set.get_class_label_of_all_sample()}
        :param file_name:
        :param information: optional trailing text
        """
        pre_topK_class_labels = pre_topK_result['pre_topK_class_labels']
        true_class_labels = pre_topK_result['true_class_labels']
        with open(file=file_name, mode='w') as f_writer:
            f_writer.write('id\ttrue-class-label\ttop1\ttop2\ttop3\t...\n')
            for i in range(pre_topK_class_labels.shape[0]):
                if true_class_labels[i] != pre_topK_class_labels[i, 0]:
                    f_writer.write(str(i) + '\t')
                    f_writer.write(str(true_class_labels[i]) + '\t')
                    for label in pre_topK_class_labels[i]:
                        f_writer.write(str(label) + '\t')
                    f_writer.write('\n')

            if information is not None:
                f_writer.write('\n')
                f_writer.write(information)

    @staticmethod
    def __GenerateBlockIndices(n, block_size):
        """
        Split range(n) into consecutive index blocks of length block_size; the last block holds
        the n % block_size remaining indices, if any.
        :return: list of 1-D integer np.ndarray (empty list when n == 0)
        :raises ValueError: if block_size < 1
        """
        if block_size < 1:
            raise ValueError('block_size must be >= 1, got ' + str(block_size))
        return [np.arange(start, min(start + block_size, n)) for start in range(0, n, block_size)]
