
from torch import nn


import os,sys
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch


from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, fmnist_iid, svhn_iid
import utils.data_utils


def load_data(args):
    """Load the dataset selected by ``args.dataset``.

    Args:
        args: Option namespace. ``args.dataset`` picks the branch:

            * ``'mnist'``, ``'cifar'``, ``'fmnist'``, ``'svhn'`` read
              ``args.iid`` and ``args.num_users``;
            * ``'domain_digits'`` reads ``args.percent``, ``args.backdoor``
              and ``args.local_bs``, and forwards ``args`` itself to the MNIST
              ``DigitsDataset``.

    Returns:
        For ``'mnist'``/``'cifar'``/``'fmnist'``/``'svhn'``: the tuple
        ``(dataset_train, dataset_test, dict_users)``, where ``dict_users`` is
        the per-user split returned by the matching ``utils.sampling`` helper.
        For ``'domain_digits'``: ``(train_loaders, test_loaders,
        backdoorloader)`` as returned by ``_load_domain_digits``.

    Raises:
        SystemExit: for an unrecognized dataset, or when a non-IID split is
            requested for cifar/fmnist/svhn (only IID is supported there).
    """
    if args.dataset == 'mnist':
        # MNIST is left unnormalized here (ToTensor only), unlike fmnist below.
        trans_mnist = transforms.Compose([transforms.ToTensor()])
        dataset_train = datasets.MNIST('./data/MNIST/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('./data/MNIST/', train=False, download=True, transform=trans_mnist)
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            sys.exit('Error: only consider IID setting in CIFAR10')
    elif args.dataset == 'fmnist':
        trans_fmnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.FashionMNIST('../data/fmnist', train=True, download=True, transform=trans_fmnist)
        dataset_test = datasets.FashionMNIST('../data/fmnist', train=False, download=True, transform=trans_fmnist)
        if args.iid:
            dict_users = fmnist_iid(dataset_train, args.num_users)
        else:
            sys.exit('Error: only consider IID setting in FMNIST')
    elif args.dataset == 'svhn':
        trans_svhn = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.SVHN('../data/svhn', split='train', download=True, transform=trans_svhn)
        dataset_test = datasets.SVHN('../data/svhn', split='test', download=True, transform=trans_svhn)
        # NOTE: the SVHN 'extra' split is not loaded.
        if args.iid:
            dict_users = svhn_iid(dataset_train, args.num_users)
        else:
            sys.exit('Error: only consider IID setting in SVHN')
    elif args.dataset == 'domain_digits':
        return _load_domain_digits(args)
    else:
        sys.exit('Error: unrecognized dataset')

    # Previously these branches fell off the end and returned None.
    return dataset_train, dataset_test, dict_users


def _digits_transform(resize=False, grayscale=False):
    """Return a digits pipeline: [Resize 28x28] -> [Grayscale to 3ch] -> ToTensor -> Normalize(0.5).

    NOTE: for a VGG16 backbone, resize to [40, 40] instead of [28, 28]
    (and add that resize for MNIST / MNIST-M, which are not resized here).
    """
    steps = []
    if resize:
        steps.append(transforms.Resize([28, 28]))
    if grayscale:
        # Replicate the single channel so every domain feeds 3-channel input.
        steps.append(transforms.Grayscale(num_output_channels=3))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)


def _load_domain_digits(args):
    """Build per-domain train/test DataLoaders for the multi-domain digits setup.

    The domains are listed in the order MNIST, SVHN, USPS, SynthDigits and
    MNIST-M, all with ``percent=args.percent``. A final MNIST-M entry uses
    ``percent=-1`` (presumably the full split; confirm against
    ``DigitsDataset``).

    Returns:
        ``(train_loaders, test_loaders, backdoorloader)``:
            * ``train_loaders`` is a list of six shuffled loaders in the domain
              order above.
            * ``test_loaders`` is a list of six non-shuffled loaders in the
              same order.
            * ``backdoorloader`` is a non-shuffled loader over the MNIST
              backdoor evaluation set, or ``None`` when ``args.backdoor`` is
              falsy.
    """
    # The module-level `import utils.data_utils` binds only `utils`, not `data_utils`.
    from utils import data_utils

    transform_mnist = _digits_transform(grayscale=True)
    transform_svhn = _digits_transform(resize=True)
    transform_usps = _digits_transform(resize=True, grayscale=True)
    transform_synth = _digits_transform(resize=True)
    transform_mnistm = _digits_transform()

    # MNIST training set; backdoor injection is delegated to DigitsDataset.
    mnist_trainset = data_utils.DigitsDataset(data_path="./data/digitdata/MNIST", channels=1, percent=args.percent,
                                              train=True, transform=transform_mnist,
                                              inject_backdoor=args.backdoor,
                                              load_backdoor=False, args=args, dataset='MNIST',
                                              backdoortest=False)
    # Backdoor evaluation view of the MNIST training data. It is only consumed
    # by backdoorloader, so build it only when backdoor mode is on.
    mnist_trainset_backdoor_test = None
    if args.backdoor:
        mnist_trainset_backdoor_test = data_utils.DigitsDataset(data_path="./data/digitdata/MNIST", channels=1,
                                                                percent=args.percent,
                                                                train=True, transform=transform_mnist,
                                                                inject_backdoor=args.backdoor,
                                                                load_backdoor=True, args=args, dataset='MNIST',
                                                                backdoortest=True)
    mnist_testset = data_utils.DigitsDataset(data_path="./data/digitdata/MNIST", channels=1, percent=args.percent,
                                             train=False, transform=transform_mnist)

    trainsets = [mnist_trainset]
    testsets = [mnist_testset]
    # (data_path, channels, transform, percent) for the remaining domains, in loader order.
    other_domains = [
        ('./data/digitdata/SVHN', 3, transform_svhn, args.percent),
        ('./data/digitdata/USPS', 1, transform_usps, args.percent),
        ('./data/digitdata/SynthDigits/', 3, transform_synth, args.percent),
        ('./data/digitdata/MNIST_M/', 3, transform_mnistm, args.percent),
        ('./data/digitdata/MNIST_M/', 3, transform_mnistm, -1),
    ]
    # Train is built before test for each domain, matching the original construction order.
    for data_path, channels, transform, percent in other_domains:
        trainsets.append(data_utils.DigitsDataset(data_path=data_path, channels=channels, percent=percent,
                                                  train=True, transform=transform))
        testsets.append(data_utils.DigitsDataset(data_path=data_path, channels=channels, percent=percent,
                                                 train=False, transform=transform))

    batch_size = args.local_bs
    train_loaders = []
    test_loaders = []
    for trainset, testset in zip(trainsets, testsets):
        train_loaders.append(torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True))
        test_loaders.append(torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False))

    # Defined on every path; previously an UnboundLocalError when backdoor was off.
    backdoorloader = None
    if mnist_trainset_backdoor_test is not None:
        backdoorloader = torch.utils.data.DataLoader(mnist_trainset_backdoor_test, batch_size=batch_size,
                                                     shuffle=False)

    return train_loaders, test_loaders, backdoorloader

