################################################################################
# This is a PyTorch implementation of the FedREP framework.
# This file contains the core parts of our code and is for peer-review only.
# The remaining parts will be released when this work is published.
################################################################################

import os
import shutil
import time
import argparse
import base64
import random
import secrets

import numpy
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.distributed as dist
import datetime

from secure_aggregation import *
from models import *
from aggrs import *
from attacks import *

# Command-line options. The coordinator, the workers and `train` each call
# `parser.parse_args()` themselves, so every process sees the same settings.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--epochs', default=160, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--momentum', default=0, type=float, metavar='M',
                    help='momentum on workers')
parser.add_argument('--lr', '--learning-rate', default=0.5, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--sp_rate', '--spars-rate', default=1, type=float,
                    metavar='RATE', help='sparsification rate')
# Read by `train`: each of the coordinates a worker sends is, with this
# probability, drawn at random instead of by largest residual magnitude.
parser.add_argument('--alpha', default=0.0, type=float, metavar='P',
                    help='probability that a sent coordinate is chosen at random '
                         'instead of by top-k magnitude (0 = pure top-k)')

parser.add_argument('--aggr_rule', default='mean', type=str,
                    help='aggregation rule applied on server')
parser.add_argument('--byz_mode', default='noAtk', type=str,
                    help='attack on workers')
parser.add_argument('--byz_num', default=0, type=int,
                    help='number of Byzantine workers')

parser.add_argument('--FoE_epsilon', default=6.0, type=float, metavar='N',
                    help='hyper-parameter of FoE attack')
parser.add_argument('--tau', default=0.5, type=float, metavar='N',
                    help='clipping hyper-parameter of CC')
parser.add_argument('--robust_iters', default=5, type=int, metavar='N',
                    help='number of iterations in robust aggregation')
parser.add_argument('--interval', default=1, type=int, metavar='N',
                    help='global model updating interval')

parser.add_argument('--model', default='ResNet20', type=str,
                    help='machine learning model used for image classification')

parser.add_argument('--quant_level', default=1000, type=int,
                    help='tuning parameter that identifies the quantization level')

parser.add_argument('--group_size', default=4, type=int,
                    help='group size')

def _receive_index_union(num_workers, index_count):
    """Receive ``index_count`` indices from each worker in rank order and return their sorted union."""
    recv_buf = torch.zeros(index_count, dtype=torch.long)
    # Seed with an empty int64 array so the result is well-defined even with no workers.
    chunks = [numpy.empty(0, dtype=numpy.int64)]
    for src in range(num_workers):
        dist.recv(recv_buf, src)
        # The buffer is reused for the next recv, so keep a private copy.
        chunks.append(recv_buf.numpy().copy())
    return numpy.unique(numpy.concatenate(chunks))


def coordinate(rank, world_size):
    """Run the coordinator (server) process.

    ``world_size`` is the number of *worker* processes. The ``__main__`` block
    starts the coordinator on the highest rank, so the coordinator's own rank
    equals ``world_size``, and that value is the ``src`` of every broadcast here.

    In each synchronization round the coordinator:
      1. receives each worker's index list and takes their union,
      2. broadcasts the union's length, then the union itself,
      3. collects the workers' updates through ``secure_agg_coordinate``,
      4. combines them with ``robust_aggregate`` and broadcasts the result.

    After each epoch it receives the flattened model from worker
    ``world_size - 1``, evaluates it, and writes
    ``epoch time_cost train_loss val_prec1`` to a timestamped ``.txt`` file.

    Args:
        rank: This process's rank (forwarded to ``secure_agg_coordinate``).
        world_size: Number of worker processes.

    Raises:
        ValueError: If ``--model`` names an unknown architecture.
    """
    args = parser.parse_args()
    robust_mode = args.aggr_rule
    byz_mode = args.byz_mode
    byz_num = args.byz_num
    momentum = args.momentum
    lr = args.lr
    interval = args.interval
    model_str = args.model
    tau = args.tau
    robust_iters = args.robust_iters
    spars_rate = args.sp_rate
    group_size = args.group_size

    # Secure-aggregation parameters: threshold of half a group, and the number
    # of full groups the workers split into.
    t_threshold = int(group_size / 2)
    group_num = int(world_size / group_size)

    if model_str == 'ResNet20':
        model = resnet20()
    elif model_str == 'AlexNet':
        model = alexnet()
    elif model_str == 'VGG19':
        model = vgg19_bn(num_classes=10)
    else:
        raise ValueError("Unknown model!")

    model = model.cuda(0)
    model_flat = flatten_all(model)
    dim = len(flatten(model))

    # Workers overwrite their freshly built model with these weights (see `run`).
    dist.broadcast(model_flat, world_size)
    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = datasets.CIFAR10(root='./data', train=True, download=False, transform=train_transform)
    # One coordinator batch spans the same number of samples as one step of
    # every worker (25 each), so len(train_loader) matches the workers' step count.
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=25 * world_size, pin_memory=True,
                                               shuffle=False, num_workers=2)

    valset = datasets.CIFAR10(root='./data', train=False, download=False, transform=val_transform)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=100, pin_memory=True, shuffle=False,
                                             num_workers=2)

    # Number of indices each worker sends per round; must match `length` in `train`.
    index_count = int(dim // world_size * spars_rate)
    data_length = len(train_loader)

    # NOTE(review): the +8h offset presumably converts to UTC+8 for the file name — confirm.
    now_time2 = datetime.datetime.now() + datetime.timedelta(hours=8)
    model_tag = model_str if momentum == 0 else model_str + 'M'
    output_path = datetime.datetime.strftime(
        now_time2,
        model_tag + '_n' + str(world_size) + '_newAlg' + str(spars_rate) + '_' + robust_mode
        + '_' + byz_mode + '_byz_num' + str(byz_num) + '_lr' + str(lr) + '_' + str(group_size)
        + '_Spars_SecAgg' + '_%m%d%H%M%S')

    with open(output_path + '.txt', 'w') as output_file:
        print("--------now_time:" + datetime.datetime.strftime(now_time2, '%m%d_%H%M%S') + "--------")

        for epoch in range(args.epochs):
            print('server\'s epoch: ' + str(epoch))
            dist.barrier()
            t1 = time.time()

            for i in range(data_length):
                # Workers only synchronize every `interval` steps and on the
                # last step of the epoch; the coordinator mirrors that schedule.
                if (i + 1) % interval != 0 and i + 1 != data_length:
                    continue

                indices_list = _receive_index_union(world_size, index_count)
                temp_length = int(indices_list.size)

                # Workers size their index buffer from this length, so it goes first.
                # int64 matches the worker-side `torch.zeros(1, dtype=int)` buffer.
                dist.broadcast(torch.tensor([temp_length], dtype=torch.int64), src=world_size)
                dist.broadcast(torch.from_numpy(indices_list), src=world_size)

                g_flat_list = secure_agg_coordinate(rank, world_size, group_num, temp_length, t_threshold, 1)
                delta = robust_aggregate(rule=robust_mode, vectors=g_flat_list, byz_num=byz_num,
                                         tau=tau, iterations=robust_iters)
                dist.broadcast(delta, world_size)

            # Receive the last worker's model for evaluation.
            dist.barrier()
            time_cost = time.time() - t1
            model_flat.zero_()
            dist.recv(model_flat, src=world_size - 1)
            unflatten_all(model, model_flat)

            loss, _ = validate(train_loader, model, criterion)
            _, prec1 = validate(val_loader, model, criterion)
            output_file.write('%d %3f %3f %3f\n' %
                              (epoch, time_cost, loss, prec1))
            output_file.flush()

    print('training finished.')


def run(rank, world_size):
    """Run a worker process for ``--epochs`` epochs.

    The worker builds the model named by ``--model``, then replaces its weights
    with the ones broadcast by the coordinator (whose rank equals
    ``world_size``). It trains on its ``DistributedSampler`` shard of CIFAR-10
    through ``train``. At epochs 80 and 120 both the learning rate and the
    ``residual`` buffer are scaled by 0.1. After every epoch, worker
    ``world_size - 1`` sends its flattened model to the coordinator for
    evaluation.

    Args:
        rank: This worker's rank, in ``[0, world_size)``.
        world_size: Number of worker processes.

    Raises:
        ValueError: If ``--model`` names an unknown architecture.
    """
    args = parser.parse_args()
    lr = args.lr
    decay_epochs = (80, 120)

    print('Start node: %d  Total: %3d' % (rank, world_size))

    # Lambdas defer name lookup, so only the selected constructor must exist.
    model_builders = {
        'ResNet20': lambda: resnet20(),
        'AlexNet': lambda: alexnet(),
        'VGG19': lambda: vgg19_bn(num_classes=10),
    }
    build_model = model_builders.get(args.model)
    if build_model is None:
        raise ValueError("Unknown model!")
    model = build_model().cuda()

    # Start from the coordinator's weights so every worker is identical.
    shared_weights = flatten_all(model)
    dist.broadcast(shared_weights, world_size)
    unflatten_all(model, shared_weights)

    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    trainset = datasets.CIFAR10(root='./data', train=True, download=False,
                                transform=transforms.Compose(augmentation))
    class_num = 10

    sampler = torch.utils.data.distributed.DistributedSampler(trainset, num_replicas=world_size, rank=rank)
    loader = torch.utils.data.DataLoader(trainset, batch_size=25, pin_memory=True, shuffle=False,
                                         num_workers=2, sampler=sampler)

    # Momentum and error-feedback buffers, handed to `train` every epoch.
    u_flat = torch.zeros_like(flatten(model))
    residual = torch.zeros_like(u_flat)

    for epoch in range(args.epochs):
        if epoch in decay_epochs:
            lr *= 0.1
            residual.mul_(0.1)

        dist.barrier()
        sampler.set_epoch(epoch)
        train(loader, model, criterion, lr, epoch, rank, world_size, args.interval,
              args.byz_mode, args.byz_num, class_num, args.momentum, u_flat, args.sp_rate, residual)
        dist.barrier()

        # Only the last worker reports its model to the coordinator.
        if rank == world_size - 1:
            dist.send(flatten_all(model), dst=world_size)


def train(train_loader, model, criterion, lr, epoch, rank, world_size, interval,
          byz_mode, byz_num, class_num, momentum, u_flat, spars_rate, residual, alpha=None):
    """Run one local training epoch on a worker, synchronizing sparse updates.

    On every step the worker:
      * gets a (possibly Byzantine) gradient from ``non_omniscient_attack_ic``,
      * folds it into the momentum buffer ``u_flat``, and
      * takes a local SGD step.

    Every ``interval`` steps, and on the last step, it also:
      1. adds its local progress since the last sync to ``residual``
         (in place, so the caller's error-feedback buffer carries over),
      2. picks ``length`` coordinates, sends their indices to the
         coordinator, and receives the union over all workers. About
         ``Binomial(length, alpha)`` of the coordinates are drawn at random;
         the rest are taken from the largest ``|residual|``,
      3. submits ``residual`` at those positions through
         ``secure_agg_client`` and zeroes them locally,
      4. receives the aggregated update, applies it to the last-synced
         weights, and continues from there.

    Args:
        train_loader: This worker's data loader.
        model: The worker's model; its parameters are re-pointed at flat weight vectors.
        criterion: Loss function forwarded to ``non_omniscient_attack_ic``.
        lr: Learning rate for the local SGD step.
        epoch: Current epoch (not used in this function).
        rank: This worker's rank.
        world_size: Number of workers; the coordinator's rank equals this value.
        interval: Number of local steps between synchronizations.
        byz_mode, byz_num, class_num: Forwarded to ``non_omniscient_attack_ic``.
        momentum: Momentum coefficient for ``u_flat``.
        u_flat: Flat momentum buffer, updated in place.
        spars_rate: Scales how many indices each worker sends per sync.
        residual: Flat error-feedback buffer, updated in place.
        alpha: Probability in ``[0, 1]`` that each sent coordinate is chosen at
            random. When None, ``--alpha`` is used if defined, otherwise 0.0.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]``.
    """
    weight_decay = 0.0001
    args = parser.parse_args()
    t_threshold = int(args.group_size / 2)
    if alpha is None:
        alpha = getattr(args, 'alpha', 0.0)
    if not 0.0 <= alpha <= 1.0:
        raise ValueError('alpha must be in [0, 1], got %r' % (alpha,))

    model.train()
    w_flat_last = flatten(model)  # weights at the last synchronization (fresh copy)
    w_flat = w_flat_last.clone().detach()  # locally updated weights

    dim = w_flat_last.size(0)
    # Must equal the receive-buffer size the coordinator allocates per worker.
    length = int(dim // world_size * spars_rate)
    data_length = len(train_loader)

    for i, (instances, target) in enumerate(train_loader):
        g_flat = non_omniscient_attack_ic(byz_mode, byz_num, rank, world_size, model, w_flat, criterion,
                                          instances, target, weight_decay, class_num)
        # u <- momentum * u + (1 - momentum) * g
        u_flat.mul_(momentum).add_(g_flat, alpha=1 - momentum)
        # Local SGD step: w <- w - lr * u
        w_flat.add_(u_flat, alpha=-lr)
        unflatten(model, w_flat)

        if (i + 1) % interval != 0 and i + 1 != data_length:
            continue

        # Error feedback. Update in place rather than rebinding, so the
        # caller's buffer keeps its value across epochs and `run` can rescale it.
        residual.add_(w_flat.sub(w_flat_last))

        # Keep a random `top_length` of the `length` largest-magnitude coordinates.
        abs_value = residual.abs()
        rand_length = int(numpy.random.binomial(n=length, p=alpha))
        top_length = length - rand_length
        _, indices = torch.topk(abs_value, length, 0, largest=True, sorted=False)
        indices = indices[torch.randperm(length)[:top_length]]

        if rand_length > 0:
            # Shift the kept coordinates below 0 so that topk over U[0,1) noise
            # selects `rand_length` random coordinates among the rest.
            scores = torch.rand_like(u_flat)
            scores.index_add_(0, indices,
                              -torch.ones(top_length, dtype=scores.dtype, device=scores.device))
            _, rand_indices = torch.topk(scores, rand_length, 0, largest=True, sorted=False)
            indices = torch.cat([indices, rand_indices])

        # Exchange indices: send ours, then receive the union's length and the union.
        dist.send(indices, dst=world_size)
        union_length = torch.zeros(1, dtype=torch.int64)
        dist.broadcast(union_length, src=world_size)
        indices = torch.zeros(int(union_length.item()), dtype=torch.int64, device=w_flat.device)
        dist.broadcast(indices, src=world_size)

        # Advanced indexing copies, so zeroing afterwards does not affect the submitted values.
        update_flat_spars = residual[indices]
        residual[indices] = 0

        secure_agg_client(rank, world_size, update_flat_spars, union_length, t_threshold, 1)

        delta_w = torch.zeros_like(update_flat_spars)
        dist.broadcast(delta_w, src=world_size)

        # Apply the aggregated update to the synced weights and restart local training from them.
        w_flat_last.index_add_(0, indices, delta_w)
        w_flat = w_flat_last.clone().detach()
        unflatten(model, w_flat)


def validate(val_loader, model, criterion):
    """Evaluate ``model`` on every batch of ``val_loader`` without tracking gradients.

    The model is switched to eval mode and left there. Inputs and targets are
    moved to the current CUDA device.

    Returns:
        ``(mean_loss, mean_top1)``, both weighted by batch size. ``mean_top1``
        is the percentage from ``accuracy``; it is a 1-element tensor once at
        least one batch has been seen.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    with torch.no_grad():
        for images, target in val_loader:
            images = images.cuda()
            target = target.cuda()

            output = model(images)
            loss = criterion(output, target)

            prec1 = accuracy(output, target, topk=(1,))
            losses.update(loss.item(), images.size(0))
            top1.update(prec1[0], images.size(0))

    return losses.avg, top1.avg


class AverageMeter(object):
    """Track the most recent value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running mean, running total and sample count."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        # True division: callers pass floats or tensors.
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: 2-D scores, one row per sample and one column per class.
        target: Ground-truth class indices, one per sample.
        topk: Values of k to evaluate.

    Returns:
        A list with one 1-element tensor per k: the percentage of samples whose
        target is among their k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # After the transpose, row r holds every sample's r-th best class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout, so `.view(-1)`
        # can fail for k > 1. `.reshape` copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def flatten_all(model):
    """Concatenate every parameter and buffer of ``model`` into one new 1-D tensor.

    Parameters come first, in ``model.parameters()`` order, then buffers in
    ``model.buffers()`` order. Buffers are converted with ``.float()``, so
    integer buffers can share the vector with the parameters.
    """
    pieces = [p.data.view(-1) for p in model.parameters()]
    pieces.extend(b.data.float().view(-1) for b in model.buffers())
    return torch.cat(pieces)


def unflatten_all(model, vec):
    """Point every parameter and buffer of ``model`` at consecutive slices of ``vec``.

    This is the inverse layout of ``flatten_all``: parameters first, then
    buffers, each slice reshaped to the tensor's size. Each tensor's ``.data``
    is replaced by a view into ``vec``, not a copy.
    """
    offset = 0
    for tensors in (model.parameters(), model.buffers()):
        for t in tensors:
            n = t.numel()
            t.data = vec[offset:offset + n].view(t.size())
            offset += n


def flatten(model):
    """Concatenate the ``.data`` of every parameter of ``model`` (buffers excluded) into a new 1-D tensor."""
    return torch.cat([p.data.view(-1) for p in model.parameters()])


def unflatten(model, vec):
    """Point each parameter's ``.data`` at its consecutive slice of ``vec`` (as a view, buffers untouched)."""
    offset = 0
    for p in model.parameters():
        n = p.numel()
        p.data = vec[offset:offset + n].view(p.size())
        offset += n


def flatten_g(model, vec):
    """Copy every parameter's gradient, in ``model.parameters()`` order, into consecutive slices of ``vec`` in place."""
    offset = 0
    for p in model.parameters():
        n = p.numel()
        vec[offset:offset + n] = p.grad.data.view(-1)
        offset += n


def unflatten_g(model, vec):
    """Replace each parameter's gradient data with its consecutive slice of ``vec``.

    Every parameter must already have a ``.grad`` tensor.
    """
    offset = 0
    for p in model.parameters():
        n = p.numel()
        p.grad.data = vec[offset:offset + n].view(p.size())
        offset += n

if __name__ == '__main__':
    # MPI launch: the highest rank is the coordinator and every other rank is
    # a worker. Both roles receive the number of workers (world_size - 1).
    dist.init_process_group('mpi')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    num_workers = world_size - 1
    entry_point = coordinate if rank == num_workers else run
    entry_point(rank, num_workers)
