import os
import shutil
import time
import argparse

import numpy
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.distributed as dist
import datetime
import random

from models import *
from aggrs import *
from attacks import *

# Command-line configuration, parsed once at import time. The module-level
# names assigned below are read directly (as globals) by coordinate(), run()
# and train().
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--epochs', default=160, type=int, metavar='N',
                    help='number of total epochs to run')
# Global batch size; each worker's DataLoader uses bs // world_size (see run()).
parser.add_argument('--bs', '--batch-size', default=128, type=int,
                    help='size of a batch')

parser.add_argument('--byz_num', default=0, type=int, metavar='N',
                    help='number of Byzantine workers')
# Peak step size; coordinate() decays it with a cosine schedule and applies it
# to the normalized aggregate.
parser.add_argument('--lr', '--learning-rate', default=0.5, type=float,
                    metavar='LR', help='initial learning rate')
# Rule / attack names are forwarded as strings to robust_aggregate() and the
# attack helpers (presumably provided by the `aggrs` / `attacks` star-imports).
parser.add_argument('--aggr_rule', default='vaSGD', type=str, metavar='N',
                    help='robust aggregation rule')
parser.add_argument('--byz_mode', default='noAtk', type=str, metavar='N',
                    help='Byzantine attack method')

# FoE_epsilon -> omniscient_attack_on_server(epsilon=...);
# tau, robust_iters -> robust_aggregate(tau=..., iterations=...).
parser.add_argument('--FoE_epsilon', default=0.5, type=float, metavar='N',
                    help='hyper-parameter of FoE attack')
parser.add_argument('--tau', default=0.1, type=float, metavar='N',
                    help='clipping hyper-parameter of CC')
parser.add_argument('--robust_iters', default=10, type=int, metavar='N',
                    help='number of iterations in robust aggregation')

args = parser.parse_args()
aggr_rule = args.aggr_rule
byz_mode = args.byz_mode
# coordinate() passes g_recv[q:] as `good_vectors`, so presumably ranks
# 0..q-1 are the Byzantine workers -- TODO confirm against the attacks module.
q = args.byz_num
bs = args.bs

# Momentum coefficient for each worker's buffer u_flat (see train()).
# momentum == 0 also switches the log-file name prefix in coordinate().
momentum = 0.9
# Passed as `weight_decay` to non_omniscient_attack_ic() in train().
wd = 0.0001


def _worker_steps_per_epoch(num_samples, num_workers, batch_size):
    """Return how many send/broadcast rounds each worker runs per epoch.

    Mirrors the worker side in run()/train(): a DistributedSampler (not
    dropping the tail) hands each worker ceil(num_samples / num_workers)
    indices, the worker DataLoader batches them with
    batch_size // num_workers, and train() skips the final batch.

    Raises:
        ValueError: if batch_size // num_workers is 0 -- the workers could not
            build their DataLoader and the server would otherwise wait forever.
    """
    worker_bs = batch_size // num_workers
    if worker_bs < 1:
        raise ValueError('batch size %d is smaller than the number of workers %d'
                         % (batch_size, num_workers))
    per_worker = -(-num_samples // num_workers)  # ceil division
    return -(-per_worker // worker_bs) - 1


def _log_path(now, world_size, lr):
    """Build the log-file stem (without extension) for this run.

    Only the timestamp suffix goes through strftime, so a '%' in the
    aggregation rule or attack name cannot be misread as a format directive.
    """
    prefix = 'Resnet20_n' if momentum == 0 else 'Resnet20M_n'
    stem = (prefix + str(world_size) + '_FTSNGMlocal_'
            + aggr_rule + '_' + byz_mode + '_q' + str(q)
            + '_bs' + str(bs) + '_lr' + str(lr))
    return stem + datetime.datetime.strftime(now, '_%m%d%H%M%S')


def coordinate(rank, world_size):
    """Parameter-server loop: aggregate worker updates and broadcast new weights.

    Each round, the server does the following:
      1. Receive one vector from every worker.
      2. Apply the configured omniscient attack.
      3. Aggregate with the configured robust rule.
      4. Take a normalized step on its master copy of the parameters.
      5. Broadcast the result.

    After every epoch it receives the full model (parameters + buffers) from
    the last worker. It then evaluates that model on the training and
    validation sets and appends ``epoch time_cost train_loss val_prec1`` to a
    log file.

    Args:
        rank: This process's rank (not used in the body).
        world_size: Number of *worker* processes. The server uses rank
            ``world_size`` as the source of every broadcast; workers are ranks
            ``0 .. world_size - 1``.
    """
    lr = args.lr
    tau = args.tau
    robust_iters = args.robust_iters
    FoE_epsilon = args.FoE_epsilon

    model = resnet20().cuda()
    model_flat = flatten_all(model)
    # Master copy of the parameters (torch.cat -> a fresh tensor, not a view).
    w_flat = flatten(model)

    # Share the initial parameters and buffers with the workers.
    dist.broadcast(model_flat, world_size)
    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose(
        [transforms.RandomCrop(32, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         normalize])
    val_transform = transforms.Compose([transforms.ToTensor(), normalize])

    # The server's own train_loader is only used to report the training loss.
    trainset = datasets.CIFAR10(root='./data', train=True, download=False, transform=train_transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=bs, pin_memory=True, shuffle=False, num_workers=2)

    valset = datasets.CIFAR10(root='./data', train=False, download=False, transform=val_transform)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=100, pin_memory=True, shuffle=False, num_workers=2)

    # The server must run exactly as many rounds per epoch as each worker.
    # Otherwise the recv/broadcast sequence goes out of step and the job hangs.
    # len(train_loader) - 1 can disagree with the workers' count when bs is not
    # a multiple of world_size, so derive it from the workers' sharding instead.
    num_steps = _worker_steps_per_epoch(len(trainset), world_size, bs)

    now_time2 = datetime.datetime.now()
    output_path = _log_path(now_time2, world_size, lr)
    print("--------now_time:"+datetime.datetime.strftime(now_time2, '%m%d_%H%M%S')+"--------")

    # One receive buffer per worker, reused every round.
    g_recv = [torch.zeros_like(w_flat) for _ in range(world_size)]

    with open(output_path + '.txt', 'w') as output_file:
        for epoch in range(args.epochs):
            # Cosine schedule decaying from lr toward 0 over args.epochs.
            current_lr = lr * 0.5 * (1 + numpy.cos(epoch*numpy.pi/args.epochs))
            print('server\'s epoch: '+str(epoch))

            dist.barrier()
            t1 = time.time()
            for _ in range(num_steps):
                # Collect one vector from each worker j.
                for j in range(world_size):
                    g_recv[j].zero_()
                    dist.recv(g_recv[j], j)

                # Omniscient attack; g_recv[q:] is handed over as `good_vectors`.
                for j in range(world_size):
                    g_recv[j] = omniscient_attack_on_server(byz_mode, q, j, g_recv[j],
                                                            good_vectors=g_recv[q:],
                                                            epsilon=FoE_epsilon)

                # Robust aggregation followed by a normalized step.
                u_flat = robust_aggregate(rule=aggr_rule, vectors=g_recv, byz_num=q,
                                          tau=tau, iterations=robust_iters)
                u_norm = torch.norm(u_flat, 2)
                # An all-zero aggregate would divide by zero and fill w_flat
                # with NaNs; take no step in that case instead.
                if u_norm > 0:
                    u_flat.div_(u_norm)
                w_flat.add_(u_flat, alpha=-current_lr)

                dist.broadcast(w_flat, world_size)

            dist.barrier()
            time_cost = time.time() - t1

            # Pull the full model (parameters + buffers) from the last worker.
            model_flat.zero_()
            dist.recv(model_flat, src=world_size-1)
            unflatten_all(model, model_flat)

            loss, _ = validate(train_loader, model, criterion)
            _, prec1 = validate(val_loader, model, criterion)
            output_file.write('%d %3f %3f %3f \n' %
                              (epoch, time_cost, loss, prec1))
            output_file.flush()

    print('training finished.')


def run(rank, world_size):
    """Worker entry point: train for ``args.epochs`` epochs against the server.

    The worker receives the server's initial parameters and buffers by
    broadcast (the server is rank ``world_size``). It then trains on its
    DistributedSampler shard with a per-worker batch of
    ``bs // world_size``. After each epoch, the last worker
    (rank ``world_size - 1``) sends its full flattened model to the server
    for evaluation.

    Args:
        rank: This worker's rank, in ``0 .. world_size - 1``.
        world_size: Number of worker processes (the server is not counted).
    """
    q = args.byz_num
    print('Start node: %d  Total: %3d' % (rank, world_size))

    model = resnet20().cuda()

    # Overwrite the local initialization with the server's model.
    model_flat = flatten_all(model)
    dist.broadcast(model_flat, world_size)
    unflatten_all(model, model_flat)

    # Momentum buffer: same size/dtype/device as the flat parameters, all zeros.
    u_flat = torch.zeros_like(flatten(model))

    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    trainset = datasets.CIFAR10(root='./data', train=True, download=False, transform=train_transform)
    # Each worker reads its own shard; the sampler handles shuffling, so the
    # DataLoader itself must not shuffle.
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, num_replicas=world_size, rank=rank)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=bs // world_size,
        pin_memory=True,
        shuffle=False,
        num_workers=2,
        sampler=train_sampler,
    )

    reports_to_server = rank == world_size - 1

    for epoch in range(args.epochs):
        dist.barrier()

        train_sampler.set_epoch(epoch)
        train(train_loader, model, criterion, epoch, rank, world_size, q, u_flat)

        dist.barrier()

        # Only one worker ships its model to the server for testing.
        if reports_to_server:
            dist.send(flatten_all(model), dst=world_size)


def train(train_loader, model, criterion, epoch, rank, world_size, q, u_flat):
    """Run one epoch of worker-side training in lock-step with the server.

    For every batch except the last one, the worker does the following:
      1. Compute its (possibly Byzantine) vector with ``non_omniscient_attack_ic``.
      2. Fold that vector into the momentum buffer ``u_flat`` in place.
      3. Send ``u_flat`` to the server (rank ``world_size``).
      4. Receive the server's updated parameters by broadcast and load them
         into ``model``.

    Args:
        train_loader: This worker's DataLoader.
        model: Local model; its parameters end up as views into a flat buffer.
        criterion: Loss function forwarded to the gradient helper.
        epoch: Current epoch index (not used here).
        rank: This worker's rank.
        world_size: Number of workers; the server is rank ``world_size``.
        q: Number of Byzantine workers, forwarded to the attack helper.
        u_flat: Flat momentum buffer, updated in place across calls.
    """
    model.train()
    # Flat copy of the parameters. After the first unflatten() the model's
    # parameters are views into this buffer, so each broadcast updates them.
    w_flat = flatten(model)

    # The last batch is not used; the server runs the same number of rounds.
    num_steps = len(train_loader) - 1

    for step, (images, target) in enumerate(train_loader):
        if step >= num_steps:
            break

        g_flat = non_omniscient_attack_ic(byz_mode=byz_mode, byz_num=q, rank=rank, world_size=world_size, model=model,
                                          w_flat=w_flat, criterion=criterion, instances=images, target=target,
                                          weight_decay=wd, class_num=10)

        # u <- momentum * u + (1 - momentum) * g  (keyword `alpha` replaces the
        # deprecated add_(Number, Tensor) overload).
        u_flat.mul_(momentum).add_(g_flat, alpha=1 - momentum)
        dist.send(u_flat, dst=world_size)

        dist.broadcast(w_flat, world_size)
        unflatten(model, w_flat)


def validate(val_loader, model, criterion):
    """Evaluate ``model`` on every batch of ``val_loader`` without autograd.

    The model is switched to eval mode and left there. Inputs and targets are
    moved to the GPU with ``.cuda()``.

    Returns:
        ``(avg_loss, avg_prec1)``, both averaged with batch-size weights.
        ``avg_loss`` is a Python number (built from ``loss.item()``).
        ``avg_prec1`` is the weighted mean of ``accuracy(...)[0]``, i.e. a
        one-element tensor holding a percentage.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    # torch.autograd.Variable is a no-op since PyTorch 0.4; no_grad alone
    # disables graph construction for the whole evaluation.
    with torch.no_grad():
        for images, target in val_loader:
            images = images.cuda()
            target = target.cuda()

            output = model(images)
            loss = criterion(output, target)

            prec1 = accuracy(output, target, topk=(1,))
            batch_size = images.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(prec1[0], batch_size)

    return losses.avg, top1.avg


class AverageMeter(object):
    """Track the most recent value and the count-weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the running sum, the sample count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh ``avg``.

        While no samples have been counted (e.g. ``n == 0`` on a fresh meter),
        ``avg`` keeps its previous value instead of dividing by zero.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: Score matrix; classes are ranked along dim 1, one row per sample.
        target: Ground-truth class index for each row of ``output``.
        topk: The k values to evaluate.

    Returns:
        A list with one single-element float tensor per k: the percentage of
        rows whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # After the transpose, row r of `pred` holds every sample's r-th best class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` can inherit the transposed
        # (non-contiguous) layout of `pred`, and view(-1) on correct[:k]
        # then raises once k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def flatten_all(model):
    """Concatenate all parameters and then all buffers of ``model`` into one 1-D tensor.

    Buffers are cast to float first, so integer buffers are stored as floats.
    The result is a new tensor (``torch.cat``), not a view of the model.
    """
    pieces = [param.data.view(-1) for param in model.parameters()]
    pieces.extend(buf.data.float().view(-1) for buf in model.buffers())
    return torch.cat(pieces)


def unflatten_all(model, vec):
    """Load ``model``'s parameters and buffers from a flat vector (inverse of ``flatten_all``).

    Consecutive slices of ``vec`` are taken in ``parameters()`` order, then in
    ``buffers()`` order. Parameters become views into ``vec`` (storage is
    shared). flatten_all stores every buffer as float, so each buffer is cast
    back to its own dtype here. Buffers of a different dtype (e.g. an integer
    counter) therefore receive a copy rather than a view.
    """
    offset = 0
    for param in model.parameters():
        count = param.numel()
        param.data = vec[offset:offset + count].view(param.size())
        offset += count
    for buf in model.buffers():
        count = buf.numel()
        # .to() is a no-op (still a view) when the dtype already matches.
        buf.data = vec[offset:offset + count].view(buf.size()).to(buf.dtype)
        offset += count


def flatten(model):
    """Return a new 1-D tensor holding every parameter of ``model`` in ``parameters()`` order."""
    return torch.cat([param.data.view(-1) for param in model.parameters()])


def unflatten(model, vec):
    """Point each parameter of ``model`` at its slice of ``vec`` (inverse of ``flatten``).

    Slices are taken consecutively in ``parameters()`` order. Each parameter's
    ``.data`` becomes a view into ``vec``, so later in-place writes to ``vec``
    are seen by the model.
    """
    offset = 0
    for param in model.parameters():
        width = param.numel()
        param.data = vec[offset:offset + width].view(param.size())
        offset += width


def flatten_g(model, vec):
    """Copy every parameter gradient of ``model`` into ``vec`` in place.

    Gradients are written to consecutive slices in ``parameters()`` order.
    Every parameter must already have a ``.grad``.
    """
    offset = 0
    for param in model.parameters():
        width = param.numel()
        vec[offset:offset + width] = param.grad.data.view(-1)
        offset += width


def unflatten_g(model, vec):
    """Replace each parameter's gradient data with its slice of ``vec``.

    Slices are taken consecutively in ``parameters()`` order and become views
    into ``vec``. Every parameter must already have a ``.grad``.
    """
    offset = 0
    for param in model.parameters():
        width = param.numel()
        param.grad.data = vec[offset:offset + width].view(param.size())
        offset += width


if __name__ == '__main__':
    # MPI backend: the highest rank acts as the parameter server and every
    # other rank is a worker. Both roles receive the number of workers.
    dist.init_process_group('mpi')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    num_workers = world_size - 1

    entry = coordinate if rank == num_workers else run
    entry(rank, num_workers)
