import os
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader

from .Utility import SubsetRandomSampler, SubsetSampler

class PurchaseDataset(Dataset):
    """Purchase100 dataset read from a ``.npz`` archive.

    The archive must contain a ``'features'`` array and a ``'labels'`` array.
    Rows are split deterministically (no shuffling): the first
    ``int(n * train_ratio)`` rows form the ``'train'`` split and the remaining
    rows form any other split (e.g. ``'test'``).

    Each item is ``(feature, label, index)``:
      * ``feature`` -- float32 row (passed through ``transform`` if given),
      * ``label``   -- int64 class id; a label vector (e.g. one-hot) is reduced
                       with ``argmax``, a scalar label is used as the class id,
      * ``index``   -- position of the item within this split.
    """

    def __init__(self, file_path, split='train', transform=None, train_ratio=0.8):
        """
        file_path: Path to the .npz file containing 'features' and 'labels'.
        split: 'train' selects the leading rows; any other value selects the rest.
        transform: Optional transform to be applied on a sample's features.
        train_ratio: Fraction of rows assigned to the 'train' split (default 0.8).

        Raises ValueError if train_ratio is outside [0, 1].
        """
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError('train_ratio must be in [0, 1], got %r' % (train_ratio,))

        # NpzFile keeps the archive open until closed; the context manager
        # releases the handle as soon as both arrays have been read.
        with np.load(file_path) as data:
            raw_features = data['features']
            raw_labels = data['labels']

        split_idx = int(len(raw_features) * train_ratio)
        rows = slice(None, split_idx) if split == 'train' else slice(split_idx, None)

        # Slice first, then convert: astype copies only the selected rows, so
        # this split does not keep the full array alive through a view.
        self.features = raw_features[rows].astype(np.float32)
        self.labels = raw_labels[rows].astype(np.float32)

        self.transform = transform

    def __len__(self):
        """Return the number of rows in this split."""
        return len(self.features)

    def __getitem__(self, index):
        """Return ``(feature, label, index)`` for the row at ``index``."""
        feature = self.features[index]
        label = self.labels[index]
        # A label vector (one-hot / scores) is reduced to its class id; a
        # scalar label already is the class id.
        if np.ndim(label) > 0:
            label = np.argmax(label)
        label = np.int64(label)

        if self.transform is not None:
            feature = self.transform(feature)

        return feature, label, index

def purchase100(batch_size, root='./data', valid_ratio=None, shuffle=True, train_subset=None, test_subset=None, augmentation=None,
                mislabel_ratio=0., mislabel_seed=0, class_subset_path=None, is_split=False, split_seed=0, is_shadow=False, shadow_ratio=0.8, shadow_seed=0,
                member_train=False, member_test=False, nonmember_train=False, nonmember_test=False, num_worker=0):
    '''
    Build Purchase100 data loaders, optionally restricted to target/shadow and
    member/non-member subsets.

    batch_size: batch size.
    root: path to the data folder; data is read from `<root>/purchase100.npz`.
    valid_ratio: the ratio of validation data from training set, None if no validation set.
        Validation indices are taken from the front of the training indices.
    shuffle: whether to shuffle training data.
    train_subset: if not None, only use a subset of training data (indices into
        the train split). The subset is permuted with the global NumPy RNG.
    test_subset: if not None, only use a subset of test data (indices into the
        test split, order kept).
    augmentation, mislabel_ratio, mislabel_seed, class_subset_path: accepted but
        currently unused by this loader.
    is_split: if True, replace train_subset/test_subset with a random half of
        each split drawn with `split_seed`.
    is_shadow: together with is_split, use the complementary halves instead,
        subsampled to `shadow_ratio` of their size with `shadow_seed`.
    member_train / member_test: keep only the first / second half of the
        training indices (mutually exclusive).
    nonmember_train / nonmember_test: keep only the first / second half of the
        test indices (mutually exclusive).
    num_worker: number of DataLoader worker processes.

    Returns (train_loader, valid_loader, test_loader, classes). valid_loader is
    None when no validation split is requested; classes is the list of class ids.

    Raises ValueError when both member flags or both non-member flags are True.
    '''
    if member_train and member_test:
        raise ValueError('member_train and member_test cannot be both True.')
    if nonmember_train and nonmember_test:
        raise ValueError('nonmember_train and nonmember_test cannot be both True.')

    file_path = os.path.join(root, 'purchase100.npz')
    trainset = PurchaseDataset(file_path=file_path, split='train')
    testset = PurchaseDataset(file_path=file_path, split='test')

    if is_split:
        # The same seed draws both halves; the global RNG is (re)seeded on
        # purpose so the later permutation of train_subset is reproducible too.
        num_train_data, num_test_data = len(trainset), len(testset)
        np.random.seed(split_seed)
        train_subset = np.random.choice(num_train_data, size=num_train_data//2, replace=False)
        np.random.seed(split_seed)
        test_subset = np.random.choice(num_test_data, size=num_test_data//2, replace=False)
        if is_shadow:
            # Shadow data comes from the halves the target model did not use.
            train_subset = set(np.arange(num_train_data)) - set(train_subset)
            train_subset = np.array(list(train_subset))
            test_subset = set(np.arange(num_test_data)) - set(test_subset)
            test_subset = np.array(list(test_subset))
            np.random.seed(shadow_seed)
            train_subset = np.random.choice(train_subset, size=int(len(train_subset) * shadow_ratio), replace=False)
            np.random.seed(shadow_seed)
            test_subset = np.random.choice(test_subset, size=int(len(test_subset) * shadow_ratio), replace=False)

    if train_subset is None:
        train_indices = list(range(len(trainset)))
    else:
        # Deterministic only if the global RNG was seeded beforehand
        # (the is_split branch above does this).
        train_indices = np.random.permutation(train_subset)
    if member_train:
        train_indices = train_indices[:len(train_indices) // 2]
    if member_test:
        train_indices = train_indices[len(train_indices) // 2:]
    train_instance_num = len(train_indices)
    print('%d instances are picked from the training set' % train_instance_num)

    if test_subset is None:
        test_indices = list(range(len(testset)))
    else:
        test_indices = test_subset
    if nonmember_train:
        test_indices = test_indices[:len(test_indices) // 2]
    if nonmember_test:
        test_indices = test_indices[len(test_indices) // 2:]
    test_instance_num = len(test_indices)
    print('%d instances are picked from the test set' % test_instance_num)
    test_sampler = SubsetSampler(test_indices)

    # Only request pinned host memory when there is a CUDA device to copy to.
    pin_memory = torch.cuda.is_available()

    def make_loader(dataset, sampler):
        # All loaders share batch size, worker count and pinning policy.
        return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                          num_workers=num_worker, pin_memory=pin_memory)

    # Carve a validation set out of the training indices (only when valid_ratio > 0).
    valid_indices = None
    if valid_ratio is not None and valid_ratio > 0:
        split_pt = int(len(train_indices) * valid_ratio)
        valid_indices = train_indices[:split_pt]
        train_indices = train_indices[split_pt:]

    if shuffle:
        train_sampler = SubsetRandomSampler(train_indices)
    else:
        train_sampler = SubsetSampler(train_indices)
    valid_sampler = SubsetSampler(valid_indices) if valid_indices is not None else None

    train_loader = make_loader(trainset, train_sampler)
    valid_loader = make_loader(trainset, valid_sampler) if valid_sampler is not None else None
    test_loader = make_loader(testset, test_sampler)

    labels = trainset.labels
    if labels.ndim > 1:
        # Label matrix (e.g. one-hot): PurchaseDataset.__getitem__ maps each row
        # to its argmax, so the class ids are the column indices. np.unique on
        # the matrix would only return the distinct cell values (e.g. 0 and 1).
        classes = list(range(labels.shape[1]))
    else:
        classes = np.unique(labels).astype(np.int64).tolist()

    return train_loader, valid_loader, test_loader, classes


if __name__ == '__main__':
    # Smoke test: build the loaders and print the shapes of the first train
    # batch and the first test batch.
    train_loader, valid_loader, test_loader, classes = purchase100(batch_size=32, root='../data')
    preview = (
        (train_loader, 'Features shape:', 'Labels shape:'),
        (test_loader, 'Test Features shape:', 'Test Labels shape:'),
    )
    for loader, feature_caption, label_caption in preview:
        for batch_features, batch_labels, _ in loader:
            print(feature_caption, batch_features.shape, label_caption, batch_labels.shape)
            break