import json
import os
import pickle
from argparse import Namespace
from pathlib import Path
from typing import Dict, List, Optional, Type

import numpy as np
import pandas as pd
import torch
import torchvision
from omegaconf import DictConfig
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.io.image import ImageReadMode, read_image
import ssl
import torch.nn.functional as F
from PIL import Image


class BaseDataset(Dataset):
    """In-memory dataset backed by a data tensor and a targets tensor.

    Two transform pairs are stored: one for training and one for evaluation.
    ``train()`` / ``eval()`` select which pair ``__getitem__`` applies; the
    training pair is active right after construction.
    """

    def __init__(
        self,
        data: torch.Tensor,
        targets: torch.Tensor,
        classes: List[int],
        train_data_transform: Optional[transforms.Compose] = None,
        train_target_transform: Optional[transforms.Compose] = None,
        test_data_transform: Optional[transforms.Compose] = None,
        test_target_transform: Optional[transforms.Compose] = None,
    ) -> None:
        """Store samples, labels, class ids and both transform pairs.

        Args:
            data: Sample tensor, indexed along its first dimension.
            targets: Label tensor; its length defines ``len(self)``.
            classes: Class ids of the dataset.
            train_data_transform: Applied to each sample in training mode.
            train_target_transform: Applied to each label in training mode.
            test_data_transform: Applied to each sample after ``eval()``.
            test_target_transform: Applied to each label after ``eval()``.
        """
        self.data = data
        self.targets = targets
        self.classes = classes
        self.train_data_transform = train_data_transform
        self.train_target_transform = train_target_transform
        self.test_data_transform = test_data_transform
        self.test_target_transform = test_target_transform
        # Start in training mode.
        self.data_transform = self.train_data_transform
        self.target_transform = self.train_target_transform

        # rescale data to fit in [0, 1.0] if needed
        # self._rescale_data()

    def _rescale_data(self) -> None:
        """Divide ``self.data`` by 255 when its maximum value exceeds 1.0.

        The division is done out of place so that the tensor passed in by the
        caller is not modified, and so that integer tensors are promoted to
        floating point instead of failing on in-place true division. Empty
        data is left untouched because ``max()`` is undefined for it.
        """
        if self.data.numel() == 0:
            return
        if self.data.max() > 1.0:
            self.data = self.data / 255.0

    def __getitem__(self, index):
        """Return ``(sample, label)`` at ``index`` with the active transforms applied."""
        data, targets = self.data[index], self.targets[index]
        if self.data_transform is not None:
            data = self.data_transform(data)
        if self.target_transform is not None:
            targets = self.target_transform(targets)

        return data, targets

    def train(self):
        """Activate the training transform pair."""
        self.data_transform = self.train_data_transform
        self.target_transform = self.train_target_transform

    def eval(self):
        """Activate the evaluation transform pair."""
        self.data_transform = self.test_data_transform
        self.target_transform = self.test_target_transform

    def __len__(self):
        """Number of samples, taken from the length of ``targets``."""
        return len(self.targets)


class FEMNIST(BaseDataset):
    """FEMNIST characters (62 classes) loaded from pre-generated ``.npy`` arrays.

    Samples are reshaped to ``(N, 1, 28, 28)`` float tensors.
    """

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ) -> None:
        """Load ``root/data.npy`` and ``root/targets.npy``.

        Args:
            root: Directory containing the two array files.
            args: Unused; accepted so all datasets share one constructor shape.
            test_data_transform, test_target_transform, train_data_transform,
            train_target_transform: Forwarded to :class:`BaseDataset`.

        Raises:
            RuntimeError: If either array file is missing.
        """
        if not isinstance(root, Path):
            root = Path(root)
        if not os.path.isfile(root / "data.npy") or not os.path.isfile(
            root / "targets.npy"
        ):
            raise RuntimeError(
                "Please run generate_data.py -d femnist for generating the data.npy and targets.npy first."
            )

        data = np.load(root / "data.npy")
        targets = np.load(root / "targets.npy")

        super().__init__(
            # Each stored sample is reshaped into a 1x28x28 image.
            data=torch.from_numpy(data).float().reshape(-1, 1, 28, 28),
            targets=torch.from_numpy(targets).long(),
            classes=list(range(62)),
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class Synthetic(BaseDataset):
    """Synthetic feature/label dataset read from pre-generated ``.npy`` arrays.

    The class list is the sorted set of distinct labels found in the targets.
    """

    def __init__(self, root, *args, **kwargs) -> None:
        """Load ``root/data.npy`` and ``root/targets.npy``.

        NOTE(review): ``*args`` / ``**kwargs`` are accepted but ignored, so no
        transforms are ever forwarded to :class:`BaseDataset` for this dataset.

        Raises:
            RuntimeError: If either array file is missing.
        """
        root = root if isinstance(root, Path) else Path(root)
        data_file = root / "data.npy"
        targets_file = root / "targets.npy"
        if not (os.path.isfile(data_file) and os.path.isfile(targets_file)):
            raise RuntimeError(
                "Please run generate_data.py -d synthetic for generating the data.npy and targets.npy first."
            )

        feature_array = np.load(data_file)
        label_array = np.load(targets_file)
        distinct_labels = np.unique(label_array).tolist()

        super().__init__(
            data=torch.from_numpy(feature_array).float(),
            targets=torch.from_numpy(label_array).long(),
            classes=sorted(distinct_labels),
        )


class CelebA(BaseDataset):
    """Binary-label CelebA subset loaded from pre-generated ``.npy`` arrays."""

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ) -> None:
        """Load ``root/data.npy`` and ``root/targets.npy``.

        Args:
            root: Directory containing the two array files.
            args: Unused; accepted so all datasets share one constructor shape.
            test_data_transform, test_target_transform, train_data_transform,
            train_target_transform: Forwarded to :class:`BaseDataset`.

        Raises:
            RuntimeError: If either array file is missing.
        """
        if not isinstance(root, Path):
            root = Path(root)
        if not os.path.isfile(root / "data.npy") or not os.path.isfile(
            root / "targets.npy"
        ):
            raise RuntimeError(
                "Please run generate_data.py -d celeba for generating the data.npy and targets.npy first."
            )

        data = np.load(root / "data.npy")
        targets = np.load(root / "targets.npy")

        super().__init__(
            # Move the last axis to dim 1 (channels-last -> channels-first,
            # assuming data.npy is stored as N x H x W x C — TODO confirm).
            data=torch.from_numpy(data).permute([0, -1, 1, 2]).float(),
            targets=torch.from_numpy(targets).long(),
            classes=[0, 1],
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class MedMNIST(BaseDataset):
    """MedMNIST subset with 11 classes read from ``root/raw/{x,y}data.npy``.

    A channel axis is inserted at dim 1 of the images; labels are squeezed.
    """

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Load images and labels from ``root/raw``.

        ``args`` is unused; the transform arguments go to :class:`BaseDataset`.
        """
        raw_dir = (root if isinstance(root, Path) else Path(root)) / "raw"
        images = torch.Tensor(np.load(raw_dir / "xdata.npy")).float().unsqueeze(1)
        labels = torch.Tensor(np.load(raw_dir / "ydata.npy")).long().squeeze()

        super().__init__(
            data=images,
            targets=labels,
            classes=list(range(11)),
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class COVID19(BaseDataset):
    """Four-class COVID-19 image dataset read from ``root/raw/{x,y}data.npy``."""

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Load images and labels from ``root/raw``.

        The image array's last axis is moved to dim 1 (presumably
        channels-last -> channels-first; verify against the stored layout).
        ``args`` is unused; the transform arguments go to :class:`BaseDataset`.
        """
        raw_dir = (root if isinstance(root, Path) else Path(root)) / "raw"
        raw_images = torch.Tensor(np.load(raw_dir / "xdata.npy"))
        images = raw_images.permute([0, -1, 1, 2]).float()
        labels = torch.Tensor(np.load(raw_dir / "ydata.npy")).long().squeeze()

        super().__init__(
            data=images,
            targets=labels,
            classes=[0, 1, 2, 3],
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class USPS(BaseDataset):
    """USPS digits (10 classes) with the official train and test splits merged.

    A channel axis is inserted at dim 1 of the images.
    """

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Download (if needed) both splits under ``root/raw`` and concatenate them.

        Train samples precede test samples. ``args`` is unused; the transform
        arguments go to :class:`BaseDataset`.
        """
        raw_dir = (root if isinstance(root, Path) else Path(root)) / "raw"
        splits = [
            torchvision.datasets.USPS(raw_dir, is_train, download=True)
            for is_train in (True, False)
        ]
        merged_data = torch.cat(
            [torch.Tensor(split.data).float().unsqueeze(1) for split in splits]
        )
        merged_targets = torch.cat(
            [torch.Tensor(split.targets).long() for split in splits]
        )

        super().__init__(
            data=merged_data,
            targets=merged_targets,
            classes=list(range(10)),
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class SVHN(BaseDataset):
    """Street View House Numbers (10 classes), train and test splits merged."""

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Download (if needed) both splits under ``root/raw`` and concatenate them.

        Images are used as stored by torchvision (no permute/unsqueeze). Train
        samples precede test samples. ``args`` is unused; the transform
        arguments go to :class:`BaseDataset`.
        """
        raw_dir = (root if isinstance(root, Path) else Path(root)) / "raw"
        splits = [
            torchvision.datasets.SVHN(raw_dir, split_name, download=True)
            for split_name in ("train", "test")
        ]
        merged_data = torch.cat([torch.Tensor(split.data).float() for split in splits])
        merged_targets = torch.cat(
            [torch.Tensor(split.labels).long() for split in splits]
        )

        super().__init__(
            data=merged_data,
            targets=merged_targets,
            classes=list(range(10)),
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class MNIST(BaseDataset):
    """MNIST digits (10 classes) with the official train and test splits merged.

    A channel axis is inserted at dim 1 of the images.
    """

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Download (if needed) both splits under ``root`` and concatenate them.

        Train samples precede test samples. ``args`` is unused; the transform
        arguments go to :class:`BaseDataset`.
        """
        train_part = torchvision.datasets.MNIST(root, True, download=True)
        # download=True for the test split as well (like the other datasets in
        # this module) so a missing test file is fetched instead of raising.
        test_part = torchvision.datasets.MNIST(root, False, download=True)
        train_data = torch.Tensor(train_part.data).float().unsqueeze(1)
        test_data = torch.Tensor(test_part.data).float().unsqueeze(1)
        train_targets = torch.Tensor(train_part.targets).long().squeeze()
        test_targets = torch.Tensor(test_part.targets).long().squeeze()

        super().__init__(
            data=torch.cat([train_data, test_data]),
            targets=torch.cat([train_targets, test_targets]),
            classes=list(range(10)),
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class FashionMNIST(BaseDataset):
    """Fashion-MNIST (10 classes) with the official train and test splits merged.

    A channel axis is inserted at dim 1 of the images.
    """

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Download (if needed) both splits under ``root`` and concatenate them.

        Train samples precede test samples. ``args`` is unused; the transform
        arguments go to :class:`BaseDataset`.
        """
        splits = [
            torchvision.datasets.FashionMNIST(root, is_train, download=True)
            for is_train in (True, False)
        ]
        merged_data = torch.cat(
            [torch.Tensor(split.data).float().unsqueeze(1) for split in splits]
        )
        merged_targets = torch.cat(
            [torch.Tensor(split.targets).long().squeeze() for split in splits]
        )

        super().__init__(
            data=merged_data,
            targets=merged_targets,
            classes=list(range(10)),
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class EMNIST(BaseDataset):
    """EMNIST with the split chosen by ``args.emnist_split``; train and test merged.

    A channel axis is inserted at dim 1 of the images. The class list has one
    id per entry of torchvision's ``classes`` for the chosen split.
    """

    def __init__(
        self,
        root,
        args,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Download (if needed) both splits of the configured EMNIST variant.

        Args:
            root: Directory torchvision downloads EMNIST into.
            args: ``Namespace`` / ``DictConfig`` with attribute ``emnist_split``,
                or ``dict`` with key ``"emnist_split"``.
            test_data_transform, test_target_transform, train_data_transform,
            train_target_transform: Forwarded to :class:`BaseDataset`.

        Raises:
            ValueError: If no split can be read from ``args``.
        """
        if isinstance(args, (Namespace, DictConfig)):
            split = args.emnist_split
        elif isinstance(args, dict):
            split = args["emnist_split"]
        else:
            split = None
        # Fail before any download; torchvision would also reject split=None
        # with a ValueError, but only after fetching the archive.
        if split is None:
            raise ValueError(
                "EMNIST requires `emnist_split` to be set in args "
                "(e.g. 'byclass', 'balanced', 'letters')."
            )

        train_part = torchvision.datasets.EMNIST(
            root, split=split, train=True, download=True
        )
        test_part = torchvision.datasets.EMNIST(
            root, split=split, train=False, download=True
        )
        train_data = torch.Tensor(train_part.data).float().unsqueeze(1)
        test_data = torch.Tensor(test_part.data).float().unsqueeze(1)
        train_targets = torch.Tensor(train_part.targets).long().squeeze()
        test_targets = torch.Tensor(test_part.targets).long().squeeze()

        super().__init__(
            data=torch.cat([train_data, test_data]),
            targets=torch.cat([train_targets, test_targets]),
            classes=list(range(len(train_part.classes))),
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class CIFAR10(BaseDataset):
    """CIFAR-10 (10 classes) with the official train and test splits merged."""

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Download (if needed) both splits under ``root`` and concatenate them.

        The image arrays' last axis is moved to dim 1. Train samples precede
        test samples. ``args`` is unused; the transform arguments go to
        :class:`BaseDataset`.
        """
        splits = [
            torchvision.datasets.CIFAR10(root, is_train, download=True)
            for is_train in (True, False)
        ]
        merged_data = torch.cat(
            [
                torch.Tensor(split.data).permute([0, -1, 1, 2]).float()
                for split in splits
            ]
        )
        merged_targets = torch.cat(
            [torch.Tensor(split.targets).long().squeeze() for split in splits]
        )

        super().__init__(
            data=merged_data,
            targets=merged_targets,
            classes=list(range(10)),
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class CIFAR100(BaseDataset):
    """CIFAR-100 with train and test splits merged, optionally using 20 super-classes."""

    def __init__(
        self,
        root,
        args,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Download (if needed) both splits and optionally remap to super-classes.

        Args:
            root: Directory torchvision downloads CIFAR-100 into.
            args: ``Namespace`` / ``DictConfig`` with attribute ``super_class``
                or ``dict`` with key ``"super_class"``; any other value (e.g.
                ``None``) keeps the 100 fine-grained classes.
            test_data_transform, test_target_transform, train_data_transform,
            train_target_transform: Forwarded to :class:`BaseDataset`.
        """
        train_part = torchvision.datasets.CIFAR100(root, True, download=True)
        test_part = torchvision.datasets.CIFAR100(root, False, download=True)
        # Move the last axis of the image arrays to dim 1.
        train_data = torch.Tensor(train_part.data).permute([0, -1, 1, 2]).float()
        test_data = torch.Tensor(test_part.data).permute([0, -1, 1, 2]).float()
        train_targets = torch.Tensor(train_part.targets).long().squeeze()
        test_targets = torch.Tensor(test_part.targets).long().squeeze()
        data = torch.cat([train_data, test_data])
        targets = torch.cat([train_targets, test_targets])
        classes = list(range(100))

        super_class = False
        if isinstance(args, (Namespace, DictConfig)):
            super_class = args.super_class
        elif isinstance(args, dict):
            super_class = args["super_class"]

        if super_class:
            # super_class id -> names of its 5 fine-grained sub-classes
            CIFAR100_SUPER_CLASS = {
                0: ["beaver", "dolphin", "otter", "seal", "whale"],
                1: ["aquarium_fish", "flatfish", "ray", "shark", "trout"],
                2: ["orchid", "poppy", "rose", "sunflower", "tulip"],
                3: ["bottle", "bowl", "can", "cup", "plate"],
                4: ["apple", "mushroom", "orange", "pear", "sweet_pepper"],
                5: ["clock", "keyboard", "lamp", "telephone", "television"],
                6: ["bed", "chair", "couch", "table", "wardrobe"],
                7: ["bee", "beetle", "butterfly", "caterpillar", "cockroach"],
                8: ["bear", "leopard", "lion", "tiger", "wolf"],
                9: ["cloud", "forest", "mountain", "plain", "sea"],
                10: ["bridge", "castle", "house", "road", "skyscraper"],
                11: ["camel", "cattle", "chimpanzee", "elephant", "kangaroo"],
                12: ["fox", "porcupine", "possum", "raccoon", "skunk"],
                13: ["crab", "lobster", "snail", "spider", "worm"],
                14: ["baby", "boy", "girl", "man", "woman"],
                15: ["crocodile", "dinosaur", "lizard", "snake", "turtle"],
                16: ["hamster", "mouse", "rabbit", "shrew", "squirrel"],
                17: ["maple_tree", "oak_tree", "palm_tree", "pine_tree", "willow_tree"],
                18: ["bicycle", "bus", "motorcycle", "pickup_truck", "train"],
                19: ["lawn_mower", "rocket", "streetcar", "tank", "tractor"],
            }
            sub_to_super = {
                sub_name: super_id
                for super_id, sub_names in CIFAR100_SUPER_CLASS.items()
                for sub_name in sub_names
            }
            # lookup[i] is the super-class id of fine label i (ordered like
            # train_part.classes); one vectorized index replaces a Python loop
            # over every sample.
            lookup = torch.tensor(
                [sub_to_super[name] for name in train_part.classes],
                dtype=torch.long,
            )
            targets = lookup[targets]
            classes = list(range(20))

        super().__init__(
            data=data,
            targets=targets,
            classes=classes,
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class TinyImagenet(BaseDataset):
    """Tiny-ImageNet with train and val images merged and cached as tensors.

    On first use every image under ``root/raw`` is decoded as RGB and the
    stacked result is saved to ``root/data.pt`` / ``root/targets.pt``; later
    constructions load those cache files directly.
    """

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Build (if missing) and load the cached tensors.

        Labels are the positions of the class ids listed in ``raw/wnids.txt``.
        ``args`` is unused; the transform arguments go to :class:`BaseDataset`.

        Raises:
            RuntimeError: If ``root/raw`` does not exist.
        """
        root = root if isinstance(root, Path) else Path(root)
        raw_dir = root / "raw"
        if not os.path.isdir(raw_dir):
            raise RuntimeError(
                "Using `data/download/tiny_imagenet.sh` to download the dataset first."
            )
        classes = pd.read_table(
            raw_dir / "wnids.txt", sep="\t", engine="python", header=None
        )[0].tolist()

        data_cache = root / "data.pt"
        targets_cache = root / "targets.pt"
        if not (os.path.isfile(data_cache) and os.path.isfile(targets_cache)):
            label_of = {wnid: idx for idx, wnid in enumerate(classes)}
            images = []
            labels = []

            def read_rgb(path):
                return read_image(str(path), mode=ImageReadMode.RGB).float()

            # Training images live in raw/train/<wnid>/images/.
            train_dir = raw_dir / "train"
            for wnid in os.listdir(train_dir):
                image_dir = train_dir / wnid / "images"
                for img_name in os.listdir(image_dir):
                    images.append(read_rgb(image_dir / img_name))
                    labels.append(label_of[wnid])

            # Validation labels come from the annotation table (file name -> wnid).
            annotations = pd.read_table(
                raw_dir / "val" / "val_annotations.txt",
                sep="\t",
                engine="python",
                header=None,
            )
            val_wnid_of = dict(zip(annotations[0].tolist(), annotations[1].tolist()))
            val_image_dir = raw_dir / "val" / "images"
            for img_name in os.listdir(val_image_dir):
                images.append(read_rgb(val_image_dir / img_name))
                labels.append(label_of[val_wnid_of[img_name]])

            torch.save(torch.stack(images), data_cache)
            torch.save(torch.tensor(labels, dtype=torch.long), targets_cache)

        super().__init__(
            data=torch.load(data_cache),
            targets=torch.load(targets_cache),
            classes=list(range(len(classes))),
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class CINIC10(BaseDataset):
    """CINIC-10 with its test/train/valid folders merged and cached as tensors.

    On first use every image under ``root/raw/{test,train,valid}/<class>/`` is
    decoded as RGB and saved to ``root/data.pt`` / ``root/targets.pt``.
    """

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Build (if missing) and load the cached tensors.

        ``args`` is unused; the transform arguments go to :class:`BaseDataset`.

        Raises:
            RuntimeError: If ``root/raw`` does not exist.
        """
        if not isinstance(root, Path):
            root = Path(root)
        if not os.path.isdir(root / "raw"):
            # Point users at the CINIC-10 download script, not Tiny-ImageNet's.
            raise RuntimeError(
                "Using `data/download/cinic10.sh` to download the dataset first."
            )
        classes = [
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        ]
        if not os.path.isfile(root / "data.pt") or not os.path.isfile(
            root / "targets.pt"
        ):
            data = []
            targets = []
            mapping = dict(zip(classes, range(10)))
            for folder in ["test", "train", "valid"]:
                for cls in os.listdir(root / "raw" / folder):
                    for img_name in os.listdir(root / "raw" / folder / cls):
                        img = read_image(
                            str(root / "raw" / folder / cls / img_name),
                            mode=ImageReadMode.RGB,
                        ).float()
                        data.append(img)
                        targets.append(mapping[cls])
            torch.save(torch.stack(data), root / "data.pt")
            torch.save(torch.tensor(targets, dtype=torch.long), root / "targets.pt")

        super().__init__(
            data=torch.load(root / "data.pt"),
            targets=torch.load(root / "targets.pt"),
            classes=list(range(len(classes))),
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )


class DomainNet(BaseDataset):
    """DomainNet with labels in memory and images read from disk per sample.

    ``__getitem__`` opens the file at ``filename_list[index]``, converts it to
    RGB, resizes it to ``metadata["image_size"]`` and converts it to a tensor
    before applying the active transforms.
    """

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ) -> None:
        """Load the preprocessed labels, metadata and image file list.

        ``args`` is unused; the transform arguments go to :class:`BaseDataset`.

        Raises:
            RuntimeError: If ``root/raw`` is missing or preprocessing has not
                produced ``targets.pt``, ``metadata.json`` and
                ``filename_list.pkl``.
        """
        if not isinstance(root, Path):
            root = Path(root)
        if not os.path.isdir(root / "raw"):
            raise RuntimeError(
                "Using `data/download/domain.sh` to download the dataset first."
            )
        targets_path = root / "targets.pt"
        metadata_path = root / "metadata.json"
        filename_list_path = root / "filename_list.pkl"
        if not (
            os.path.isfile(targets_path)
            and os.path.isfile(metadata_path)
            and os.path.isfile(filename_list_path)
        ):
            raise RuntimeError(
                "Run data/domain/preprocess.py to preprocess DomainNet first."
            )

        with open(metadata_path, "r") as f:
            metadata = json.load(f)
        # NOTE(review): pickle.load can execute arbitrary code; this is only
        # acceptable because the file is presumably produced locally by the
        # preprocessing script. Never point `root` at untrusted data.
        with open(filename_list_path, "rb") as f:
            self.filename_list = pickle.load(f)

        self.pre_transform = transforms.Compose(
            [transforms.Resize(metadata["image_size"]), transforms.ToTensor()]
        )
        super().__init__(
            data=torch.empty(1, 1, 1, 1),  # placeholder: images are read lazily
            targets=torch.load(targets_path),
            classes=list(range(len(metadata["classes"]))),
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )

    def __getitem__(self, index):
        """Read image ``index`` from disk and return ``(sample, label)``."""
        # The context manager releases the file handle even if decoding fails.
        with Image.open(self.filename_list[index]) as img:
            data = self.pre_transform(img.convert("RGB"))
        targets = self.targets[index]
        if self.data_transform is not None:
            data = self.data_transform(data)
        if self.target_transform is not None:
            targets = self.target_transform(targets)
        return data, targets

class MultiDomainDigits(BaseDataset):
    """Multi-domain digit dataset built from MNIST, USPS, SVHN and SYN.

    Every domain is resized to 32x32 (single-channel domains are repeated to
    3 channels) and appended to ``self.data`` / ``self.targets`` in the order
    of ``self.domains``: a domain's train samples first, then its test
    samples. The global positions of each domain's samples are recorded in
    ``domain_indices``, ``domain_train_indices`` and ``domain_test_indices``.
    """

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Load every requested domain under ``root``.

        Args:
            root: Directory torchvision datasets are downloaded into; SYN is
                read from ``root/syn/imgs_train`` and ``root/syn/imgs_valid``.
            args: Optional; if it has a truthy ``domains`` attribute, that list
                replaces the default domains.
            test_data_transform, test_target_transform, train_data_transform,
            train_target_transform: Forwarded to :class:`BaseDataset`.

        Raises:
            ValueError: If a requested domain is not supported.
        """
        if not isinstance(root, Path):
            root = Path(root)

        # Stored as an attribute because the domain loaders read it.
        self.root = root

        # Work around certificate errors during dataset downloads.
        # NOTE(review): this disables HTTPS certificate verification for the
        # whole process, not only for these downloads.
        ssl._create_default_https_context = ssl._create_unverified_context

        self.domains = ["mnist", "usps", "svhn", "syn"]
        if args and hasattr(args, "domains") and args.domains:
            self.domains = args.domains

        self.domain_indices = {}  # domain -> all global indices (train + test)
        self.domain_train_indices = {}  # domain -> global indices of train samples
        self.domain_test_indices = {}  # domain -> global indices of test samples
        self.domain_data = {}  # domain -> that domain's image tensor
        self.domain_targets = {}  # domain -> that domain's label tensor

        # Empty accumulators; each loaded domain is appended to them.
        self.data = torch.empty(0)
        self.targets = torch.empty(0, dtype=torch.long)

        for domain in self.domains:
            self._load_domain_dataset(domain)

        super().__init__(
            data=self.data,
            targets=self.targets,
            classes=list(range(10)),  # assumes every domain is 10-class digits
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )

    def train(self):
        """Training mode: random 32x32 crop (4px padding), then normalize.

        NOTE(review): the data transform is built here and does not use
        ``self.train_data_transform`` (from the constructor or
        ``set_transform``); only the target transform is taken from it.
        """
        self.data_transform = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )
        self.target_transform = self.train_target_transform

    def eval(self):
        """Evaluation mode: normalize only.

        NOTE(review): like ``train()``, this ignores ``self.test_data_transform``.
        """
        self.data_transform = transforms.Compose(
            [transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
        )
        self.target_transform = self.test_target_transform

    def set_transform(self, domain):
        """Set ``train_data_transform`` / ``test_data_transform`` for ``domain``.

        Training: random 32x32 crop with 4px padding, then normalize. Testing:
        normalize only. The pipelines are identical for every domain because
        single-channel domains are already repeated to 3 channels at load time.
        """
        print(f"设置数据集的transform: {domain}")
        normalize = transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        )
        self.train_data_transform = transforms.Compose(
            [transforms.RandomCrop(32, padding=4), normalize]
        )
        self.test_data_transform = transforms.Compose([normalize])

    @staticmethod
    def _materialize(dataset):
        """Iterate ``dataset`` once; return (stacked images, label tensor).

        Assumes a non-empty dataset of equally shaped image tensors
        (``torch.stack`` raises otherwise).
        """
        images, labels = [], []
        for img, label in dataset:
            images.append(img)
            labels.append(label)
        return torch.stack(images, dim=0), torch.tensor(labels)

    def _mark_domain_empty(self, domain):
        """Register ``domain`` with no samples so per-domain lookups still succeed."""
        self.domain_indices[domain] = []
        self.domain_train_indices[domain] = []
        self.domain_test_indices[domain] = []

    def _load_syn(self, transform):
        """Load SYN's train/valid ImageFolders.

        Returns:
            ``(train_data, train_targets, test_data, test_targets)``, or
            ``None`` (after printing a warning) when a folder is missing or
            loading fails — SYN is treated as optional.
        """
        syn_train_path = self.root / "syn" / "imgs_train"
        syn_test_path = self.root / "syn" / "imgs_valid"
        if not syn_train_path.exists() or not syn_test_path.exists():
            print(f"警告: SYN数据集路径不存在 - {syn_train_path} 或 {syn_test_path}")
            return None
        try:
            from torchvision.datasets import ImageFolder

            train_dataset = ImageFolder(syn_train_path, transform=transform)
            test_dataset = ImageFolder(syn_test_path, transform=transform)
            train_data, train_targets = self._materialize(train_dataset)
            test_data, test_targets = self._materialize(test_dataset)
        except Exception as e:
            # Deliberate best effort: a broken SYN copy skips the domain.
            print(f"加载SYN数据集时出错: {e}")
            return None
        return train_data, train_targets, test_data, test_targets

    def _load_domain_dataset(self, domain):
        """Load ``domain``'s train/test splits and append them to the global tensors.

        Supported domains are "mnist", "usps", "svhn" (torchvision downloads)
        and "syn" (ImageFolder under ``root/syn``; skipped with a warning when
        unavailable).

        Raises:
            ValueError: If ``domain`` is not supported.
            RuntimeError: If the domain's tensors cannot be concatenated with
                the previously loaded domains.
        """
        base_transform = transforms.Compose(
            [transforms.Resize((32, 32)), transforms.ToTensor()]
        )
        # Grayscale domains: repeat the single channel to get 3-channel images.
        sin_base_transform = transforms.Compose(
            [
                transforms.Resize((32, 32)),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            ]
        )

        if domain == "syn":
            loaded = self._load_syn(base_transform)
            if loaded is None:
                self._mark_domain_empty(domain)
                return
            train_data, train_targets, test_data, test_targets = loaded
        else:
            # torchvision.datasets.* is spelled out because this module defines
            # its own MNIST / USPS / SVHN classes.
            if domain == "mnist":
                train_dataset = torchvision.datasets.MNIST(
                    self.root, train=True, download=True, transform=sin_base_transform
                )
                test_dataset = torchvision.datasets.MNIST(
                    self.root, train=False, download=True, transform=sin_base_transform
                )
            elif domain == "usps":
                train_dataset = torchvision.datasets.USPS(
                    self.root, train=True, download=True, transform=sin_base_transform
                )
                test_dataset = torchvision.datasets.USPS(
                    self.root, train=False, download=True, transform=sin_base_transform
                )
            elif domain == "svhn":
                train_dataset = torchvision.datasets.SVHN(
                    self.root, split="train", download=True, transform=base_transform
                )
                test_dataset = torchvision.datasets.SVHN(
                    self.root, split="test", download=True, transform=base_transform
                )
            else:
                raise ValueError(f"Unsupported domain: {domain!r}")
            train_data, train_targets = self._materialize(train_dataset)
            test_data, test_targets = self._materialize(test_dataset)

        print(f"Domain {domain} data shape - train: {train_data.shape}, test: {test_data.shape}")

        # Global positions: train samples, then test samples, after everything
        # loaded so far.
        start_idx = len(self.data)
        train_end = start_idx + len(train_data)
        test_end = train_end + len(test_data)
        train_indices = list(range(start_idx, train_end))
        test_indices = list(range(train_end, test_end))

        domain_data = torch.cat([train_data, test_data])
        domain_targets = torch.cat([train_targets, test_targets])

        if len(self.data) == 0:  # first domain that contributes data
            self.data = domain_data
            self.targets = domain_targets
        else:
            try:
                merged_data = torch.cat([self.data, domain_data])
                merged_targets = torch.cat([self.targets, domain_targets])
            except RuntimeError as e:
                print(f"Error adding domain {domain}: {e}")
                print(f"Current data shape: {self.data.shape}, domain data shape: {domain_data.shape}")
                # Re-raise: continuing would leave the recorded indices
                # pointing past the end of self.data.
                raise
            self.data = merged_data
            self.targets = merged_targets
            print(f"Added domain {domain}, updated data shape: {self.data.shape}")

        # Record the domain only once its samples are actually in self.data.
        self.domain_data[domain] = domain_data
        self.domain_targets[domain] = domain_targets
        self.domain_indices[domain] = train_indices + test_indices
        self.domain_train_indices[domain] = train_indices
        self.domain_test_indices[domain] = test_indices

    @staticmethod
    def get_normalization_transform():
        """Return the ImageNet-statistics ``Normalize`` transform used by this dataset."""
        transform = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        return transform

    def get_domain_split_indices(self):
        """Return ``{'train': domain -> indices, 'test': domain -> indices}``."""
        return {
            'train': self.domain_train_indices,
            'test': self.domain_test_indices
        }




class MultiDomainOffice(BaseDataset):
    """Multi-domain Office dataset with caltech, amazon, webcam and dslr domains.

    Every domain folder ``root / <domain>`` is read with ``ImageFolder``,
    resized to 192x192 and split 7:3 into train/test. All domains are
    concatenated into ``self.data`` / ``self.targets``; the global sample
    indices of each domain are recorded in ``domain_indices``,
    ``domain_train_indices`` and ``domain_test_indices``.
    """

    # Per-channel mean/std used by every normalization transform below.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)
    # Size every image is resized to before being stored.
    _IMAGE_SIZE = (192, 192)
    # Out of each run of ``_SUBSET_CAPACITY`` consecutive samples, the first
    # ``_SUBSET_TRAIN_NUM`` go to the train split (a 7:3 train/test ratio).
    _SUBSET_TRAIN_NUM = 7
    _SUBSET_CAPACITY = 10

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Load every configured domain and initialise the base dataset.

        Args:
            root: Directory (str or Path) containing one sub-folder per domain.
            args: Optional namespace; a non-empty ``args.domains`` replaces the
                default domain list.
            test_data_transform, test_target_transform, train_data_transform,
            train_target_transform: Forwarded unchanged to ``BaseDataset``.
        """
        if not isinstance(root, Path):
            root = Path(root)
        self.root = root

        # All available domains, unless overridden by ``args.domains``.
        self.domains = ["caltech", "amazon", "webcam", "dslr"]
        if args and hasattr(args, "domains") and args.domains:
            self.domains = args.domains

        self.domain_indices = {}  # domain -> all global indices of that domain
        self.domain_train_indices = {}  # domain -> global train indices
        self.domain_test_indices = {}  # domain -> global test indices
        self.domain_data = {}  # domain -> stacked images (train first, then test)
        self.domain_targets = {}  # domain -> labels aligned with domain_data

        # Start empty; each successfully loaded domain is appended in order.
        self.data = torch.empty(0)
        self.targets = torch.empty(0, dtype=torch.long)

        for domain in self.domains:
            self._load_domain_dataset(domain)

        super().__init__(
            data=self.data,
            targets=self.targets,
            classes=list(range(10)),  # the Office data here is treated as 10 classes
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )

    @classmethod
    def _default_train_transform(cls):
        """Build the training pipeline: random horizontal flip, then normalize."""
        return transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(cls._NORM_MEAN, cls._NORM_STD),
        ])

    @classmethod
    def _default_test_transform(cls):
        """Build the evaluation pipeline: normalization only."""
        return transforms.Compose([
            transforms.Normalize(cls._NORM_MEAN, cls._NORM_STD),
        ])

    def train(self):
        """Switch to training mode (random horizontal flip + normalization).

        NOTE: this always installs the built-in pipeline; it does not read
        ``self.train_data_transform`` (whether passed to ``__init__`` or set by
        ``set_transform``).
        """
        self.data_transform = self._default_train_transform()
        self.target_transform = self.train_target_transform

    def eval(self):
        """Switch to evaluation mode (normalization only).

        NOTE: like ``train``, this ignores ``self.test_data_transform``.
        """
        self.data_transform = self._default_test_transform()
        self.target_transform = self.test_target_transform

    def set_transform(self, domain):
        """Reset the stored train/test data transforms to the built-in pipelines.

        ``domain`` is accepted for interface compatibility but is not used.
        """
        self.train_data_transform = self._default_train_transform()
        self.test_data_transform = self._default_test_transform()

    def _register_empty_domain(self, domain):
        """Record ``domain`` as contributing no samples (used when loading fails)."""
        self.domain_indices[domain] = []
        self.domain_train_indices[domain] = []
        self.domain_test_indices[domain] = []

    def _load_domain_dataset(self, domain):
        """Load one domain, split it into train/test, and append it globally.

        Samples are read with ``ImageFolder`` from ``self.root / domain``. Of
        every ``_SUBSET_CAPACITY`` consecutive samples (in ImageFolder sample
        order) the first ``_SUBSET_TRAIN_NUM`` go to train and the rest to test.
        This is a positional split, not an explicit per-class stratification.
        On any loading error the domain is registered with empty index lists.
        """
        domain_path = self.root / domain

        if not domain_path.exists():
            print(f"警告: {domain}数据集路径不存在 - {domain_path}")
            self._register_empty_domain(domain)
            return

        try:
            from torchvision.datasets import ImageFolder

            base_transform = transforms.Compose([
                transforms.Resize(self._IMAGE_SIZE),
                transforms.ToTensor(),
            ])
            dataset = ImageFolder(domain_path, transform=base_transform)

            train_imgs, train_labels = [], []
            test_imgs, test_labels = [], []
            for i in range(len(dataset)):
                img, target = dataset[i]
                # Strict "<" so exactly _SUBSET_TRAIN_NUM of every
                # _SUBSET_CAPACITY samples land in train (7:3, not 8:2).
                if i % self._SUBSET_CAPACITY < self._SUBSET_TRAIN_NUM:
                    train_imgs.append(img)
                    train_labels.append(target)
                else:
                    test_imgs.append(img)
                    test_labels.append(target)

            # torch.stack raises on an empty split; that is reported below.
            train_data = torch.stack(train_imgs)
            test_data = torch.stack(test_imgs)
            train_targets = torch.tensor(train_labels)
            test_targets = torch.tensor(test_labels)
        except Exception as e:
            # Best-effort: a broken domain is reported and skipped.
            print(f"加载{domain}数据集时出错: {e}")
            self._register_empty_domain(domain)
            return

        print(f"Domain {domain} data shape - train: {train_data.shape}, test: {test_data.shape}")
        self._append_domain(domain, train_data, train_targets, test_data, test_targets)

    def _append_domain(self, domain, train_data, train_targets, test_data, test_targets):
        """Append a domain's samples (train first, then test) to the global data.

        The domain's global indices are recorded only after the concatenation
        succeeds, so the ``domain_*_indices`` dicts never point past
        ``self.data``.
        """
        start_idx = len(self.data)
        train_size = len(train_data)
        test_size = len(test_data)
        domain_data = torch.cat([train_data, test_data])
        domain_targets = torch.cat([train_targets, test_targets])

        if start_idx == 0:
            # First domain: self.data is still the 1-D empty placeholder, which
            # cannot be concatenated with image batches, so replace it outright.
            self.data = domain_data
            self.targets = domain_targets
        else:
            try:
                new_data = torch.cat([self.data, domain_data])
                new_targets = torch.cat([self.targets, domain_targets])
            except RuntimeError as e:
                print(f"Error adding domain {domain}: {e}")
                print(f"Current data shape: {self.data.shape}, domain data shape: {domain_data.shape}")
                # The samples were not appended, so record no indices for them.
                self._register_empty_domain(domain)
                return
            self.data = new_data
            self.targets = new_targets
            print(f"Added domain {domain}, updated data shape: {self.data.shape}")

        train_indices = list(range(start_idx, start_idx + train_size))
        test_indices = list(range(start_idx + train_size, start_idx + train_size + test_size))
        self.domain_data[domain] = domain_data
        self.domain_targets[domain] = domain_targets
        self.domain_indices[domain] = train_indices + test_indices
        self.domain_train_indices[domain] = train_indices
        self.domain_test_indices[domain] = test_indices

    @staticmethod
    def get_normalization_transform():
        """Return the ``Normalize`` transform used by this dataset's pipelines."""
        return transforms.Normalize(MultiDomainOffice._NORM_MEAN, MultiDomainOffice._NORM_STD)

    def get_domain_split_indices(self):
        """Return ``{'train': domain_train_indices, 'test': domain_test_indices}``."""
        return {
            'train': self.domain_train_indices,
            'test': self.domain_test_indices
        }
        


class MultiDomainPacs(BaseDataset):
    """Multi-domain PACS dataset with art_painting, cartoon, photo and sketch domains.

    Images are read from ``root / <domain> / <class_name>``, resized to
    224x224 and split 8:2 into train/test per class. All domains are
    concatenated into ``self.data`` / ``self.targets``; the global sample
    indices of each domain are recorded in ``domain_indices``,
    ``domain_train_indices`` and ``domain_test_indices``.
    """

    # Per-channel mean/std used by every normalization transform below.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)
    # Size every image is resized to before being stored.
    _IMAGE_SIZE = (224, 224)
    # Share of each class that goes to the train split.
    _TRAIN_RATIO = 0.8
    # File suffixes (compared case-insensitively) that are loaded as images.
    _IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(
        self,
        root,
        args=None,
        test_data_transform=None,
        test_target_transform=None,
        train_data_transform=None,
        train_target_transform=None,
    ):
        """Load every configured domain and initialise the base dataset.

        Args:
            root: Directory (str or Path) containing one sub-folder per domain.
            args: Optional namespace; a non-empty ``args.domains`` replaces the
                default domain list.
            test_data_transform, test_target_transform, train_data_transform,
            train_target_transform: Forwarded unchanged to ``BaseDataset``.
        """
        if not isinstance(root, Path):
            root = Path(root)
        self.root = root

        # All available domains, unless overridden by ``args.domains``.
        self.domains = ["art_painting", "cartoon", "photo", "sketch"]
        if args and hasattr(args, "domains") and args.domains:
            self.domains = args.domains

        # The 7 class folder names; a label is the name's position in this list.
        self.class_names = ["dog", "elephant", "giraffe", "guitar", "horse", "house", "person"]
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.class_names)}

        self.domain_indices = {}  # domain -> all global indices of that domain
        self.domain_train_indices = {}  # domain -> global train indices
        self.domain_test_indices = {}  # domain -> global test indices
        self.domain_data = {}  # domain -> stacked images (train first, then test)
        self.domain_targets = {}  # domain -> labels aligned with domain_data

        # Start empty; each successfully loaded domain is appended in order.
        self.data = torch.empty(0)
        self.targets = torch.empty(0, dtype=torch.long)

        for domain in self.domains:
            self._load_domain_dataset(domain)

        super().__init__(
            data=self.data,
            targets=self.targets,
            classes=list(range(len(self.class_names))),  # 7 PACS classes
            test_data_transform=test_data_transform,
            test_target_transform=test_target_transform,
            train_data_transform=train_data_transform,
            train_target_transform=train_target_transform,
        )

    @classmethod
    def _default_train_transform(cls):
        """Build the training pipeline: random horizontal flip, then normalize."""
        return transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(cls._NORM_MEAN, cls._NORM_STD),
        ])

    @classmethod
    def _default_test_transform(cls):
        """Build the evaluation pipeline: normalization only."""
        return transforms.Compose([
            transforms.Normalize(cls._NORM_MEAN, cls._NORM_STD),
        ])

    def set_transform(self, domain):
        """Reset the stored train/test data transforms to the built-in pipelines.

        ``domain`` is accepted for interface compatibility but is not used.
        """
        self.train_data_transform = self._default_train_transform()
        self.test_data_transform = self._default_test_transform()

    def train(self):
        """Switch to training mode (random horizontal flip + normalization).

        NOTE: this always installs the built-in pipeline; it does not read
        ``self.train_data_transform`` (whether passed to ``__init__`` or set by
        ``set_transform``).
        """
        self.data_transform = self._default_train_transform()
        self.target_transform = self.train_target_transform

    def eval(self):
        """Switch to evaluation mode (normalization only).

        NOTE: like ``train``, this ignores ``self.test_data_transform``.
        """
        self.data_transform = self._default_test_transform()
        self.target_transform = self.test_target_transform

    def _register_empty_domain(self, domain):
        """Record ``domain`` as contributing no samples (used when loading fails)."""
        self.domain_indices[domain] = []
        self.domain_train_indices[domain] = []
        self.domain_test_indices[domain] = []

    def _load_images(self, domain_path, base_transform):
        """Read every image of every known class under ``domain_path``.

        Files are picked by case-insensitive suffix and visited in sorted order
        so the result does not depend on filesystem listing order. Unreadable
        images are reported and skipped.

        Returns:
            (list of image tensors, list of int labels), aligned by position.
        """
        all_data = []
        all_targets = []
        for class_name in self.class_names:
            class_path = domain_path / class_name
            if not class_path.exists():
                print(f"警告: 类别路径不存在 - {class_path}")
                continue

            class_idx = self.class_to_idx[class_name]
            image_files = sorted(
                p for p in class_path.iterdir()
                if p.is_file() and p.suffix.lower() in self._IMAGE_SUFFIXES
            )
            for img_file in image_files:
                try:
                    # Context manager guarantees the file handle is released.
                    with Image.open(img_file) as img:
                        img_tensor = base_transform(img.convert("RGB"))
                except Exception as e:
                    print(f"加载图片失败 {img_file}: {e}")
                    continue
                all_data.append(img_tensor)
                all_targets.append(class_idx)
        return all_data, all_targets

    def _stratified_split(self, all_data, all_targets):
        """Split samples per class: a random ``_TRAIN_RATIO`` share of each class goes to train.

        Uses ``torch.randperm`` (global torch RNG) to shuffle each class.

        Returns:
            (train_data, train_targets, test_data, test_targets); an empty split
            is returned as a correctly shaped empty tensor.
        """
        train_parts, test_parts = [], []
        train_label_parts, test_label_parts = [], []

        for class_idx in range(len(self.class_names)):
            class_mask = all_targets == class_idx
            class_data = all_data[class_mask]
            class_targets = all_targets[class_mask]
            if len(class_data) == 0:
                continue

            # int() floors, so e.g. a single-sample class goes entirely to test.
            train_size = int(self._TRAIN_RATIO * len(class_data))
            perm = torch.randperm(len(class_data))
            train_idx, test_idx = perm[:train_size], perm[train_size:]

            train_parts.append(class_data[train_idx])
            test_parts.append(class_data[test_idx])
            train_label_parts.append(class_targets[train_idx])
            test_label_parts.append(class_targets[test_idx])

        if train_parts:
            train_data = torch.cat(train_parts, dim=0)
            train_targets = torch.cat(train_label_parts, dim=0)
        else:
            train_data = torch.empty(0, 3, *self._IMAGE_SIZE)
            train_targets = torch.empty(0, dtype=torch.long)

        if test_parts:
            test_data = torch.cat(test_parts, dim=0)
            test_targets = torch.cat(test_label_parts, dim=0)
        else:
            test_data = torch.empty(0, 3, *self._IMAGE_SIZE)
            test_targets = torch.empty(0, dtype=torch.long)

        return train_data, train_targets, test_data, test_targets

    def _load_domain_dataset(self, domain):
        """Load one PACS domain, split it per class, and append it globally.

        On a missing path, an empty domain, or any loading error the domain is
        registered with empty index lists and skipped.
        """
        domain_path = self.root / domain

        if not domain_path.exists():
            print(f"警告: {domain}数据集路径不存在 - {domain_path}")
            self._register_empty_domain(domain)
            return

        try:
            base_transform = transforms.Compose([
                transforms.Resize(self._IMAGE_SIZE),
                transforms.ToTensor(),
            ])
            all_data, all_targets = self._load_images(domain_path, base_transform)

            if not all_data:
                print(f"警告: {domain} 域没有加载到任何数据")
                self._register_empty_domain(domain)
                return

            all_data = torch.stack(all_data)
            all_targets = torch.tensor(all_targets, dtype=torch.long)
            train_data, train_targets, test_data, test_targets = self._stratified_split(
                all_data, all_targets
            )
        except Exception as e:
            # Best-effort: a broken domain is reported and skipped.
            print(f"加载{domain}数据集时出错: {e}")
            self._register_empty_domain(domain)
            return

        print(f"Domain {domain} data shape - train: {train_data.shape}, test: {test_data.shape}")
        self._append_domain(domain, train_data, train_targets, test_data, test_targets)

    def _append_domain(self, domain, train_data, train_targets, test_data, test_targets):
        """Append a domain's samples (train first, then test) to the global data.

        The domain's global indices are recorded only after the concatenation
        succeeds, so the ``domain_*_indices`` dicts never point past
        ``self.data``.
        """
        start_idx = len(self.data)
        train_size = len(train_data)
        test_size = len(test_data)
        domain_data = torch.cat([train_data, test_data])
        domain_targets = torch.cat([train_targets, test_targets])

        if start_idx == 0:
            # First domain: self.data is still the 1-D empty placeholder, which
            # cannot be concatenated with image batches, so replace it outright.
            self.data = domain_data
            self.targets = domain_targets
        else:
            try:
                new_data = torch.cat([self.data, domain_data])
                new_targets = torch.cat([self.targets, domain_targets])
            except RuntimeError as e:
                print(f"Error adding domain {domain}: {e}")
                print(f"Current data shape: {self.data.shape}, domain data shape: {domain_data.shape}")
                # The samples were not appended, so record no indices for them.
                self._register_empty_domain(domain)
                return
            self.data = new_data
            self.targets = new_targets
            print(f"Added domain {domain}, updated data shape: {self.data.shape}")

        train_indices = list(range(start_idx, start_idx + train_size))
        test_indices = list(range(start_idx + train_size, start_idx + train_size + test_size))
        self.domain_data[domain] = domain_data
        self.domain_targets[domain] = domain_targets
        self.domain_indices[domain] = train_indices + test_indices
        self.domain_train_indices[domain] = train_indices
        self.domain_test_indices[domain] = test_indices

    @staticmethod
    def get_normalization_transform():
        """Return the ``Normalize`` transform used by this dataset's pipelines."""
        return transforms.Normalize(MultiDomainPacs._NORM_MEAN, MultiDomainPacs._NORM_STD)

    def get_domain_split_indices(self):
        """Return ``{'train': domain_train_indices, 'test': domain_test_indices}``."""
        return {
            'train': self.domain_train_indices,
            'test': self.domain_test_indices
        }




# Registry from dataset key to the dataset class that loads it. The three
# medmnist* keys all map to the same MedMNIST class; insertion order is kept.
DATASETS: Dict[str, Type[BaseDataset]] = dict(
    cifar10=CIFAR10,
    cifar100=CIFAR100,
    mnist=MNIST,
    emnist=EMNIST,
    fmnist=FashionMNIST,
    femnist=FEMNIST,
    medmnistS=MedMNIST,
    medmnistC=MedMNIST,
    medmnistA=MedMNIST,
    covid19=COVID19,
    celeba=CelebA,
    synthetic=Synthetic,
    svhn=SVHN,
    usps=USPS,
    tiny_imagenet=TinyImagenet,
    cinic10=CINIC10,
    domain=DomainNet,
    multi_domain_digits=MultiDomainDigits,
    multi_domain_office=MultiDomainOffice,
    multi_domain_pacs=MultiDomainPacs,
)
