import os
import os.path
import sys
import torch
import numpy as np
import pickle
# import h5py
import scipy
from scipy.io import loadmat
import torch.utils.data as data
from copy import deepcopy
from sklearn.model_selection import KFold
from torch.utils.data import Dataset
from global_var import REALWORLD_DATA_ROOT

def load_realworld(ds, batch_size, split_seed=42, device=None, partial_rate=None, partial_num=None, num_or_rate="rate", has_eval_train_loader=False, has_meta_valid=False):
    """Load a real-world partial-label ``.mat`` dataset and build its loaders.

    The file ``<REALWORLD_DATA_ROOT>/<ds>.mat`` is read through
    ``RealwordDataLoader``/``RealWorldData`` and split into train / valid /
    test subsets (roughly 80/10/10) with a seeded ``random_split``.

    Args:
        ds: dataset name; the file ``ds + '.mat'`` is loaded.
        batch_size: batch size used by every returned loader.
        split_seed: seed of the generator driving the random split.
        device, partial_rate, partial_num, num_or_rate: not used by this
            function; presumably kept so its signature matches other dataset
            loaders — confirm against callers.
        has_eval_train_loader: if true, also return a second loader over the
            training subset.
        has_meta_valid: if true, also return the validation features and
            true labels as tensors.

    Returns:
        list: ``[train_loader, valid_loader, test_loader, dim, K]`` where
        ``dim`` is the flattened feature width and ``K`` the number of
        classes, followed by ``eval_train_loader`` (if requested) and then
        ``valid_X, valid_Y`` (if requested).
    """
    data_path = os.path.join(REALWORLD_DATA_ROOT, ds + '.mat')
    data_reader = RealwordDataLoader(data_path)
    full_dataset = RealWorldData(data_reader)
    full_data_size = len(full_dataset)
    # 10% test, 10% validation (each floored by int()), remainder training.
    test_size, valid_size = int(full_data_size * 0.1), int(full_data_size * 0.1)
    train_size = full_data_size - test_size - valid_size
    train_split, valid_split, test_split = torch.utils.data.random_split(
        full_dataset, [train_size, valid_size, test_size],
        torch.Generator().manual_seed(split_seed))
    # Re-wrap each split's index list so every subset owns its tensors and
    # returns the item format appropriate to its role.
    train_dataset = My_Subset(full_dataset, train_split.indices, 'train')
    valid_dataset = My_Subset(full_dataset, valid_split.indices, 'valid')
    test_dataset = My_Subset(full_dataset, test_split.indices, 'test')

    partial_matrix_train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                              batch_size=batch_size,
                                                              shuffle=True)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=True)

    train_X, train_p_Y = train_dataset.images, train_dataset.given_label_matrix
    # Count every non-sample axis, so `dim` is the flattened input width
    # even if features ever have more than two dimensions.
    dim = train_X.shape[1:].numel()
    K = train_p_Y.shape[-1]

    return_list = [partial_matrix_train_loader, valid_loader, test_loader, dim, K]
    if has_eval_train_loader:
        # NOTE(review): drop_last=True skips the final partial batch, so one
        # pass does not visit every training sample — confirm intended.
        eval_train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=8,
                                                        drop_last=True)
        return_list.append(eval_train_loader)
    if has_meta_valid:
        return_list.append(valid_dataset.images)
        return_list.append(valid_dataset.true_labels)

    return return_list


class RealwordDataLoader:
    """Read a partial-label ``.mat`` file and standardise its features.

    Two key layouts are recognised: ``data`` / ``target`` /
    ``partial_target``, or ``features`` / ``logitlabels`` / ``p_labels``.
    Any scipy-sparse array is densified, the label matrices are transposed
    when their first axis does not match the number of feature rows, and
    every feature column is z-scored (constant columns are only centred).

    Args:
        mat_path: path of the ``.mat`` file passed to ``scipy.io.loadmat``.
    """

    def __init__(self, mat_path):
        self.data = loadmat(mat_path)
        print(self.data.keys())
        if "data" in self.data.keys():
            keys = ('data', 'target', 'partial_target')
        else:
            keys = ('features', 'logitlabels', 'p_labels')
        # Each array may independently come back from loadmat as a scipy
        # sparse matrix, so each one is densified on its own.
        self.features, self.targets, self.partial_targets = (
            self._to_dense(self.data[key]) for key in keys
        )
        if self.features.shape[0] != self.targets.shape[0]:
            # Label matrices stored class-major; make them sample-major.
            self.targets = self.targets.transpose()
            self.partial_targets = self.partial_targets.transpose()

        # normalize
        print(self.features.shape, self.targets.shape, self.partial_targets.shape)
        mean = self.features.mean(axis=0, keepdims=True)
        std = self.features.std(axis=0, keepdims=True)
        # A constant column has std 0; dividing by it would produce NaN.
        std = np.where(std > 0, std, 1.0)
        self.features = (self.features - mean) / std
        self.num_features, self.num_classes = self.features.shape[-1], self.targets.shape[-1]

    @staticmethod
    def _to_dense(x):
        """Return ``x`` as a dense array if it is scipy-sparse, else unchanged."""
        from scipy.sparse import issparse
        return x.toarray() if issparse(x) else x

    def get_data(self):
        """Convert the loaded arrays to tensors.

        Also stores the converted values back on ``self.features``,
        ``self.partial_targets``, ``self.final_labels`` and
        ``self.true_labels``. Safe to call more than once.

        Returns:
            tuple: ``(features, partial_targets, final_labels, true_labels)``,
            where the first three are float32 tensors (``final_labels`` is
            ``targets`` divided by its row sums) and ``true_labels`` is the
            per-row argmax of ``targets`` (int64).
        """
        def to_sum_one(x):
            return x / x.sum(axis=1, keepdims=True)

        def to_torch(x):
            # as_tensor also accepts an already-converted tensor, which is
            # what self.features / self.partial_targets hold after a call.
            return torch.as_tensor(x, dtype=torch.float32)

        self.final_labels = to_sum_one(self.targets)
        self.features, self.partial_targets, self.final_labels, self.true_labels = map(to_torch, (
            self.features, self.partial_targets, self.final_labels, self.targets
        ))
        self.true_labels = torch.argmax(self.true_labels, dim=1)
        return self.features, self.partial_targets, self.final_labels, self.true_labels


class RealWorldData(data.Dataset):
    """Map-style dataset over every sample returned by a reader's ``get_data``.

    Each item is ``(features, candidate_label_row, true_label, index)``.
    """

    def __init__(self, realword_dataloader):
        # Keep the full 4-tuple; its third entry (normalised targets) is
        # not exposed as an attribute.
        self.dataset = realword_dataloader.get_data()
        features, candidates, _soft_targets, truth = self.dataset
        self.images = features
        self.given_label_matrix = candidates
        self.true_labels = truth

    def __getitem__(self, index):
        return (
            self.images[index],
            self.given_label_matrix[index],
            self.true_labels[index],
            index,
        )

    def __len__(self):
        return len(self.images)


class My_Subset(Dataset):
    """
    Subset of a dataset at specified indices.

    The selected rows are copied out of the parent dataset at construction
    time, so item access does not touch the parent afterwards.

    Arguments:
        dataset (Dataset): The whole dataset; must expose ``images``,
            ``given_label_matrix`` and ``true_labels`` tensors.
        indices (sequence): Indices in the whole set selected for subset
        settype (str): ``'train'``, ``'valid'`` or ``'test'``; selects the
            item format produced by ``__getitem__``.
    """

    def __init__(self, dataset, indices, settype):
        self.indices = indices
        self.settype = settype
        self.images = dataset.images[indices, :]
        self.given_label_matrix = dataset.given_label_matrix[indices, :]
        self.true_labels = dataset.true_labels[indices]

    def __getitem__(self, index):
        """Return one item; ``index`` is a position within this subset.

        ``'train'`` -> ``([features], candidate_row, true_label, index)``;
        ``'valid'`` / ``'test'`` -> ``(features, true_label)``.

        Raises:
            ValueError: if ``settype`` is not one of the three known values.
        """
        each_image_o = self.images[index]
        each_true_label = self.true_labels[index]
        if self.settype == "train":
            # Features are wrapped in a one-element list; presumably so the
            # trainer can treat them like a list of views — confirm there.
            return [each_image_o, ], self.given_label_matrix[index], each_true_label, index
        if self.settype in ("valid", "test"):
            return each_image_o, each_true_label
        raise ValueError(
            f"unknown settype {self.settype!r}; expected 'train', 'valid' or 'test'")

    def __len__(self):
        return len(self.indices)