import os

# Repository id of the PACS dataset (presumably a Hugging Face Hub dataset id,
# given the ``huggingface_hub`` import below -- confirm at the call site).
# Override it by setting the PACS_REPO environment variable; defaults to
# "flwrlabs/pacs".
PACS_REPO = os.environ.get("PACS_REPO", "flwrlabs/pacs")

from huggingface_hub import snapshot_download
import json
import pyarrow.parquet as pq

import pandas as pd
from torch.utils.data import Dataset,TensorDataset, DataLoader
from PIL import Image
import os
from models.vggmodule import vgg
from models.Resnet import ResNet18
import timm
from transformers import ViTForImageClassification
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import re

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, fmnist_iid, svhn_iid
from models.Nets import MLP, CNNMnist, CNNCifar, Lenet5, LeNet, DigitModel
from utils import data_utils
from utils.data_utils import OfficeDataset, DomainNetDataset, PACSDataset
from collections import defaultdict
from torch.utils.data import Subset

from torch.utils.data import DataLoader
from PIL import Image
from torchvision.datasets import ImageFolder

import pickle
import random
from torch.utils.data import Subset
import numpy as np
from sklearn.model_selection import train_test_split

from models.vit import vit_base_patch16, vit_small_patch16, vit_large_patch16, mae_vit_base, mae_vit_small, load_pretrained_mae, replace_classifier

# --- PACS dataset utilities ---
def _extract_pacs_images(pacs_dir):
    """Extract images from Hugging Face parquet files into domain folders."""
    parquet_path = os.path.join(pacs_dir, "data", "train-00000-of-00001.parquet")
    if not os.path.isfile(parquet_path):
        return

    print(f"Extracting PACS images from {parquet_path}...")
    pf = pq.ParquetFile(parquet_path)
    meta = json.loads(pf.metadata.metadata[b'huggingface'])
    label_names = meta["info"]["features"]["label"]["names"]

    for batch in pf.iter_batches():
        imgs = batch.column("image")
        domains = batch.column("domain")
        labels = batch.column("label")
        for img_struct, domain, label in zip(imgs, domains, labels):
            domain_name = domain.as_py()
            cls = label_names[label.as_py()]
            img_data = img_struct.as_py()
            out_dir = os.path.join(pacs_dir, domain_name, cls)
            os.makedirs(out_dir, exist_ok=True)
            filename = os.path.basename(img_data.get("path", "img.jpg"))
            out_path = os.path.join(out_dir, filename)
            if not os.path.exists(out_path):
                with open(out_path, "wb") as f:
                    f.write(img_data["bytes"])


def create_balanced_subset(dataset, class_counts):
    """Return a class-balanced ``Subset`` of ``dataset``.

    Every label keeps only its first ``min(class_counts.values())`` samples
    (in dataset order), so all labels end up equally represented.

    Args:
        dataset: Iterable dataset whose samples are tuples with the label at
            position 1, e.g. ``(img, label)`` or ``(img, label, idx)``.
        class_counts: Mapping label -> number of samples (for example the
            result of ``count_samples_per_class``).  Only its values are used.

    Returns:
        ``torch.utils.data.Subset`` over ``dataset``.  It is empty when
        ``class_counts`` is empty.
    """
    min_samples_per_class = min(class_counts.values(), default=0)

    indices_per_class = defaultdict(list)
    for idx, sample in enumerate(dataset):
        label = sample[1]
        # 0-d tensors hash by identity, so equal tensor labels would land in
        # different buckets; use the plain Python scalar instead.
        if hasattr(label, "item"):
            label = label.item()
        indices_per_class[label].append(idx)

    balanced_indices = []
    for indices in indices_per_class.values():
        balanced_indices.extend(indices[:min_samples_per_class])

    return torch.utils.data.Subset(dataset, balanced_indices)

def count_samples_per_class(dataset):
    """Count how many samples of ``dataset`` carry each label.

    Args:
        dataset: Iterable of tuples with the label at position 1, e.g.
            ``(img, label)`` or ``(img, label, idx)``.

    Returns:
        ``defaultdict(int)`` mapping label -> number of samples.  Tensor and
        numpy scalar labels are converted to plain Python values via
        ``.item()`` so that equal labels are counted together.
    """
    class_counts = defaultdict(int)
    for sample in dataset:
        label = sample[1]
        # 0-d tensors hash by identity; without this every tensor label would
        # get its own count of 1.
        if hasattr(label, "item"):
            label = label.item()
        class_counts[label] += 1
    return class_counts

# ---- Domain Split Utilities ----
def _get_base_domains(dataset_name):
    if dataset_name == 'domain_digits':
        return ['MNIST', 'SVHN', 'USPS', 'SynthDigits', 'MNIST_M']
    elif dataset_name == 'office-caltech10':
        return ['amazon', 'caltech', 'dslr', 'webcam']
    elif dataset_name == 'DomainNet':
        return ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']
    elif dataset_name == 'PACS':
        return ['art_painting', 'cartoon', 'photo', 'sketch']
    else:
        return []

def _parse_split_counts(args, num_domains):
    """Return a list describing how many clients each domain is split into.

    By default the behaviour is controlled via ``domain_split_config`` or
    ``domain_split_factor``.  When ``bkd_domain_idx`` is set to a valid domain
    index, however, we ensure that the corresponding domain is only split into a
    single client while all other domains are duplicated ``domain_times_factor``
    times.  ``domain_split_factor`` is aligned with ``domain_times_factor`` to
    keep existing path conventions intact.
    """

    bkd_idx = getattr(args, "bkd_domain_idx", -1)
    if 0 <= bkd_idx < num_domains:
        factor = max(1, int(getattr(args, "domain_times_factor", 1)))
        counts = [factor for _ in range(num_domains)]
        counts[bkd_idx] = 1
        # ``domain_split_factor`` is used for naming paths.  Update it to match
        # the new replication factor so that the generated paths remain
        # compatible with previous conventions.
        args.domain_split_factor = args.domain_times_factor
    elif getattr(args, 'domain_split_config', ''):
        parts = re.split('[,:]', args.domain_split_config)
        counts = [int(x) for x in parts if x]
        if len(counts) != num_domains:
            raise ValueError('domain_split_config length mismatch')
    else:
        factor = max(1, int(getattr(args, 'domain_split_factor', 1)))
        counts = [factor for _ in range(num_domains)]
    return counts

def _indices_by_label(dataset):
    """Group the indices of ``dataset`` by the label stored at ``sample[1]``."""
    groups = defaultdict(list)
    for idx in range(len(dataset)):
        label = dataset[idx][1]
        # 0-d tensors hash by identity; use the Python scalar so equal labels
        # share one bucket.
        if hasattr(label, "item"):
            label = label.item()
        groups[label].append(idx)
    return groups


def _split_loaders(args, train_loaders, test_loaders, base_names):
    """Split each domain's train/test loaders into several client loaders.

    The number of clients per domain comes from ``_parse_split_counts``.
    Within a split domain, the indices of every label are shuffled and dealt
    out with ``np.array_split`` across that domain's clients, separately for
    the train and the test split, so clients get a similar label mix.  The
    split count is reduced (never below 1) when some label has fewer train or
    test samples than requested clients.  Domains with a count of 1 keep
    their original loaders.

    Side effects:
        Sets ``args.dataset_names`` (``"<domain>_<i>"`` for split domains,
        ``"<domain>"`` otherwise) and ``args.num_users``.

    Args:
        args: Experiment namespace (see ``_parse_split_counts``).
        train_loaders: One train ``DataLoader`` per base domain.
        test_loaders: One test ``DataLoader`` per base domain.
        base_names: Domain names, aligned with the loader lists.

    Returns:
        ``(train_loaders, test_loaders)`` with one entry per client.
    """
    counts = _parse_split_counts(args, len(base_names))
    if all(c == 1 for c in counts):
        args.dataset_names = base_names
        args.num_users = len(base_names)
        return train_loaders, test_loaders

    new_train, new_test, names = [], [], []
    for name, tr_loader, te_loader, cnt in zip(base_names, train_loaders, test_loaders, counts):
        if cnt <= 1:
            # Nothing to split: reuse the loaders as-is and skip the full
            # pass over the dataset that reading every label would cost.
            new_train.append(tr_loader)
            new_test.append(te_loader)
            names.append(name)
            continue

        train_dataset = tr_loader.dataset
        test_dataset = te_loader.dataset
        label_to_indices_train = _indices_by_label(train_dataset)
        label_to_indices_test = _indices_by_label(test_dataset)

        min_train = min((len(v) for v in label_to_indices_train.values()), default=0)
        min_test = min((len(v) for v in label_to_indices_test.values()), default=0)
        effective_cnt = max(1, min(cnt, min_train, min_test))
        if effective_cnt < cnt:
            print(
                f"Reducing split count for domain {name} from {cnt} to {effective_cnt} due to limited samples"
            )
        cnt = effective_cnt

        subset_indices_train = [[] for _ in range(cnt)]
        subset_indices_test = [[] for _ in range(cnt)]
        for label in set(label_to_indices_train) | set(label_to_indices_test):
            idxs_train = label_to_indices_train.get(label, [])
            idxs_test = label_to_indices_test.get(label, [])
            random.shuffle(idxs_train)
            random.shuffle(idxs_test)
            train_chunks = np.array_split(idxs_train, cnt) if idxs_train else [np.array([], dtype=int)] * cnt
            test_chunks = np.array_split(idxs_test, cnt) if idxs_test else [np.array([], dtype=int)] * cnt
            for i in range(cnt):
                subset_indices_train[i].extend(train_chunks[i].tolist())
                subset_indices_test[i].extend(test_chunks[i].tolist())

        # ``DataLoader`` has no ``shuffle`` attribute (the old getattr always
        # fell back to True); shuffling is encoded by its sampler.
        shuffle = isinstance(tr_loader.sampler, torch.utils.data.RandomSampler)
        for i in range(cnt):
            tr_subset = Subset(train_dataset, subset_indices_train[i])
            te_subset = Subset(test_dataset, subset_indices_test[i])
            new_train.append(
                DataLoader(tr_subset, batch_size=tr_loader.batch_size, shuffle=shuffle)
            )
            new_test.append(
                DataLoader(te_subset, batch_size=te_loader.batch_size, shuffle=False)
            )
            names.append(f"{name}_{i}" if cnt > 1 else name)

    args.dataset_names = names
    args.num_users = len(names)
    return new_train, new_test


def _get_domain(name):
    """Return the base domain name for a dataset."""
    parts = name.split('_')
    if len(parts) > 1 and parts[-1].isdigit():
        return '_'.join(parts[:-1])
    return name

def data_processing():
    """Build the CelebA gender CSV files from the raw CelebA annotation files.

    Reads ``list_attr_celeba.txt`` (``Male`` column only) and
    ``list_eval_partition.txt`` from the current working directory.  It then
    rewrites ``Male == -1`` to 0 and joins the two tables on their index.
    Finally it writes:

    * ``celeba-gender-partitions.csv`` with the ``Male`` and ``Partition`` columns;
    * ``celeba-gender-train.csv``, ``celeba-gender-valid.csv`` and
      ``celeba-gender-test.csv`` with the rows whose ``Partition`` is 0, 1
      and 2 respectively.
    """
    # Raw strings: "\s" in a normal literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    attrs = pd.read_csv('list_attr_celeba.txt', sep=r"\s+", skiprows=1, usecols=['Male'])
    attrs.loc[attrs['Male'] == -1, 'Male'] = 0

    partitions = pd.read_csv('list_eval_partition.txt', sep=r"\s+", header=None)
    partitions.columns = ['Filename', 'Partition']
    partitions = partitions.set_index('Filename')

    merged = attrs.merge(partitions, left_index=True, right_index=True)
    merged.to_csv('celeba-gender-partitions.csv')

    # Derive the per-partition files from the combined CSV just written.
    combined = pd.read_csv('celeba-gender-partitions.csv', index_col=0)
    for partition_id, split_name in ((0, 'train'), (1, 'valid'), (2, 'test')):
        combined.loc[combined['Partition'] == partition_id].to_csv(f'celeba-gender-{split_name}.csv')

class CelebaDataset(Dataset):
    """Custom Dataset for loading CelebA face images with their ``Male`` label.

    Args:
        csv_path: CSV whose first column (used as index) holds image file names
            and which has a ``Male`` column, e.g. a file written by
            ``data_processing``.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, csv_path, img_dir, transform=None):
        df = pd.read_csv(csv_path, index_col=0)
        self.img_dir = img_dir
        self.csv_path = csv_path
        self.img_names = df.index.values
        self.y = df['Male'].values
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``."""
        path = os.path.join(self.img_dir, self.img_names[index])
        # ``Image.open`` is lazy and keeps the file handle open until the image
        # is garbage-collected; copy the pixels inside ``with`` so the file is
        # always closed.
        with Image.open(path) as raw:
            img = raw.copy()

        if self.transform is not None:
            img = self.transform(img)

        label = self.y[index]
        return img, label

    def __len__(self):
        return self.y.shape[0]

class CustomDataset(torch.utils.data.Dataset):
    """Wrap an ``(img, label)`` dataset so items become ``(img, label, idx)``.

    If ``transform`` is truthy, each image is first converted with
    ``transforms.ToPILImage()`` and the result is passed to ``transform``.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        if not self.transform:
            return img, label, idx
        as_pil = transforms.ToPILImage()(img)
        return self.transform(as_pil), label, idx

def init_data(args, all_data=True, evals=False):
    """Build the data for the dataset selected by ``args.dataset``.

    For the multi-domain datasets (``domain_digits``, ``office-caltech10``,
    ``DomainNet``, ``PACS``) this returns ``(train_loaders, test_loaders,
    backdoor_loader)``.  For all but PACS (delegated to
    ``create_pacs_datasets``), ``_split_loaders`` may split every domain into
    several clients and sets ``args.dataset_names`` / ``args.num_users``.
    ``backdoor_loader`` is ``None`` unless ``args.verify == "backdoor"``.

    Args:
        args: Experiment namespace.  Reads ``dataset``, ``local_bs`` and
            ``verify``, plus dataset-specific fields such as ``iid``,
            ``num_users``, ``percent`` and ``backdoor_client_idx``.
        all_data: DomainNet only.  When ``False``, every train domain is capped
            at 1000 samples (and never above the smallest domain).
        evals: When ``True``, train loaders are not shuffled (office-caltech10,
            DomainNet) and DomainNet train augmentation is disabled.  Forwarded
            to ``create_pacs_datasets`` for PACS.

    NOTE(review): the ``mnist``/``cifar``/``fmnist``/``svhn`` branches build
    datasets and ``dict_users`` but have no ``return``, so the function returns
    ``None`` for them -- confirm what callers expect.
    """
    # No local variable in this function may be called ``datasets``: any
    # assignment to that name makes it local for the whole function, and the
    # torchvision ``datasets.MNIST``/``CIFAR10``/... calls below would then
    # raise UnboundLocalError.
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose([transforms.ToTensor()])
        dataset_train = datasets.MNIST('./data/MNIST/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('./data/MNIST/', train=False, download=True, transform=trans_mnist)
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    elif args.dataset == 'fmnist':
        trans_fmnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.FashionMNIST('../data/fmnist', train=True, download=True, transform=trans_fmnist)
        dataset_test = datasets.FashionMNIST('../data/fmnist', train=False, download=True, transform=trans_fmnist)
        if args.iid:
            dict_users = fmnist_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in FMNIST')
    elif args.dataset == 'svhn':
        trans_svhn = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.SVHN('../data/svhn', split='train', download=True, transform=trans_svhn)
        dataset_test = datasets.SVHN('../data/svhn', split='test', download=True, transform=trans_svhn)
        if args.iid:
            dict_users = svhn_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in SVHN')
    elif args.dataset == 'domain_digits':
        resize = 32
        transform_mnist = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize([resize, resize]),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        transform_svhn = transforms.Compose([
            transforms.Resize([resize, resize]),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        transform_usps = transforms.Compose([
            transforms.Resize([resize, resize]),
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        transform_synth = transforms.Compose([
            transforms.Resize([resize, resize]),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        transform_mnistm = transforms.Compose([
            transforms.Resize([resize, resize]),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # (domain name, data path, channels, transform).  The order defines the
        # base client order.
        digit_specs = [
            ('MNIST', './data/digitdata/MNIST', 1, transform_mnist),
            ('SVHN', './data/digitdata/SVHN', 3, transform_svhn),
            ('USPS', './data/digitdata/USPS', 1, transform_usps),
            ('SynthDigits', './data/digitdata/SynthDigits/', 3, transform_synth),
            ('MNIST_M', './data/digitdata/MNIST_M/', 3, transform_mnistm),
        ]
        digit_trainsets, digit_testsets = [], []
        for _, digit_path, digit_channels, digit_tfm in digit_specs:
            digit_trainsets.append(data_utils.DigitsDataset(
                data_path=digit_path, channels=digit_channels, percent=args.percent,
                train=True, transform=digit_tfm))
            digit_testsets.append(data_utils.DigitsDataset(
                data_path=digit_path, channels=digit_channels, percent=args.percent,
                train=False, transform=digit_tfm))

        # Evaluate on a class-balanced subset of every domain's test set.
        class_counts_per_domain = [count_samples_per_class(ds) for ds in digit_testsets]
        balanced_testsets = [create_balanced_subset(ds, counts) for ds, counts in zip(digit_testsets, class_counts_per_domain)]
        test_loaders = [DataLoader(ds, batch_size=args.local_bs, shuffle=False) for ds in balanced_testsets]
        train_loaders = [torch.utils.data.DataLoader(ds, batch_size=args.local_bs, shuffle=True) for ds in digit_trainsets]

        base_names = [spec[0] for spec in digit_specs]
        train_loaders, test_loaders = _split_loaders(args, train_loaders, test_loaders, base_names)

        backdoorloader = None
        if args.verify == "backdoor":
            bd_indices = (
                args.backdoor_client_idx
                if isinstance(args.backdoor_client_idx, (list, tuple))
                else [args.backdoor_client_idx]
            )
            for b_idx in bd_indices:
                # b_idx is a client index after splitting; map it to its domain.
                domain = _get_domain(args.dataset_names[b_idx])
                if domain in ['MNIST', 'USPS']:
                    channel = 1
                    transform = transforms.Compose([
                        transforms.Grayscale(num_output_channels=3),
                        transforms.Resize([resize, resize]),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                    ])
                else:
                    channel = 3
                    transform = transforms.Compose([
                        transforms.Resize([resize, resize]),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                    ])

                backdoor_trainset = data_utils.DigitsDataset(
                    data_path=f"./data/digitdata/{domain}",
                    channels=channel,
                    percent=args.percent,
                    train=True,
                    transform=transform,
                    inject_backdoor=True,
                    load_backdoor=False,
                    args=args,
                    dataset=domain,
                    backdoortest=False
                )

                # Keep exactly the samples this client already owned.
                indices = train_loaders[b_idx].dataset.indices if isinstance(train_loaders[b_idx].dataset, Subset) else list(range(len(train_loaders[b_idx].dataset)))
                subset = Subset(backdoor_trainset, indices)
                train_loaders[b_idx] = DataLoader(subset, batch_size=args.local_bs, shuffle=True)

                if backdoorloader is None:
                    trainset_backdoor_test = data_utils.DigitsDataset(
                        data_path=f"./data/digitdata/{domain}",
                        channels=channel,
                        percent=args.percent,
                        inject_backdoor=True,
                        args=args,
                        load_backdoor=True,
                        backdoortest=True,
                        dataset=domain,
                        train=True,
                        transform=transform
                    )
                    backdoorloader = DataLoader(trainset_backdoor_test, batch_size=args.local_bs, shuffle=False)

        return train_loaders, test_loaders, backdoorloader

    elif args.dataset == 'office-caltech10':
        data_base_path = './data'
        resize = 224
        transform_office = transforms.Compose([
            transforms.Resize([resize, resize]),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation((-30, 30)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        transform_test = transforms.Compose([
            transforms.Resize([resize, resize]),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        datasets_name = ['amazon', 'caltech', 'dslr', 'webcam']

        train_sets, test_sets = [], []
        for site in datasets_name:
            train_sets.append(data_utils.OfficeDataset(data_base_path, site, train=True, transform=transform_office, args=args, dataset_name=site, resize=resize))
            test_sets.append(data_utils.OfficeDataset(data_base_path, site, train=False, transform=transform_test, args=args, dataset_name=site, resize=resize))

        train_loaders = [DataLoader(ds, batch_size=args.local_bs, shuffle=True if evals is False else False) for ds in train_sets]
        test_loaders = [DataLoader(ds, batch_size=args.local_bs, shuffle=False) for ds in test_sets]

        backdoorloader = None

        if args.verify == "backdoor":
            # Map (post-split) client indices back to base-domain indices.
            counts = _parse_split_counts(args, len(datasets_name))
            client_to_domain = []
            for idx, cnt in enumerate(counts):
                client_to_domain.extend([idx] * cnt)

            domain_indices = set()
            bd_indices = (
                args.backdoor_client_idx
                if isinstance(args.backdoor_client_idx, (list, tuple))
                else [args.backdoor_client_idx]
            )
            for b_idx in bd_indices:
                if b_idx < 0 or b_idx >= len(client_to_domain):
                    raise IndexError("backdoor_client_idx out of range")
                domain_indices.add(client_to_domain[b_idx])

            for d_idx in domain_indices:
                client_site_name = datasets_name[d_idx]

                backdoor_trainset = data_utils.OfficeDataset(
                    data_base_path, client_site_name, train=True, transform=transform_office,
                    inject_backdoor=True, load_backdoor=True, args=args, dataset_name=client_site_name,
                    backdoortest=False, resize=resize
                )
                train_loaders[d_idx] = torch.utils.data.DataLoader(
                    backdoor_trainset, batch_size=args.local_bs, shuffle=True
                )

                backdoor_testset_orig = data_utils.OfficeDataset(
                    data_base_path, client_site_name, train=False, transform=transform_test,
                    inject_backdoor=True, load_backdoor=True, args=args, dataset_name=client_site_name,
                    backdoortest=True, resize=resize
                )
                if backdoorloader is None:
                    backdoorloader = torch.utils.data.DataLoader(
                        backdoor_testset_orig, batch_size=args.local_bs, shuffle=False
                    )

        for i, domain in enumerate(datasets_name):
            print("Dataset {} Trainset size: {}".format(domain, len(train_loaders[i].dataset)))
            print("Dataset {} Testset size: {}".format(domain, len(test_loaders[i].dataset)))
        if backdoorloader:
            print("Backdoor Testset size (ASR): {}".format(len(backdoorloader.dataset)))

        train_loaders, test_loaders = _split_loaders(args, train_loaders, test_loaders, datasets_name)
        return train_loaders, test_loaders, backdoorloader

    elif args.dataset == 'DomainNet':
        data_base_path = './data/domainnet'
        if evals is True:
            transform_train = transforms.Compose([
                transforms.Resize([224, 224]),
                transforms.ToTensor(),
            ])
        else:
            transform_train = transforms.Compose([
                transforms.Resize([224, 224]),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation((-30, 30)),
                transforms.ToTensor(),
            ])

        transform_test = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor()
        ])

        domain_names = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']
        full_trainsets, test_sets = [], []
        for domain in domain_names:
            full_trainsets.append(DomainNetDataset(data_base_path, domain, transform=transform_train))
            test_sets.append(DomainNetDataset(data_base_path, domain, transform=transform_test, train=False))

        # Truncate every train domain to the same length.
        min_data_len = min(len(ds) for ds in full_trainsets)
        if all_data is False:
            # Cap at 1000 but never beyond the smallest domain; a larger value
            # would put out-of-range indices into the Subsets below.
            min_data_len = min(min_data_len, 1000)
        train_sets = [torch.utils.data.Subset(ds, list(range(min_data_len))) for ds in full_trainsets]

        shuffles = evals is not True
        train_loaders = [torch.utils.data.DataLoader(ds, batch_size=args.local_bs, shuffle=shuffles) for ds in train_sets]
        test_loaders = [DataLoader(ds, batch_size=args.local_bs, shuffle=False) for ds in test_sets]

        backdoor_loader = None
        if args.verify == "backdoor":
            bd_indices = (
                args.backdoor_client_idx
                if isinstance(args.backdoor_client_idx, (list, tuple))
                else [args.backdoor_client_idx]
            )
            # NOTE(review): b_idx indexes the base-domain list here, whereas the
            # domain_digits / office-caltech10 branches treat it as a client
            # index.  The backdoor set is built with transform=None and never
            # replaces a train loader -- confirm this is intended.
            for b_idx in bd_indices:
                backdoor_trainset = DomainNetDataset(data_base_path, domain_names[b_idx], transform=None)
                backdoor_trainset = torch.utils.data.Subset(backdoor_trainset, list(range(min_data_len)))
                if backdoor_loader is None:
                    backdoor_loader = torch.utils.data.DataLoader(backdoor_trainset, batch_size=args.local_bs, shuffle=False)

        train_loaders, test_loaders = _split_loaders(args, train_loaders, test_loaders, domain_names)
        return train_loaders, test_loaders, backdoor_loader
    elif args.dataset == 'PACS':
        inject_backdoor = args.verify == "backdoor"
        train_loaders, test_loaders, backdoor_loader = create_pacs_datasets(
            args, inject_backdoor=inject_backdoor, evals=evals
        )
        return train_loaders, test_loaders, backdoor_loader
    else:
        exit('Error: unrecognized dataset')

def init_model(args):
    """Build the global model selected by ``args.model`` for ``args.dataset``.

    Args:
        args: Namespace-like run configuration. Always reads ``model``,
            ``dataset`` and ``device``; depending on the branch also reads
            ``num_classes``, ``img_size``, ``norm_pix_loss``, ``pretrained``,
            ``pretrained_path`` and ``finetune``.

    Returns:
        The instantiated ``nn.Module``, moved to ``args.device``.

    Raises:
        SystemExit: if the model/dataset combination is not supported.
    """
    model, dataset = args.model, args.dataset

    if model == 'cnn' and dataset == 'domain_digits':
        net_glob = DigitModel(dataset=dataset)
    elif model == 'cnn' and dataset == 'PACS':
        net_glob = DigitModel(dataset=dataset, num_classes=args.num_classes)
    elif model == 'vgg16' and dataset in ('domain_digits', 'office-caltech10', 'DomainNet'):
        net_glob = vgg(dataset=dataset, depth=16, init_weights=True, cfg=None)
    elif model == 'vgg16' and dataset == 'PACS':
        net_glob = vgg(dataset='PACS', depth=16, init_weights=True, cfg=None)
        # Resize the final layer to the PACS label space, but only when the
        # classifier really ends in a Linear layer we can rebuild.
        classifier = getattr(net_glob, 'classifier', None)
        if isinstance(classifier, nn.Sequential) and isinstance(classifier[-1], nn.Linear):
            classifier[-1] = nn.Linear(classifier[-1].in_features, args.num_classes)
    elif model == 'resnet18' and dataset == 'office-caltech10':
        net_glob = ResNet18(dataset=dataset)
    elif model == 'resnet18' and dataset in ('DomainNet', 'PACS'):
        net_glob = ResNet18(dataset=dataset, num_classes=args.num_classes)
    elif model == 'resnet50':
        from torchvision.models import resnet50, ResNet50_Weights
        net_glob = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        if args.num_classes != 1000:
            net_glob.fc = nn.Linear(net_glob.fc.in_features, args.num_classes)
    elif model == 'vit':
        net_glob = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
        if args.num_classes != net_glob.config.num_labels:
            net_glob.classifier = nn.Linear(net_glob.classifier.in_features, args.num_classes)
    elif model == 'mobilevit':
        # Let timm size the classification head. The previous
        # ``hasattr(net, 'classifier')`` probe did not match the attribute
        # timm uses for MobileViT's head, so the head was left at 1000 outputs.
        net_glob = timm.create_model('mobilevit_s', pretrained=True, num_classes=args.num_classes)
    elif model == 'mae_vit':
        model_fn = mae_vit_small if dataset == 'office-caltech10' else mae_vit_base
        net_glob = model_fn(
            img_size=getattr(args, 'img_size', 224),
            in_chans=3,
            norm_pix_loss=args.norm_pix_loss,
        ).to(args.device)
        if getattr(args, 'pretrained', False):
            load_pretrained_mae(net_glob, args.pretrained_path)
        if getattr(args, 'finetune', False):
            # Use the configured label space instead of a hard-coded 10.
            replace_classifier(net_glob, num_classes=getattr(args, 'num_classes', 10))
    else:
        raise SystemExit('Error: unrecognized model')

    # A single final move also covers heads created after construction
    # (e.g. the classifier swapped in by ``replace_classifier``).
    return net_glob.to(args.device)


def get_dataset(args):
    """Return the per-client names for ``args.dataset``.

    A base domain split into ``count`` clients yields ``'<domain>_<i>'`` for
    ``i`` in ``range(count)``. A domain with ``count == 1`` keeps its plain
    name. If ``args.dataset_names`` already exists it is returned instead of
    the computed list.

    Side effect: when ``args.bkd_domain_idx`` is a valid domain index, both
    ``args.unlearning_client`` and ``args.backdoor_client_idx`` are set to a
    one-element list holding the index of that domain's first client. They are
    lists because the rest of the codebase expects lists.
    """
    base_domains = _get_base_domains(args.dataset)
    split_counts = _parse_split_counts(args, len(base_domains))

    domain_idx = getattr(args, "bkd_domain_idx", -1)
    if 0 <= domain_idx < len(base_domains):
        # Client names are emitted domain by domain (see below), so the first
        # client of domain k follows every client of domains 0..k-1.
        first_client = sum(split_counts[:domain_idx])
        args.unlearning_client = [first_client]
        args.backdoor_client_idx = [first_client]

    def _names_for(domain, count):
        return [domain] if count == 1 else [f'{domain}_{i}' for i in range(count)]

    client_names = [
        client
        for domain, count in zip(base_domains, split_counts)
        for client in _names_for(domain, count)
    ]
    return getattr(args, 'dataset_names', client_names)
  
def print_domain_sample(datasets_name, dataloaders):
    """Print the per-class sample count of every loader's dataset.

    Args:
        datasets_name: Display name for each loader, in the same order as
            ``dataloaders``.
        dataloaders: Iterable of objects exposing ``.dataset``, whose items
            carry the label at position 1.

    The number of classes is inferred as ``max(label) + 1``. Labels are
    coerced with ``int()``, so 0-d tensors and numpy scalars count like plain
    ints. Negative labels are included in the total but fall in no class
    bucket; previously they were miscounted into the last class through
    negative list indexing.
    """
    for num, train_loader in enumerate(dataloaders):
        dataset = train_loader.dataset
        # Index the dataset directly so loader batching/shuffling is irrelevant.
        labels = [int(dataset[i][1]) for i in range(len(dataset))]
        num_classes = max(labels) + 1 if labels else 0
        count = [0] * max(num_classes, 0)
        for lbl in labels:
            if 0 <= lbl < len(count):
                count[lbl] += 1
        print(f"Train Dataset {datasets_name[num]}, Each Class Number: {count} , All Number: {len(labels)}")


def sample_backdoor_dataset(dataset, n_train, target_label):
    """Subsample ``dataset`` while keeping every sample of the backdoor target class.

    - Non-target classes: keep the first ``n_train`` samples of the class
      (all of them if the class has fewer).
    - Target class (``target_label``): keep every sample.

    Args:
        dataset: Indexable dataset whose items are ``(x, label)`` or
            ``(x, label, extra)`` tuples.
        n_train: Per-class cap applied to non-target classes.
        target_label: Label whose samples are all kept.

    Returns:
        ``Subset(dataset, indices)``. Indices are grouped by label, in order of
        each label's first appearance, and keep dataset order within a label.

    Raises:
        ValueError: if an item is neither a 2- nor a 3-tuple.
    """
    label_indices = defaultdict(list)
    for idx in range(len(dataset)):
        sample = dataset[idx]
        if len(sample) not in (2, 3):
            raise ValueError(f"Unexpected dataset format at index {idx}: {sample}")
        label = sample[1]
        # 0-d tensors hash by identity, which made every sample its own
        # "class"; group on the underlying Python value instead.
        if hasattr(label, 'item'):
            label = label.item()
        label_indices[label].append(idx)

    final_indices = []
    for label, indices in label_indices.items():
        # Slicing already keeps everything when a class has <= n_train samples.
        final_indices.extend(indices if label == target_label else indices[:n_train])
    return Subset(dataset, final_indices)

def init_data_methodone(args, all_data=True, evals=False):
    """Create (and cache) or reload the per-client loaders for ``args.dataset``.

    Loaders are cached under
    ``./save/datasets/<dataset>/dsf_<factor>/<backdoor clients>/<verify>/``.

    ``<factor>`` is ``args.domain_times_factor`` when that attribute exists and
    ``args.bkd_domain_idx`` is set to something other than the 12345 sentinel;
    otherwise it is ``args.domain_split_factor``. For non-learning runs, the
    legacy ``dsf_<domain_split_factor>`` directory is used when only it holds a
    cached ``train_loaders.pth``.

    Args:
        args: Run configuration. Reads ``dataset``, ``target``, ``verify``,
            ``backdoor_client_idx``, ``domain_split_factor`` and optionally
            ``domain_times_factor`` and ``bkd_domain_idx``.
        all_data: Unused; kept for interface compatibility.
        evals: Forwarded to the Office-Caltech10 and PACS builders.

    Returns:
        ``(train_loaders, test_loaders, backdoorloader)``.

    Raises:
        ValueError: for an unsupported ``args.dataset``.
    """
    split_factor = getattr(args, "domain_times_factor", args.domain_split_factor)
    if getattr(args, "bkd_domain_idx", 12345) == 12345:
        split_factor = args.domain_split_factor
    bkd = args.backdoor_client_idx
    bkd_str = '_'.join(str(i) for i in (bkd if isinstance(bkd, (list, tuple)) else [bkd]))
    new_save_path = f'./save/datasets/{args.dataset}/dsf_{split_factor}/{bkd_str}/{args.verify}/'
    old_save_path = f'./save/datasets/{args.dataset}/dsf_{args.domain_split_factor}/{bkd_str}/{args.verify}/'

    learning = args.target == 'learning'
    backdoor = args.verify == 'backdoor'

    save_path = new_save_path
    if (not learning
            and not os.path.exists(os.path.join(new_save_path, 'train_loaders.pth'))
            and os.path.exists(os.path.join(old_save_path, 'train_loaders.pth'))):
        # Fall back to a cache written under the legacy directory name.
        save_path = old_save_path
    os.makedirs(save_path, exist_ok=True)

    def _save(train_loaders, test_loaders, backdoorloader=None):
        torch.save(train_loaders, os.path.join(save_path, 'train_loaders.pth'))
        torch.save(test_loaders, os.path.join(save_path, 'test_loaders.pth'))
        if backdoorloader is not None:
            torch.save(backdoorloader, os.path.join(save_path, 'backdoorloader.pth'))

    def _load(filename):
        path = os.path.join(save_path, filename)
        # The cache holds pickled DataLoader objects written by _save (local,
        # trusted files), not tensors. PyTorch >= 2.6 defaults to
        # weights_only=True, which refuses to unpickle them.
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            # torch < 1.13 has no ``weights_only`` parameter.
            return torch.load(path)

    if args.dataset == 'domain_digits':
        if learning:
            kind = 'backdoor' if backdoor else 'non-backdoor'
            print(f"Creating new datasets for learning ({kind})...")
            train_loaders, test_loaders, backdoorloader = create_datasets(args, inject_backdoor=backdoor)
            print("Saving datasets...")
            _save(train_loaders, test_loaders, backdoorloader)
            return train_loaders, test_loaders, backdoorloader
        if not backdoor:
            print("Loading saved datasets for non-learning (non-backdoor)...")
            return _load('train_loaders.pth'), _load('test_loaders.pth'), None
        # NOTE: despite the message, the backdoor variant is rebuilt rather
        # than read from the cache.
        print("Loading saved datasets for non-learning (backdoor)...")
        return create_datasets(args, inject_backdoor=True)

    if args.dataset == 'office-caltech10':
        # NOTE: non-learning runs also rebuild the datasets rather than load them.
        if learning:
            print(f"Creating new {args.dataset} datasets...")
        else:
            print(f"Loading saved {args.dataset} datasets...")
        train_loaders, test_loaders, backdoorloader = create_office_caltech10_datasets(
            args, inject_backdoor=backdoor, evals=evals
        )
        if learning:
            _save(train_loaders, test_loaders, backdoorloader)
        return train_loaders, test_loaders, backdoorloader

    if args.dataset == 'PACS':
        if learning:
            print("Creating new PACS datasets...")
        else:
            print("Loading saved PACS datasets...")
        # Inject the backdoor whenever verify == 'backdoor', as for the other
        # datasets (previously hard-coded to False for PACS).
        train_loaders, test_loaders, backdoorloader = create_pacs_datasets(
            args, inject_backdoor=backdoor, evals=evals
        )
        if learning:
            _save(train_loaders, test_loaders, backdoorloader)
        return train_loaders, test_loaders, backdoorloader

    raise ValueError(f"Unrecognized dataset: {args.dataset}")

def _root_subset_indices(dataset):
    """Map positions of *dataset* to indices of its innermost non-``Subset`` dataset.

    A plain dataset maps to ``range(len(dataset))``. Nested ``Subset``
    wrappers are composed, e.g. ``Subset(Subset(base, [4, 7, 9]), [2, 0])``
    yields ``[9, 4]``.
    """
    if not isinstance(dataset, Subset):
        return list(range(len(dataset)))
    parent = _root_subset_indices(dataset.dataset)
    return [parent[i] for i in dataset.indices]


def create_datasets(args, inject_backdoor=False):
    """Build per-client train/test loaders for the five digit domains.

    Client ``i`` (``i < 5``) gets domain ``i`` of
    MNIST / SVHN / USPS / SynthDigits / MNIST_M. Any further client gets
    ``None`` loaders. Each training set is capped at ``args.n_train`` samples
    per class via ``sample_dataset``. The loaders are then passed through
    ``_split_loaders``.

    When ``inject_backdoor`` is True, for every index in
    ``args.backdoor_client_idx`` the client's training loader is rebuilt on a
    ``DigitsDataset(inject_backdoor=True)`` restricted to the same base
    samples. A backdoor evaluation loader is built from the first such
    client's domain.

    Returns:
        ``(train_loaders, test_loaders, backdoorloader)``. ``backdoorloader``
        is ``None`` unless ``inject_backdoor`` is True.
    """
    resize = 32
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # Grayscale domains are expanded to 3 channels. MNIST converts before
    # resizing and USPS after; both orders are kept as originally configured.
    transform_gray_first = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize([resize, resize]),
        transforms.ToTensor(),
        normalize,
    ])
    transform_gray_last = transforms.Compose([
        transforms.Resize([resize, resize]),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        normalize,
    ])
    transform_rgb = transforms.Compose([
        transforms.Resize([resize, resize]),
        transforms.ToTensor(),
        normalize,
    ])

    # (domain name, source channels, transform); list position == client index.
    domain_specs = [
        ('MNIST', 1, transform_gray_first),
        ('SVHN', 3, transform_rgb),
        ('USPS', 1, transform_gray_last),
        ('SynthDigits', 3, transform_rgb),
        ('MNIST_M', 3, transform_rgb),
    ]

    def _clean_split(name, channels, transform, train):
        return data_utils.DigitsDataset(
            data_path=f"./data/digitdata/{name}", channels=channels,
            percent=args.percent, train=train, transform=transform,
            inject_backdoor=False, args=args, dataset=name, backdoortest=False,
        )

    domain_sets = [
        (_clean_split(name, ch, tf, True), _clean_split(name, ch, tf, False))
        for name, ch, tf in domain_specs
    ]

    train_loaders, test_loaders = [], []
    for client_idx in range(args.num_users):
        if client_idx < len(domain_sets):
            trainset, testset = domain_sets[client_idx]
            trainset = sample_dataset(trainset, args.n_train)
            train_loaders.append(DataLoader(trainset, batch_size=args.local_bs, shuffle=True))
            test_loaders.append(DataLoader(testset, batch_size=args.local_bs, shuffle=False))
        else:
            train_loaders.append(None)
            test_loaders.append(None)

    base_names = [name for name, _, _ in domain_specs]
    train_loaders, test_loaders = _split_loaders(args, train_loaders, test_loaders, base_names)

    backdoorloader = None
    if not inject_backdoor:
        return train_loaders, test_loaders, backdoorloader

    bd = args.backdoor_client_idx
    for b_idx in (bd if isinstance(bd, (list, tuple)) else [bd]):
        domain = _get_domain(args.dataset_names[b_idx])
        if domain in ['MNIST', 'USPS']:
            channel, transform = 1, transform_gray_first
        else:
            channel, transform = 3, transform_rgb
        backdoor_trainset = data_utils.DigitsDataset(
            data_path=f"./data/digitdata/{domain}",
            channels=channel,
            percent=args.percent,
            train=True,
            transform=transform,
            inject_backdoor=True,
            load_backdoor=False,
            args=args,
            dataset=domain,
            backdoortest=False,
        )
        # Re-select exactly the samples this client trains on. Its dataset may
        # be a (possibly nested) Subset from sample_dataset / _split_loaders,
        # so compose the indices down to the base dataset; only the outer
        # ``.indices`` would point at the wrong samples once Subsets are nested.
        indices = _root_subset_indices(train_loaders[b_idx].dataset)
        train_loaders[b_idx] = DataLoader(
            Subset(backdoor_trainset, indices), batch_size=args.local_bs, shuffle=True
        )

        if backdoorloader is None:
            trainset_backdoor_test = data_utils.DigitsDataset(
                data_path=f"./data/digitdata/{domain}",
                channels=channel,
                percent=args.percent,
                inject_backdoor=True,
                args=args,
                load_backdoor=True,
                backdoortest=True,
                dataset=domain,
                train=True,
                transform=transform,
            )
            backdoorloader = DataLoader(trainset_backdoor_test, batch_size=args.local_bs, shuffle=False)

    return train_loaders, test_loaders, backdoorloader

def sample_dataset(dataset, n_train):
    """Keep at most ``n_train`` samples per class, preserving dataset order.

    Classes with fewer than ``n_train`` samples are kept entirely.

    Args:
        dataset: Indexable dataset whose items are ``(x, label)`` or
            ``(x, label, extra)`` tuples.
        n_train: Per-class cap.

    Returns:
        ``Subset(dataset, indices)`` with the selected indices in ascending order.

    Raises:
        ValueError: if an item is neither a 2- nor a 3-tuple.
    """
    label_counts = defaultdict(int)
    indices = []
    for idx in range(len(dataset)):
        sample = dataset[idx]
        if len(sample) not in (2, 3):
            raise ValueError(f"Unexpected dataset format at index {idx}: {sample}")
        label = sample[1]
        # 0-d tensors hash by identity, which would give each sample its own
        # counter and bypass the cap; count on the Python value instead.
        if hasattr(label, 'item'):
            label = label.item()
        if label_counts[label] < n_train:
            indices.append(idx)
            label_counts[label] += 1
    return Subset(dataset, indices)


def preprocess_office_caltech10_4half(data_base_path, args, evals):
    """Build Office-Caltech10 loaders directly from the ImageFolder layout.

    For each domain the train and test loaders both read the *whole*
    ``<data_base_path>/office_caltech_10/<domain>`` directory. They differ
    only in transform (random flip/rotation for training) and in shuffling:
    training loaders shuffle unless ``evals`` is set.

    When ``args.verify == "backdoor"``, the third return value is an unshuffled
    loader over the first backdoor client's domain, using the training
    transform. No trigger is applied here. Otherwise it is ``None``.
    """
    train_tf = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation((-30, 30)),
        transforms.ToTensor(),
    ])
    eval_tf = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ])

    domain_names = ['amazon', 'caltech', 'dslr', 'webcam']
    office_root = os.path.join(data_base_path, 'office_caltech_10')
    shuffle_train = not evals

    train_loaders, test_loaders = [], []
    for name in domain_names:
        domain_dir = os.path.join(office_root, name)
        train_set = ImageFolder(root=domain_dir, transform=train_tf)
        eval_set = ImageFolder(root=domain_dir, transform=eval_tf)

        train_loaders.append(DataLoader(train_set, batch_size=args.local_bs, shuffle=shuffle_train))
        test_loaders.append(DataLoader(eval_set, batch_size=args.local_bs, shuffle=False))

        print(f"Dataset {name} Trainset size: {len(train_set)}")
        print(f"Dataset {name} Testset size: {len(eval_set)}")

    if args.verify != "backdoor":
        return train_loaders, test_loaders, None

    bd = args.backdoor_client_idx
    first_bd = bd[0] if isinstance(bd, (list, tuple)) else bd
    backdoor_set = ImageFolder(
        root=os.path.join(office_root, domain_names[first_bd]),
        transform=train_tf,
    )
    backdoor_loader = DataLoader(backdoor_set, batch_size=args.local_bs, shuffle=False)
    return train_loaders, test_loaders, backdoor_loader




class OfficeDataset(Dataset):
    """Office-Caltech10 domain dataset with a seeded per-class train/test split.

    NOTE: this class rebinds the module-level name ``OfficeDataset``, shadowing
    the one imported from ``utils.data_utils`` at the top of the file.
    """

    def __init__(self, base_path, domain, transform=None, train=True,
                 inject_backdoor=False, target_label=None, backdoortest=False,
                 n_train=None, seed=1234):
        """Index the images of one domain and keep either the train or test split.

        Args:
            base_path: Root directory containing one sub-directory per domain.
            domain: Domain sub-directory name (e.g. ``'amazon'``).
            transform: Optional callable applied to the PIL image.
            train: Keep the training split (True) or the held-out split (False).
            inject_backdoor: Stamp the trigger on every image and relabel it to
                ``target_label``.
            target_label: Backdoor target class.
            backdoortest: Build a backdoor evaluation set: every sample is
                relabelled to ``target_label`` and stamped with the trigger.
            n_train: Optional cap on training images per class.
            seed: Seed for the per-class shuffle. Train and test instances must
                use the same seed for their splits to be complementary.

        Raises:
            FileNotFoundError: if ``base_path`` or the domain directory is missing.
            ValueError: if no image was found.
        """
        if not os.path.exists(base_path):
            raise FileNotFoundError(f"数据集根目录不存在: {base_path}")
        domain_path = os.path.join(base_path, domain)
        if not os.path.isdir(domain_path):
            raise FileNotFoundError(f"域目录不存在: {domain_path}")

        self.transform = transform
        self.domain = domain
        self.classes = ['back_pack', 'bike', 'calculator', 'headphones',
                        'keyboard', 'laptop_computer', 'monitor', 'mouse',
                        'mug', 'projector']
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.data = []
        self.inject_backdoor = inject_backdoor
        self.target_label = target_label
        self.backdoortest = backdoortest
        self.n_train = n_train

        print(f"\n{'='*40}")
        print(f"初始化 {domain} 数据集 ({'训练' if train else '测试'})")
        print(f"路径: {domain_path}")

        # Private RNG: reproducible split without reseeding numpy's global RNG.
        rng = np.random.RandomState(seed)
        for class_name in self.classes:
            class_dir = os.path.join(domain_path, class_name)
            if not os.path.exists(class_dir):
                print(f"  !! 警告：跳过缺失的类别目录 {class_name}")
                continue

            # os.listdir order is arbitrary; sort so the seeded shuffle (and
            # hence the train/test split) is reproducible across machines.
            images = sorted(
                os.path.join(class_dir, f)
                for f in os.listdir(class_dir)
                if f.lower().endswith((".jpg", ".jpeg", ".png"))
                and os.path.isfile(os.path.join(class_dir, f))
            )

            if not images:
                print(f"  !! 警告：{class_name} 无有效图像")
                continue

            label = self.class_to_idx[class_name]
            rng.shuffle(images)

            # 80/20 split, optionally capped by n_train. Whenever a class has
            # more than one image, at least one goes to each side.
            total = len(images)
            limit = int(total * 0.8)
            if n_train is not None:
                limit = min(limit, n_train)
            if total > 1:
                limit = min(max(1, limit), total - 1)
            else:
                limit = total

            selected = images[:limit] if train else images[limit:]
            self.data.extend((p, label) for p in selected)
            print(
                f"  √ {class_name}: {'训练' if train else '测试'} {len(selected)}个样本"
            )

        if not self.data:
            raise ValueError(f"数据集为空！请检查: {domain_path}")
        print(f"总计加载 {len(self.data)} 个样本")
        print('='*40 + '\n')

        if backdoortest:
            self._create_backdoor_test_samples()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, optionally triggered, and return ``(image, label)``."""
        path, label = self.data[idx]
        # Close the file handle; convert() loads the pixels into a new image.
        with Image.open(path) as img:
            image = img.convert('RGB')

        # Backdoor evaluation samples must carry the trigger too; their labels
        # were already set to target_label by _create_backdoor_test_samples.
        if self.inject_backdoor or self.backdoortest:
            image = self._add_trigger(image)
        if self.inject_backdoor:
            label = self.target_label

        if self.transform:
            image = self.transform(image)

        return image, label

    def _add_trigger(self, image):
        """Return a copy of ``image`` with a white 3x3 patch in the bottom-right corner."""
        image = image.copy()
        width, height = image.size
        for i in range(max(0, width-3), width):
            for j in range(max(0, height-3), height):
                image.putpixel((i, j), (255, 255, 255))
        return image

    def _create_backdoor_test_samples(self):
        """Relabel every sample to ``target_label``."""
        self.backdoor_data = [(path, self.target_label) for path, _ in self.data]
        self.data = self.backdoor_data

def create_office_caltech10_datasets(args, inject_backdoor=False, evals=False):
    """Build Office-Caltech10 train/test loaders and, optionally, a backdoor test loader.

    Args:
        args: Run configuration. Requires ``n_train``, ``local_bs``,
            ``backdoor_client_idx`` and ``backdoor_target_label``. When
            ``args.domain_skew`` is truthy, ``domain_skew_ratio`` is also
            required; it scales ``n_train`` for the amazon and caltech domains.
            ``args.seed`` (default 1234) seeds every per-class split.
        inject_backdoor: Poison the training sets of the backdoor clients'
            domains and build a triggered backdoor test loader.
        evals: Disable shuffling of the training loaders.

    Returns:
        ``(train_loaders, test_loaders, backdoor_loader)``. ``backdoor_loader``
        is ``None`` unless ``inject_backdoor`` is True.

    Raises:
        ValueError: on missing or invalid parameters.
    """
    domain_skew = getattr(args, 'domain_skew', False)
    required_params = ['n_train', 'local_bs', 'backdoor_client_idx', 'backdoor_target_label']
    if domain_skew:
        required_params.append('domain_skew_ratio')

    for param in required_params:
        if not hasattr(args, param):
            raise ValueError(f"缺少必需参数: {param}")

    if domain_skew and args.domain_skew_ratio <= 0:
        raise ValueError("domain_skew_ratio必须大于0")
    if args.n_train <= 0:
        raise ValueError(f"n_train必须>0，当前为{args.n_train}")

    data_base_path = './data/office_caltech_10'
    transform_train = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ])

    datasets_name = ['amazon', 'caltech', 'dslr', 'webcam']
    # Train and test instances must share one seed: each draws its split from
    # the same seeded shuffle. Otherwise the held-out images overlap training.
    seed = getattr(args, 'seed', 1234)

    bd_indices = (
        args.backdoor_client_idx
        if isinstance(args.backdoor_client_idx, (list, tuple))
        else [args.backdoor_client_idx]
    )
    bd_domains = {datasets_name[i] for i in bd_indices} if inject_backdoor else set()

    train_sets = []
    test_sets = []
    domain_n_train = {}
    for name in datasets_name:
        try:
            adjusted_n_train = args.n_train
            if domain_skew and name in ['amazon', 'caltech']:
                adjusted_n_train = int(args.n_train * args.domain_skew_ratio)
                print(f"* 应用域偏斜: {name} 每类样本数从 {args.n_train} 调整为 {adjusted_n_train}")
            domain_n_train[name] = adjusted_n_train

            train_sets.append(OfficeDataset(
                base_path=data_base_path,
                domain=name,
                transform=transform_train,
                train=True,
                n_train=adjusted_n_train,
                inject_backdoor=name in bd_domains,
                target_label=args.backdoor_target_label,
                seed=seed,
            ))
            test_sets.append(OfficeDataset(
                base_path=data_base_path,
                domain=name,
                transform=transform_test,
                train=False,
                n_train=adjusted_n_train,
                seed=seed,
            ))
        except Exception as e:
            print(f"加载{name}数据集失败: {str(e)}")
            raise

    train_loaders = [
        DataLoader(ds, batch_size=args.local_bs, shuffle=not evals, pin_memory=True)
        for ds in train_sets
    ]
    test_loaders = [
        DataLoader(ds, batch_size=args.local_bs, shuffle=False, pin_memory=True)
        for ds in test_sets
    ]

    backdoor_loader = None
    if inject_backdoor:
        bd_domain = datasets_name[bd_indices[0]]
        backdoor_testset = OfficeDataset(
            base_path=data_base_path,
            domain=bd_domain,
            transform=transform_test,
            train=False,
            backdoortest=True,
            # Stamp the trigger; without it the "backdoor" test images are clean.
            inject_backdoor=True,
            target_label=args.backdoor_target_label,
            # Same cap and seed as that domain's training split, so the
            # held-out images really are disjoint from training.
            n_train=domain_n_train[bd_domain],
            seed=seed,
        )
        backdoor_loader = DataLoader(backdoor_testset, batch_size=args.local_bs, shuffle=False)

    print("\n数据集统计:")
    for name, train_set, test_set in zip(datasets_name, train_sets, test_sets):
        train_size = len(train_set)
        test_size = len(test_set)
        print(f"[{name.upper()}] 训练集: {train_size} samples | 测试集: {test_size} samples")
        if domain_skew and name in ['amazon', 'caltech']:
            expected = int(10 * args.n_train * args.domain_skew_ratio)
            print(f"  偏斜验证: 预期 {expected} → 实际 {train_size} (ratio={args.domain_skew_ratio})")

    return train_loaders, test_loaders, backdoor_loader


def _pacs_domains_all_empty(pacs_dir, domains):
    """Return True when every ``pacs_dir/<domain>`` folder has no entries.

    ``os.scandir`` is used as a context manager so the directory handle is
    released immediately instead of lingering until garbage collection.
    """
    for d in domains:
        with os.scandir(os.path.join(pacs_dir, d)) as entries:
            if any(entries):
                return False
    return True


def _ensure_pacs_data(pacs_dir, domains):
    """Make sure ``pacs_dir`` holds per-domain image folders, fetching them if needed.

    The domain folders are created if missing. If all of them are empty
    (which also covers a freshly created ``pacs_dir``), the ``PACS_REPO``
    snapshot is downloaded once and unpacked with ``_extract_pacs_images``.

    Raises:
        RuntimeError: if every domain folder is still empty after the
            download and extraction attempt.
    """
    existed = os.path.isdir(pacs_dir)
    # Create the domain folders up front so the emptiness check can scan
    # them unconditionally.
    for d in domains:
        os.makedirs(os.path.join(pacs_dir, d), exist_ok=True)

    if not _pacs_domains_all_empty(pacs_dir, domains):
        return

    if existed:
        # e.g. a previous download or extraction failed part-way.
        print(f"PACS data empty at {pacs_dir}, downloading from {PACS_REPO}...")
    else:
        print(f"PACS data not found at {pacs_dir}, downloading from {PACS_REPO}...")
    # PACS is published as a *dataset* repo on the Hugging Face Hub, so
    # ``repo_type="dataset"`` is required (the default looks for a model repo).
    # NOTE(review): newer huggingface_hub releases deprecate/ignore
    # ``local_dir_use_symlinks``; kept for older versions — confirm installed version.
    snapshot_download(
        repo_id=PACS_REPO,
        repo_type="dataset",
        local_dir=pacs_dir,
        local_dir_use_symlinks=False,
    )
    # The hub snapshot may ship a parquet file rather than image folders;
    # unpack it into the per-domain directories (no-op if the file is absent).
    _extract_pacs_images(pacs_dir)

    if _pacs_domains_all_empty(pacs_dir, domains):
        raise RuntimeError("Failed to prepare PACS dataset")


def create_pacs_datasets(args, inject_backdoor=False, evals=False):
    """Create PACS dataloaders using the local directory structure.

    Data is read from ``./data/PACS``. If no images are present there, it is
    first downloaded from ``PACS_REPO``. Domains are processed in the fixed
    order art_painting, cartoon, photo, sketch. Domain ``i`` is matched
    against the indices in ``args.backdoor_client_idx``.

    Args:
        args: Run configuration. Reads ``backdoor_target_label``,
            ``backdoor_client_idx`` (an int or a list/tuple of domain
            indices), ``backdoor_percent_poison``, ``local_bs`` and, if
            present, ``n_train``. ``args.backdoor_target_label`` is reset to
            0 when it is not a valid PACS class index.
        inject_backdoor: When True, the training datasets of the domains in
            ``backdoor_client_idx`` are built with ``inject_backdoor=True``.
            A backdoor test loader is also built for the first listed domain.
        evals: When True, the training loaders are not shuffled.

    Returns:
        ``(train_loaders, test_loaders, backdoor_loader)``. The first two are
        whatever ``_split_loaders`` returns for the per-domain loaders.
        ``backdoor_loader`` is None unless ``inject_backdoor`` is True.

    Raises:
        RuntimeError: if the PACS images cannot be prepared.
        IndexError: if ``inject_backdoor`` is True and the first backdoor
            index is not a valid domain index.
    """
    # PACS has exactly 7 classes. Validate the backdoor target label before
    # touching the filesystem or instantiating any datasets; negative values
    # are as invalid as values >= num_classes.
    num_classes = 7
    target_label = getattr(args, "backdoor_target_label", 0)
    if not 0 <= target_label < num_classes:
        print(
            f"backdoor_target_label {target_label} is outside [0, {num_classes}), resetting to 0"
        )
        args.backdoor_target_label = 0

    data_base_path = './data'
    pacs_dir = os.path.join(data_base_path, 'PACS')
    # Order matters: the position of a domain is its backdoor client index.
    domains = ['art_painting', 'cartoon', 'photo', 'sketch']
    _ensure_pacs_data(pacs_dir, domains)

    transform_train = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation((-30, 30)),
        transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ])

    bd_indices = (
        args.backdoor_client_idx
        if isinstance(args.backdoor_client_idx, (list, tuple))
        else [args.backdoor_client_idx]
    )
    # ``PACSDataset`` performs the actual 8:2 train/test split. Here we only
    # forward the optional ``n_train`` value as an upper bound on the number
    # of training samples per class.
    n_train = getattr(args, "n_train", None)

    train_sets = []
    test_sets = []
    for idx, d in enumerate(domains):
        inject = inject_backdoor and (idx in bd_indices)
        train_sets.append(
            data_utils.PACSDataset(
                data_base_path,
                d,
                train=True,
                transform=transform_train,
                n_train=n_train,
                inject_backdoor=inject,
                target_label=args.backdoor_target_label,
                percent_poison=args.backdoor_percent_poison,
            )
        )
        test_sets.append(
            data_utils.PACSDataset(
                data_base_path,
                d,
                train=False,
                transform=transform_test,
                n_train=n_train,
            )
        )

    backdoor_loader = None
    if inject_backdoor:
        # Only the first backdoor client's domain gets a backdoor test set.
        backdoor_domain = domains[bd_indices[0]]
        backdoor_test = data_utils.PACSDataset(
            data_base_path,
            backdoor_domain,
            train=False,
            transform=transform_test,
            n_train=n_train,
            inject_backdoor=True,
            backdoortest=True,
            target_label=args.backdoor_target_label,
            percent_poison=args.backdoor_percent_poison,
        )
        backdoor_loader = DataLoader(backdoor_test, batch_size=args.local_bs, shuffle=False)

    train_loaders = [DataLoader(ds, batch_size=args.local_bs,
                                shuffle=not evals) for ds in train_sets]
    test_loaders = [DataLoader(ds, batch_size=args.local_bs,
                               shuffle=False) for ds in test_sets]

    train_loaders, test_loaders = _split_loaders(args, train_loaders, test_loaders, domains)

    return train_loaders, test_loaders, backdoor_loader

if __name__ == "__main__":
    # Runs only when this file is executed directly, never on import.
    # ``data_processing`` is not imported at the top of the file, so it is
    # presumably defined elsewhere in this module — confirm.
    data_processing()