import os
import random
import sys
import numpy as np
import scipy.io as sio
from scipy import sparse
from sklearn.model_selection import train_test_split
from utils import util



def load_multiview_data(config):
    """Load a multi-view dataset and return its views and labels.

    The dataset is chosen by ``config['dataset']`` and read from a ``.mat``
    file. 'Scene_15', '100leaves', 'Hdigit' and 'Fashion' resolve the file
    under ``<sys.path[0]>/data``. 'NoisyMNIST', 'MNIST_USPS' and 'handwritten'
    use the relative path ``./data``, so they depend on the current working
    directory.

    Args:
        config: mapping that provides the dataset name under ``'dataset'``.
            Supported names: 'Scene_15', '100leaves', 'NoisyMNIST', 'Hdigit',
            'MNIST_USPS', 'Fashion', 'handwritten'.

    Returns:
        tuple ``(X_list, Y_list)``:
            ``X_list[v]`` is the feature matrix of view ``v``. Most branches
            arrange it as samples x features. 'Scene_15' and 'handwritten'
            use the stored layout as-is (TODO confirm for those two).
            ``Y_list`` holds one label array per view, so
            ``len(Y_list) == len(X_list)``.

    Raises:
        ValueError: if the dataset name is not supported.
    """
    data_name = config['dataset']
    main_dir = sys.path[0]
    X_list = []
    Y_list = []

    if data_name == 'Scene_15':
        mat = sio.loadmat(os.path.join(main_dir, 'data', 'Scene_15.mat'))
        X = mat['X'][0]
        y = np.squeeze(mat['Y'])
        # Three views (feature dims 20, 59 and 40 according to earlier notes).
        for view in range(3):
            X_list.append(X[view].astype('float32'))
            Y_list.append(y)

    elif data_name == '100leaves':
        mat = sio.loadmat(os.path.join(main_dir, 'data', data_name + '.mat'))
        X = mat['data']  # cell array of views, stored as (1, 3)
        truelabel = mat['truelabel']  # one label cell per view, (1, 3)
        for view in range(3):
            # Each view was noted as (64, 1600); transpose so samples come first.
            x = util.normalize(X[0, view].T).astype('float32')
            y = truelabel[0, view].flatten().astype('int')
            X_list.append(x)
            Y_list.append(y)

    elif data_name == 'NoisyMNIST':
        data = sio.loadmat('./data/NoisyMNIST.mat')
        # Only the validation ("tune") and test splits are returned. The
        # training split is deliberately not wrapped, to avoid an unused
        # float32 copy of it.
        tune = DataSet_NoisyMNIST(data['XV1'], data['XV2'], data['tuneLabel'])
        test = DataSet_NoisyMNIST(data['XTe1'], data['XTe2'], data['testLabel'])
        X_list.append(np.concatenate([tune.images1, test.images1], axis=0))
        X_list.append(np.concatenate([tune.images2, test.images2], axis=0))
        labels = np.concatenate([np.squeeze(tune.labels[:, 0]),
                                 np.squeeze(test.labels[:, 0])])
        Y_list.append(labels)
        Y_list.append(labels)

    elif data_name == 'Hdigit':
        mat = sio.loadmat(os.path.join(main_dir, 'data', 'Hdigit.mat'))
        X = mat['data']
        label = np.squeeze(mat['truelabel'][0, 0]).astype('int')
        # Two views, noted as (784, 10000) and (256, 10000). Transpose to
        # samples-first and normalize each view.
        for view in range(2):
            x = X[0][view].astype('float32').T
            X_list.append(util.normalize(x).astype('float32'))
            Y_list.append(label)

    elif data_name == 'MNIST_USPS':
        data = sio.loadmat('./data/MNIST_USPS.mat')
        # Two image views, each (N, H, W). Flatten every sample to a vector so
        # MLP-style models can consume it. N is read from the data instead of
        # being hard-coded.
        Y = np.squeeze(data['Y']).astype(np.int64)  # stored as (1, N) -> (N,)
        for key in ('X1', 'X2'):
            images = data[key]
            X_list.append(images.reshape(images.shape[0], -1).astype(np.float32))
            Y_list.append(Y)

    elif data_name == 'Fashion':
        mat = sio.loadmat(os.path.join(main_dir, 'data', 'Fashion.mat'))
        # Three image views, each (N, 28, 28), flattened to (N, 784).
        y = np.squeeze(mat['Y'])  # stored as (1, N) -> (N,)
        for key in ('X1', 'X2', 'X3'):
            images = mat[key]
            X_list.append(images.reshape(images.shape[0], -1).astype('float32'))
            Y_list.append(y)

    elif data_name == 'handwritten':
        data = sio.loadmat('./data/handwritten.mat')
        # 'X' is a (1, 6) cell array. Each cell is one view, noted as (2000, D_v).
        views = data['X'][0]
        Y = np.squeeze(data['Y']).astype(np.int64)  # (1, 2000) -> (2000,)
        for view in views:
            X_list.append(view.astype(np.float32))
            Y_list.append(Y)  # all views share the same labels

    else:
        raise ValueError(f"Unsupported dataset: {data_name!r}")

    return X_list, Y_list


class DataSet_NoisyMNIST(object):
    """Two aligned image views plus labels, with a mini-batch iterator.

    Stores ``images1``, ``images2`` and ``labels`` (aligned along axis 0) and
    returns successive mini-batches from :meth:`next_batch`. At each epoch
    boundary it shuffles all three arrays with one shared permutation.
    """

    def __init__(self, images1, images2, labels, fake_data=False, one_hot=False,
                 dtype=np.float32):
        """Construct a DataSet.

        Args:
            images1, images2: per-view sample arrays whose ``shape[0]`` must
                equal ``labels.shape[0]``. They are not validated when
                ``fake_data`` is True.
            labels: label array aligned with the views along axis 0.
            fake_data: if True, skip validation and report 10000 examples.
            one_hot: whether the labels returned by
                ``next_batch(..., fake_data=True)`` are one-hot vectors.
            dtype: ``np.uint8`` leaves the inputs untouched. ``np.float32``
                casts any non-float32 view to float32. Values are NOT
                rescaled; no division by 255 is performed.

        Raises:
            TypeError: if ``dtype`` is neither ``np.uint8`` nor ``np.float32``.
            AssertionError: if a view's sample count differs from the labels'.
        """
        if dtype not in (np.uint8, np.float32):
            raise TypeError('Invalid image dtype %r, expected uint8 or float32' % dtype)

        # Set this unconditionally: next_batch(fake_data=True) reads it even on
        # instances built from real data.
        self.one_hot = one_hot

        if fake_data:
            self._num_examples = 10000
        else:
            assert images1.shape[0] == labels.shape[0], (
                'images1.shape: %s labels.shape: %s' % (images1.shape, labels.shape))
            assert images2.shape[0] == labels.shape[0], (
                'images2.shape: %s labels.shape: %s' % (images2.shape, labels.shape))
            self._num_examples = images1.shape[0]
            # Cast only; the value range of the input is preserved.
            if dtype == np.float32 and images1.dtype != np.float32:
                print("type conversion view 1")
                images1 = images1.astype(np.float32)
            if dtype == np.float32 and images2.dtype != np.float32:
                print("type conversion view 2")
                images2 = images2.astype(np.float32)

        self._images1 = images1
        self._images2 = images2
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images1(self):
        return self._images1

    @property
    def images2(self):
        return self._images2

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False):
        """Return the next ``batch_size`` examples as ``(images1, images2, labels)``.

        If the current epoch cannot supply a full batch, the leftover examples
        are skipped, the epoch counter is incremented, and all three arrays
        are shuffled with one shared permutation. The batch is then taken
        from the start of the new order.

        With ``fake_data=True``, returns lists of constant 784-length images
        and constant labels without advancing the iterator.

        Raises:
            AssertionError: if ``batch_size`` exceeds the number of examples
                when a new epoch starts.
        """
        if fake_data:
            fake_image = [1] * 784
            fake_label = [1] + [0] * 9 if self.one_hot else 0
            return ([fake_image for _ in range(batch_size)],
                    [fake_image for _ in range(batch_size)],
                    [fake_label for _ in range(batch_size)])

        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Epoch finished: reshuffle all arrays with the same permutation
            # so view/label pairing is kept.
            self._epochs_completed += 1
            perm = np.arange(self._num_examples)
            np.random.shuffle(perm)
            self._images1 = self._images1[perm]
            self._images2 = self._images2[perm]
            self._labels = self._labels[perm]
            # Start the next epoch from the beginning.
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples

        end = self._index_in_epoch
        return self._images1[start:end], self._images2[start:end], self._labels[start:end]





