#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import torch
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid
from sampling import svhn_iid, svhn_noniid
import math
import numpy as np
import timeit


def get_dataset(args):
    """Return the train/test datasets and the per-user split of the train set.

    Args:
        args: options object; this function reads ``dataset`` (one of
            'cifar', 'svhn', 'mnist', 'fmnist'), ``iid``, ``unequal`` and
            ``num_users``.

    Returns:
        ``(train_dataset, test_dataset, user_groups)``. ``user_groups`` is
        whatever the matching ``sampling`` helper returns -- presumably a dict
        keyed by user index whose values are that user's training samples
        (the helpers are not visible here; confirm against ``sampling``).

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
        NotImplementedError: for non-IID *unequal* splits of CIFAR-10/SVHN.
    """
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            # IID user data from CIFAR-10.
            user_groups = cifar_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits are not implemented for CIFAR-10.
            raise NotImplementedError()
        else:
            # Non-IID user data, equal split size for every user.
            user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset == 'svhn':
        data_dir = '../data/svhn/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_dataset = datasets.SVHN(data_dir, split='train', download=True,
                                      transform=apply_transform)
        test_dataset = datasets.SVHN(data_dir, split='test', download=True,
                                     transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            # IID user data from SVHN.
            user_groups = svhn_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Unequal non-IID splits are not implemented for SVHN.
            raise NotImplementedError()
        else:
            # Non-IID user data, equal split size for every user.
            user_groups = svhn_noniid(train_dataset, args.num_users)

    # Membership test: the former `== 'mnist' or 'fmnist'` was always true,
    # so every unknown dataset name silently fell through to Fashion-MNIST.
    elif args.dataset in ('mnist', 'fmnist'):
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # Sample training data amongst users.
        if args.iid:
            # IID user data from (Fashion-)MNIST.
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Non-IID user data with unequal split sizes.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            # Non-IID user data, equal split size for every user.
            user_groups = mnist_noniid(train_dataset, args.num_users)

    else:
        raise ValueError(f'Unsupported dataset: {args.dataset!r}')

    return train_dataset, test_dataset, user_groups


def average_weights(w, args):
    """Aggregate a list of client state dicts into one state dict.

    With ``args.discrete == 0`` each entry is the element-wise mean over all
    clients (plain FedAvg). Otherwise every entry is produced by
    ``discrete_mechanism``. In both cases the elapsed wall-clock aggregation
    time is printed.

    Args:
        w: non-empty list of state dicts that share the same keys and the
            same tensor shape per key.
        args: options object; reads ``discrete`` here (plus whatever
            ``discrete_mechanism`` reads when it is used).

    Returns:
        A deep copy of ``w[0]`` whose values are replaced by the aggregates.
        On the mean path, ``torch.div`` performs true division, so integer
        entries come back as floating point.
    """
    start = timeit.default_timer()
    # Deep copy so the in-place `+=` below never touches the caller's w[0].
    w_avg = copy.deepcopy(w[0])
    if args.discrete == 0:
        for key in w_avg.keys():
            for i in range(1, len(w)):
                w_avg[key] += w[i][key]
            # Divide once, after all clients have been summed; dividing inside
            # the loop (as before) does not yield the mean.
            w_avg[key] = torch.div(w_avg[key], len(w))
    else:
        for key in w_avg.keys():
            # The extra return values of discrete_mechanism are unused here.
            w_avg[key], _, _ = discrete_mechanism(w, key, args)
    stop = timeit.default_timer()
    print('Aggregation Time: ', stop - start)
    return w_avg


def discrete_mechanism(w, k, args):
    """Aggregate parameter ``k`` across clients via 1-bit stochastic rounding.

    All client values of ``k`` are min-max scaled into [-1, 1] using the
    global min/max over every client and coordinate. Each scaled value ``s``
    is then stochastically rounded to +1 with probability ``(s + 1) / 2``
    (else -1), which is unbiased for ``s``. The signs of the first
    ``args.apa`` clients are flipped to simulate an availability attack. The
    signs are mapped back to the original range and averaged over clients.

    Args:
        w: non-empty list of client state dicts; all ``w[i][k]`` share one
            shape.
        k: key of the parameter to aggregate.
        args: options object; reads ``apa`` (number of attacking clients,
            taken from the front of ``w``). Values above ``len(w)`` are
            clamped to ``len(w)``; negative values count as 0. The random
            draws and the result live on the device of the weights.

    Returns:
        ``(aggregate, max_value_return, min_value_return)``. ``aggregate``
        has the shape of ``w[0][k]``. The last two are fixed placeholders
        (-100000 and 100000) that this function never updates.
    """
    max_value_return, min_value_return = -100000, 100000
    num_clients = len(w)
    reference = w[0][k]

    # (dimension, num_clients): column j holds client j's flattened values.
    matrix = torch.stack([client[k].reshape(-1) for client in w], dim=1)
    # Drawn from the CPU generator and then moved (as before), so the random
    # stream is unchanged; the device now follows the weights instead of a
    # hard-coded 'cuda:<gpu>', which failed on CPU-only runs.
    random_matrix = torch.rand(matrix.shape[0], num_clients).to(matrix.device)

    max_value = torch.max(matrix)
    min_value = torch.min(matrix)
    mean_value = (max_value + min_value) / 2.0
    distance = (max_value - min_value) / 2.0

    scale_matrix = (matrix - mean_value) / distance
    # If all values are equal, distance == 0 and the division yields 0/0 = NaN;
    # mapping those to 1 makes them round to +1 and dequantize back to the mean.
    scale_matrix[torch.isnan(scale_matrix)] = 1
    scale_matrix = torch.clamp(scale_matrix, -1.0, 1.0)

    # Stochastic rounding: +1 with probability (s + 1) / 2, otherwise -1.
    quantized = (scale_matrix + 1) / 2 - random_matrix
    quantized[quantized >= 0] = 1.0
    quantized[quantized < 0] = -1.0

    # Availability attack: the first `apa` clients report flipped signs.
    num_attackers = min(max(int(args.apa), 0), num_clients)
    client_signs = torch.ones(num_clients, device=matrix.device)
    client_signs[:num_attackers] = -1.0
    quantized = quantized * client_signs

    # Map the signs back to [min_value, max_value] and average over clients.
    dequantized = quantized * distance + mean_value
    aggregate = torch.mean(dequantized, dim=1).view(reference.shape)
    return aggregate, max_value_return, min_value_return


def get_dimension(w):
    """Return the number of scalar entries in ``w``.

    Computed as the product of every extent in ``w.shape``; an empty
    (zero-dimensional) shape yields 1.
    """
    remaining = list(w.shape)
    product = 1
    while remaining:
        product *= remaining.pop()
    return product


def exp_details(args):
    """Print a human-readable summary of the experiment configuration.

    Reads ``model``, ``discrete``, ``num_users``, ``apa``, ``iba``,
    ``optimizer``, ``lr``, ``epochs``, ``iid``, ``frac``, ``local_bs`` and
    ``local_ep`` from ``args`` and writes them to stdout. Returns None.
    """
    print('\nExperimental details:')
    print(f'    Model     : {args.model}')
    print(f'    Discrete Mechanism      : {args.discrete}')
    print(f'    Num Users     : {args.num_users}')
    # Labels previously misspelled as "Avalibility Posioning"/"Itergreity".
    print(f'    Num Availability Poisoning Attackers     : {args.apa}')
    print(f'    Num Integrity Backdoor Attackers     : {args.iba}')
    print(f'    Optimizer : {args.optimizer}')
    print(f'    Learning  : {args.lr}')
    print(f'    Global Rounds   : {args.epochs}\n')

    print('    Federated parameters:')
    print('    IID' if args.iid else '    Non-IID')
    print(f'    Fraction of users  : {args.frac}')
    print(f'    Local Batch size   : {args.local_bs}')
    print(f'    Local Epochs       : {args.local_ep}\n')
