import torch
import numpy as np


class APFLOptimizer:
    """Jointly train a global model and a local model on the same batches.

    Each call to :meth:`step` takes one SGD step on the global model using
    the global model's own loss, then one SGD step on the local model using
    the loss of the ``alpha``-weighted mix of both models' predictions.

    Both models must provide ``predict(x1, x2)`` returning four tensors and
    ``loss_fn(p1, z2, p2, z1)`` returning a loss tensor that supports
    ``.backward()``.

    NOTE(review): the name suggests the APFL (Adaptive Personalized
    Federated Learning) scheme -- confirm against the training script.
    """

    def __init__(self, device, global_model, local_model, kwargs):
        """Store the models and hyper-parameters and build both SGD optimizers.

        Args:
            device: Stored on the instance; not used by the code in this class.
            global_model: Model trained on its own loss.
            local_model: Model trained on the mixed-prediction loss.
            kwargs: Mapping that must contain the keys ``'lr'``, ``'alpha'``,
                ``'adaptive_alpha'``, ``'lambda'``, ``'momentum'``, ``'wd'``
                and ``'accumulation_steps'``.

        Raises:
            KeyError: If any of the required keys is missing from ``kwargs``.
        """
        self.device = device
        self.global_model = global_model
        self.local_model = local_model

        self.lr = kwargs['lr']
        self.alpha = kwargs['alpha']
        self.adaptive_alpha = kwargs['adaptive_alpha']
        # Stored for callers; not read anywhere in this class.
        self.g_lambda = kwargs['lambda']
        self.momentum = kwargs['momentum']
        self.wd = kwargs['wd']
        # NOTE(review): stored but never used to scale the accumulated losses
        # in `step`; gradients of all micro-batches are summed as-is. Confirm
        # whether the caller compensates (e.g. through the learning rate).
        self.accumulation_steps = kwargs['accumulation_steps']

        self.global_opt = torch.optim.SGD(
            self.global_model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.wd
        )
        self.local_opt = torch.optim.SGD(
            self.local_model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.wd
        )

    @staticmethod
    def _has_grads(model):
        """Return True if at least one parameter of ``model`` holds a gradient."""
        return any(param.grad is not None for param in model.parameters())

    def step(self, x_accumulator):
        """Run one optimisation step on the global and then the local model.

        Args:
            x_accumulator: Iterable of ``(x1, x2)`` input pairs. It is
                iterated twice, so it must be re-iterable (e.g. a list, not a
                generator).

        Returns:
            The mixed loss tensor of the last pair processed, or ``0`` if
            ``x_accumulator`` is empty.
        """
        # Phase 1: update the global model on its own loss, accumulating
        # gradients over every pair before taking a single optimizer step.
        self.global_opt.zero_grad()
        for x1, x2 in x_accumulator:
            p1g, z2g, p2g, z1g = self.global_model.predict(x1, x2)
            global_loss = self.global_model.loss_fn(p1g, z2g, p2g, z1g)
            global_loss.backward()
        self.global_opt.step()

        # The alpha update reads the gradients of both models. The local
        # model has none before its first backward pass (and neither model
        # has any when no batch was processed), so skip the update in that
        # case instead of crashing on a missing gradient.
        if (
            self.adaptive_alpha
            and self._has_grads(self.global_model)
            and self._has_grads(self.local_model)
        ):
            self.alpha = alpha_update(self.global_model, self.local_model, self.alpha, self.lr)

        # Phase 2: update the local model on the loss of the mixed predictions.
        mix_loss = 0
        self.local_opt.zero_grad()
        for x1, x2 in x_accumulator:
            local_preds = self.local_model.predict(x1, x2)
            # Only the local optimizer steps in this phase, so the global
            # predictions act as constants: skipping autograd for them saves
            # memory and keeps stray gradients off the global parameters.
            with torch.no_grad():
                global_preds = self.global_model.predict(x1, x2)
            p1m, z2m, p2m, z1m = [
                linear_mix(local_out, global_out, self.alpha)
                for local_out, global_out in zip(local_preds, global_preds)
            ]
            # Use the loss function (as in the global phase), not the model's
            # forward pass, to score the mixed predictions.
            mix_loss = self.local_model.loss_fn(p1m, z2m, p2m, z1m)
            mix_loss.backward()
        self.local_opt.step()

        return mix_loss


def alpha_update(global_model, local_model, alpha, lr, reg=0.02):
    """Take one gradient step on the mixing weight ``alpha`` and clip it to [0, 1].

    For every pair of corresponding parameters the contribution to the
    gradient of ``alpha`` is the dot product of
    ``local_param - global_param`` with
    ``alpha * local_grad + (1 - alpha) * global_grad``. A regularisation
    term ``reg * alpha`` is added to the summed gradient before the step
    ``alpha - lr * grad_alpha`` is taken.

    Args:
        global_model: Model whose parameters and gradients form the
            ``(1 - alpha)`` side of the mix.
        local_model: Model whose parameters and gradients form the ``alpha``
            side of the mix. Parameters are paired with those of
            ``global_model`` in iteration order.
        alpha: Current mixing weight.
        lr: Step size for the update of ``alpha``.
        reg: Coefficient of the ``reg * alpha`` regularisation term.

    Returns:
        The updated mixing weight as a NumPy float in ``[0.0, 1.0]``.
    """
    grad_alpha = 0.0
    for global_param, local_param in zip(global_model.parameters(), local_model.parameters()):
        global_grad = global_param.grad
        local_grad = local_param.grad
        if global_grad is None and local_grad is None:
            # Neither side contributes anything for this parameter.
            continue
        # A parameter that has not received a gradient yet (e.g. before the
        # first backward pass of its model) is treated as having zero gradient.
        if global_grad is None:
            global_grad = torch.zeros_like(global_param)
        if local_grad is None:
            local_grad = torch.zeros_like(local_param)

        # reshape (not view) so non-contiguous parameters are handled too.
        diff = (local_param.detach() - global_param.detach()).reshape(-1)
        mixed_grad = (alpha * local_grad.detach() + (1 - alpha) * global_grad.detach()).reshape(-1)
        grad_alpha = grad_alpha + diff.dot(mixed_grad)

    grad_alpha = grad_alpha + reg * alpha
    alpha_n = alpha - lr * grad_alpha
    # float() accepts both a one-element tensor and a plain number, so models
    # without parameters do not break the conversion.
    return np.clip(float(alpha_n), 0.0, 1.0)


def linear_mix(a, b, ratio):
    """Return the weighted combination ``ratio * a + (1 - ratio) * b``.

    ``ratio`` is the weight given to ``a``; ``b`` receives the complementary
    weight ``1 - ratio``.
    """
    weighted_a = ratio * a
    weighted_b = (1 - ratio) * b
    return weighted_a + weighted_b
