#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import torch
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid,mnist_noniid_unequal,nlp_iid,nlp_noniid_new
from sampling import cifar_iid, cifar_noniid
from sampling import mnist_noniid_degree4,cifar_noniid_degree4,mnist_forget,fmnist_forget,cifar_forget
import numpy as np
from ibp_torch import torch_sample_BernConcrete



import pdb

# Seed every random number generator this module relies on (NumPy, PyTorch
# on CPU and, when a GPU is present, PyTorch on all CUDA devices) with the
# same fixed value at import time.
np.random.seed(1)
torch.manual_seed(1)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(1)

def get_dataset(args):
    """Load the train/test datasets named by ``args.dataset`` and split the
    training set amongst users.

    Args:
        args: options object. This function reads ``dataset`` ('cifar',
            'sstt', 'mnist' or 'fmnist'), ``iid``, ``unequal`` and
            ``num_users``, plus ``Lam`` and ``num_chunk`` for the equal
            non-IID 'sstt' split.

    Returns:
        ``(train_dataset, test_dataset, user_groups)``. ``user_groups`` is
        whatever the selected ``sampling`` helper returns (described upstream
        as a dict mapping user index to that user's data).

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the supported
            names, or for the unequal non-IID CIFAR split.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        # NOTE(review): an augmenting transform (RandomCrop + RandomHorizontalFlip)
        # used to be built here but was never passed to any dataset; the
        # training set has always been loaded with this plain transform.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=transform_test)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=transform_test)

        # Sample training data amongst users.
        if args.iid:
            # IID split of the CIFAR training set.
            user_groups = cifar_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits are not available for CIFAR.
            raise NotImplementedError()
        else:
            # Equal-sized non-IID split.
            user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset == 'sstt':
        # NOTE(review): smt_dataset, BertTokenizer, pad_sequences,
        # TensorDataset and sst_noniid_unequal are not imported in the part
        # of this file visible here -- confirm they are in scope before
        # relying on this branch.
        train_data_, test_data_ = smt_dataset(train=True, test=True)

        # Keep only the non-neutral examples, then shuffle each split.
        train_data = [ex for ex in train_data_ if ex['label'] != 'neutral']
        test_data = [ex for ex in test_data_ if ex['label'] != 'neutral']
        np.random.shuffle(train_data)
        np.random.shuffle(test_data)

        print("You got %d training data, and %d testing data" %(len(train_data), len(test_data)))

        train_texts, train_labels = list(zip(*map(lambda d: (d['text'], d['label']), train_data)))
        test_texts, test_labels = list(zip(*map(lambda d: (d['text'], d['label']), test_data)))

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

        # Prepend [CLS] and cap every sequence at 512 tokens in total.
        train_tokens = [['[CLS]'] + tokenizer.tokenize(t)[:511] for t in train_texts]
        test_tokens = [['[CLS]'] + tokenizer.tokenize(t)[:511] for t in test_texts]

        train_tokens_ids = [tokenizer.convert_tokens_to_ids(toks) for toks in train_tokens]
        test_tokens_ids = [tokenizer.convert_tokens_to_ids(toks) for toks in test_tokens]

        train_tokens_ids = pad_sequences(train_tokens_ids, maxlen=512, truncating="post", padding="post", dtype="int")
        test_tokens_ids = pad_sequences(test_tokens_ids, maxlen=512, truncating="post", padding="post", dtype="int")

        # Binary targets: True where the label is 'positive'.
        train_y = np.array(train_labels) == 'positive'
        test_y = np.array(test_labels) == 'positive'

        # Masks: 1.0 where the token id is > 0, else 0.0.
        train_masks = [[float(i > 0) for i in ii] for ii in train_tokens_ids]
        test_masks = [[float(i > 0) for i in ii] for ii in test_tokens_ids]

        train_dataset = TensorDataset(
            torch.tensor(train_tokens_ids),
            torch.tensor(train_masks),
            torch.tensor(train_y.reshape(-1, 1)).float())
        test_dataset = TensorDataset(
            torch.tensor(test_tokens_ids),
            torch.tensor(test_masks),
            torch.tensor(test_y.reshape(-1, 1)).float())

        # Sample training data amongst users.
        if args.iid:
            user_groups = nlp_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits.
            user_groups = sst_noniid_unequal(train_dataset, args.num_users)
        else:
            # Equal non-IID splits.
            user_groups = nlp_noniid_new(train_dataset, args.num_users, args.Lam, args.num_chunk)

    # The previous test, ``args.dataset == 'mnist' or 'fmnist'``, was always
    # truthy, so unknown dataset names were silently treated as Fashion-MNIST
    # and the final ``else`` could never run.
    elif args.dataset in ('mnist', 'fmnist'):
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST

        # Both datasets are normalised with the same constants.
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True, transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True, transform=apply_transform)

        if args.iid:
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            # Equal non-IID splits.
            user_groups = mnist_noniid(train_dataset, args.num_users)

    else:
        print("Please specify your dataset!!")
        raise NotImplementedError("unsupported dataset: %r" % (args.dataset,))

    return train_dataset, test_dataset, user_groups




def get_dataset_fair(args):
    """Load CIFAR-10, MNIST or Fashion-MNIST and split the training set
    amongst users, using the ``*_forget`` samplers for the equal non-IID case.

    Args:
        args: options object. This function reads ``dataset`` ('cifar',
            'mnist' or 'fmnist'), ``iid``, ``unequal`` and ``num_users``.

    Returns:
        ``(train_dataset, test_dataset, user_groups)``. ``user_groups`` is
        whatever the selected ``sampling`` helper returns (described upstream
        as a dict mapping user index to that user's data).

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the supported
            names, or for the unequal non-IID CIFAR split.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        # NOTE(review): an augmenting transform (RandomCrop + RandomHorizontalFlip)
        # used to be built here but was never passed to any dataset; the
        # training set has always been loaded with this plain transform.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=transform_test)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=transform_test)

        # Sample training data amongst users.
        if args.iid:
            # IID split of the CIFAR training set.
            user_groups = cifar_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits are not available for CIFAR.
            raise NotImplementedError()
        else:
            # Equal non-IID splits; the trailing 2 is passed straight to the sampler.
            user_groups = cifar_forget(train_dataset, args.num_users, 2)

    # The previous test, ``args.dataset == 'mnist' or 'fmnist'``, was always
    # truthy, so unknown dataset names were silently treated as Fashion-MNIST
    # and the final ``else`` could never run.
    elif args.dataset in ('mnist', 'fmnist'):
        is_mnist = args.dataset == 'mnist'
        if is_mnist:
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST

        # Both datasets are normalised with the same constants.
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True, transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True, transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        elif is_mnist:
            # Equal non-IID splits.
            user_groups = mnist_forget(train_dataset, args.num_users, 2)
        else:
            user_groups = fmnist_forget(train_dataset, args.num_users, 2)

    else:
        print("Please specify your dataset!!")
        raise NotImplementedError("unsupported dataset: %r" % (args.dataset,))

    return train_dataset, test_dataset, user_groups







def get_dataset_test(args):
    """Load the train/test datasets named by ``args.dataset`` and split the
    training set amongst users, using the ``*_noniid_degree4`` samplers with
    degree 4 for the equal non-IID image splits.

    Args:
        args: options object. This function reads ``dataset`` ('cifar',
            'sstt', 'mnist' or 'fmnist'), ``iid``, ``unequal`` and
            ``num_users``, plus ``Lam`` and ``num_chunk`` for the equal
            non-IID 'sstt' split.

    Returns:
        ``(train_dataset, test_dataset, user_groups)``. ``user_groups`` is
        whatever the selected ``sampling`` helper returns (described upstream
        as a dict mapping user index to that user's data).

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the supported
            names, or for the unequal non-IID CIFAR split.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        # NOTE(review): an augmenting transform (RandomCrop + RandomHorizontalFlip)
        # used to be built here but was never passed to any dataset; the
        # training set has always been loaded with this plain transform.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=transform_test)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=transform_test)

        # Sample training data amongst users.
        if args.iid:
            # IID split of the CIFAR training set.
            user_groups = cifar_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits are not available for CIFAR.
            raise NotImplementedError()
        else:
            # Equal non-IID splits with degree 4.
            user_groups = cifar_noniid_degree4(train_dataset, args.num_users, 4)

    elif args.dataset == 'sstt':
        # NOTE(review): smt_dataset, BertTokenizer, pad_sequences,
        # TensorDataset and sst_noniid_unequal are not imported in the part
        # of this file visible here -- confirm they are in scope before
        # relying on this branch.
        train_data_, test_data_ = smt_dataset(train=True, test=True)

        # Keep only the non-neutral examples, then shuffle each split.
        train_data = [ex for ex in train_data_ if ex['label'] != 'neutral']
        test_data = [ex for ex in test_data_ if ex['label'] != 'neutral']
        np.random.shuffle(train_data)
        np.random.shuffle(test_data)

        print("You got %d training data, and %d testing data" %(len(train_data), len(test_data)))

        train_texts, train_labels = list(zip(*map(lambda d: (d['text'], d['label']), train_data)))
        test_texts, test_labels = list(zip(*map(lambda d: (d['text'], d['label']), test_data)))

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

        # Prepend [CLS] and cap every sequence at 512 tokens in total.
        train_tokens = [['[CLS]'] + tokenizer.tokenize(t)[:511] for t in train_texts]
        test_tokens = [['[CLS]'] + tokenizer.tokenize(t)[:511] for t in test_texts]

        train_tokens_ids = [tokenizer.convert_tokens_to_ids(toks) for toks in train_tokens]
        test_tokens_ids = [tokenizer.convert_tokens_to_ids(toks) for toks in test_tokens]

        train_tokens_ids = pad_sequences(train_tokens_ids, maxlen=512, truncating="post", padding="post", dtype="int")
        test_tokens_ids = pad_sequences(test_tokens_ids, maxlen=512, truncating="post", padding="post", dtype="int")

        # Binary targets: True where the label is 'positive'.
        train_y = np.array(train_labels) == 'positive'
        test_y = np.array(test_labels) == 'positive'

        # Masks: 1.0 where the token id is > 0, else 0.0.
        train_masks = [[float(i > 0) for i in ii] for ii in train_tokens_ids]
        test_masks = [[float(i > 0) for i in ii] for ii in test_tokens_ids]

        train_dataset = TensorDataset(
            torch.tensor(train_tokens_ids),
            torch.tensor(train_masks),
            torch.tensor(train_y.reshape(-1, 1)).float())
        test_dataset = TensorDataset(
            torch.tensor(test_tokens_ids),
            torch.tensor(test_masks),
            torch.tensor(test_y.reshape(-1, 1)).float())

        # Sample training data amongst users.
        if args.iid:
            user_groups = nlp_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits.
            user_groups = sst_noniid_unequal(train_dataset, args.num_users)
        else:
            # Equal non-IID splits.
            user_groups = nlp_noniid_new(train_dataset, args.num_users, args.Lam, args.num_chunk)

    # The previous test, ``args.dataset == 'mnist' or 'fmnist'``, was always
    # truthy, so unknown dataset names fell into this branch and the final
    # ``else`` could never run.
    elif args.dataset in ('mnist', 'fmnist'):
        # 'fmnist' used to load datasets.MNIST into the fmnist directory;
        # it now loads Fashion-MNIST, matching get_dataset/get_dataset_fair.
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST

        # Both datasets are normalised with the same constants.
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            # Equal non-IID splits with degree 4.
            user_groups = mnist_noniid_degree4(train_dataset, args.num_users, 4)

    else:
        print("Please specify your dataset!!")
        raise NotImplementedError("unsupported dataset: %r" % (args.dataset,))

    return train_dataset, test_dataset, user_groups



def get_dataset_3(args):
    """Load the train/test datasets named by ``args.dataset`` and split the
    training set amongst users; the equal non-IID CIFAR split uses
    ``cifar_noniid_degree4`` with degree 3.

    Args:
        args: options object. This function reads ``dataset`` ('cifar',
            'sstt', 'mnist' or 'fmnist'), ``iid``, ``unequal`` and
            ``num_users``, plus ``Lam`` and ``num_chunk`` for the equal
            non-IID 'sstt' split.

    Returns:
        ``(train_dataset, test_dataset, user_groups)``. ``user_groups`` is
        whatever the selected ``sampling`` helper returns (described upstream
        as a dict mapping user index to that user's data).

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the supported
            names, or for the unequal non-IID CIFAR split.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        # NOTE(review): an augmenting transform (RandomCrop + RandomHorizontalFlip)
        # used to be built here but was never passed to any dataset; the
        # training set has always been loaded with this plain transform.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=transform_test)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=transform_test)

        # Sample training data amongst users.
        if args.iid:
            # IID split of the CIFAR training set.
            user_groups = cifar_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits are not available for CIFAR.
            raise NotImplementedError()
        else:
            # Equal non-IID splits with degree 3.
            user_groups = cifar_noniid_degree4(train_dataset, args.num_users, 3)

    elif args.dataset == 'sstt':
        # NOTE(review): smt_dataset, BertTokenizer, pad_sequences,
        # TensorDataset and sst_noniid_unequal are not imported in the part
        # of this file visible here -- confirm they are in scope before
        # relying on this branch.
        train_data_, test_data_ = smt_dataset(train=True, test=True)

        # Keep only the non-neutral examples, then shuffle each split.
        train_data = [ex for ex in train_data_ if ex['label'] != 'neutral']
        test_data = [ex for ex in test_data_ if ex['label'] != 'neutral']
        np.random.shuffle(train_data)
        np.random.shuffle(test_data)

        print("You got %d training data, and %d testing data" %(len(train_data), len(test_data)))

        train_texts, train_labels = list(zip(*map(lambda d: (d['text'], d['label']), train_data)))
        test_texts, test_labels = list(zip(*map(lambda d: (d['text'], d['label']), test_data)))

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

        # Prepend [CLS] and cap every sequence at 512 tokens in total.
        train_tokens = [['[CLS]'] + tokenizer.tokenize(t)[:511] for t in train_texts]
        test_tokens = [['[CLS]'] + tokenizer.tokenize(t)[:511] for t in test_texts]

        train_tokens_ids = [tokenizer.convert_tokens_to_ids(toks) for toks in train_tokens]
        test_tokens_ids = [tokenizer.convert_tokens_to_ids(toks) for toks in test_tokens]

        train_tokens_ids = pad_sequences(train_tokens_ids, maxlen=512, truncating="post", padding="post", dtype="int")
        test_tokens_ids = pad_sequences(test_tokens_ids, maxlen=512, truncating="post", padding="post", dtype="int")

        # Binary targets: True where the label is 'positive'.
        train_y = np.array(train_labels) == 'positive'
        test_y = np.array(test_labels) == 'positive'

        # Masks: 1.0 where the token id is > 0, else 0.0.
        train_masks = [[float(i > 0) for i in ii] for ii in train_tokens_ids]
        test_masks = [[float(i > 0) for i in ii] for ii in test_tokens_ids]

        train_dataset = TensorDataset(
            torch.tensor(train_tokens_ids),
            torch.tensor(train_masks),
            torch.tensor(train_y.reshape(-1, 1)).float())
        test_dataset = TensorDataset(
            torch.tensor(test_tokens_ids),
            torch.tensor(test_masks),
            torch.tensor(test_y.reshape(-1, 1)).float())

        # Sample training data amongst users.
        if args.iid:
            user_groups = nlp_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits.
            user_groups = sst_noniid_unequal(train_dataset, args.num_users)
        else:
            # Equal non-IID splits.
            user_groups = nlp_noniid_new(train_dataset, args.num_users, args.Lam, args.num_chunk)

    # The previous test, ``args.dataset == 'mnist' or 'fmnist'``, was always
    # truthy, so unknown dataset names fell into this branch and the final
    # ``else`` could never run.
    elif args.dataset in ('mnist', 'fmnist'):
        # 'fmnist' used to load datasets.MNIST into the fmnist directory;
        # it now loads Fashion-MNIST, matching get_dataset/get_dataset_fair.
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST

        # Both datasets are normalised with the same constants.
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            # NOTE(review): degree is 4 here although the CIFAR branch of this
            # function uses 3 -- possibly a copy-paste leftover; confirm.
            user_groups = mnist_noniid_degree4(train_dataset, args.num_users, 4)

    else:
        print("Please specify your dataset!!")
        raise NotImplementedError("unsupported dataset: %r" % (args.dataset,))

    return train_dataset, test_dataset, user_groups




def get_dataset_5(args):
    """Load the train/test datasets named by ``args.dataset`` and split the
    training set amongst users; the equal non-IID CIFAR split uses
    ``cifar_noniid_degree4`` with degree 5.

    Args:
        args: options object. This function reads ``dataset`` ('cifar',
            'sstt', 'mnist' or 'fmnist'), ``iid``, ``unequal`` and
            ``num_users``, plus ``Lam`` and ``num_chunk`` for the equal
            non-IID 'sstt' split.

    Returns:
        ``(train_dataset, test_dataset, user_groups)``. ``user_groups`` is
        whatever the selected ``sampling`` helper returns (described upstream
        as a dict mapping user index to that user's data).

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the supported
            names, or for the unequal non-IID CIFAR split.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        # NOTE(review): an augmenting transform (RandomCrop + RandomHorizontalFlip)
        # used to be built here but was never passed to any dataset; the
        # training set has always been loaded with this plain transform.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=transform_test)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=transform_test)

        # Sample training data amongst users.
        if args.iid:
            # IID split of the CIFAR training set.
            user_groups = cifar_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits are not available for CIFAR.
            raise NotImplementedError()
        else:
            # Equal non-IID splits with degree 5.
            user_groups = cifar_noniid_degree4(train_dataset, args.num_users, 5)

    elif args.dataset == 'sstt':
        # NOTE(review): smt_dataset, BertTokenizer, pad_sequences,
        # TensorDataset and sst_noniid_unequal are not imported in the part
        # of this file visible here -- confirm they are in scope before
        # relying on this branch.
        train_data_, test_data_ = smt_dataset(train=True, test=True)

        # Keep only the non-neutral examples, then shuffle each split.
        train_data = [ex for ex in train_data_ if ex['label'] != 'neutral']
        test_data = [ex for ex in test_data_ if ex['label'] != 'neutral']
        np.random.shuffle(train_data)
        np.random.shuffle(test_data)

        print("You got %d training data, and %d testing data" %(len(train_data), len(test_data)))

        train_texts, train_labels = list(zip(*map(lambda d: (d['text'], d['label']), train_data)))
        test_texts, test_labels = list(zip(*map(lambda d: (d['text'], d['label']), test_data)))

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

        # Prepend [CLS] and cap every sequence at 512 tokens in total.
        train_tokens = [['[CLS]'] + tokenizer.tokenize(t)[:511] for t in train_texts]
        test_tokens = [['[CLS]'] + tokenizer.tokenize(t)[:511] for t in test_texts]

        train_tokens_ids = [tokenizer.convert_tokens_to_ids(toks) for toks in train_tokens]
        test_tokens_ids = [tokenizer.convert_tokens_to_ids(toks) for toks in test_tokens]

        train_tokens_ids = pad_sequences(train_tokens_ids, maxlen=512, truncating="post", padding="post", dtype="int")
        test_tokens_ids = pad_sequences(test_tokens_ids, maxlen=512, truncating="post", padding="post", dtype="int")

        # Binary targets: True where the label is 'positive'.
        train_y = np.array(train_labels) == 'positive'
        test_y = np.array(test_labels) == 'positive'

        # Masks: 1.0 where the token id is > 0, else 0.0.
        train_masks = [[float(i > 0) for i in ii] for ii in train_tokens_ids]
        test_masks = [[float(i > 0) for i in ii] for ii in test_tokens_ids]

        train_dataset = TensorDataset(
            torch.tensor(train_tokens_ids),
            torch.tensor(train_masks),
            torch.tensor(train_y.reshape(-1, 1)).float())
        test_dataset = TensorDataset(
            torch.tensor(test_tokens_ids),
            torch.tensor(test_masks),
            torch.tensor(test_y.reshape(-1, 1)).float())

        # Sample training data amongst users.
        if args.iid:
            user_groups = nlp_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits.
            user_groups = sst_noniid_unequal(train_dataset, args.num_users)
        else:
            # Equal non-IID splits.
            user_groups = nlp_noniid_new(train_dataset, args.num_users, args.Lam, args.num_chunk)

    # The previous test, ``args.dataset == 'mnist' or 'fmnist'``, was always
    # truthy, so unknown dataset names fell into this branch and the final
    # ``else`` could never run.
    elif args.dataset in ('mnist', 'fmnist'):
        # 'fmnist' used to load datasets.MNIST into the fmnist directory;
        # it now loads Fashion-MNIST, matching get_dataset/get_dataset_fair.
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST

        # Both datasets are normalised with the same constants.
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            # NOTE(review): degree is 4 here although the CIFAR branch of this
            # function uses 5 -- possibly a copy-paste leftover; confirm.
            user_groups = mnist_noniid_degree4(train_dataset, args.num_users, 4)

    else:
        print("Please specify your dataset!!")
        raise NotImplementedError("unsupported dataset: %r" % (args.dataset,))

    return train_dataset, test_dataset, user_groups



def get_dataset_6(args):
    """Load the train/test datasets named by ``args.dataset`` and split the
    training set amongst users; the equal non-IID CIFAR split uses
    ``cifar_noniid_degree4`` with degree 6.

    Args:
        args: options object. This function reads ``dataset`` ('cifar',
            'sstt', 'mnist' or 'fmnist'), ``iid``, ``unequal`` and
            ``num_users``, plus ``Lam`` and ``num_chunk`` for the equal
            non-IID 'sstt' split.

    Returns:
        ``(train_dataset, test_dataset, user_groups)``. ``user_groups`` is
        whatever the selected ``sampling`` helper returns (described upstream
        as a dict mapping user index to that user's data).

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the supported
            names, or for the unequal non-IID CIFAR split.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        # NOTE(review): an augmenting transform (RandomCrop + RandomHorizontalFlip)
        # used to be built here but was never passed to any dataset; the
        # training set has always been loaded with this plain transform.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=transform_test)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=transform_test)

        # Sample training data amongst users.
        if args.iid:
            # IID split of the CIFAR training set.
            user_groups = cifar_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits are not available for CIFAR.
            raise NotImplementedError()
        else:
            # Equal non-IID splits with degree 6.
            user_groups = cifar_noniid_degree4(train_dataset, args.num_users, 6)

    elif args.dataset == 'sstt':
        # NOTE(review): smt_dataset, BertTokenizer, pad_sequences,
        # TensorDataset and sst_noniid_unequal are not imported in the part
        # of this file visible here -- confirm they are in scope before
        # relying on this branch.
        train_data_, test_data_ = smt_dataset(train=True, test=True)

        # Keep only the non-neutral examples, then shuffle each split.
        train_data = [ex for ex in train_data_ if ex['label'] != 'neutral']
        test_data = [ex for ex in test_data_ if ex['label'] != 'neutral']
        np.random.shuffle(train_data)
        np.random.shuffle(test_data)

        print("You got %d training data, and %d testing data" %(len(train_data), len(test_data)))

        train_texts, train_labels = list(zip(*map(lambda d: (d['text'], d['label']), train_data)))
        test_texts, test_labels = list(zip(*map(lambda d: (d['text'], d['label']), test_data)))

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

        # Prepend [CLS] and cap every sequence at 512 tokens in total.
        train_tokens = [['[CLS]'] + tokenizer.tokenize(t)[:511] for t in train_texts]
        test_tokens = [['[CLS]'] + tokenizer.tokenize(t)[:511] for t in test_texts]

        train_tokens_ids = [tokenizer.convert_tokens_to_ids(toks) for toks in train_tokens]
        test_tokens_ids = [tokenizer.convert_tokens_to_ids(toks) for toks in test_tokens]

        train_tokens_ids = pad_sequences(train_tokens_ids, maxlen=512, truncating="post", padding="post", dtype="int")
        test_tokens_ids = pad_sequences(test_tokens_ids, maxlen=512, truncating="post", padding="post", dtype="int")

        # Binary targets: True where the label is 'positive'.
        train_y = np.array(train_labels) == 'positive'
        test_y = np.array(test_labels) == 'positive'

        # Masks: 1.0 where the token id is > 0, else 0.0.
        train_masks = [[float(i > 0) for i in ii] for ii in train_tokens_ids]
        test_masks = [[float(i > 0) for i in ii] for ii in test_tokens_ids]

        train_dataset = TensorDataset(
            torch.tensor(train_tokens_ids),
            torch.tensor(train_masks),
            torch.tensor(train_y.reshape(-1, 1)).float())
        test_dataset = TensorDataset(
            torch.tensor(test_tokens_ids),
            torch.tensor(test_masks),
            torch.tensor(test_y.reshape(-1, 1)).float())

        # Sample training data amongst users.
        if args.iid:
            user_groups = nlp_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits.
            user_groups = sst_noniid_unequal(train_dataset, args.num_users)
        else:
            # Equal non-IID splits.
            user_groups = nlp_noniid_new(train_dataset, args.num_users, args.Lam, args.num_chunk)

    # The previous test, ``args.dataset == 'mnist' or 'fmnist'``, was always
    # truthy, so unknown dataset names fell into this branch and the final
    # ``else`` could never run.
    elif args.dataset in ('mnist', 'fmnist'):
        # 'fmnist' used to load datasets.MNIST into the fmnist directory;
        # it now loads Fashion-MNIST, matching get_dataset/get_dataset_fair.
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST

        # Both datasets are normalised with the same constants.
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            # NOTE(review): degree is 4 here although the CIFAR branch of this
            # function uses 6 -- possibly a copy-paste leftover; confirm.
            user_groups = mnist_noniid_degree4(train_dataset, args.num_users, 4)

    else:
        print("Please specify your dataset!!")
        raise NotImplementedError("unsupported dataset: %r" % (args.dataset,))

    return train_dataset, test_dataset, user_groups











def average_weights(w):
    """
    Returns the key-wise average of a collection of weight dictionaries.

    The first entry ``w[0]`` is deep-copied and used as the accumulator, so
    the input dictionaries themselves are left untouched. Every other entry
    ``w[1] .. w[len(w) - 1]`` is added onto the copy, and each sum is then
    divided by ``len(w)`` with ``torch.div``.

    Args:
        w: indexable collection of dict-like objects that all provide the
           keys of ``w[0]`` (presumably model ``state_dict()``s -- confirm
           against callers).

    Returns:
        The deep copy of ``w[0]`` with every value replaced by the average.
    """
    averaged = copy.deepcopy(w[0])
    num_models = len(w)
    for name in averaged:
        total = averaged[name]
        for idx in range(1, num_models):
            # Accumulate onto the copied value, never onto the inputs.
            total += w[idx][name]
        averaged[name] = torch.div(total, num_models)
    return averaged

def average_weights_online(w_avg, w_new, nth):
    """
    Folds one more set of weights into a running average.

    For every key of ``w_avg`` the incremental-mean update
    ``avg_t = avg_{t-1} + (a_t - avg_{t-1}) / t`` is applied, with
    ``t = nth``. The entries of ``w_avg`` are rebound in place.

    Args:
        w_avg: dict holding the running average; it is modified.
        w_new: dict holding the new sample; must contain every key of
               ``w_avg``.
        nth:   divisor ``t`` of the update.

    Returns:
        The same ``w_avg`` object, after the update.
    """
    for name, running in w_avg.items():
        correction = (w_new[name] - running) / nth
        w_avg[name] = running + correction
    return w_avg

def weights2cpu(we):
    """
    Replaces every value of ``we`` with the result of its ``.cpu()`` call.

    The dictionary is updated in place and the same object is returned.
    """
    for name, value in we.items():
        we[name] = value.cpu()
    return we

def weights2gpu(we):
    """
    Replaces every value of ``we`` with the result of its ``.cuda()`` call.

    The dictionary is updated in place and the same object is returned.
    """
    for name, value in we.items():
        we[name] = value.cuda()
    return we

def replace_with_local_rp(local_model_weights, rts=None, pts=None):
    """
    Builds a new per-layer weight list with the ``r`` and/or ``p`` slots
    swapped out.

    ``local_model_weights`` is iterated as ``(wa, r, p, wb)`` tuples and
    exactly one output tuple is produced per input tuple:

    * ``rts`` given  -> ``r`` is replaced by ``rts[i]``.
    * ``pts`` given  -> ``p`` is replaced by ``pts[i]``.
    * both given     -> both slots are replaced in the same tuple.
    * neither given  -> ``r`` is kept and ``p`` becomes
      ``0.5 * torch.ones_like(p)`` (prepare for the global model or new
      clients).

    Args:
        local_model_weights: iterable of 4-tuples ``(wa, r, p, wb)``.
        rts: optional indexable of replacement ``r`` values, one per layer.
        pts: optional indexable of replacement ``p`` values, one per layer.

    Returns:
        A new list of ``(wa, r, p, wb)`` tuples; ``wa`` and ``wb`` are
        passed through unchanged.
    """
    use_defaults = rts is None and pts is None
    weights = list()
    for i, (wa, r, p, wb) in enumerate(local_model_weights):
        if use_defaults:
            # prepare for the global model or new clients
            new_r, new_p = r, 0.5 * torch.ones_like(p)
        else:
            # Apply whichever replacements were supplied in a single pass so
            # the result always has one entry per layer.
            new_r = r if rts is None else rts[i]
            new_p = p if pts is None else pts[i]
        weights.append((wa, new_r, new_p, wb))
    return weights


def replace_with_local_r_only(local_model_weights, rts=None):
    """
    Builds a new per-layer weight list with the ``r`` slot replaced.

    ``local_model_weights`` is iterated as ``(wa, r, wb)`` tuples and the
    middle element of the i-th tuple is replaced by ``rts[i]``.

    Args:
        local_model_weights: iterable of 3-tuples ``(wa, r, wb)``.
        rts: indexable of replacement ``r`` values, one per layer.

    Returns:
        A new list of ``(wa, rts[i], wb)`` tuples, or an empty list when
        ``rts`` is ``None`` (the input is not iterated in that case).
    """
    if rts is None:
        return []
    return [
        (wa, rts[layer], wb)
        for layer, (wa, _old_r, wb) in enumerate(local_model_weights)
    ]





def update_client_info(weights, model_arch):
    """
    Collects the per-layer ``r`` and ``p`` entries of one client.

    For every ``i`` in ``range(len(model_arch))`` the values stored under
    ``'mlp_{i}_r'`` and ``'mlp_{i}_p'`` in ``weights`` are gathered. Only the
    length of ``model_arch`` is used.

    Args:
        weights: mapping with the keys above (presumably a model
                 ``state_dict()`` -- confirm against callers).
        model_arch: sized object with one element per layer.

    Returns:
        ``{'pts': [...], 'rts': [...]}`` with one value per layer.
    """
    num_layers = len(model_arch)
    client_info = {'pts': [], 'rts': []}
    for layer in range(num_layers):
        r_value = weights['mlp_{}_r'.format(layer)]
        client_info['rts'].append(r_value)
        p_value = weights['mlp_{}_p'.format(layer)]
        client_info['pts'].append(p_value)
    return client_info


def update_client_info_ponly(weights, model_arch):
    """
    Collects the per-layer ``p`` entries of one client.

    For every ``i`` in ``range(len(model_arch))`` the value stored under
    ``'mlp_{i}_p'`` in ``weights`` is gathered. Only the length of
    ``model_arch`` is used.

    Args:
        weights: mapping with the keys above (presumably a model
                 ``state_dict()`` -- confirm against callers).
        model_arch: sized object with one element per layer.

    Returns:
        ``{'pts': [...]}`` with one value per layer.
    """
    num_layers = len(model_arch)
    return {
        'pts': [weights['mlp_{}_p'.format(layer)] for layer in range(num_layers)],
    }



def update_client_bayes_info(weights, model_arch):
    """
    Collects the per-layer ``A``, ``B`` and ``p`` entries of one client.

    For every ``i`` in ``range(len(model_arch))`` the values stored under
    ``'A{i}'``, ``'B{i}'`` and ``'mlp_{i}_p'`` in ``weights`` are gathered.
    Only the length of ``model_arch`` is used.

    Args:
        weights: mapping with the keys above (presumably a model
                 ``state_dict()`` -- confirm against callers).
        model_arch: sized object with one element per layer.

    Returns:
        ``{'Al': [...], 'Bl': [...], 'pi': [...]}`` with one value per layer.
    """
    num_layers = len(model_arch)
    client_info = {'Al': [], 'Bl': [], 'pi': []}
    for layer in range(num_layers):
        a_value = weights['A{}'.format(layer)]
        client_info['Al'].append(a_value)
        b_value = weights['B{}'.format(layer)]
        client_info['Bl'].append(b_value)
        p_value = weights['mlp_{}_p'.format(layer)]
        client_info['pi'].append(p_value)
    return client_info


def update_client_info_general(weights, model_arch):
    """
    Collects the per-layer ``r`` and ``p`` entries of one client, using a
    layer-specific key prefix.

    For every ``i`` in ``range(len(model_arch))`` the prefix is
    ``model_arch[i][0]`` and the values stored under ``'<prefix>_{i}_r'`` and
    ``'<prefix>_{i}_p'`` in ``weights`` are gathered.

    Args:
        weights: mapping with the keys above (presumably a model
                 ``state_dict()`` -- confirm against callers).
        model_arch: indexable of per-layer descriptions whose first element
                    is the string prefix of that layer's keys.

    Returns:
        ``{'pts': [...], 'rts': [...]}`` with one value per layer.
    """
    num_layers = len(model_arch)
    client_info = {'pts': [], 'rts': []}
    for layer in range(num_layers):
        r_key = model_arch[layer][0] + '_{}_r'.format(layer)
        client_info['rts'].append(weights[r_key])

        p_key = model_arch[layer][0] + '_{}_p'.format(layer)
        client_info['pts'].append(weights[p_key])
    return client_info


def update_client_bayes_info_general(weights, model_arch):
    """
    Collects the per-layer ``A``, ``B`` and ``p`` entries of one client,
    using a layer-specific key prefix for ``p``.

    For every ``i`` in ``range(len(model_arch))`` the values stored under
    ``'A{i}'``, ``'B{i}'`` and ``'<prefix>_{i}_p'`` in ``weights`` are
    gathered, where the prefix is ``model_arch[i][0]``.

    Args:
        weights: mapping with the keys above (presumably a model
                 ``state_dict()`` -- confirm against callers).
        model_arch: indexable of per-layer descriptions whose first element
                    is the string prefix of that layer's ``p`` key.

    Returns:
        ``{'Al': [...], 'Bl': [...], 'pi': [...]}`` with one value per layer.
    """
    num_layers = len(model_arch)
    client_info = {'Al': [], 'Bl': [], 'pi': []}
    for layer in range(num_layers):
        a_value = weights['A{}'.format(layer)]
        client_info['Al'].append(a_value)
        b_value = weights['B{}'.format(layer)]
        client_info['Bl'].append(b_value)

        p_key = model_arch[layer][0] + '_{}_p'.format(layer)
        client_info['pi'].append(weights[p_key])
    return client_info


def weights_with_average_rp(local_model_weights, clients_info):
    """
    Rebuilds the per-layer weight list with ``r`` and ``p`` replaced by their
    averages over all clients.

    ``local_model_weights`` is iterated as ``(wa, r, p, wb)`` tuples. For
    layer ``i`` the new ``r`` is the sum of ``clients_info[c]['rts'][i]``
    over every client ``c`` divided by ``len(clients_info)``, and likewise
    for ``p`` with ``'pts'``. The incoming ``r`` and ``p`` are ignored.

    Args:
        local_model_weights: iterable of 4-tuples ``(wa, r, p, wb)``.
        clients_info: dict mapping a client id to a dict with ``'rts'`` and
                      ``'pts'`` per-layer lists.

    Returns:
        A new list of ``(wa, avg_r, avg_p, wb)`` tuples.
    """
    averaged = []
    for layer, (wa, _r, _p, wb) in enumerate(local_model_weights):
        num_clients = len(clients_info)
        mean_r = sum([info['rts'][layer] for info in clients_info.values()]) / num_clients
        mean_p = sum([info['pts'][layer] for info in clients_info.values()]) / num_clients
        averaged.append((wa, mean_r, mean_p, wb))
    return averaged


def average_dl(local_model_weights, clients_info, lambda_post):
    """
    Averages, per layer, the masked ``r`` values of all clients.

    For each layer ``i`` and each client, ``torch_sample_BernConcrete`` is
    called with that client's ``'pts'[i]`` and ``lambda_post``; the second
    element of the returned pair is multiplied with the client's
    ``'rts'[i]``, squeezed and accumulated. The accumulated sum is divided by
    ``len(clients_info)``.

    Args:
        local_model_weights: iterable of 4-tuples; only its length/order is
                             used, the tuple contents are ignored.
        clients_info: non-empty dict mapping a client id to a dict with
                      ``'rts'`` and ``'pts'`` per-layer lists.
        lambda_post: forwarded unchanged to ``torch_sample_BernConcrete``.

    Returns:
        A list with one averaged value per layer.
    """
    first_client = list(clients_info)[0]
    template_r = clients_info[first_client]['rts'][0]
    num_layers_seen = []
    for layer, (_wa, _r, _p, _wb) in enumerate(local_model_weights):
        # NOTE(review): the accumulator is shaped after layer 0's 'rts' entry
        # for every layer -- confirm all layers are compatible with it.
        accumulated = torch.zeros_like(template_r, dtype=template_r.dtype).squeeze()
        for info in clients_info.values():
            layer_r = info['rts'][layer]
            layer_pi = info['pts'][layer]
            _relaxed, binary_mask = torch_sample_BernConcrete(layer_pi, lambda_post)
            accumulated += torch.squeeze(binary_mask * layer_r)
        num_layers_seen.append(accumulated / len(clients_info))
    return num_layers_seen

def average_rp(local_model_weights, clients_info):
    """
    Computes the per-layer averages of ``r`` and ``p`` over all clients.

    For layer ``i`` the average ``r`` is the sum of
    ``clients_info[c]['rts'][i]`` over every client ``c`` divided by
    ``len(clients_info)``, and likewise for ``p`` with ``'pts'``. The tuples
    of ``local_model_weights`` only determine how many layers are processed.

    Args:
        local_model_weights: iterable of 4-tuples ``(wa, r, p, wb)``.
        clients_info: dict mapping a client id to a dict with ``'rts'`` and
                      ``'pts'`` per-layer lists.

    Returns:
        ``(rlist, plist)``: two lists with one averaged value per layer.
    """
    mean_rs = []
    mean_ps = []
    for layer, (_wa, _r, _p, _wb) in enumerate(local_model_weights):
        num_clients = len(clients_info)
        mean_r = sum([info['rts'][layer] for info in clients_info.values()]) / num_clients
        mean_p = sum([info['pts'][layer] for info in clients_info.values()]) / num_clients
        mean_rs.append(mean_r)
        mean_ps.append(mean_p)

    return mean_rs, mean_ps




def exp_details(args):
    """
    Prints a summary of the experiment configuration to stdout.

    Reads, in this order, ``args.model``, ``args.optimizer``, ``args.lr``,
    ``args.epochs``, ``args.dataset``, ``args.iid``, ``args.frac``,
    ``args.local_bs`` and ``args.local_ep``. A truthy ``args.iid`` is
    reported as ``IID``, anything else as ``Non-IID``.

    Returns:
        None.
    """
    def show(template, attr):
        # Look the attribute up only when its line is about to be printed.
        print(template.format(getattr(args, attr)))

    print('\nExperimental details:')
    show('    Model     : {}', 'model')
    show('    Optimizer : {}', 'optimizer')
    show('    Learning  : {}', 'lr')
    show('    Global Rounds   : {}\n', 'epochs')

    show('    Dataset     : {}', 'dataset')
    print('    Federated parameters:')
    print('    IID' if args.iid else '    Non-IID')
    show('    Fraction of users  : {}', 'frac')
    show('    Local Batch size   : {}', 'local_bs')
    show('    Local Epochs       : {}\n', 'local_ep')

