#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
# Pytorch version: 1.1.0 or above


import os
import logging
import copy
import random
import numpy as np
import math
import torch
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms

from data_sampling_cc import U_data_random, A_data_random, U_data_if_iid_equal, A_data_if_iid_equal
from models_fed.AE_cc import Autoencoder, ConvAutoencoder
from models_fed.UNet_cc import UNet
from models_fed.MLP_cc import MLP
from models_fed.Alexnet_cc import AlexNet
from models_fed.small_inception_cc import Small_Inception
from models_fed.vgg_cc import vgg11, vgg16
from models_fed.resnet_cc import resnet18, resnet34


class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric."""

    def __init__(self):
        # reset() initialises every attribute; no separate assignments needed.
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum, the sample count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running average.

        Args:
            val: the newest measurement.
            n: weight (number of samples) the measurement represents.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against an update with n == 0 before any weighted sample arrived,
        # which would otherwise divide by zero.
        self.avg = self.sum / self.count if self.count else 0


class DatasetSplit(torch.utils.data.Dataset):
    """A view of ``dataset`` restricted to the sample indices in ``idxs``."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Indices may arrive as floats or numpy scalars; normalise to Python ints.
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    @staticmethod
    def _as_tensor_copy(value):
        """Return ``value`` as a fresh tensor without triggering copy-construct warnings."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(tensor) works but warns; clone().detach() is the
            # documented way to copy an existing tensor.
            return value.clone().detach()
        return torch.tensor(value)

    def __getitem__(self, item):
        """Return ``(input, target)`` of the ``item``-th selected sample as tensors."""
        input_init, target = self.dataset[self.idxs[item]]
        return self._as_tensor_copy(input_init), self._as_tensor_copy(target)


def get_data_loaders(args):
    """Load the train/test datasets and partition the training set across users.

    Reads ``args.script_path``, ``args.data`` ('cifar10', 'cifar100' or 'MNIST'),
    ``args.U_batch_size`` and ``args.random``; ``args`` is also forwarded to the
    partitioning helpers.

    Returns:
        (train_dataset, test_loader, U_user_groups, A_user_groups,
         U_user_data_size, A_user_data_size)

    Raises:
        Exception: if ``args.data`` is not a supported dataset name.
    """
    data_path = os.path.join(args.script_path, 'data')

    # CIFAR-100 reuses the CIFAR-10 channel statistics.
    cifar_mean = [x / 255.0 for x in [125.3, 123.0, 113.9]]
    cifar_std = [x / 255.0 for x in [63.0, 62.1, 66.7]]
    dataset_table = {
        'cifar10': (datasets.CIFAR10, cifar_mean, cifar_std),
        'cifar100': (datasets.CIFAR100, cifar_mean, cifar_std),
        'MNIST': (datasets.MNIST, (0.1307,), (0.3081,)),
    }
    if args.data not in dataset_table:
        raise Exception('Unsupported dataset: {0}'.format(args.data))
    dataset_cls, mean, std = dataset_table[args.data]

    apply_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    train_dataset = dataset_cls(root=data_path, train=True, download=True, transform=apply_transform)
    test_dataset = dataset_cls(root=data_path, train=False, download=True, transform=apply_transform)

    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.U_batch_size, shuffle=False)

    # Pick the partitioning strategy, then split for U-users before A-users.
    if args.random == 1:
        split_u, split_a = U_data_random, A_data_random
    else:
        split_u, split_a = U_data_if_iid_equal, A_data_if_iid_equal
    U_user_groups, U_user_data_size = split_u(args, train_dataset)
    A_user_groups, A_user_data_size = split_a(args, train_dataset)
    return train_dataset, test_loader, U_user_groups, A_user_groups, U_user_data_size, A_user_data_size


def fl_get_train_data_loaders(args, train_dataset, idxs, flag):
    """Return a training DataLoader over the samples ``idxs`` of ``train_dataset``.

    ``flag`` selects the sampling mode:
      * 'U_training'  -- U_batch_size batches; local_steps_init steps for
        learn_type 'USGD', otherwise round(local_steps_init * rou) steps.
      * 'A_training'  -- A_batch_size batches; the remaining
        local_steps_init - round(local_steps_init * rou) steps.
      * 'UA_training' -- U_batch_size batches for local_steps_init steps.
      * anything else -- one shuffled batch holding every selected sample.
    The three training modes sample with replacement, batch_size * steps draws.
    """
    dataset = DatasetSplit(train_dataset, idxs)

    if flag not in ('U_training', 'A_training', 'UA_training'):
        return torch.utils.data.DataLoader(dataset, batch_size=len(idxs), shuffle=True)

    if flag == 'A_training':
        batch_size = args.A_batch_size
        num_steps = args.local_steps_init - round(args.local_steps_init * args.rou)
    else:
        batch_size = args.U_batch_size
        if flag == 'U_training' and args.learn_type != 'USGD':
            num_steps = round(args.local_steps_init * args.rou)
        else:
            num_steps = args.local_steps_init

    sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=batch_size * num_steps)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)


def average_weights(local_weights, user_data_size_lst):
    """Return the average of the local weight dicts, weighted by each user's data size.

    Args:
        local_weights: sequence of weight dicts (e.g. model state_dicts) sharing keys.
        user_data_size_lst: data size of each user, aligned with ``local_weights``.

    The result is a deep copy of ``local_weights[0]`` (so its dict type and any
    attached metadata survive) with every entry replaced by the weighted sum.
    """
    total = sum(user_data_size_lst)
    coeffs = [size / total for size in user_data_size_lst]
    averaged = copy.deepcopy(local_weights[0])
    for name in averaged:
        acc = coeffs[0] * averaged[name]
        for idx in range(1, len(local_weights)):
            acc += coeffs[idx] * local_weights[idx][name]
        averaged[name] = acc
    return averaged


def get_prox_term(weight_1, weight_2, lr, mu):
    """Return ``{key: lr * mu * (weight_2[key] - weight_1[key])}`` for every key of ``weight_1``.

    The result is a deep copy of ``weight_1`` with its values replaced, so the
    container type of ``weight_1`` is preserved and the inputs are not modified.
    """
    scale = lr * mu
    prox = copy.deepcopy(weight_1)
    for name in prox:
        prox[name] = scale * (weight_2[name] - weight_1[name])
    return prox


def get_prox_term_2(model_1, model_2, lr, mu):
    """Return ``[lr * mu * (p2 - p1)]`` for each pair of corresponding parameters.

    Parameters are paired positionally in ``model.parameters()`` order. The
    returned tensors stay attached to the autograd graph (no detach).
    """
    params_1 = list(model_1.parameters())
    params_2 = list(model_2.parameters())
    return [lr * mu * (params_2[idx] - p1) for idx, p1 in enumerate(params_1)]


def average_np_list(np_list, user_data_size_lst):
    """Return the data-size-weighted average of the entries of ``np_list`` as a Python scalar.

    Args:
        np_list: 1-D numpy array (or any array-like) of per-user values.
        user_data_size_lst: data size of each user, aligned with ``np_list``.

    Returns 0.0 for an empty ``np_list``.
    """
    # asarray lets plain Python lists through; ndarrays are passed unchanged.
    values = np.asarray(np_list)
    total_data_size = sum(user_data_size_lst)
    res = 0.0
    for i in range(values.size):
        res += (user_data_size_lst[i] / total_data_size) * values[i]
    # res may be a Python float (empty input) or a numpy scalar; .item() on an
    # ndarray view handles both.
    return np.asarray(res).item()


def get_grad_norm_square(model):
    """Return the squared L2 norm of all parameter gradients of ``model``.

    Parameters whose ``.grad`` is None (they did not take part in the last
    backward pass) are skipped instead of raising AttributeError. Returns 0
    when no parameter has a gradient.
    """
    l2_norm_square = 0
    for param in model.parameters():
        if param.grad is None:
            continue
        # reshape (not view) also copes with non-contiguous gradients.
        grad_np = param.grad.detach().cpu().reshape(1, -1).numpy()
        l2_norm_square += np.sum(np.square(grad_np))
    return l2_norm_square


def get_weight_norm_square(model):
    """Return the squared L2 norm of all parameters of ``model`` (summed with numpy).

    Each parameter is flattened with ``reshape`` rather than ``view``, so
    non-contiguous parameters (e.g. transposed or channels_last weights) work.
    """
    per_param = [
        np.sum(np.square(param.detach().cpu().reshape(1, -1).numpy()))
        for param in model.parameters()
    ]
    return np.sum(per_param)


def get_weight_norm_square_2(weight):
    """Return the squared L2 norm of every tensor in the dict ``weight`` as a 0-d numpy array.

    Integer tensors (for example a state_dict's ``num_batches_tracked`` buffers)
    are cast to float first, because ``torch.norm`` rejects integer inputs.
    An empty dict yields ``0.0``.
    """
    l2_norm_square = 0
    for key in weight.keys():
        tensor = weight[key]
        if not (tensor.is_floating_point() or tensor.is_complex()):
            tensor = tensor.float()
        l2_norm_square += (torch.norm(tensor, p=2)).pow(2)
    if not torch.is_tensor(l2_norm_square):
        # No entries: the accumulator is still the int 0, which has no .cpu().
        return np.zeros((), dtype=np.float32)
    return l2_norm_square.cpu().detach().numpy()


def get_weight_norm_square_3(lst):
    """Return the squared L2 norm of all tensors in ``lst`` combined (numpy scalar, or 0 if empty).

    ``reshape`` is used instead of ``view`` so non-contiguous tensors are accepted.
    """
    l2_norm_square = 0
    for tensor in lst:
        flat = tensor.detach().cpu().reshape(1, -1).numpy()
        l2_norm_square += np.sum(np.square(flat))
    return l2_norm_square


def get_weight_l2_dist(model_1, model_2):
    """Return the L2 distance between the parameters of two models (computed in numpy).

    Parameters are paired positionally in ``model.parameters()`` order and
    flattened with ``reshape``, so non-contiguous parameters are supported.
    """
    params_2 = list(model_2.parameters())
    dist = 0
    for idx, param_1 in enumerate(model_1.parameters()):
        flat_1 = param_1.detach().cpu().reshape(1, -1).numpy()
        flat_2 = params_2[idx].detach().cpu().reshape(1, -1).numpy()
        dist += np.sum(np.square(flat_1 - flat_2))
    return np.sqrt(dist)


def get_weight_l2_dist_tensor(model_1, model_2):
    """Return the L2 distance between two models' parameters as a torch tensor.

    Parameters are paired positionally in ``model.parameters()`` order; the
    result stays on the autograd graph and on the parameters' device.
    """
    params_1 = list(model_1.parameters())
    params_2 = list(model_2.parameters())
    squared_total = sum(
        torch.dist(param, params_2[idx], p=2).pow(2) for idx, param in enumerate(params_1)
    )
    return torch.sqrt(squared_total)


def get_weight_norm_variance(w_local_lst, w_global):
    """Return ``2/(n-1) * sum(w_local_lst) - 2n/(n-1) * w_global`` with n = len(w_local_lst).

    Returns 0 when there is a single local value (the formula would divide by zero).
    NOTE(review): presumably the inputs are squared weight norms and this is a
    variance-style estimate across users -- confirm against the callers.
    """
    num_local = len(w_local_lst)
    if num_local == 1:
        return 0
    return 2/(num_local-1)*sum(w_local_lst) - 2*num_local/(num_local-1)*w_global


def _input_shape(args):
    """Return ``[channels, height, width]`` of one input sample for ``args.data``."""
    if args.data == 'cifar10' or args.data == 'cifar100':
        return [3, 32, 32]
    if args.data == 'MNIST':
        return [1, 28, 28]
    # Previously an unknown dataset left the shape unbound (UnboundLocalError
    # for 'ae') or silently omitted the MLP input dimension.
    raise Exception('Unsupported dataset: {0}'.format(args.data))


def get_model(args):
    """Build the network selected by ``args.arch``, moving it to CUDA when available.

    Supported architectures: 'ae', 'conv_ae', 'unet', 'mlp', 'alexnet',
    'inception', 'vgg11', 'vgg16', 'resnet18', 'resnet34'. The 'ae' and 'mlp'
    models derive their input size from ``args.data``; 'mlp' additionally reads
    ``args.mlp_spec`` (hidden sizes joined by 'x') and ``args.num_classes``.

    Raises:
        Exception: for an unsupported architecture, or an unsupported dataset
            when the architecture needs the input size.
    """
    if args.arch == 'ae':
        model = Autoencoder(args, _input_shape(args))
    elif args.arch == 'conv_ae':
        model = ConvAutoencoder(args)
    elif args.arch == 'unet':
        model = UNet(args)
    elif args.arch == 'mlp':
        n_units = [int(x) for x in args.mlp_spec.split('x')]  # hidden dims
        n_units.append(args.num_classes)  # output dim
        channels, height, width = _input_shape(args)
        n_units.insert(0, height * width * channels)  # flattened input dim
        model = MLP(args, n_units)
    elif args.arch == 'alexnet':
        model = AlexNet(args)
    elif args.arch == 'inception':
        model = Small_Inception(args)
    elif args.arch == 'vgg11':
        model = vgg11(args)
    elif args.arch == 'vgg16':
        model = vgg16(args)
    elif args.arch == 'resnet18':
        model = resnet18(args)
    elif args.arch == 'resnet34':
        model = resnet34(args)
    else:
        raise Exception('Unsupported model: {0}'.format(args.arch))
    if torch.cuda.is_available():
        model = model.cuda()
    return model


def adjust_learning_rate(args, optimizer, epoch):
    """Step-decay the LR: ``args.learning_rate * 0.3 ** (epoch // args.adjust_lr)`` for every param group."""
    decay_steps = epoch // args.adjust_lr
    new_lr = args.learning_rate * (0.3 ** decay_steps)
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def adjust_learning_rate_per_iter(args, optimizer, iteration):
    """Set every param group's LR to ``args.learning_rate / sqrt(iteration)``.

    Iteration numbers below 1 are treated as 1 (the undecayed initial LR).
    Otherwise ``iteration == 0`` divides by zero and negative values raise a
    math domain error.
    """
    lr = args.learning_rate / math.sqrt(max(iteration, 1))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for each k in ``topk``.

    Args:
        output: score tensor whose dim 1 indexes classes.
        target: class-index tensor with one entry per sample (``target.size(0)``
            is used as the batch size).
        topk: the k values to evaluate.

    Returns:
        list of one-element tensors, one per k, in the order of ``topk``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: ``correct`` inherits the non-contiguous layout of
        # ``pred.t()``, so ``view(-1)`` raises for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def accuracy_tar_prob(output, target, topk=(1,)):
    """Compute precision@k (in percent) when the target is a per-class score tensor.

    The top-``max(topk)`` class ranks of ``output`` and ``target`` are compared
    position by position. For each k, a hit is counted for every matching rank
    among the first k positions, and the total is divided by the batch size.

    Returns:
        list of one-element tensors, one per k, in the order of ``topk``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, output_hard = output.topk(maxk, 1, True, True)
    _, target_hard = target.topk(maxk, 1, True, True)
    output_hard = output_hard.t()
    target_hard = target_hard.t()
    correct = output_hard.eq(target_hard)

    res = []
    for k in topk:
        # reshape, not view: the transposed ranks make ``correct`` non-contiguous,
        # so ``view(-1)`` raises for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def accuracy_encoder_decoder(input_data, output_data):
    """Return a Euclidean-distance accuracy score between input and reconstruction.

    Both tensors are passed through a sigmoid. The score is
    ``1 - ||s(in) - s(out)|| / ||s(in) - (1 - round(s(in)))||``: the denominator
    is the distance from the input to its rounded-and-flipped counterpart.
    """
    inp = torch.sigmoid(input_data).cpu().detach().numpy()
    out = torch.sigmoid(output_data).cpu().detach().numpy()
    flipped = 1 - np.round(inp)
    return 1 - (np.linalg.norm(inp - out) / np.linalg.norm(inp - flipped))


def accuracy_regression(output_data, target_data):
    """Return ``1 - ||target - output|| / ||target||`` (one minus the relative L2 error)."""
    target_np = target_data.cpu().detach().numpy()
    output_np = output_data.cpu().detach().numpy()
    relative_error = np.linalg.norm(target_np - output_np) / np.linalg.norm(target_np)
    return 1 - relative_error


def setup_logging(args, exp_dir):
    """Send DEBUG-level logs to a dated file in ``exp_dir`` and INFO-level logs to the console.

    The file is ``LOG_train.<yymmdd>.txt`` when ``args.command == 'train'``,
    otherwise ``LOG_test.<yymmdd>.txt``; it is opened in 'w' mode (overwritten).
    """
    import datetime
    if args.command == 'train':
        name_template = "LOG_train.{0}.txt"
    else:
        name_template = "LOG_test.{0}.txt"
    log_fn = os.path.join(exp_dir, name_template.format(datetime.date.today().strftime("%y%m%d")))
    logging.basicConfig(filename=log_fn, filemode='w', level=logging.DEBUG)

    # Mirror INFO-and-above records to the console via the root logger.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console_handler)
    print('Logging into %s...' % exp_dir)


def seed_everything(seed=1234):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and make cuDNN deterministic.

    Besides ``cudnn.deterministic``, cuDNN benchmarking is disabled: with it on,
    cuDNN may pick different convolution algorithms from run to run, defeating
    reproducibility even under a fixed seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Only affects interpreters started after this point (e.g. subprocesses);
    # the current process's hash seed is fixed at interpreter startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
