import torch
#import torch_dataset_mirror
from torchvision import datasets, transforms
from torch.utils.data import Dataset,random_split
import os
from PIL import Image
import sys
from torch.utils.data.distributed import DistributedSampler

# Helper class for datasets requiring custom split handling
class TransformedSubset(Dataset):
    """Wrap an indexable dataset and apply ``transform`` to each input.

    Every item of ``subset`` is unpacked as an ``(input, label)`` pair; the
    label is returned untouched. This gives splits produced by
    ``random_split`` or ``ConcatDataset`` (which carry no transform of their
    own) a per-split transform.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset        # underlying dataset yielding (x, y) pairs
        self.transform = transform  # optional callable applied to x only

    def __getitem__(self, index):
        sample, label = self.subset[index]
        return (self.transform(sample) if self.transform else sample), label

    def __len__(self):
        return len(self.subset)

def load_data_mlp(dataset, batch_size, dim_factor=2, data_root="../data"):
    """Build train/test DataLoaders and MLP layer sizes for ``dataset``.

    Args:
        dataset: Dataset name. One of ``'EMNIST'``, ``'MNIST'``,
            ``'Fashion_MNIST'``, ``'CIFAR10'``, ``'CIFAR100'``, ``'FER2013'``,
            ``'SVHN'``, ``'tiny-imagenet-ori'``, ``'tiny-imagenet-crop'``,
            ``'tiny-imagenet-resize'``, ``'OxfordFlowers'``, ``'Caltech256'``,
            ``'OxfordIIITPet'`` or ``'INaturalist'``.
        batch_size: Batch size used by both loaders.
        dim_factor: Each hidden layer is ``indim * dim_factor`` wide.
        data_root: Directory the datasets are read from / downloaded to.

    Returns:
        ``(train_loader, test_loader, indim, outdim, hiddim)`` where ``indim``
        is the flattened input size, ``outdim`` the number of classes and
        ``hiddim`` a list of three equal hidden-layer widths.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # Shared 64x64 transforms (ImageNet statistics) for the RGB datasets
    # whose images come in varying sizes.
    imnet_mean = [0.485, 0.456, 0.406]
    imnet_std = [0.229, 0.224, 0.225]
    transform_train_rgb = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(imnet_mean, imnet_std)
    ])
    transform_test_rgb = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(imnet_mean, imnet_std)
    ])

    if dataset == 'EMNIST':
        # 'balanced' split, 47 classes.
        emnist_tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1751,), (0.3332,))
        ])
        train_loader = torch.utils.data.DataLoader(
            datasets.EMNIST(root=data_root, train=True, transform=emnist_tf,
                            download=True, split='balanced'),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.EMNIST(root=data_root, train=False, transform=emnist_tf,
                            download=True, split='balanced'),
            batch_size=batch_size, shuffle=False)
        indim = 784
        outdim = 47

    elif dataset == "MNIST":
        mnist_tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(data_root, train=True, download=True,
                           transform=mnist_tf),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(data_root, train=False, transform=mnist_tf),
            batch_size=batch_size, shuffle=False)
        indim = 784
        outdim = 10

    elif dataset == "Fashion_MNIST":
        fmnist_tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.2860,), (0.3530,))
        ])
        train_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST(root=data_root, train=True,
                                  transform=fmnist_tf, download=True),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST(root=data_root, train=False,
                                  transform=fmnist_tf, download=True),
            batch_size=batch_size, shuffle=False)
        indim = 784
        outdim = 10

    elif dataset == "CIFAR10":
        cifar10_norm = transforms.Normalize([0.4914, 0.4822, 0.4465],
                                            [0.2470, 0.2435, 0.2616])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(data_root, train=True, download=True,
                             transform=transforms.Compose([
                                 transforms.RandomCrop(32, padding=4),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 cifar10_norm])),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(data_root, train=False,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 cifar10_norm])),
            batch_size=batch_size, shuffle=False)
        indim = 3072
        outdim = 10

    elif dataset == "CIFAR100":
        cifar100_norm = transforms.Normalize([0.5071, 0.4866, 0.4409],
                                             [0.2673, 0.2564, 0.2762])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(data_root, train=True, download=True,
                              transform=transforms.Compose([
                                  transforms.RandomCrop(32, padding=4),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor(),
                                  cifar100_norm])),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(data_root, train=False,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  cifar100_norm])),
            batch_size=batch_size, shuffle=False)
        indim = 3072
        outdim = 100

    elif dataset == "FER2013":
        # No download flag is passed: the data must already exist under
        # data_root.
        fer_tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5077,), (0.2550,))
        ])
        train_loader = torch.utils.data.DataLoader(
            datasets.FER2013(root=data_root, split='train', transform=fer_tf),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.FER2013(root=data_root, split='test', transform=fer_tf),
            batch_size=batch_size, shuffle=False)
        indim = 48*48
        outdim = 7

    elif dataset == "SVHN":
        svhn_root = os.path.join(data_root, 'svhn')
        svhn_tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4377, 0.4438, 0.4728],
                                 [0.198, 0.201, 0.197])
        ])
        train_loader = torch.utils.data.DataLoader(
            datasets.SVHN(root=svhn_root, split='train', transform=svhn_tf,
                          download=True),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.SVHN(root=svhn_root, split='test', transform=svhn_tf,
                          download=True),
            batch_size=batch_size, shuffle=False)
        indim = 32*32*3
        outdim = 10

    elif dataset == 'tiny-imagenet-ori':
        # Native 64x64 images, padded random-crop augmentation.
        tiny_root = os.path.join(data_root, 'tiny-imagenet-200')
        mean = [0.4802, 0.4481, 0.3975]
        std = [0.2296, 0.2263, 0.2255]
        train_loader = torch.utils.data.DataLoader(
            TinyImageNet_load(root=tiny_root, train=True, transform=transforms.Compose([
                transforms.RandomCrop(64, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            TinyImageNet_load(root=tiny_root, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])),
            batch_size=batch_size, shuffle=False)
        indim = 64*64*3
        outdim = 200

    elif dataset == 'tiny-imagenet-crop':
        # 48x48 crops: random for training, centered for testing.
        tiny_root = os.path.join(data_root, 'tiny-imagenet-200')
        mean = [0.4802, 0.4481, 0.3975]
        std = [0.2296, 0.2263, 0.2255]
        train_loader = torch.utils.data.DataLoader(
            TinyImageNet_load(root=tiny_root, train=True, transform=transforms.Compose([
                transforms.Resize(64),
                transforms.RandomCrop(48),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            TinyImageNet_load(root=tiny_root, train=False, transform=transforms.Compose([
                transforms.Resize(64),
                transforms.CenterCrop(48),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])),
            batch_size=batch_size, shuffle=False)
        indim = 48*48*3
        outdim = 200

    elif dataset == 'tiny-imagenet-resize':
        # Whole image downscaled to 48 px with bicubic anti-aliased resize.
        tiny_root = os.path.join(data_root, 'tiny-imagenet-200')
        mean = [0.4802, 0.4481, 0.3976]
        std = [0.2175, 0.2139, 0.2133]
        train_loader = torch.utils.data.DataLoader(
            TinyImageNet_load(root=tiny_root, train=True, transform=transforms.Compose([
                transforms.Resize(48, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            TinyImageNet_load(root=tiny_root, train=False, transform=transforms.Compose([
                transforms.Resize(48, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])),
            batch_size=batch_size, shuffle=False)
        indim = 48*48*3
        outdim = 200

    elif dataset == 'OxfordFlowers':
        # Oxford 102 Flowers: train on the official train + val splits.
        from torchvision.datasets import Flowers102

        train_split = Flowers102(root=data_root, split='train',
                                 download=True, transform=None)
        val_split = Flowers102(root=data_root, split='val',
                               download=True, transform=None)
        test_set = Flowers102(root=data_root, split='test',
                              download=True, transform=transform_test_rgb)
        # ConcatDataset carries no transform, so wrap it to add augmentation.
        train_set = TransformedSubset(
            torch.utils.data.ConcatDataset([train_split, val_split]),
            transform=transform_train_rgb)

        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=batch_size, shuffle=False)
        indim = 64 * 64 * 3
        outdim = 102  # 102 flower categories

    elif dataset == 'Caltech256':
        from torchvision.datasets import Caltech256

        def ensure_rgb(image):
            # Images are loaded without an RGB conversion here; convert any
            # non-RGB mode (e.g. grayscale) so the 3-channel Normalize applies.
            if image.mode != 'RGB':
                return image.convert('RGB')
            return image

        full_set = Caltech256(root=data_root, download=False, transform=None)
        transform_train_rgb_caltech = transforms.Compose([
            transforms.Lambda(ensure_rgb),
            transforms.Resize((64, 64)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(imnet_mean, imnet_std)
        ])
        transform_test_rgb_caltech = transforms.Compose([
            transforms.Lambda(ensure_rgb),
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            transforms.Normalize(imnet_mean, imnet_std)
        ])

        # There is no official split, so use an 80/20 split. A dedicated
        # fixed-seed generator gives the same partition on every call.
        # Otherwise a model trained in one run and evaluated in another
        # would be "tested" on its own training images.
        train_size = int(0.8 * len(full_set))
        test_size = len(full_set) - train_size
        train_sub, test_sub = random_split(
            full_set,
            [train_size, test_size],
            generator=torch.Generator().manual_seed(0),
        )

        train_set = TransformedSubset(train_sub, transform=transform_train_rgb_caltech)
        test_set = TransformedSubset(test_sub, transform=transform_test_rgb_caltech)

        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=batch_size, shuffle=False)
        indim = 64 * 64 * 3
        outdim = 257  # 256 categories + 1 clutter class

    elif dataset == "OxfordIIITPet":
        # Load the dataset (downloaded automatically if missing).
        train_val_set = datasets.OxfordIIITPet(
            root=data_root, split='trainval', download=True,
            transform=transform_train_rgb)
        test_set = datasets.OxfordIIITPet(
            root=data_root, split='test', download=True,
            transform=transform_test_rgb)

        # Build the DataLoaders.
        train_loader = torch.utils.data.DataLoader(
            train_val_set, batch_size=batch_size, shuffle=True, num_workers=4)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=batch_size, shuffle=False, num_workers=4)
        indim = 64 * 64 * 3  # input size (64x64 RGB image)
        outdim = 37  # 37 pet breeds (12 cats + 25 dogs)

    elif dataset == 'INaturalist':
        # iNaturalist 2021: the "mini" training subset and the full validation set.
        from torchvision.datasets import INaturalist

        inat_root = os.path.join(data_root, 'INaturalist')
        train_loader = torch.utils.data.DataLoader(
            INaturalist(root=inat_root, version='2021_train_mini',
                        download=True, transform=transform_train_rgb),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            INaturalist(root=inat_root, version='2021_valid',
                        download=True, transform=transform_test_rgb),
            batch_size=batch_size, shuffle=False)
        indim = 64 * 64 * 3
        outdim = 10000  # 10,000 species categories

    else:
        # Previously this fell through and crashed with UnboundLocalError on
        # ``indim``; fail clearly instead (matches load_data_cnn).
        raise ValueError(f"Unsupported dataset {dataset}")

    dimension = indim * dim_factor
    hiddim = [dimension, dimension, dimension]

    return train_loader, test_loader, indim, outdim, hiddim


class TinyImageNet_load(Dataset):
    """Tiny-ImageNet-200 read from its extracted directory layout.

    Files read under ``root``::

        train/<wnid>/.../*.JPEG      training images, one folder per class
        val/images/*.JPEG            validation images
        val/val_annotations.txt      tab-separated: file name, wnid, ...
        wnids.txt                    one wnid per line
        words.txt                    tab-separated: wnid, comma-separated names

    Each sample is ``(image, target)``. ``image`` is the RGB PIL image, passed
    through ``transform`` when one is given. ``target`` is the index of the
    class wnid in sorted order.
    """

    def __init__(self, root, train=True, transform=None):
        self.Train = train
        self.root_dir = root
        self.transform = transform
        self.train_dir = os.path.join(self.root_dir, "train")
        self.val_dir = os.path.join(self.root_dir, "val")

        if self.Train:
            self._create_class_idx_dict_train()
        else:
            self._create_class_idx_dict_val()

        self._make_dataset(self.Train)
        # Take the length from the samples actually found, so __len__ can
        # never exceed the indices __getitem__ accepts (e.g. an annotated
        # val image that is missing on disk).
        self.len_dataset = len(self.images)

        words_file = os.path.join(self.root_dir, "words.txt")
        wnids_file = os.path.join(self.root_dir, "wnids.txt")

        with open(wnids_file, 'r') as fo:
            self.set_nids = {entry.strip("\n") for entry in fo}

        # wnid -> first human-readable name, for the wnids listed in wnids.txt.
        self.class_to_label = {}
        with open(words_file, 'r') as fo:
            for entry in fo:
                words = entry.split("\t")
                if words[0] in self.set_nids:
                    self.class_to_label[words[0]] = words[1].strip("\n").split(",")[0]

    def _create_class_idx_dict_train(self):
        """Map the sorted class folder names under ``train/`` to indices."""
        with os.scandir(self.train_dir) as it:
            classes = sorted(d.name for d in it if d.is_dir())

        self.tgt_idx_to_class = dict(enumerate(classes))
        self.class_to_tgt_idx = {c: i for i, c in enumerate(classes)}

    def _create_class_idx_dict_val(self):
        """Read val_annotations.txt and index the classes it mentions."""
        val_image_dir = os.path.join(self.val_dir, "images")
        if not os.path.isdir(val_image_dir):
            raise FileNotFoundError(f"Validation image directory not found: {val_image_dir}")

        val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt")
        self.val_img_to_class = {}
        set_of_classes = set()
        with open(val_annotations_file, 'r') as fo:
            for data in fo:
                words = data.split("\t")
                self.val_img_to_class[words[0]] = words[1]
                set_of_classes.add(words[1])

        classes = sorted(set_of_classes)
        self.class_to_tgt_idx = {c: i for i, c in enumerate(classes)}
        self.tgt_idx_to_class = dict(enumerate(classes))

    def _make_dataset(self, Train=True):
        """Fill ``self.images`` with ``(path, target)`` pairs in sorted order."""
        self.images = []
        if Train:
            img_root_dir = self.train_dir
            list_of_dirs = list(self.class_to_tgt_idx.keys())
        else:
            img_root_dir = self.val_dir
            list_of_dirs = ["images"]

        for tgt in list_of_dirs:
            dirs = os.path.join(img_root_dir, tgt)
            if not os.path.isdir(dirs):
                continue

            for root, _, files in sorted(os.walk(dirs)):
                for fname in sorted(files):
                    if not fname.endswith(".JPEG"):
                        continue
                    path = os.path.join(root, fname)
                    if Train:
                        item = (path, self.class_to_tgt_idx[tgt])
                    else:
                        item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]])
                    self.images.append(item)

    def return_label(self, idx):
        """Map a sequence of target tensors to their human-readable names."""
        return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx]

    def __len__(self):
        return self.len_dataset

    def __getitem__(self, idx):
        img_path, tgt = self.images[idx]
        # Decode from the handle we opened. convert() forces the pixel load,
        # so the file can be closed when the with-block exits.
        with open(img_path, 'rb') as f:
            sample = Image.open(f).convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, tgt
    

def load_data_cnn(dataset, batch_size, world_size, rank, data_root='../data'):
    """Build DistributedSampler-backed train/test loaders for CNN training.

    Args:
        dataset: One of ``'ImageNet'``, ``'CIFAR10'``, ``'CIFAR100'``,
            ``'tiny-imagenet-ori'``, ``'tiny-imagenet-crop'`` or
            ``'tiny-imagenet-resize'``.
        batch_size: Global batch size. Each rank's loader uses
            ``batch_size // world_size``.
        world_size: Number of distributed replicas. Must be >= 1 and must
            divide ``batch_size``.
        rank: This process's replica index, passed to both samplers.
        data_root: Root directory for the CIFAR datasets only.

    Returns:
        ``(train_loader, test_loader, outdim, train_sampler, test_sampler)``.
        The train sampler shuffles and the test sampler does not. Per the
        PyTorch docs, ``train_sampler.set_epoch(epoch)`` must be called each
        epoch to get a different shuffle.

    Raises:
        ValueError: if ``world_size`` is invalid, does not divide
            ``batch_size``, or ``dataset`` is unsupported.
    """
    # Explicit checks rather than ``assert``. Asserts are stripped under
    # ``python -O``, and the per-rank batch would then silently round down.
    if world_size < 1:
        raise ValueError(f"world_size must be >= 1, got {world_size}")
    if batch_size % world_size != 0:
        raise ValueError('Batch size should be fully divided by world_size!')

    # NOTE(review): the ImageNet and Tiny-ImageNet paths are relative to the
    # current working directory and ignore ``data_root``. They are kept as-is
    # so existing callers keep resolving the same files.
    tiny_root = "./data/tiny-imagenet-200/"

    if dataset == "ImageNet":
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = datasets.ImageFolder(
            './data/train_raw',
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize
            ])
        )
        test_dataset = datasets.ImageFolder(
            './data/test',
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize
            ])
        )
        outdim = 1000
    elif dataset == 'CIFAR10':
        normalize = transforms.Normalize([0.4914, 0.4822, 0.4465],
                                         [0.2470, 0.2435, 0.2616])
        train_dataset = datasets.CIFAR10(data_root, train=True, download=True,
                                         transform=transforms.Compose([
                                             transforms.RandomCrop(32, padding=4),
                                             transforms.RandomHorizontalFlip(),
                                             transforms.ToTensor(),
                                             normalize]))
        test_dataset = datasets.CIFAR10(data_root, train=False,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                            normalize]))
        outdim = 10
    elif dataset == 'CIFAR100':
        normalize = transforms.Normalize([0.5071, 0.4866, 0.4409],
                                         [0.2673, 0.2564, 0.2762])
        train_dataset = datasets.CIFAR100(data_root, train=True, download=True,
                                          transform=transforms.Compose([
                                              transforms.RandomCrop(32, padding=4),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              normalize]))
        test_dataset = datasets.CIFAR100(data_root, train=False,
                                         transform=transforms.Compose([
                                             transforms.ToTensor(),
                                             normalize]))
        outdim = 100
    elif dataset == 'tiny-imagenet-ori':
        mean = [0.4802, 0.4481, 0.3975]
        std = [0.2296, 0.2263, 0.2255]
        train_dataset = TinyImageNet_load(root=tiny_root, train=True, transform=transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]))
        test_dataset = TinyImageNet_load(tiny_root, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]))
        outdim = 200
    elif dataset == 'tiny-imagenet-crop':
        mean = [0.4802, 0.4481, 0.3975]
        std = [0.2296, 0.2263, 0.2255]
        train_dataset = TinyImageNet_load(root=tiny_root, train=True, transform=transforms.Compose([
            transforms.Resize(64),
            transforms.RandomCrop(48),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]))
        test_dataset = TinyImageNet_load(tiny_root, train=False, transform=transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(48),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]))
        outdim = 200
    elif dataset == 'tiny-imagenet-resize':
        mean = [0.4802, 0.4481, 0.3976]
        std = [0.2175, 0.2139, 0.2133]
        train_dataset = TinyImageNet_load(root=tiny_root, train=True, transform=transforms.Compose([
            transforms.Resize(48, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]))
        test_dataset = TinyImageNet_load(tiny_root, train=False, transform=transforms.Compose([
            transforms.Resize(48, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]))
        outdim = 200
    else:
        raise ValueError(f"Unsupported dataset {dataset}")

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False)

    # The samplers decide ordering, so the loaders themselves must not shuffle.
    per_rank_batch = batch_size // world_size
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=per_rank_batch,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        sampler=train_sampler
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=per_rank_batch,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        sampler=test_sampler
    )

    return train_loader, test_loader, outdim, train_sampler, test_sampler


if __name__ == '__main__':
    # Smoke test: build the loaders for each listed dataset and print how many
    # training batches it yields. Other names load_data_mlp accepts:
    # "CIFAR10", "EMNIST", "Fashion_MNIST", "MNIST", "FER2013", "SVHN",
    # "tiny-imagenet-ori", "tiny-imagenet-crop", "tiny-imagenet-resize",
    # "OxfordFlowers", "Caltech256", "OxfordIIITPet", "INaturalist".
    dataset_names = ["CIFAR100"]

    for name in dataset_names:
        loader, *_ = load_data_mlp(name, 1)
        print(f'{name}: {len(loader)}')




