from sklearn.preprocessing import MinMaxScaler, StandardScaler
import numpy as np
from torch.utils.data import Dataset
import torch
import hdf5storage  # 处理 .mat 格式（MATLAB 保存的矩阵数据）的 Python 库。
import os

class MultiviewData(Dataset):
    """Multi-view dataset read from a MATLAB ``.mat`` file.

    Every view is cast to ``float32``, rescaled with a scikit-learn scaler
    (``MinMaxScaler``, or ``StandardScaler`` for ``Digit-Product``) and
    stored as a tensor on ``device``. Labels are kept as a NumPy ``int32``
    array.

    Args:
        db: Dataset name; must be a key of ``_CELL_DATASETS`` or
            ``_KEYED_DATASETS``.
        device: Target handed to ``Tensor.to`` for every view.
        path: Directory that contains the ``.mat`` files.

    Attributes:
        data_views: List with one tensor per view (the data matrices X).
        num_views: Number of views.
        labels: NumPy ``int32`` label array (the label matrix Y).

    Raises:
        NotImplementedError: If ``db`` is not a known dataset name.
    """

    # Default layout of a file that stores all views in one cell array.
    # Entries of _CELL_DATASETS only list the fields that deviate from it.
    _DEFAULT_LAYOUT = {
        "views_key": "X",           # mat key of the cell array holding the views
        "transpose_views": False,   # apply .T to each view before the float32 cast
        "labels_key": "Y",          # mat key of the labels
        "transpose_labels": False,  # apply .T to the squeezed labels
        "cell_labels": False,       # labels are read as mat[key][0] and flattened
        "standardize": False,       # StandardScaler instead of MinMaxScaler
    }

    # Datasets stored as "<name>.mat" with the views in a single cell array.
    _CELL_DATASETS = {
        "RGB-D": {},
        "CCV": {},
        "Digit-Product": {"standardize": True},
        "ALOI-100": {},
        # NOTE(review): an unreachable duplicate branch for "Hdigit" used to
        # read keys 'data' / 'truelabel' with transposed views. The layout
        # below is the one that actually ran; switch it if Hdigit.mat turns
        # out to use the other key names.
        "Hdigit": {},
        "20newsgroups": {"transpose_labels": True},
        "animal": {"transpose_views": True, "labels_key": "gt"},
        "AWA": {},
        "AWA2": {},
        "BBC": {"transpose_labels": True},
        "BBCSport": {"transpose_labels": True},
        "BDGP": {"transpose_labels": True},
        "Caltech101-7": {},
        "Caltech101-20-2views": {"transpose_labels": True},
        "Caltech101-20-6views": {"transpose_labels": True},
        "Caltech101-all": {},
        "cifar10": {
            "views_key": "data",
            "transpose_views": True,
            "labels_key": "truelabel",
            "cell_labels": True,
        },
        "CiteSeer": {"transpose_labels": True},
        "COIL20": {"transpose_labels": True},
        "Cora": {"transpose_labels": True},
        "handwritten": {"transpose_labels": True},
        "HW1256": {"transpose_labels": True},
        "Mfeat": {"transpose_labels": True},
        "prokaryotic": {"transpose_labels": True},
        "scene": {"transpose_labels": True},
        "uci-digit": {"transpose_labels": True},
    }

    # Datasets whose views live under separate mat keys:
    # name -> (file name, keys of the views in order).
    # NOTE(review): "Caltech-4V" and "Caltech-5V" read 'Caltech-3V.mat' and
    # "NGs" reads 'NGS.mat', exactly as the previous implementation did;
    # confirm these file names against the data directory.
    _KEYED_DATASETS = {
        "Fashion": ("Fashion.mat", ("X1", "X2", "X3")),
        "MNIST_USPS": ("MNIST_USPS.mat", ("X1", "X2")),
        "NGs": ("NGS.mat", ("X1", "X2", "X3")),
        "Caltech-3V": ("Caltech-3V.mat", ("X1", "X2", "X5")),
        "Caltech-4V": ("Caltech-3V.mat", ("X1", "X2", "X4", "X5")),
        "Caltech-5V": ("Caltech-3V.mat", ("X1", "X2", "X3", "X4", "X5")),
    }

    def __init__(self, db, device, path="data/"):
        self.data_views = list()  # the data matrices X, one per view

        # Resolve the dataset name before touching the file system so an
        # unknown name fails fast with a descriptive error.
        if db in self._CELL_DATASETS:
            self._load_cell_dataset(db, path)
        elif db in self._KEYED_DATASETS:
            self._load_keyed_dataset(db, path)
        else:
            raise NotImplementedError("Unknown dataset: {!r}".format(db))

        # Convert each NumPy view to a tensor and move it to the target device.
        self.data_views = [
            torch.from_numpy(view).to(device) for view in self.data_views
        ]

    @staticmethod
    def _squeezed_labels(raw, transpose):
        """Drop size-1 axes from ``raw``, optionally transpose, cast to int32."""
        labels = np.array(np.squeeze(raw))
        if transpose:
            labels = labels.T
        return labels.astype(np.int32)

    def _load_cell_dataset(self, db, path):
        """Load a dataset whose views are stored in one cell array."""
        layout = dict(self._DEFAULT_LAYOUT, **self._CELL_DATASETS[db])
        mat = hdf5storage.loadmat(os.path.join(path, db + ".mat"))

        cells = mat[layout["views_key"]]
        self.num_views = cells.shape[1]  # views are indexed as cells[0, idx]
        scaler = StandardScaler() if layout["standardize"] else MinMaxScaler()
        for idx in range(self.num_views):
            view = cells[0, idx]
            if layout["transpose_views"]:
                view = view.T
            # fit_transform refits the scaler, so every view is scaled
            # independently even though one scaler object is reused.
            self.data_views.append(
                scaler.fit_transform(view.astype(np.float32))
            )

        if layout["cell_labels"]:
            raw = mat[layout["labels_key"]][0]
            # An object array holds nested arrays; concatenate them before
            # flattening. Otherwise flatten the numeric array directly.
            if raw.dtype == np.dtype('O'):
                flat = np.concatenate(raw).ravel()
            else:
                flat = np.asarray(raw).ravel()
            self.labels = flat.astype(np.int32)
        else:
            self.labels = self._squeezed_labels(
                mat[layout["labels_key"]], layout["transpose_labels"]
            )

    def _load_keyed_dataset(self, db, path):
        """Load a dataset whose views are stored under separate mat keys."""
        file_name, view_keys = self._KEYED_DATASETS[db]
        mat = hdf5storage.loadmat(os.path.join(path, file_name))

        self.num_views = len(view_keys)
        scaler = MinMaxScaler()
        for key in view_keys:
            view = mat[key].astype(np.float32)
            # Flatten everything after the first axis into one feature axis.
            flat_view = view.reshape(view.shape[0], -1)
            self.data_views.append(scaler.fit_transform(flat_view))

        self.labels = self._squeezed_labels(mat['Y'], True)

    def __len__(self):
        """Return the number of labels (the sample count n)."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(views, label, index)`` for the sample at ``index``.

        ``views`` is a list holding the ``index``-th row of every view.
        """
        sub_data_views = [
            self.data_views[view_idx][index]
            for view_idx in range(self.num_views)
        ]
        return sub_data_views, self.labels[index], index


def get_multiview_data(mv_data, batch_size):
    """Wrap a multi-view dataset in a shuffling mini-batch loader.

    Args:
        mv_data: Dataset object exposing ``data_views`` and ``labels``.
        batch_size: Number of samples per batch.

    Returns:
        Tuple ``(loader, num_views, num_samples, num_clusters)`` where
        ``num_clusters`` is the number of distinct values in
        ``mv_data.labels``. The last, possibly smaller, batch is kept.
    """
    labels = mv_data.labels
    view_count = len(mv_data.data_views)
    sample_count = len(labels)
    cluster_count = len(np.unique(labels))

    loader_options = {
        "batch_size": batch_size,
        "shuffle": True,
        "drop_last": False,
    }
    loader = torch.utils.data.DataLoader(mv_data, **loader_options)

    return loader, view_count, sample_count, cluster_count


def get_all_multiview_data(mv_data):
    """Wrap a multi-view dataset in a loader that yields it as one batch.

    The batch size equals the number of labels and shuffling is off, so a
    single pass produces every sample in index order.

    Args:
        mv_data: Dataset object exposing ``data_views`` and ``labels``.

    Returns:
        Tuple ``(loader, num_views, num_samples, num_clusters)`` where
        ``num_clusters`` is the number of distinct values in
        ``mv_data.labels``.
    """
    labels = mv_data.labels
    view_count = len(mv_data.data_views)
    sample_count = len(labels)
    cluster_count = len(np.unique(labels))

    loader_options = {
        "batch_size": sample_count,
        "shuffle": False,
        "drop_last": False,
    }
    loader = torch.utils.data.DataLoader(mv_data, **loader_options)

    return loader, view_count, sample_count, cluster_count