import torchvision.transforms as transforms

from utils.datasets import DigitsDataset


def compose_transforms(trns, image_norm):
    """Chain ``trns`` with the per-channel normalization selected by ``image_norm``.

    Args:
        trns: list of transforms applied first; this function does not modify it.
        image_norm: one of '0.5', 'torch', 'torch-resnet' or 'none'. 'none'
            composes ``trns`` as-is; every other accepted name appends a
            ``transforms.Normalize`` built from the (mean, std) pair below.

    Returns:
        A ``transforms.Compose`` over the resulting transform list.

    Raises:
        ValueError: if ``image_norm`` is not one of the accepted names.
    """
    # name -> (mean, std). Tuple/list container types deliberately mirror the
    # literals the statistics have always been passed as.
    norm_stats = {
        '0.5': ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        'torch': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        'torch-resnet': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    }
    if image_norm == 'none':
        return transforms.Compose(trns)
    if image_norm not in norm_stats:
        raise ValueError(f"Invalid image_norm: {image_norm}")
    mean, std = norm_stats[image_norm]
    return transforms.Compose(trns + [transforms.Normalize(mean=mean, std=std)])


def _digits_sets(domains, percent, image_norm):
    """Build per-domain Digits train/test sets ('default' image_norm -> '0.5')."""
    if image_norm == 'default':
        image_norm = '0.5'
    for domain in domains:
        if domain not in DigitsDataset.all_domains:
            raise ValueError(f"Invalid domain: {domain}")
    # Grayscale domains are expanded to 3 channels and some domains are resized
    # to 28x28 -- presumably so every domain yields same-shaped 3-channel inputs.
    trns = {
        'MNIST': [
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
        ],
        'SVHN': [
            transforms.Resize([28, 28]),
            transforms.ToTensor(),
        ],
        'USPS': [
            transforms.Resize([28, 28]),
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
        ],
        'SynthDigits': [
            transforms.Resize([28, 28]),
            transforms.ToTensor(),
        ],
        'MNIST_M': [
            transforms.ToTensor(),
        ],
    }
    train_sets = [DigitsDataset(domain,
                                percent=percent, train=True,
                                transform=compose_transforms(trns[domain], image_norm))
                  for domain in domains]
    test_sets = [DigitsDataset(domain,
                               train=False,
                               transform=compose_transforms(trns[domain], image_norm))
                 for domain in domains]
    return train_sets, test_sets


def _domainnet_sets(domains, full_set):
    """Build per-domain DomainNet train/test sets (256x256, augmented train split)."""
    # The module never imported DomainNetDataset, so this branch used to raise
    # NameError. NOTE(review): assumed to live next to DigitsDataset in
    # utils.datasets -- confirm. A lazy import keeps other datasets usable either way.
    from utils.datasets import DomainNetDataset

    transform_train = transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation((-30, 30)),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.ToTensor(),
    ])
    train_sets = [
        DomainNetDataset(domain, transform=transform_train, full_set=full_set)
        for domain in domains
    ]
    test_sets = [
        DomainNetDataset(domain, transform=transform_test, train=False, full_set=full_set)
        for domain in domains
    ]
    return train_sets, test_sets


def _cifar10_sets(domains, image_norm):
    """Build per-domain CIFAR-10 train/test sets ('default' image_norm -> 'torch')."""
    # Same situation as DomainNetDataset: CifarDataset was referenced but never
    # imported. NOTE(review): assumed to live in utils.datasets -- confirm.
    from utils.datasets import CifarDataset

    if image_norm == 'default':
        image_norm = 'torch'
    for domain in domains:
        if domain not in CifarDataset.all_domains:
            raise ValueError(f"Invalid domain: {domain}")
    trn_train = [transforms.RandomCrop(32, padding=4),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor()]
    trn_test = [transforms.ToTensor()]
    train_sets = [CifarDataset(domain, train=True,
                               transform=compose_transforms(trn_train, image_norm))
                  for domain in domains]
    test_sets = [CifarDataset(domain, train=False,
                              transform=compose_transforms(trn_test, image_norm))
                 for domain in domains]
    return train_sets, test_sets


def get_central_data(name: str, domains: list, percent=1., image_norm='none',
                     disable_image_norm_error=False):
    """Build one train set and one test set per requested domain of a dataset.

    Args:
        name: dataset name, case-insensitive: 'digits', 'domainnet',
            'domainnetf' (DomainNet with ``full_set=True``) or 'cifar10'.
        domains: domain names to load, in order; Digits and CIFAR-10 validate
            them against the dataset class's ``all_domains``.
        percent: fraction of training data, only accepted for 'digits'.
        image_norm: normalization name understood by ``compose_transforms``, or
            'default' (-> '0.5' for Digits, 'torch' for CIFAR-10).
            NOTE(review): the DomainNet branch ignores this argument.
        disable_image_norm_error: allow ``image_norm != 'none'``, which is
            otherwise rejected because PGD clips images into [0, 1].

    Returns:
        ``(train_sets, test_sets)``: two lists aligned with ``domains``.

    Raises:
        RuntimeError: on a non-'none' image_norm without the override, or on
            ``percent != 1`` for a dataset other than Digits.
        ValueError: on an unknown Digits/CIFAR-10 domain or image_norm.
        NotImplementedError: on an unsupported dataset name.
    """
    if image_norm != 'none' and not disable_image_norm_error:
        raise RuntimeError(f"This is a hard warning. Use image_norm != none will make the PGD"
                           f" attack invalid since PGD will clip the image into [0,1] range. "
                           f"Think before you choose {image_norm} image_norm.")
    key = name.lower()
    if percent != 1. and key != 'digits':
        raise RuntimeError(f"percent={percent} should not be used in get_central_data."
                           f" Pass it to make_fed_data instead.")
    if key == 'digits':
        return _digits_sets(domains, percent, image_norm)
    if key in ('domainnet', 'domainnetf'):
        return _domainnet_sets(domains, full_set=key == 'domainnetf')
    if key == 'cifar10':
        return _cifar10_sets(domains, image_norm)
    raise NotImplementedError(f"name: {name}")


if __name__ == '__main__':
    # Smoke check: load every Digits domain and report each training-set size.
    digit_domains = DigitsDataset.all_domains
    train_sets, _ = get_central_data('Digits', digit_domains)
    for domain_name, train_set in zip(digit_domains, train_sets):
        print(f"{domain_name}: {len(train_set)}")
