import numpy as np
import torch
from torch.utils.data import Dataset
import os
import scipy.io  # 处理mat数据


class BaseDataset(Dataset):
    """In-memory multi-view dataset.

    Subclasses override :meth:`load_data`, which must fill ``self.data`` with
    one array per view (samples along the first axis) and ``self.labels``
    with one label per sample.
    """

    def __init__(self, path):
        # Prefix that subclasses prepend to their .mat file name.
        self.path = path
        self.data = []
        self.labels = []
        self.load_data()

    def __len__(self):
        """Return the number of samples, i.e. the number of labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``([view_0[idx], view_1[idx], ...], labels[idx])``.

        Each per-view row is wrapped in a ``torch.FloatTensor``; the label is
        returned exactly as stored in ``self.labels``.
        """
        sample_views = []
        for view in self.data:
            sample_views.append(torch.FloatTensor(view[idx]))
        label = self.labels[idx]
        return sample_views, label

    def load_data(self):
        """Populate ``self.data`` and ``self.labels``; must be overridden."""
        raise NotImplementedError


class Caltech101_20(BaseDataset):
    """Caltech101-20 multi-view dataset: 2386 samples, six views."""

    def load_data(self):
        """Load labels (``Y``) and six views (``X[0][k]``) from ``Caltech101_20.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(2386, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 2386
        dims = (48, 40, 254, 1984, 512, 928)
        mat = scipy.io.loadmat(self.path + 'Caltech101_20.mat')
        self.y = mat['Y'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['X'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y


class Caltech101_20all(BaseDataset):
    """Caltech101 (all classes) multi-view dataset: 9144 samples, six views."""

    def load_data(self):
        """Load labels (``Y``) and six views (``X[0][k]``) from ``Caltech101-all.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(9144, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 9144
        dims = (48, 40, 254, 1984, 512, 928)
        mat = scipy.io.loadmat(self.path + 'Caltech101-all.mat')
        self.y = mat['Y'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['X'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y


class Caltech101_2(BaseDataset):
    """Two-view Caltech101 dataset: 2386 samples."""

    def load_data(self):
        """Load labels (``Y``) and two views (``X[0][k]``) from ``Caltech101_2.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(2386, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 2386
        dims = (1984, 512)
        mat = scipy.io.loadmat(self.path + 'Caltech101_2.mat')
        self.y = mat['Y'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['X'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y


class Scene_15(BaseDataset):
    """Scene-15 multi-view dataset: 4485 samples, three views."""

    def load_data(self):
        """Load labels (``Y``) and three views (``X[0][k]``) from ``Scene_15.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(4485, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 4485
        dims = (20, 59, 40)
        mat = scipy.io.loadmat(self.path + 'Scene_15.mat')
        self.y = mat['Y'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['X'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y


class NGs(BaseDataset):
    """NGs multi-view dataset: 500 samples, three 2000-dimensional views."""

    def load_data(self):
        """Load labels (``truelabel``) and three views (``data[0][k]``) from ``NGs.mat``.

        Every view is produced samples-first with shape ``(500, 2000)``, so
        ``view[idx]`` is one sample's 2000-dim feature vector. The previous
        ``reshape(2000, 500)`` produced 500-dim rows (and only the first 500
        of 2000 rows were reachable through the 500 labels).

        If a view is stored features x samples, i.e. ``(2000, 500)``, it is
        transposed. Otherwise it is reshaped to ``(500, 2000)``.

        The file is parsed a single time. The raw and oriented arrays remain
        available as ``V<k>`` / ``data<k>`` (1-based) for backward
        compatibility.
        """
        n_samples = 500
        dims = (2000, 2000, 2000)
        mat = scipy.io.loadmat(self.path + 'NGs.mat')
        self.y = mat['truelabel'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['data'][0][k - 1].astype(np.float32)
            if raw.shape == (dim, n_samples):
                # Stored features x samples: transpose so rows are samples, and
                # make it contiguous so per-row indexing yields dense vectors.
                view = np.ascontiguousarray(raw.T)
            else:
                view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y


class MSRCv1(BaseDataset):
    """MSRC-v1 multi-view dataset (four-view file): 210 samples."""

    def load_data(self):
        """Load labels (``Y``) and four views (``X[0][k]``) from ``MSRC-v1.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(210, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 210
        dims = (24, 512, 256, 254)
        mat = scipy.io.loadmat(self.path + 'MSRC-v1.mat')
        self.y = mat['Y'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['X'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y
class MSRC_v1(BaseDataset):
    """MSRC-v1 multi-view dataset (five-view file): 210 samples."""

    def load_data(self):
        """Load labels (``Y``) and five views (``X[0][k]``) from ``MSRCv1.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(210, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 210
        dims = (24, 576, 512, 256, 254)
        mat = scipy.io.loadmat(self.path + 'MSRCv1.mat')
        self.y = mat['Y'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['X'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y


class Reuters(BaseDataset):
    """Reuters multi-view dataset: 1200 samples, five 2000-dimensional views."""

    def load_data(self):
        """Load labels (``gt``) and five views (``fea[0][k]``) from ``Reuters.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(1200, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 1200
        dims = (2000, 2000, 2000, 2000, 2000)
        mat = scipy.io.loadmat(self.path + 'Reuters.mat')
        self.y = mat['gt'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['fea'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y


class ORL(BaseDataset):
    """ORL multi-view dataset: 400 samples, three views."""

    def load_data(self):
        """Load labels (``gt``) and three views (``fea[0][k]``) from ``ORL.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(400, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 400
        dims = (4096, 3304, 6750)
        mat = scipy.io.loadmat(self.path + 'ORL.mat')
        self.y = mat['gt'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['fea'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y


class BBCSport(BaseDataset):
    """BBCSport multi-view dataset: 544 samples, two views."""

    def load_data(self):
        """Load labels (``label``) and two views (``data[0][k]``) from ``BBCSport.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(544, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 544
        dims = (3183, 3203)
        mat = scipy.io.loadmat(self.path + 'BBCSport.mat')
        self.y = mat['label'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['data'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y


class HW2(BaseDataset):
    """Two-view handwritten-digits dataset (``HW-2.mat``): 2000 samples."""

    def load_data(self):
        """Load labels (``truelabel``) and two views (``data[0][k]``) from ``HW-2.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(2000, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 2000
        dims = (784, 256)
        mat = scipy.io.loadmat(self.path + 'HW-2.mat')
        self.y = mat['truelabel'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['data'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y


class HWall(BaseDataset):
    """Six-view handwritten-digits dataset (``HW-all.mat``): 2000 samples."""

    def load_data(self):
        """Load labels (``truelabel``) and six views (``data[0][k]``) from ``HW-all.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(2000, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 2000
        dims = (216, 76, 64, 6, 240, 47)
        mat = scipy.io.loadmat(self.path + 'HW-all.mat')
        self.y = mat['truelabel'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['data'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y


class NUS_WIDE(BaseDataset):
    """NUS-WIDE multi-view dataset: 2000 samples, five views."""

    def load_data(self):
        """Load labels (``gt``) and five views (``fea[0][k]``) from ``NUS_WIDE.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(2000, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 2000
        dims = (65, 226, 145, 74, 129)
        mat = scipy.io.loadmat(self.path + 'NUS_WIDE.mat')
        self.y = mat['gt'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['fea'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y

class ALOI_100_2(BaseDataset):
    """Two-view ALOI-100 dataset (``ALOI_100_2v.mat``): 10800 samples."""

    def load_data(self):
        """Load labels (``gt``) and two views (``fea[0][k]``) from ``ALOI_100_2v.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(10800, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 10800
        dims = (77, 64)
        mat = scipy.io.loadmat(self.path + 'ALOI_100_2v.mat')
        self.y = mat['gt'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['fea'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y
class ALOI_100(BaseDataset):
    """Four-view ALOI-100 dataset (``ALOI_100.mat``): 10800 samples."""

    def load_data(self):
        """Load labels (``gt``) and all four views (``fea[0][k]``) from ``ALOI_100.mat``.

        All four views go into ``self.data``. Previously only the first two
        were exposed, although all four were loaded and the factory reports
        four views with dims ``[77, 13, 64, 125]``.

        The file is parsed a single time. View ``k`` is cast to float32 and
        reshaped to ``(10800, dims[k])``. The raw and reshaped arrays remain
        available as ``V<k>`` / ``data<k>`` (1-based) for backward
        compatibility.
        """
        n_samples = 10800
        dims = (77, 13, 64, 125)
        mat = scipy.io.loadmat(self.path + 'ALOI_100.mat')
        self.y = mat['gt'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['fea'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y
class Landuse_21(BaseDataset):
    """Landuse-21 multi-view dataset: 2100 samples, three views."""

    def load_data(self):
        """Load labels (``gt``) and three views (``X[0][k]``) from ``Landuse_21.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(2100, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 2100
        dims = (20, 59, 59)
        mat = scipy.io.loadmat(self.path + 'Landuse_21.mat')
        self.y = mat['gt'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['X'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y
class Yale_3(BaseDataset):
    """Three-view Yale dataset (``Yale3.mat``): 165 samples."""

    def load_data(self):
        """Load labels (``gt``) and three views (``fea[0][k]``) from ``Yale3.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(165, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 165
        dims = (4096, 3304, 6750)
        mat = scipy.io.loadmat(self.path + 'Yale3.mat')
        self.y = mat['gt'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['fea'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y
class wikipedia(BaseDataset):
    """Wikipedia multi-view dataset: 693 samples, two views."""

    def load_data(self):
        """Load labels (``Y``) and two views (``data[0][k]``) from ``wikipedia.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(693, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 693
        dims = (128, 10)
        mat = scipy.io.loadmat(self.path + 'wikipedia.mat')
        self.y = mat['Y'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['data'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y
class animal(BaseDataset):
    """Animal multi-view dataset (``Animal.mat``): 11673 samples, four views."""

    def load_data(self):
        """Load labels (``Y``) and four views (``X[0][k]``) from ``Animal.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(11673, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 11673
        dims = (2689, 2000, 2001, 2000)
        mat = scipy.io.loadmat(self.path + 'Animal.mat')
        self.y = mat['Y'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['X'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y
class stl10(BaseDataset):
    """STL-10 feature dataset (``stl10_fea_v7.mat``): 13000 samples, three views."""

    def load_data(self):
        """Load labels (``Y``) and three views (``X[0][k]``) from ``stl10_fea_v7.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(13000, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 13000
        dims = (1024, 512, 2048)
        mat = scipy.io.loadmat(self.path + 'stl10_fea_v7.mat')
        self.y = mat['Y'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['X'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y
class Mnist(BaseDataset):
    """Multi-view MNIST dataset (``MNIST.mat``): 60000 samples, three views."""

    def load_data(self):
        """Load labels (``Y``) and three views (``X[0][k]``) from ``MNIST.mat``.

        The file is parsed a single time and every view is taken from that
        result. View ``k`` is cast to float32 and reshaped to
        ``(60000, dims[k])``; the raw and reshaped arrays remain available as
        ``V<k>`` / ``data<k>`` (1-based) for backward compatibility.
        """
        n_samples = 60000
        dims = (342, 1024, 64)
        mat = scipy.io.loadmat(self.path + 'MNIST.mat')
        self.y = mat['Y'].astype(np.int32).reshape(n_samples)
        self.data = []
        for k, dim in enumerate(dims, start=1):
            raw = mat['X'][0][k - 1].astype(np.float32)
            view = raw.reshape(n_samples, dim)
            # Keep the per-view attributes this class has always exposed.
            setattr(self, f'V{k}', raw)
            setattr(self, f'data{k}', view)
            self.data.append(view)
        self.labels = self.y

def load_data(dataset, path='./datasets/'):
    """Instantiate a multi-view dataset by name and return it with its metadata.

    Args:
        dataset: Registered dataset name, e.g. ``"Caltech101_20"``,
            ``"HW-all"`` or ``"mnist"``.
        path: Prefix the dataset's ``.mat`` file name is appended to.
            Defaults to ``'./datasets/'``, the value that used to be hard-coded.

    Returns:
        ``(dataset, dims, view, data_size, class_num)``:
        - the dataset instance;
        - a fresh list with the feature dimension of each view;
        - the number of views (``len(dims)``);
        - the number of samples;
        - the number of classes.

    Raises:
        NotImplementedError: if ``dataset`` is not a registered name.
    """
    # name -> (dataset class, per-view feature dims, number of samples, number of classes)
    registry = {
        "Caltech101_20": (Caltech101_20, (48, 40, 254, 1984, 512, 928), 2386, 20),
        "Caltech101_20all": (Caltech101_20all, (48, 40, 254, 1984, 512, 928), 9144, 102),
        "Caltech101_2": (Caltech101_2, (1984, 512), 2386, 20),
        "Scene_15": (Scene_15, (20, 59, 40), 4485, 15),
        "NGs": (NGs, (2000, 2000, 2000), 500, 5),
        "MSRC-v1": (MSRCv1, (24, 512, 256, 254), 210, 7),
        "MSRCv1": (MSRC_v1, (24, 576, 512, 256, 254), 210, 7),
        "Reuters": (Reuters, (2000, 2000, 2000, 2000, 2000), 1200, 6),
        "ORL": (ORL, (4096, 3304, 6750), 400, 40),
        "BBCSport": (BBCSport, (3183, 3203), 544, 5),
        "HW-2": (HW2, (784, 256), 2000, 10),
        "HW-all": (HWall, (216, 76, 64, 6, 240, 47), 2000, 10),
        "NUS-WIDE": (NUS_WIDE, (65, 226, 145, 74, 129), 2000, 31),
        "ALOI_100_2": (ALOI_100_2, (77, 64), 10800, 100),
        "ALOI_100": (ALOI_100, (77, 13, 64, 125), 10800, 100),
        "Landuse-21": (Landuse_21, (20, 59, 59), 2100, 21),
        "Yale_3": (Yale_3, (4096, 3304, 6750), 165, 15),
        "wikipedia": (wikipedia, (128, 10), 693, 10),
        "animal": (animal, (2689, 2000, 2001, 2000), 11673, 20),
        "stl10": (stl10, (1024, 512, 2048), 13000, 10),
        "mnist": (Mnist, (342, 1024, 64), 60000, 10),
    }
    try:
        dataset_cls, dims, data_size, class_num = registry[dataset]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which the old == chain also rejected.
        raise NotImplementedError(f"Dataset {dataset} is not supported") from None
    # The view count is always the number of per-view dims, so derive it rather
    # than maintaining it by hand.
    return dataset_cls(path), list(dims), len(dims), data_size, class_num




















