import os
import os.path as osp
import glob
import torch
from torch.utils.data import ConcatDataset
from torchvision import transforms

# When this module has no parent package (`__package__` is None, e.g. when the
# file is executed directly as a script), append the directory two levels above
# this file to sys.path. Presumably this is what lets the absolute `data.` and
# `configs.` imports below resolve in that mode.
if __package__ is None:
    import sys
    from os import path

    sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))

from data.meta_dataset import MetaDataset, GetDataLoaderDict
from configs.default import digits_path

# Short domain code -> dataset folder name under the digits root.
# Key order is significant: a domain's integer label is the position of its
# key in this mapping, and it also fixes the site iteration order.
digits_name_dict = dict(
    m='mnist',
    mm='mnist_m',
    s='svhn',
    sy='syn',
)

# Per-channel normalization statistics (named after CIFAR) that are applied to
# every digit image in this module.
_CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

# Default training pipeline: resize to 32x32, random 32x32 crop with 4px
# padding, random horizontal flip, tensor conversion, then normalization.
# NOTE: this is shared module state. base_transforms() inserts extra
# augmentations into its `.transforms` list in place, and get_transforms()
# rebinds the name to a new pipeline.
transform_train = transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD)
])

# Deterministic evaluation pipeline: resize, tensor conversion, normalization.
# Digits_SingleDomain uses it whenever no train_transform is supplied.
transform_test = transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD)
])

def get_transforms():
    """Rebind the module-level pipelines and return reusable augmentation pieces.

    Side effects:
        * ``transform_train`` becomes Resize(32x32) -> RandomCrop(32, pad 4)
          -> ToTensor. It contains neither a flip nor a Normalize.
        * ``transform_test`` becomes Resize(32x32) -> ToTensor -> Normalize.

    Returns:
        tuple: ``(base_aug, normalizer)``. ``base_aug`` is the list
        ``[RandomHorizontalFlip(), normalizer]``. ``normalizer`` is the single
        Normalize instance, which is also the one used inside the new
        ``transform_test``. NOTE(review): presumably callers append
        ``base_aug`` to complete the training pipeline; confirm against callers.
    """
    global transform_train, transform_test

    norm = transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD)
    flip_then_norm = [transforms.RandomHorizontalFlip(), norm]

    resize = [transforms.Resize([32, 32])]
    transform_train = transforms.Compose(
        resize + [transforms.RandomCrop(32, padding=4), transforms.ToTensor()]
    )
    transform_test = transforms.Compose(
        [transforms.Resize([32, 32]), transforms.ToTensor(), norm]
    )
    return flip_then_norm, norm


def base_transforms(aug='standard'):
    """Add one extra augmentation to the module-level ``transform_train`` in place.

    Args:
        aug (str): Which augmentation to add.
            * ``'aa'`` inserts ``AutoAugment`` with the SVHN policy at index 1,
              i.e. right after the leading Resize.
            * ``'ra'`` inserts ``RandAugment(num_ops=3, magnitude=5)`` at index 1.
            * ``'cutout'`` inserts ``RandomErasing`` immediately before the first
              ``Normalize``, or appends it when the pipeline has no Normalize.
            * Any other value, including the default ``'standard'``, leaves the
              pipeline unchanged.

    ``transform_train`` is shared module state, so an augmentation is added only
    if the pipeline does not already contain one of that type. Repeated calls are
    therefore idempotent instead of stacking duplicate augmentations.
    """
    pipeline = transform_train.transforms

    def _contains(kind):
        # True if the current pipeline already holds a transform of this type.
        return any(isinstance(t, kind) for t in pipeline)

    if aug == 'aa':
        if not _contains(transforms.AutoAugment):
            pipeline.insert(1, transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.SVHN))
    elif aug == 'ra':
        if not _contains(transforms.RandAugment):
            pipeline.insert(1, transforms.RandAugment(num_ops=3, magnitude=5))
    elif aug == 'cutout':
        if not _contains(transforms.RandomErasing):
            normalize_index = next(
                (i for i, t in enumerate(pipeline) if isinstance(t, transforms.Normalize)), None)
            # RandomErasing operates on tensors: placing it just before Normalize
            # (or at the end) keeps it after ToTensor in both pipeline variants.
            if normalize_index is not None:
                pipeline.insert(normalize_index, transforms.RandomErasing())
            else:
                pipeline.append(transforms.RandomErasing())


def listdir_nohidden(path, sort=False):
    """Return the names in ``path`` that do not begin with a dot.

    Args:
         path (str): directory path.
         sort (bool): if True, return the names in ascending order.

    Returns:
        list: bare entry names (not joined with ``path``), files and
        directories alike.
    """
    visible = (entry for entry in os.listdir(path) if not entry.startswith("."))
    return sorted(visible) if sort else list(visible)


class Digits_SingleDomain():
    """One split of one digit domain, wrapped in a ``MetaDataset``.

    Images are read from ``<root_path>/<domain>/<split>/<class_folder>/*.jpg``.
    Each class folder's label is its index in the sorted list of class folders.
    The ``'test'`` split has no folder of its own: it is the union of the
    domain's ``train`` and ``val`` folders.
    """

    def __init__(self, root_path=digits_path, domain_name='m', split='train', train_transform=None, seed=0):
        """
        Args:
            root_path (str): root directory that contains one folder per domain.
            domain_name (str): short domain code, a key of ``digits_name_dict``.
            split (str): one of ``'train'``, ``'val'``, ``'test'``.
            train_transform: transform to apply. When None, the module-level
                ``transform_test`` is used instead.
            seed (int): stored on the instance; not otherwise used here.
        """
        self.domain_name = domain_name
        assert domain_name in digits_name_dict.keys(), 'domain_name must be in {}'.format(
            digits_name_dict.keys())
        self.root_path = root_path
        self.domain = digits_name_dict[domain_name]
        # Domain label = position of the domain code in digits_name_dict.
        self.domain_label = list(digits_name_dict.keys()).index(domain_name)

        self.split = split
        assert self.split in ['train', 'val', 'test'], 'split must be train, val or test'

        if train_transform is not None:
            self.transform = train_transform
        else:
            self.transform = transform_test
        self.seed = seed

        self.imgs, self.labels = Digits_SingleDomain.read_data(self.root_path, self.domain, self.split)
        # Wrap the image paths, class labels and domain label as a dataset.
        self.dataset = MetaDataset(self.imgs, self.labels, self.domain_label, self.transform)

    @staticmethod
    def read_data(dataset_dir, domain, split):
        """Collect ``(image_paths, class_labels)`` for ``domain``/``split``.

        Returns:
            tuple[list[str], list[int]]: parallel lists of ``.jpg`` paths and
            their class indices, in a deterministic (sorted) order.
        """
        def _load_data_from_directory(directory):
            # Only sub-directories are classes. A stray file would otherwise
            # consume a label index and shift every later class label.
            folders = [name for name in listdir_nohidden(directory, sort=True)
                       if osp.isdir(osp.join(directory, name))]
            imgs_, labels_ = [], []
            for label, folder in enumerate(folders):
                # glob's result order depends on the filesystem; sort it so the
                # sample order (and any seeded shuffling) is reproducible.
                impaths = sorted(glob.glob(osp.join(directory, folder, "*.jpg")))
                imgs_.extend(impaths)
                labels_.extend([label] * len(impaths))
            return imgs_, labels_

        if split == "test":
            train_dir = osp.join(dataset_dir, domain, "train")
            imgs_train, labels_train = _load_data_from_directory(train_dir)
            val_dir = osp.join(dataset_dir, domain, "val")
            imgs_val, labels_val = _load_data_from_directory(val_dir)
            imgs, labels = imgs_train + imgs_val, labels_train + labels_val
        else:
            split_dir = osp.join(dataset_dir, domain, split)
            imgs, labels = _load_data_from_directory(split_dir)

        return imgs, labels


class Digits_FedDG():
    """Leave-one-domain-out federated split over all digit domains.

    Every domain in ``digits_name_dict`` becomes a site with its own
    train/val/test datasets and loaders. The site named by ``test_domain`` is
    held out: validation pools the ``val`` datasets of the remaining sites, and
    testing uses the held-out site's ``test`` dataset and loader.
    """

    def __init__(self, test_domain='m', batch_size=128, seed=0):
        self.batch_size = batch_size
        self.domain_list = list(digits_name_dict.keys())
        self.test_domain = test_domain
        # Copy-then-remove: raises ValueError if test_domain is not a known code.
        self.train_domain_list = list(self.domain_list)
        self.train_domain_list.remove(self.test_domain)
        self.seed = seed

        self.site_dataset_dict = {}
        self.site_dataloader_dict = {}
        for site in self.domain_list:
            loaders, datasets = Digits_FedDG.SingleSite(site, self.batch_size, self.seed)
            self.site_dataloader_dict[site] = loaders
            self.site_dataset_dict[site] = datasets

        val_parts = [self.site_dataset_dict[site]['val'] for site in self.train_domain_list]
        self.val_dataset = ConcatDataset(val_parts)
        self.val_dataloader = torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=8,
            pin_memory=True,
        )

        held_out = self.test_domain
        self.test_dataset = self.site_dataset_dict[held_out]['test']
        self.test_dataloader = self.site_dataloader_dict[held_out]['test']

    @staticmethod
    def SingleSite(domain_name, batch_size=128, seed=0):
        """Build the train/val/test datasets of one domain and their loaders.

        Only the ``'train'`` split receives the module-level ``transform_train``.
        The other splits fall back to ``Digits_SingleDomain``'s default.

        Returns:
            tuple: ``(dataloader_dict, dataset_dict)``, both keyed by split name.
        """
        def _split(split, **extra):
            return Digits_SingleDomain(domain_name=domain_name, split=split, seed=seed, **extra).dataset

        datasets = {
            'train': _split('train', train_transform=transform_train),
            'val': _split('val'),
            'test': _split('test'),
        }
        loaders = GetDataLoaderDict(datasets, batch_size)
        return loaders, datasets

    def GetData(self):
        """Return ``(site_dataloader_dict, site_dataset_dict)`` keyed by domain code."""
        return self.site_dataloader_dict, self.site_dataset_dict
