import torch as t
import numpy as np
import pickle
import hdf5storage


class DataSetTool(object):
    """Static helpers that read data-set arrays from .txt, .npy and .mat files."""

    @staticmethod
    def Load_X_y_FromFile(data_set_file, data_set_information=None):
        """
        Load the feature array X and the label array y from a file.

        For .txt (comma separated) and .npy files the last column holds the labels and
        the remaining columns hold the features. For .mat files the variables 'X' and
        'y' are read.

        :param data_set_file: file name, can be None (then (None, None) is returned)
        :param data_set_information: dict, can be None for 'vector' data set.
            When data_set_information['type'] == 'image', X is reshaped to
            [sample_num, image_channel, image_height, image_width] using the
            'image_channel', 'image_height' and 'image_width' entries.
        :return: (X, y) with X as float64 and y as int32
        :raises ValueError: if the file name does not end with .txt, .npy or .mat
        """
        if data_set_file is None:
            return None, None
        if data_set_information is None:
            data_set_information = {'type': 'vector'}

        # Everything from the last '.' onwards; empty when the name has no '.' at all
        # so that such names reach the "unsupported type" error below.
        file_name = str(data_set_file)
        dot_index = file_name.rfind('.')
        extension_name = file_name[dot_index:] if dot_index != -1 else ''

        if extension_name == '.npy':
            # NOTE(security): allow_pickle=True lets np.load unpickle objects stored in
            # the file; only load .npy files that come from a trusted source.
            file_data = np.load(file=data_set_file, allow_pickle=True)
            X = file_data[:, :-1]
            y = file_data[:, -1]
        elif extension_name == '.txt':
            # ndmin=2 keeps a single-row file two-dimensional so the column split works.
            file_data = np.loadtxt(fname=data_set_file, dtype=np.float64, delimiter=',', ndmin=2)
            X = file_data[:, :-1]
            y = file_data[:, -1]
        elif extension_name == '.mat':
            file_data = hdf5storage.loadmat(file_name=data_set_file)
            y = np.squeeze(file_data['y'])
            X = file_data['X']
        else:
            raise ValueError('The type of data_set_file must be .txt or .npy or .mat！!')

        # atleast_1d: a single-sample label squeezes to 0-d, which has no shape[0].
        y = np.atleast_1d(np.asarray(y).astype(np.int32))
        X = np.asarray(X, dtype=np.float64)
        if data_set_information['type'] == 'image':
            X = np.reshape(X, (y.shape[0],
                               data_set_information['image_channel'],
                               data_set_information['image_height'],
                               data_set_information['image_width']))
        return X, y

    @staticmethod
    def LoadRandSamInd(mat_file_name):
        """
        Read the variable 'rand_sam_ind' from a .mat file.

        :param mat_file_name: name of the .mat file
        :return: the stored 'rand_sam_ind' value, as returned by hdf5storage.loadmat
        """
        return hdf5storage.loadmat(file_name=mat_file_name)['rand_sam_ind']

    @staticmethod
    def Load_R_classFromFile(R_class_file):
        """
        Read the variable 'R_class' from a .mat file and convert it to float64.

        :param R_class_file: name of the .mat file
        :return: R_class as float64
        """
        file_data = hdf5storage.loadmat(file_name=R_class_file)
        return np.float64(file_data['R_class'])


# Python numpy based DataSet
class DataSet_python(object):
    """
    In-memory NumPy data set.

    Holds a feature array X, a label array y and a sample-by-sample class relation
    matrix R_class, and serves index-based mini-batches over one or more epochs.
    """

    def __init__(self, data_set_file, data_set_information=None, R_class_file=None):
        """
        :param data_set_file: file to load X and y from; None creates an empty data
            set whose arrays are assigned afterwards
        :param data_set_information: dict
        {
         'type': 'image' or 'vector'
         'feature_num':
         'image_width':
         'image_height':
         'image_channel':
        }
        :param R_class_file: optional file to load R_class from; when None, R_class
            is built from the labels the first time it is needed
        """
        super(DataSet_python, self).__init__()
        X, y = None, None
        if data_set_file is not None:
            X, y = DataSetTool.Load_X_y_FromFile(data_set_file=data_set_file,
                                                 data_set_information=data_set_information)
        self.__X = X
        self.__y = y

        self.__R_class = None
        if R_class_file is not None:
            self.__R_class = DataSetTool.Load_R_classFromFile(R_class_file=R_class_file)

        # Batch iteration state, filled in by init_epoch().
        self.__batch_size = None
        self.__next_batch_beg_index = None
        self.__random_sam_index_arr = None

        # Not used by any method of this class; kept so the attribute set of
        # instances stays the same.
        self.__class_space = None
        self._class_sam_table = None

    def __init_R_class(self):
        """
        Build R_class from the labels.

        R_class[i, j] is 1.0 when samples i and j have the same class label and
        0.0 otherwise.
        """
        labels = np.asarray(self.__y)
        self.__R_class = (labels[:, None] == labels[None, :]).astype(np.float64)

    def __del__(self):
        """Remove every instance attribute."""
        self.__dict__.clear()

    # Basic Information

    def get_data_type(self):
        """
        :return: 'vector' when X is 2-D, 'image' when X is 4-D
        :raises ValueError: for any other number of dimensions
        """
        dim_num = self.__X.ndim
        if dim_num == 2:
            return 'vector'
        if dim_num == 4:
            return 'image'
        raise ValueError('Unknown data type!')

    def __image_axis_size(self, axis):
        """Return X.shape[axis] for 'image' data and -1 for 'vector' data."""
        if self.get_data_type() == 'image':
            return self.__X.shape[axis]
        return -1

    def get_feature_num(self):
        """:return: X.shape[1] for 'vector' data, -1 for 'image' data"""
        if self.get_data_type() == 'vector':
            return self.__X.shape[1]
        return -1

    def get_image_channel_num(self):
        """:return: X.shape[1] for 'image' data, -1 for 'vector' data"""
        return self.__image_axis_size(1)

    def get_image_high(self):
        """:return: X.shape[2] for 'image' data, -1 for 'vector' data"""
        return self.__image_axis_size(2)

    def get_image_width(self):
        """:return: X.shape[3] for 'image' data, -1 for 'vector' data"""
        return self.__image_axis_size(3)

    def get_sample_num(self):
        """:return: number of samples, taken from y.shape[0]"""
        return self.__y.shape[0]

    def get_class_label(self, sample_index):
        """:return: y[sample_index]"""
        return self.__y[sample_index]

    def get_feature(self, sample_index):
        """:return: X[sample_index, :]"""
        return self.__X[sample_index, :]

    def get_feature_of_all_sample(self):
        """:return: the whole X array (not a copy)"""
        return self.__X

    def get_feature_of_class_sample(self, class_label):
        """
        :param class_label: class label to select
        :return: dict with 'X' (features of the samples whose label equals
            class_label) and 'sam_indices' (their indices)
        """
        sam_indices = np.flatnonzero(self.__y == class_label)
        return {'X': self.__X[sam_indices, :], 'sam_indices': sam_indices}

    def get_class_space(self):
        """:return: sorted unique class labels (np.unique of y)"""
        return np.unique(self.__y)

    def get_class_num(self):
        """:return: number of distinct class labels"""
        return np.unique(self.__y).shape[0]

    def get_class_label_of_all_sample(self):
        """:return: the whole y array (not a copy)"""
        return self.__y

    def get_R_class(self):
        """:return: R_class, built from the labels first if it is not available yet"""
        if self.__R_class is None:
            self.__init_R_class()
        return self.__R_class

    def get_data_information(self):
        """:return: dict describing the data; entries that do not apply are -1"""
        return {
            'type': self.get_data_type(),
            'feature_num': self.get_feature_num(),
            'image_width': self.get_image_width(),
            'image_height': self.get_image_high(),
            'image_channel': self.get_image_channel_num()
        }

    def show_myself(self):
        """Print the data information and the shapes of X and y."""
        print('class DataSet_python(DataSet):')
        print(self.get_data_information())
        print('X.shape =', self.__X.shape)
        print('y.shape =', self.__y.shape)

    def get_class_distribution(self):
        """

        :return:
        class_sam_num_list: list, len=class_num
        class_sam_ind_list: 2-D list, class_sam_ind[i]: list, the indices of samples in i-th class
        """
        class_sam_num_list = []
        class_sam_ind_list = []
        for c in self.get_class_space():
            c_sam_ind = np.flatnonzero(self.__y == c)
            class_sam_num_list.append(c_sam_ind.shape[0])
            class_sam_ind_list.append(list(c_sam_ind))
        return {'class_sam_num_list': class_sam_num_list, 'class_sam_ind_list': class_sam_ind_list}

    def image_0_255_to_Neg1_Pos1(self):
        """
        for image data,
        convert {0, 1, ... 255} to [-1, 1]
        :return:
        """
        self.__X = (self.__X - 127.5) / 127.5

    # Epoch and Batch

    def init_epoch(self, epoch_num=1, batch_size=32, do_shuffle=True, do_balance=False):
        """
        Prepare the sample index sequence that get_next_batch() walks through.

        :param epoch_num: how many times the per-epoch index list is repeated
        :param batch_size: number of indices per batch (the last batch may be smaller)
        :param do_shuffle: shuffle the complete index sequence with np.random.shuffle
        :param do_balance: repeat the indices of every class (cyclically) until each
            class contributes as many indices as the largest class
        :raises ValueError: if batch_size is smaller than 1
        """
        if batch_size < 1:
            # A non-positive size would make get_next_batch() return empty batches forever.
            raise ValueError('batch_size must be a positive integer!')
        if self.__R_class is None:
            self.__init_R_class()

        self.__batch_size = batch_size
        self.__next_batch_beg_index = 0
        if do_balance:
            class_dis = self.get_class_distribution()
            c_sam_num_list = class_dis['class_sam_num_list']
            c_sam_ind_list = class_dis['class_sam_ind_list']
            max_sam_num = max(c_sam_num_list)
            epoch_indices = []
            for c_sam_num, c_sam_ind in zip(c_sam_num_list, c_sam_ind_list):
                full_repeats, remainder = divmod(max_sam_num, c_sam_num)
                epoch_indices.extend(c_sam_ind * full_repeats + c_sam_ind[:remainder])
        else:
            epoch_indices = list(range(self.get_sample_num()))

        self.__random_sam_index_arr = np.array(epoch_indices * epoch_num, dtype=np.int64)
        if do_shuffle:
            np.random.shuffle(self.__random_sam_index_arr)

    def get_next_batch(self, random_distortion=None, device=None):
        """
        Return the next batch of the sequence prepared by init_epoch().

        :param random_distortion: optional callable; when given, batch_X is converted
            to a torch tensor on `device`, passed through it and converted back to NumPy
        :param device: device for the temporary torch tensor
        :return: dict with 'index', 'batch_X', 'batch_y', 'batch_R' (rows and columns
            of R_class selected by 'index') and 'is_last_batch'
        :raises RuntimeError: if init_epoch() has not been called or all batches
            have already been returned
        """
        index_arr = self.__random_sam_index_arr
        if index_arr is None or self.__next_batch_beg_index >= len(index_arr):
            raise RuntimeError('There is no next batch! Please call init_epoch(***) first!!')

        batch_beg_index = self.__next_batch_beg_index
        batch_end_index = min(batch_beg_index + self.__batch_size, len(index_arr))
        batch_index = index_arr[batch_beg_index: batch_end_index]
        self.__next_batch_beg_index = batch_end_index

        batch_R = self.__R_class[np.ix_(batch_index, batch_index)]
        batch_X = self.__X[batch_index]
        if random_distortion is not None:
            batch_X = t.tensor(data=batch_X, device=device)
            batch_X = random_distortion(batch_X)
            # detach() so a result that tracks gradients can still be converted.
            batch_X = batch_X.detach().cpu().numpy()
        return {'index': batch_index,
                'batch_X': batch_X,
                'batch_y': self.__y[batch_index],
                'batch_R': batch_R,
                'is_last_batch': batch_end_index == len(index_arr)}

    def save_myself_to_binary_file(self, file_name):
        """
        Pickle this object into a binary file.

        :param file_name: file to write
        :return:
        """
        with open(file_name, 'wb') as bin_write_file:
            pickle.dump(self, bin_write_file)

    def get_subset_of_myself(self, exemplar_indices):
        """
        :param exemplar_indices: indices of the samples to keep
        :return: new DataSet_python holding X[exemplar_indices] and y[exemplar_indices];
            R_class is not copied and is rebuilt from the labels when needed
        """
        data_set = DataSet_python(data_set_file=None, data_set_information=None)
        data_set.__X = self.__X[exemplar_indices]
        data_set.__y = self.__y[exemplar_indices]
        return data_set

    @staticmethod
    def _as_index_array(rand_sam_ind):
        """Return rand_sam_ind as a flat int64 array usable for indexing."""
        return np.asarray(rand_sam_ind).ravel().astype(np.int64)

    @staticmethod
    def Union_Training_Test_Dataset(train_data_set, test_data_set):
        """Union_Training_TestDataSet
        Get the union of a training data set and a test data set.
        Features and labels of both sets are concatenated unchanged, training samples first.
        NOTE: this method does not relabel the test samples (e.g. to -999); a caller that
        needs the test labels hidden has to do that itself.
        :param train_data_set: n1 samples
        :param test_data_set: n2 samples
        :return: data_set: n1 + n2 samples
        """
        data_set = DataSet_python(data_set_file=None, data_set_information=None)
        data_set.__X = np.concatenate((train_data_set.__X, test_data_set.__X), axis=0)
        data_set.__y = np.concatenate((train_data_set.__y, test_data_set.__y), axis=0)
        return data_set

    @staticmethod
    def Split_Training_Test_Dataset(data_set, rand_sam_ind, train_sam_ratio=0.75):
        """

        :param data_set:
        :param rand_sam_ind: shape=[n], a random permutation of {0, 1, 2, ..., n-1};
        it is flattened and cast to int64 before use
        :param train_sam_ratio: float,
        the first ceil(n*train_sam_ratio) samples in the rand_sam_ind are used as the training data
        :return: dict with 'train_data_set' and 'test_data_set'
        """
        rand_sam_ind = DataSet_python._as_index_array(rand_sam_ind)
        sam_num = data_set.get_sample_num()
        train_sam_num = int(np.ceil(sam_num * train_sam_ratio))

        train_sam_ind = rand_sam_ind[:train_sam_num]
        train_data_set = DataSet_python(data_set_file=None, data_set_information=None)
        train_data_set.__X = data_set.__X[train_sam_ind]
        train_data_set.__y = data_set.__y[train_sam_ind]

        test_sam_ind = rand_sam_ind[train_sam_num:sam_num]
        test_data_set = DataSet_python(data_set_file=None, data_set_information=None)
        test_data_set.__X = data_set.__X[test_sam_ind]
        test_data_set.__y = data_set.__y[test_sam_ind]

        return {'train_data_set': train_data_set, 'test_data_set': test_data_set}

    @staticmethod
    def Split_Training_Test_Dataset_Open_World(data_set, rand_sam_ind, train_sam_ratio=0.75, train_class_ratio=0.5):
        """

        :param data_set:
        :param rand_sam_ind: shape=[n], a random permutation of {0, 1, 2, ..., n-1};
        it is flattened and cast to int64 before use
        :param train_sam_ratio: float,
        the first ceil(n*train_sam_ratio) samples in the rand_sam_ind are the training candidates
        :param train_class_ratio: float
        only the candidates that belong to the first ceil(c*train_class_ratio) classes
        (in sorted label order) are kept as training data; the test data are the
        remaining samples of rand_sam_ind, from all classes
        :return: dict with 'train_data_set' and 'test_data_set'
        """
        rand_sam_ind = DataSet_python._as_index_array(rand_sam_ind)
        sam_num = data_set.get_sample_num()
        train_sam_num = int(np.ceil(sam_num * train_sam_ratio))

        class_space = data_set.get_class_space()
        train_class_num = int(np.ceil(class_space.shape[0] * train_class_ratio))
        train_class_space = class_space[:train_class_num]

        # A boolean mask keeps the index array integer-typed even when nothing matches.
        candidate_sam_ind = rand_sam_ind[:train_sam_num]
        is_train_class = np.isin(data_set.__y[candidate_sam_ind], train_class_space)
        train_sam_ind = candidate_sam_ind[is_train_class]

        train_data_set = DataSet_python(data_set_file=None, data_set_information=None)
        train_data_set.__X = data_set.__X[train_sam_ind]
        train_data_set.__y = data_set.__y[train_sam_ind]

        test_sam_ind = rand_sam_ind[train_sam_num:sam_num]
        test_data_set = DataSet_python(data_set_file=None, data_set_information=None)
        test_data_set.__X = data_set.__X[test_sam_ind]
        test_data_set.__y = data_set.__y[test_sam_ind]

        return {'train_data_set': train_data_set, 'test_data_set': test_data_set}

    @staticmethod
    def Generate_random_data_set(sample_num, feature_num, class_num):
        """
        Create a 'vector' data set with uniform random features in [0, 1) and random
        integer labels in {0, ..., class_num - 1}.

        :param sample_num:
        :param feature_num:
        :param class_num:
        :return: the generated DataSet_python
        """
        data_set = DataSet_python(None)
        data_set.__X = np.random.uniform(size=[sample_num, feature_num])
        # The builtin int replaces the np.int alias, which newer NumPy no longer provides.
        data_set.__y = np.random.randint(low=0, high=class_num, size=[sample_num], dtype=int)
        # Not read by any method of this class; kept so generated objects keep this attribute.
        data_set.__information = {'type': 'matrix',
                                  'feature_num': feature_num,
                                  'image_width': 'N/A',
                                  'image_height': 'N/A',
                                  'image_channel': 'N/A'}
        return data_set

    @staticmethod
    def Load_from_binary_file(file_name):
        """
        Unpickle an object from a binary file.

        NOTE(security): pickle.load can execute arbitrary code; only open files that
        come from a trusted source.

        :param file_name: file to read
        :return: the unpickled object
        """
        with open(file_name, 'rb') as bin_read_file:
            return pickle.load(bin_read_file)
