"""
Code of BASGD Algorithm

This file only contains the core part of our code.
"""


def byzt(g, rank, q, epoch, mode):
    """Return the gradient a worker sends, corrupted if the worker is Byzantine.

    A worker whose ``rank`` is below the module-level ``byzt_num`` is treated
    as Byzantine and applies the attack selected by ``mode``. Every other
    worker gets ``g`` back unchanged, whatever ``mode`` is.

    Args:
        g: Gradient tensor computed by the worker.
        rank: Rank of the calling worker.
        q: Not used by this function; kept for interface compatibility.
        epoch: Not used by this function; kept for interface compatibility.
        mode: Attack name. ``'10gAtk'`` returns ``-10 * g``;
            ``'ranDisturbAtk'`` adds standard-normal noise scaled by
            ``0.2 * ||g||_2``; ``'noAtk'`` returns ``g`` itself.

    Returns:
        The tensor to send. For honest workers and for ``'noAtk'`` this is the
        same object as ``g``; the two attacks return a new tensor.

    Raises:
        ValueError: If a Byzantine worker is given an unknown ``mode``.
            Previously this case silently returned ``None``.
    """
    # Honest workers never tamper with their gradient.
    if rank >= byzt_num:
        return g

    if mode == '10gAtk':
        return torch.mul(g, -10)
    if mode == 'ranDisturbAtk':
        noise = torch.randn_like(g).mul_(0.2 * torch.norm(g, 2))
        return torch.add(g, noise)
    if mode == 'noAtk':
        return g

    raise ValueError('unknown Byzantine attack mode: {!r}'.format(mode))


def aggr(g_history, world_size, q, current_buffer, mode):
    """Aggregate the buffered gradients coordinate-wise.

    The gradients are stacked along a new leading dimension and sorted
    independently per coordinate. A symmetric number of the smallest and
    largest values is then discarded and the rest is averaged.

    Args:
        g_history: Non-empty sequence of equally shaped gradient tensors, one
            per buffer.
        world_size: Not used by this function; kept for interface
            compatibility.
        q: Number of values trimmed from each end in ``'trmean'`` mode.
            Ignored in ``'median'`` mode.
        current_buffer: Not used by this function; kept for interface
            compatibility.
        mode: ``'median'`` for the coordinate-wise median (mean of the two
            middle values when the number of buffers is even), or
            ``'trmean'`` for the coordinate-wise ``q``-trimmed mean.

    Returns:
        A tensor with the shape of a single gradient.

    Raises:
        ValueError: If ``mode`` is unknown, ``g_history`` is empty, or ``q``
            is negative or leaves nothing to average (``2 * q >= len``) in
            ``'trmean'`` mode. Previously these cases returned ``None`` or a
            NaN-filled tensor.
    """
    if mode not in ('median', 'trmean'):
        raise ValueError('unknown aggregation mode: {!r}'.format(mode))

    length = len(g_history)
    if length == 0:
        raise ValueError('g_history must contain at least one gradient')

    if mode == 'median':
        # Keeps one middle row for odd lengths and two for even lengths.
        trim = (length - 1) // 2
    else:
        trim = q
        if trim < 0 or 2 * trim >= length:
            raise ValueError(
                'q={!r} is invalid for {} buffers: need 0 <= q and 2*q < {}'
                .format(q, length, length))

    sorted_g, _ = torch.sort(torch.stack(list(g_history)), dim=0)
    return torch.mean(sorted_g[trim:length - trim], dim=0)


def coordinate(rank, world_size):
    """Run the server loop: buffer worker gradients and update parameters.

    The server receives one gradient at a time (tag 111) and folds it into
    buffer ``sender_rank % number_of_buffers`` as a running mean. Once every
    buffer has received at least one gradient, the buffers are combined with
    ``aggr`` (using the module-level ``aggr_mode``), the parameters are
    updated, and the buffers are marked empty again. After every received
    gradient the current parameters are sent back to the sender (tag 222).

    Args:
        rank: Not used by this function.
        world_size: Used as the broadcast source rank, as the default number
            of buffers, and to size the per-epoch message count
            (``len(train_loader) * world_size``). Presumably the number of
            workers, with the server running as rank ``world_size`` —
            TODO confirm against the launcher.

    Command-line options read via the module-level ``parser``: ``lr``, ``q``,
    ``B`` (number of buffers, ``-1`` meaning ``world_size``) and ``epochs``.
    """
    # server

    args = parser.parse_args()
    current_lr = args.lr
    q = args.q
    buffer_num = args.B
    if buffer_num == -1:
        buffer_num = world_size
    # Epochs at which the learning rate is multiplied by 0.1.
    adjust = [80, 120]
    model = resnet20()
    model = model.cuda()
    # NOTE(review): flatten_all / flatten are defined elsewhere; presumably
    # they return the full model state and the trainable parameters as flat
    # tensors — confirm.
    model_flat = flatten_all(model)
    w_flat = flatten(model)
    g_flat = torch.zeros_like(w_flat)

    # initialize buffer
    g_history = [torch.zeros_like(w_flat) for _ in range(buffer_num)]
    dist.broadcast(model_flat, world_size)

    cudnn.benchmark = True

    # num_b[b] counts the gradients folded into buffer b since the last
    # update; received_buffer counts the buffers that are currently non-empty.
    # NOTE(review): if buffer_num exceeds the number of distinct sender ranks,
    # received_buffer can never reach buffer_num and no update is applied.
    num_b = [0] * buffer_num
    received_buffer = 0
    for epoch in range(args.epochs):

        # adjust learning rate
        if epoch in adjust:
            current_lr = current_lr * 0.1

        for _ in range(len(train_loader) * world_size):
            g_flat.zero_()
            # recv without an explicit source accepts any sender and returns
            # the sender's rank.
            src = dist.recv(g_flat, tag=111)
            buffer = src % len(g_history)

            num_b[buffer] += 1
            count = num_b[buffer]
            if count == 1:
                # First gradient for this buffer since the last update.
                received_buffer += 1
                g_history[buffer].copy_(g_flat)
            else:
                # Running mean: new = old * (count - 1) / count + g / count.
                g_history[buffer].mul_((count - 1) / count)
                g_flat.div_(count)
                g_history[buffer].add_(g_flat)

            if received_buffer == buffer_num:

                # update parameter: w <- w - lr * aggregated gradient
                update = aggr(g_history, world_size, q,
                              current_buffer=buffer, mode=aggr_mode)
                w_flat.add_(update, alpha=-current_lr)

                # reset buffer
                received_buffer = 0
                num_b = [0] * buffer_num

            # send new parameters back
            dist.send(w_flat, src, tag=222)


def run(rank, world_size):
    """Worker entry point: sync the initial model, then train every epoch.

    The worker builds a ResNet-20 on the GPU, receives the broadcast whose
    source is rank ``world_size`` into the flattened model state, loads that
    state into the model, and calls ``train`` once per epoch.

    Args:
        rank: Rank of this worker; forwarded to ``train``.
        world_size: Source rank of the initial broadcast; also forwarded to
            ``train``.

    Command-line options read via the module-level ``parser``: ``q`` and
    ``epochs``.
    """
    # worker

    cli_args = parser.parse_args()
    q_arg = cli_args.q

    # Build the local replica and overwrite its state with the broadcast one.
    # NOTE(review): flatten_all / unflatten_all are defined elsewhere;
    # presumably they pack and unpack the whole model state — confirm.
    net = resnet20().cuda()
    flat_state = flatten_all(net)
    dist.broadcast(flat_state, world_size)
    unflatten_all(net, flat_state)

    # define loss function (criterion)
    loss_fn = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    for epoch_idx in range(cli_args.epochs):
        # Give the sampler the epoch number before iterating the loader.
        train_sampler.set_epoch(epoch_idx)
        train(train_loader, net, loss_fn, epoch_idx, rank, world_size, q_arg)


def train(train_loader, model, criterion, epoch, rank, world_size, q, wd=0.0001):
    """Train for one epoch, exchanging every mini-batch gradient with the server.

    For each batch the worker computes the loss gradient, adds the weight
    decay term ``wd * w``, passes the result through ``byzt`` (using the
    module-level ``byzt_mode``), sends it to rank ``world_size`` with tag 111,
    waits for fresh parameters with tag 222 and loads them into ``model``.

    Args:
        train_loader: Iterable yielding ``(input, target)`` batches.
        model: Model to train; inputs are moved to the GPU before the forward
            pass.
        criterion: Loss function called as ``criterion(output, target)``.
        epoch: Current epoch index; forwarded to ``byzt``.
        rank: Rank of this worker; forwarded to ``byzt``.
        world_size: Rank that gradients are sent to and parameters are
            received from.
        q: Forwarded to ``byzt``.
        wd: Weight-decay coefficient. Defaults to the previously hard-coded
            ``0.0001``.
    """
    # NOTE(review): flatten / flatten_g / unflatten are defined elsewhere;
    # presumably they copy parameters or gradients between the model and a
    # flat tensor — confirm.
    w_flat = flatten(model)
    g_flat = torch.zeros_like(w_flat)

    for inputs, target in train_loader:
        inputs = inputs.cuda()
        target = target.cuda()

        # compute output
        output = model(inputs)
        loss = criterion(output, target)

        # compute gradient
        model.zero_grad()
        loss.backward()
        flatten_g(model, g_flat)
        # Weight decay: g <- g + wd * w
        g_flat.add_(w_flat, alpha=wd)

        # communicate with server
        outgoing = byzt(g_flat, rank, q, epoch, mode=byzt_mode)
        dist.send(outgoing, world_size, tag=111)
        dist.recv(w_flat, world_size, tag=222)
        unflatten(model, w_flat)
