from torchvision import datasets, transforms
from PIL import ImageFile



class iDIGIT10:
    """Digit-recognition benchmark with one domain per task (10 classes).

    ``task_id`` selects the domain. Each domain is read with
    ``datasets.ImageFolder`` from ``<data_root>/<DOMAIN>/train`` and
    ``<data_root>/<DOMAIN>/test``::

        0 -> USPS    1 -> SVHN    2 -> EMNIST    3 -> MNIST

    USPS, EMNIST and MNIST are resized to 32x32, expanded to 3-channel
    grayscale and normalized with mean/std 0.5. SVHN keeps its colour and
    uses per-channel mean/std instead.

    Args:
        task_id: Integer index of the domain (see table above).
        data_root: Directory holding one sub-folder per domain. Defaults to
            the original hard-coded location ``./data/DIGIT10``.

    Raises:
        ValueError: If ``task_id`` is not one of the known domain indices.
    """

    # task_id -> domain sub-directory under ``data_root``.
    _DOMAINS = {0: 'USPS', 1: 'SVHN', 2: 'EMNIST', 3: 'MNIST'}

    def __init__(self, task_id, data_root='./data/DIGIT10'):
        # Validate first so a bad id fails fast with a clear message, before
        # any transform is built or any directory is scanned.
        domain = self._DOMAINS.get(task_id)
        if domain is None:
            raise ValueError(
                f'{type(self).__name__}: unknown task_id {task_id!r}; '
                f'expected one of {sorted(self._DOMAINS)}')

        if domain == 'SVHN':
            train_trsf, test_trsf = self._svhn_transforms()
        else:
            train_trsf, test_trsf = self._grayscale_transforms()

        self.num_classes = 10
        self.train_data = datasets.ImageFolder(
            root=f'{data_root}/{domain}/train', transform=train_trsf)
        self.test = datasets.ImageFolder(
            root=f'{data_root}/{domain}/test', transform=test_trsf)

    @staticmethod
    def _grayscale_transforms():
        """Return ``(train, test)`` transforms shared by USPS, EMNIST and MNIST."""
        train_trsf = transforms.Compose([
            transforms.Resize((32, 32)),
            # Expand to 3 channels so every domain yields the same channel count.
            transforms.Grayscale(num_output_channels=3),
            transforms.RandomCrop(32, padding=4),
            # NOTE(review): horizontal flips mirror digit shapes; kept to
            # preserve the original training recipe.
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=63 / 255),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        test_trsf = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),  # -> float tensor in [0, 1], shape [3, 32, 32]
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        return train_trsf, test_trsf

    @staticmethod
    def _svhn_transforms():
        """Return ``(train, test)`` transforms for the colour SVHN domain."""
        # NOTE(review): no Resize here -- assumes the stored SVHN images are
        # already 32x32; confirm against the data on disk.
        train_trsf = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=63 / 255),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.4377, 0.4438, 0.4728),
                                 std=(0.1980, 0.2010, 0.1970)),
        ])
        test_trsf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.4377, 0.4438, 0.4728),
                                 std=(0.1980, 0.2010, 0.1970)),
        ])
        return train_trsf, test_trsf

    def download_data(self):
        """Return ``(train_dataset, test_dataset, num_classes)``.

        Nothing is downloaded; this returns the datasets built in ``__init__``.
        """
        return self.train_data, self.test, self.num_classes



class iPACS:
    """PACS benchmark with one visual domain per task (7 classes).

    ``task_id`` selects the domain. Data is read with ``datasets.ImageFolder``
    from ``<data_root>/train/<domain>`` and ``<data_root>/test/<domain>``::

        0 -> sketch    1 -> cartoon    2 -> photo    3 -> art_painting

    Images are normalized with the ImageNet mean/std and cropped to 224x224.
    Training uses a random resized crop plus flip; evaluation uses a center
    crop.

    Args:
        task_id: Integer index of the domain (see table above).
        data_root: Directory containing the ``train`` and ``test`` splits.
            Defaults to the original hard-coded location ``./data/pacs``.

    Raises:
        ValueError: If ``task_id`` is not one of the known domain indices.
    """

    # task_id -> domain sub-directory inside each split.
    _DOMAINS = {0: 'sketch', 1: 'cartoon', 2: 'photo', 3: 'art_painting'}

    def __init__(self, task_id, data_root='./data/pacs'):
        # Fail fast on an unknown id instead of hitting an unbound local later.
        domain = self._DOMAINS.get(task_id)
        if domain is None:
            raise ValueError(
                f'{type(self).__name__}: unknown task_id {task_id!r}; '
                f'expected one of {sorted(self._DOMAINS)}')

        train_trsf = transforms.Compose([
            transforms.Resize(256),                  # shorter side -> 256
            transforms.RandomResizedCrop(224),       # random crop, rescaled to 224x224
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet stats
                                 std=[0.229, 0.224, 0.225]),
        ])
        test_trsf = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        self.num_classes = 7
        self.train_data = datasets.ImageFolder(
            root=f'{data_root}/train/{domain}', transform=train_trsf)
        self.test = datasets.ImageFolder(
            root=f'{data_root}/test/{domain}', transform=test_trsf)

    def download_data(self):
        """Return ``(train_dataset, test_dataset, num_classes)``.

        Nothing is downloaded; this returns the datasets built in ``__init__``.
        """
        return self.train_data, self.test, self.num_classes

class iVLCS:
    """VLCS benchmark with one source dataset per task (5 classes).

    ``task_id`` selects the domain. Data is read with ``datasets.ImageFolder``
    from ``<data_root>/train/<domain>`` and ``<data_root>/test/<domain>``::

        0 -> Caltech101    1 -> LabelMe    2 -> SUN09    3 -> VOC2007

    Images are resized directly to 224x224 (aspect ratio not preserved) and
    normalized with the ImageNet mean/std. Training adds a random horizontal
    flip.

    Args:
        task_id: Integer index of the domain (see table above).
        data_root: Directory containing the ``train`` and ``test`` splits.
            Defaults to the original hard-coded location ``./data/vlcs``.

    Raises:
        ValueError: If ``task_id`` is not one of the known domain indices.
    """

    # task_id -> domain sub-directory inside each split.
    _DOMAINS = {0: 'Caltech101', 1: 'LabelMe', 2: 'SUN09', 3: 'VOC2007'}

    def __init__(self, task_id, data_root='./data/vlcs'):
        # Fail fast on an unknown id instead of hitting an unbound local later.
        domain = self._DOMAINS.get(task_id)
        if domain is None:
            raise ValueError(
                f'{type(self).__name__}: unknown task_id {task_id!r}; '
                f'expected one of {sorted(self._DOMAINS)}')

        train_trsf = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),       # data augmentation
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet stats
                                 std=[0.229, 0.224, 0.225]),
        ])
        test_trsf = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Process-wide PIL switch: allow truncated image files to load instead
        # of raising. This affects every later PIL load, not just this dataset.
        ImageFile.LOAD_TRUNCATED_IMAGES = True

        self.num_classes = 5
        self.train_data = datasets.ImageFolder(
            root=f'{data_root}/train/{domain}', transform=train_trsf)
        self.test = datasets.ImageFolder(
            root=f'{data_root}/test/{domain}', transform=test_trsf)

    def download_data(self):
        """Return ``(train_dataset, test_dataset, num_classes)``.

        Nothing is downloaded; this returns the datasets built in ``__init__``.
        """
        return self.train_data, self.test, self.num_classes

class iDN4IL:
    """Reduced DomainNet ("littleDomainnet") benchmark, one domain per task.

    ``task_id`` selects the domain. Data is read with ``datasets.ImageFolder``
    from ``<data_root>/train/<domain>`` and ``<data_root>/test/<domain>``::

        0 -> sketch     1 -> infograph   2 -> painting
        3 -> quickdraw  4 -> real        5 -> clipart

    The ordering is not alphabetical: ``sketch`` is id 0 and ``clipart`` is
    id 5. It is kept exactly as-is because the id order defines the task
    sequence.

    Images are resized directly to 224x224 and normalized with the ImageNet
    mean/std. Training adds a horizontal flip and colour jitter.

    Args:
        task_id: Integer index of the domain (see table above).
        data_root: Directory containing the ``train`` and ``test`` splits.
            Defaults to the original location ``./data/littleDomainnet``.

    Raises:
        ValueError: If ``task_id`` is not one of the known domain indices.
    """

    # task_id -> domain sub-directory inside each split (order is deliberate).
    _DOMAINS = {
        0: 'sketch',
        1: 'infograph',
        2: 'painting',
        3: 'quickdraw',
        4: 'real',
        5: 'clipart',
    }

    def __init__(self, task_id, data_root='./data/littleDomainnet'):
        # Fail fast on an unknown id instead of hitting an unbound local later.
        domain = self._DOMAINS.get(task_id)
        if domain is None:
            raise ValueError(
                f'{type(self).__name__}: unknown task_id {task_id!r}; '
                f'expected one of {sorted(self._DOMAINS)}')

        train_trsf = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2,
                                   saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet stats
                                 std=[0.229, 0.224, 0.225]),
        ])
        test_trsf = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Process-wide PIL switch: allow truncated image files to load instead
        # of raising. This affects every later PIL load, not just this dataset.
        ImageFile.LOAD_TRUNCATED_IMAGES = True

        # NOTE(review): 100 classes is taken as given for littleDomainnet;
        # confirm it matches the class folders on disk.
        self.num_classes = 100
        self.train_data = datasets.ImageFolder(
            root=f'{data_root}/train/{domain}', transform=train_trsf)
        self.test = datasets.ImageFolder(
            root=f'{data_root}/test/{domain}', transform=test_trsf)

    def download_data(self):
        """Return ``(train_dataset, test_dataset, num_classes)``.

        Nothing is downloaded; this returns the datasets built in ``__init__``.
        """
        return self.train_data, self.test, self.num_classes