#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import os
import json
import pickle
import copy
import numpy as np
import torch
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid
from torch.utils.data import TensorDataset

def split_backup(dict_users, frac=0.8):
    """Split every user's samples into a retained head and a backup tail.

    Each value of ``dict_users`` is converted to a list and cut at
    ``max(int(frac * n), 1)``, where ``n`` is that user's sample count. The
    head replaces the entry in ``dict_users`` (the mapping is modified in
    place), and the tail is stored under the same key in the backup mapping.

    Args:
        dict_users: Mapping from user key to an iterable of samples.
        frac: Fraction of each user's samples kept in ``dict_users``.

    Returns:
        A tuple ``(dict_users, dict_back, dict_beta)``: the same (mutated)
        input mapping, the mapping of backup tails, and a mapping that
        assigns ``0`` to every user key.
    """
    dict_back, dict_beta = {}, {}
    # Snapshot the keys; only the values of existing keys are reassigned.
    for user in list(dict_users.keys()):
        samples = list(dict_users[user])
        # The cut point is never below 1, so a non-empty user keeps a sample.
        cut = max(int(frac * len(samples)), 1)
        dict_users[user] = samples[:cut]
        dict_back[user] = samples[cut:]
        dict_beta[user] = 0
    return dict_users, dict_back, dict_beta

def get_dataset(args):
    """Load per-user training datasets, the test dataset and group weights.

    Data files are looked up under ``Upcycled_objective/data/<args.dataset>``
    inside the parent of the current working directory.

    Two layouts are handled:

    * Dataset names containing ``'sent140'`` or ``'shakespeare'``: training
      data is read from ``mytrain.json`` and indexed by consecutive integers,
      each entry unpacking to an ``(x, y)`` pair. The test data is read from
      ``mytest.json`` and its first two entries are used as ``x`` and ``y``.
      Everything is wrapped in ``TensorDataset`` objects of ``LongTensor``.
    * Every other dataset: ``mytrain.pt`` / ``mytest.pt`` are loaded with
      ``torch.load`` and ``user_groups.json`` supplies, per user key, the
      indices used to slice the training data.

    Args:
        args: Object exposing a ``dataset`` attribute (a string).

    Returns:
        A 4-tuple ``(train_datasets, test_dataset, user_groups, group_ws)``.
        ``train_datasets`` maps an integer user id to a ``TensorDataset``.
        ``user_groups`` is an empty list for the JSON layout and the loaded
        user-group mapping otherwise. ``group_ws`` maps an integer user id to
        that user's share of all training samples.
    """
    print(args.dataset)

    def _data_file(suffix):
        # Resolve a file below the dataset folder, relative to the parent of
        # the current working directory.
        return os.path.join(os.path.dirname(os.getcwd()),
                            "Upcycled_objective/data/" + args.dataset + suffix)

    is_text_task = 'sent140' in args.dataset or 'shakespeare' in args.dataset
    if is_text_task:
        train_path = _data_file("/data/train/mytrain.json")
        test_path = _data_file("/data/test/mytest.json")

        with open(train_path, 'r') as handle:
            per_user = json.load(handle)

        train_datasets = {}
        sample_counts = []
        for uid in range(len(per_user)):
            features, labels = per_user[uid]
            train_datasets[int(uid)] = TensorDataset(torch.LongTensor(features),
                                                     torch.LongTensor(labels))
            sample_counts.append(len(labels))

        # A user's weight is its share of the total number of samples.
        n_samples = sum(sample_counts)
        group_ws = {uid: count / n_samples
                    for uid, count in enumerate(sample_counts)}

        with open(test_path, 'rb') as handle:
            held_out = json.load(handle)
        test_dataset = TensorDataset(torch.LongTensor(held_out[0]),
                                     torch.LongTensor(held_out[1]))

        return train_datasets, test_dataset, [], group_ws

    # NOTE: the former `"synthetic" in args.dataset` guard is disabled, so
    # every remaining dataset name is loaded through the path below.
    print('load' + args.dataset)

    train_path = _data_file("/data/train/mytrain.pt")
    test_path = _data_file("/data/test/mytest.pt")
    user_group_path = _data_file("/data/train/user_groups.json")

    with open(user_group_path, 'rb') as handle:
        user_groups = json.load(handle)
    train_dataset = torch.load(train_path)
    test_dataset = torch.load(test_path)

    # Presumably indexing the loaded training set with an index array yields a
    # tuple of tensors, which is unpacked into a TensorDataset -- TODO confirm.
    train_datasets = {
        int(uid): TensorDataset(*train_dataset[np.array(indices)])
        for uid, indices in user_groups.items()
    }

    return train_datasets, test_dataset, user_groups, compute_group_weights(user_groups)

def compute_group_weights(user_groups):
    """Return each user's share of the total number of samples.

    Args:
        user_groups: Mapping from user key (anything ``int()`` accepts) to a
            sized collection of that user's sample indices.

    Returns:
        A dict mapping ``int(key)`` to ``len(group) / total``, where ``total``
        is the summed length of all groups. An empty mapping yields ``{}``.

    Raises:
        ZeroDivisionError: If ``user_groups`` is non-empty but every group is
            empty, because the total sample count is then zero.
    """
    sizes = {int(key): len(members) for key, members in user_groups.items()}
    # Use the builtin ``sum`` rather than shadowing it with a local counter.
    total = sum(len(members) for members in user_groups.values())
    return {key: size / total for key, size in sizes.items()}

def average_weights(ls, global_weights):
    """
    Returns the weighted average of the given weight dictionaries.

    ``ls`` is a sequence of entries whose element ``[0]`` is a weight
    dictionary and element ``[1]`` is that dictionary's weighting factor.
    Every key is combined as ``sum(value_i * factor_i) / sum(factor_i)``,
    except keys containing ``'num_batches_tracked'``, which keep the value
    from the first entry. The inputs are not modified: the result starts
    from a deep copy of the first dictionary.

    ``global_weights`` is accepted but not used by this function.
    """
    w_avg = copy.deepcopy(ls[0][0])
    # Keys that take part in the averaging; the key set of w_avg is fixed.
    averaged = [name for name in w_avg.keys()
                if 'num_batches_tracked' not in name]

    sum_ws = 0
    lead_factor = ls[0][1]
    for name in averaged:
        w_avg[name] *= lead_factor
    sum_ws += lead_factor

    for entry in ls[1:]:
        state, factor = entry[0], entry[1]
        for name in averaged:
            w_avg[name] += (state[name] * factor)
        sum_ws += factor

    for name in averaged:
        w_avg[name] /= sum_ws

    return w_avg


def exp_details(args):
    """Print the experiment configuration stored on ``args`` to stdout.

    One line is printed per setting, in a fixed order, after a header line.
    Attributes are read one at a time as their line is printed.

    Args:
        args: Object exposing the attributes ``dataset``, ``algorithm``,
            ``model``, ``lr``, ``epochs``, ``frac``, ``local_bs``,
            ``local_ep``, ``prox_param``, ``upcycled_param``, ``straggler``,
            ``delta``, ``sigma``, ``clip`` and ``alpha``.

    Returns:
        None.
    """
    # (line prefix, attribute name, line suffix)
    rows = (
        ('    Dataset     : ', 'dataset', ''),
        ('    Algorithm     : ', 'algorithm', ''),
        ('    Model     : ', 'model', ''),
        ('    Learning  : ', 'lr', ''),
        ('    Global Rounds   : ', 'epochs', '\n'),
        ('    Fraction of users  : ', 'frac', ''),
        ('    Local Batch size   : ', 'local_bs', ''),
        ('    Local Epochs       : ', 'local_ep', '\n'),
        ('    Mu     : ', 'prox_param', ''),
        ('    Upcycled     : ', 'upcycled_param', ''),
        ('    Straggler     : ', 'straggler', ''),
        ('    delta     : ', 'delta', ''),
        ('    sigma     : ', 'sigma', ''),
        ('    clip     : ', 'clip', ''),
        ('    alpha     : ', 'alpha', ''),
    )

    print('\nExperimental details:')
    for prefix, attr, suffix in rows:
        print(f'{prefix}{getattr(args, attr)}{suffix}')
    return
