import random
import sys

import numpy as np
import torch
from torchvision import datasets, transforms

from utils.sampling import (
    mnist_iid,
    mnist_iid_normal,
    fmnist_noniid_split,
    cifar_iid,
    mnist_noniid_split,
    minmax_dataset,
    fmnist_iid_normal,
    fmnist_noniid_normal,
)


def load_data(args):
    """Load a train/test dataset pair and partition the training set across users.

    Args:
        args: Namespace-like config. Always reads ``args.dataset``, which must
            be one of ``'mnist'``, ``'fmnist'``, ``'cifar'`` or
            ``'minmax_synthetic'``. The torchvision branches also read
            ``args.iid`` and ``args.num_users``; MNIST additionally reads
            ``args.size``, and the non-IID MNIST/FMNIST splits read ``args.p``.
            For ``'minmax_synthetic'`` the whole ``args`` object is passed to
            ``minmax_dataset``.

    Returns:
        Tuple ``(dataset_train, dataset_test, dict_users, img_size,
        dataset_train_real)``. ``dict_users`` and ``dataset_train_real`` are
        whatever the ``utils.sampling`` splitter for the chosen dataset
        returns. ``img_size`` is the ``.shape`` of the first training sample
        for the torchvision datasets, and the value returned by
        ``minmax_dataset`` for ``'minmax_synthetic'``.

    Raises:
        SystemExit: if ``args.dataset`` is not recognized, or if CIFAR-10 is
            requested with ``args.iid`` false (only the IID split exists).
    """
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        # Paths are relative to the current working directory.
        dataset_train = datasets.MNIST(
            '../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST(
            '../data/mnist/', train=False, download=True, transform=trans_mnist)
        # Split the training set among users.
        if args.iid:
            dict_users, dataset_train_real = mnist_iid_normal(
                dataset_train, args.num_users, args.size)
        else:
            dict_users, dataset_train_real = mnist_noniid_split(
                dataset_train, args.num_users, args.size, args.p)
    elif args.dataset == 'fmnist':
        # NOTE(review): these statistics differ from the commonly quoted
        # FashionMNIST values (mean ~0.2860, std ~0.3530) -- confirm intended.
        trans_fmnist = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.3476,), (0.3568,))])
        dataset_train = datasets.FashionMNIST(
            '../data/fmnist/', train=True, download=True, transform=trans_fmnist)
        dataset_test = datasets.FashionMNIST(
            '../data/fmnist/', train=False, download=True, transform=trans_fmnist)

        # Keep targets and data the same length. Slicing only trims surplus
        # labels; it cannot repair the case of fewer labels than samples.
        if len(dataset_train.targets) != len(dataset_train.data):
            dataset_train.targets = dataset_train.targets[:len(dataset_train.data)]
        if len(dataset_test.targets) != len(dataset_test.data):
            dataset_test.targets = dataset_test.targets[:len(dataset_test.data)]

        # Split the training set among users.
        if args.iid:
            dict_users, dataset_train_real = fmnist_iid_normal(
                dataset_train, args.num_users)
        else:
            dict_users, dataset_train_real = fmnist_noniid_split(
                dataset_train, args.num_users, args.p)
    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10(
            '../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10(
            '../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users, dataset_train_real = cifar_iid(dataset_train, args.num_users)
        else:
            # sys.exit rather than the site-provided exit(), which is meant for
            # the interactive shell and is absent under `python -S`.
            sys.exit('Error: only consider IID setting in CIFAR10')
    elif args.dataset == 'minmax_synthetic':
        # minmax_dataset supplies its own img_size; return it unchanged rather
        # than overwriting it with a shape derived from the first sample.
        dataset_train, dataset_test, dict_users, img_size, dataset_train_real = minmax_dataset(args)
        return dataset_train, dataset_test, dict_users, img_size, dataset_train_real
    else:
        sys.exit('Error: unrecognized dataset')

    # Each torchvision sample is an (image_tensor, label) pair.
    img_size = dataset_train[0][0].shape
    return dataset_train, dataset_test, dict_users, img_size, dataset_train_real


