from .comlib import *
from .task import Task

class MnistTask(Task):
    """Task description for MNIST handwritten-digit classification.

    Stores the per-sample input geometry, the sizes of the train/test
    splits and the number of classes, and provides a loader for the
    torchvision MNIST splits.
    """

    def __init__(self):
        super().__init__()
        # One grayscale channel, 28x28 pixels per sample.
        self.input_shape = [1, 28, 28]
        # Value range produced by the default Normalize((0.5,), (0.5,))
        # transform; the digit (foreground) pixels end up at 1.
        self.input_range = [-1, 1]
        # Number of scalar elements in one sample (product of input_shape).
        self.input_elenum = np.prod(self.input_shape)
        # Split sizes and label count.
        self.train_data_num = 60000
        self.test_data_num = 10000
        self.class_num = 10

    def load_mnist_data(self, root="/home/yjf/FL/data", transform=None, target_transform=None):
        """Build the MNIST train and test datasets, downloading them if needed.

        Args:
            root: Directory holding (or receiving) the MNIST files.
            transform: Image transform applied to both splits. When None,
                ToTensor followed by Normalize(mean=0.5, std=0.5) is used.
            target_transform: Optional transform applied to the labels.

        Returns:
            A ``(train_dataset, test_dataset)`` tuple.
        """
        transform = transform if transform is not None else transforms.Compose([
            transforms.ToTensor(),                  # image -> Tensor
            transforms.Normalize((0.5,), (0.5,)),   # normalize with mean 0.5, std 0.5
        ])
        # Both splits share every constructor argument except `train`.
        shared = dict(
            root=root,
            download=True,
            transform=transform,
            target_transform=target_transform,
        )
        train_split = datasets.MNIST(train=True, **shared)
        test_split = datasets.MNIST(train=False, **shared)
        return train_split, test_split
    

class CifarTask(Task):
    """Task description for CIFAR-10 image classification.

    Stores the per-sample input geometry, the sizes of the train/test
    splits and the number of classes, and provides a loader for the
    torchvision CIFAR-10 splits.
    """

    def __init__(self):
        super().__init__()
        # Three colour channels, 32x32 pixels per sample.
        self.input_shape = [3, 32, 32]
        # NOTE(review): input_range is not set for this task (MnistTask sets
        # [-1, 1]). With the default transform in load_data the pixel values
        # are presumably also in [-1, 1] — confirm whether the base Task
        # expects this attribute before enabling it.
        # Number of scalar elements in one sample (product of input_shape).
        self.input_elenum = np.prod(self.input_shape)
        # CIFAR-10 ships 50,000 training and 10,000 test images. The
        # previous value (60000) was the combined total of both splits.
        self.train_data_num = 50000
        self.test_data_num = 10000
        self.class_num = 10

    def load_data(self, root="/home/yjf/FL/data", transform=None, target_transform=None):
        """Build the CIFAR-10 train and test datasets, downloading them if needed.

        Args:
            root: Directory holding (or receiving) the CIFAR-10 files.
            transform: Image transform applied to both splits. When None,
                ToTensor followed by a per-channel Normalize with mean 0.5
                and std 0.5 is used.
            target_transform: Optional transform applied to the labels.

        Returns:
            A ``(train_dataset, test_dataset)`` tuple.
        """
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor(),  # image -> Tensor
                # Explicit per-channel statistics for the 3 RGB channels.
                # Numerically identical to the broadcast single-value form.
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        train_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=transform, target_transform=target_transform)
        test_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform, target_transform=target_transform)
        return train_dataset, test_dataset