import copy
import torch
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import numpy as np
from utils.get_closure import get_optimizer_closure
import sys
sys.path.append('./')


def average(params_lst):
    """Return the element-wise mean of several equally long lists of tensors.

    Args:
        params_lst: non-empty sequence of lists; the i-th output entry is the
            mean of the i-th entry of every list.  The length of the first
            list determines the output length.

    Returns:
        list of tensors, computed without recording autograd history.
    """
    n_lists = len(params_lst)
    with torch.no_grad():
        # Each term is divided before summing, then accumulated left to right
        # starting from 0.
        return [sum((params[idx] / n_lists for params in params_lst), 0)
                for idx in range(len(params_lst[0]))]


def L2_distance(params_lst1, params_lst2):
    """Return the Euclidean distance between two lists of tensors.

    The lists are treated as one concatenated vector each: the squared
    Frobenius norms of the pairwise differences are summed and the square
    root of the total is returned.  Pairs are formed with ``zip``, so extra
    entries in the longer list are ignored; two empty lists give ``0.0``.
    """
    squared_total = sum(
        (torch.norm(left - right) ** 2
         for left, right in zip(params_lst1, params_lst2)),
        0,
    )
    return squared_total ** 0.5


class GradientCalculator(optim.Optimizer):
    """Optimizer shell that computes L2-regularized stochastic gradients.

    It registers ``model``'s parameters as a single param group but defines
    no ``step``; gradients are returned to the caller instead of applied.
    """

    def __init__(self, model, weight_decay):
        super(GradientCalculator, self).__init__(
            params=model.parameters(), defaults={})
        self._model = model
        # Project helper; presumably runs forward + backward and returns the
        # loss (TODO confirm against utils.get_closure).
        self._closure = get_optimizer_closure(model)
        self._weight_decay = weight_decay

    def _add_l2(self, grad, p):
        """Add the L2 term ``weight_decay * p`` to ``grad`` in place and return it."""
        return grad.add_(p, alpha=self._weight_decay)

    def _compute_stochastic_grad(self, input, target):
        """Evaluate the closure on one batch and collect regularized gradients.

        Args:
            input, target: one mini-batch; moved to CUDA when it is available.

        Returns:
            ``(grads, loss)`` where ``grads[i]`` corresponds to the i-th
            parameter: its ``.grad`` tensor with the L2 term added in place,
            or the int ``0`` placeholder when that parameter has no gradient.
        """
        if torch.cuda.is_available():
            input, target = input.cuda(), target.cuda()
        loss = self._closure(input, target)
        grads = []
        # no_grad: adding the (grad-requiring) parameter into .grad must not
        # attach autograd history to the gradient tensor.
        with torch.no_grad():
            for p in self.param_groups[0]["params"]:
                if p.grad is None:
                    # Placeholder keeps grads[i] aligned with the i-th param.
                    grads.append(0)
                else:
                    grads.append(self._add_l2(grad=p.grad, p=p))
        return grads, loss

    def copy_params(self, group):
        """Overwrite this optimizer's parameter data with copies from ``group["params"]``."""
        for (p_to, p_from) in zip(self.param_groups[0]["params"], group["params"]):
            p_to.data = copy.deepcopy(p_from.data)


class Client(GradientCalculator):
    """A worker that evaluates gradients on its own training data.

    Subclasses provide the gradient estimator via ``_compute_grad_estimator``.
    """

    def __init__(self, model, eta, weight_decay, train_loader, n_iters, **kargs):
        # Extra keyword arguments are accepted and ignored.
        super(Client, self).__init__(model, weight_decay)
        self._eta = eta
        self._train_loader = train_loader
        self._n_iters = n_iters

    def train(self):
        """Evaluate the gradient estimator on ``n_iters`` consecutive mini-batches.

        The loader is restarted when it is exhausted.

        NOTE(review): parameters are not updated here (``_update`` is not
        called), so presumably every iteration is evaluated at the same point
        and only the last one is returned. Confirm this is intended.

        Returns:
            ``(grads, loss)`` of the final mini-batch.

        Raises:
            ValueError: if ``n_iters`` is smaller than 1.
        """
        if self._n_iters < 1:
            # Previously surfaced as an UnboundLocalError on ``grads``.
            raise ValueError(
                "n_iters must be at least 1, got {}".format(self._n_iters))
        data_iterator = iter(self._train_loader)
        for _ in range(self._n_iters):
            try:
                input, target = next(data_iterator)
            except StopIteration:
                # Loader exhausted: start a fresh pass over the data.
                data_iterator = iter(self._train_loader)
                input, target = next(data_iterator)
            grads, loss = self._compute_grad_estimator(input, target)
        return grads, loss  # gradients and loss of the final mini-batch

    def _compute_grad_estimator(self, input, target):
        """Return ``(grads, loss)`` for one mini-batch; implemented by subclasses."""
        raise NotImplementedError

    def _update(self, grads, noise=0.0):
        """Apply ``p <- p - eta * grads[i] + noise * z`` to every param with a grad.

        ``z`` is drawn with ``torch.randn`` (on the CPU generator) and moved to
        CUDA when CUDA is available.  Parameters whose ``.grad`` is ``None`` are
        left untouched.
        """
        use_cuda = torch.cuda.is_available()
        with torch.no_grad():
            for i, p in enumerate(self.param_groups[0]["params"]):
                if p.grad is None:
                    continue
                p.add_(grads[i], alpha=-self._eta)
                perturbation = torch.randn(p.size())
                if use_cuda:
                    perturbation = perturbation.cuda()
                p.add_(perturbation, alpha=noise)


class SGD_Client(Client):
    """Client whose gradient estimator is the plain stochastic (mini-batch) gradient."""

    def _compute_grad_estimator(self, input, target):
        """Return ``(grads, loss)`` from a single stochastic-gradient evaluation."""
        grads, loss = self._compute_stochastic_grad(input, target)
        return grads, loss


class Server:
    """Base class for a server coordinating ``n_workers`` worker optimizers.

    Subclasses implement ``_get_optimizer`` (build one worker from keyword
    arguments) and ``_update`` (run one round).
    """

    def __init__(self, model, eta, weight_decay, train_loaders,
                 n_local_iters, n_workers, epsilon, **kargs):
        # Extra keyword arguments are accepted and ignored.
        self._model = model
        self._n_workers = n_workers
        # Every worker gets its own deep copy of the model and the loader
        # with the same index.
        workers = []
        for worker_idx in range(self._n_workers):
            workers.append(self._get_optimizer(
                model=copy.deepcopy(self._model),
                eta=eta,
                weight_decay=weight_decay,
                train_loader=train_loaders[worker_idx],
                n_iters=n_local_iters,
            ))
        self._optimizers = workers
        self._update_count = 0
        self._defect_workers = set()  # indices of workers that stopped participating
        self.epsilon = epsilon

    def update(self):
        """Run one round through ``_update`` and increment the round counter."""
        self._update()
        self._update_count += 1

    def get_model(self):
        """Return the server's current model object."""
        return self._model

    def _update(self):
        """Perform one round; must be overridden."""
        raise NotImplementedError

    def _get_optimizer(self, **kargs):
        """Build one worker; must be overridden."""
        raise NotImplementedError


class LocalSGD_Server(Server):
    """Server that combines worker gradients with ADA-GD style adaptive aggregation.

    Each round, every worker not in ``self._defect_workers`` reports its
    gradients and loss.  Workers with ``loss <= 1.21 * epsilon`` form the set
    ``D`` and the rest form ``ND``.  A worker with ``loss <= epsilon`` is also
    recorded as defected and is skipped in all later rounds.
    """

    def _update(self):
        """Run one round: collect worker gradients, aggregate, refresh ``self._model``."""
        grads = {i: None for i in range(self._n_workers)}
        losses = {i: None for i in range(self._n_workers)}
        D, ND = set(), set()
        for i, optimizer in enumerate(self._optimizers):
            if i in self._defect_workers:
                continue
            agent_i_grads, agent_i_loss = optimizer.train()
            grads[i] = agent_i_grads
            losses[i] = agent_i_loss
            if agent_i_loss <= self.epsilon:
                # Defected workers are not trained in later rounds.
                self._defect_workers.add(i)
            if agent_i_loss <= self.epsilon * 1.21:
                D.add(i)
            else:
                ND.add(i)
        self._adaptive_aggregation(grads, losses, D, ND)
        # NOTE: self._communicate_params() is the plain parameter-averaging
        # (mini-batch SGD) alternative to the adaptive aggregation above.
        self._model = copy.deepcopy(self._optimizers[0]._model)

    def _get_optimizer(self, **kargs):
        """Build one worker as an ``SGD_Client``."""
        return SGD_Client(**kargs)

    def _communicate_params(self):  # mini-batch SGD
        """Average the parameters of non-defected workers and copy them to all workers.

        Does nothing when every worker has defected.
        """
        params_lst = [optimizer.param_groups[0]["params"]
                      for i, optimizer in enumerate(self._optimizers)
                      if i not in self._defect_workers]
        if len(params_lst) > 0:
            averaged = average(params_lst)
            for optimizer in self._optimizers:
                optimizer.copy_params({"params": averaged})

    def _adaptive_aggregation(self, grads, losses, D, ND):  # ADA-GD
        """Take one aggregated gradient step and broadcast the parameters.

        Only workers that reported this round (``D | ND``) are considered.
        The case is decided on those sets rather than on ``n_workers``. The
        original comparison broke as soon as any worker had defected.

        * Case 3, ``ND`` empty (all reporting workers are within
          ``1.21 * epsilon``, or nobody reported): no update.
        * Case 2, ``D`` empty: step along the mean gradient of ``ND``.
        * Case 1, both non-empty: step along the ``ND`` gradient after
          removing, tensor by tensor, its component along the ``D`` gradient.
          Only one worker per set is supported.

        Each direction tensor whose L2 norm exceeds 1 is rescaled to unit
        norm.  The step is applied to the lowest-indexed ``ND`` worker's
        parameters with that worker's ``_eta``.  The resulting parameters are
        then copied into every worker.

        Args:
            grads: worker index -> per-parameter gradient list (``None`` for
                workers that did not report); non-tensor entries mean
                "no gradient for this parameter".
            losses: worker index -> loss; unused here.
            D, ND: indices of reporting workers within / above ``1.21 * epsilon``.

        Raises:
            NotImplementedError: in Case 1 when ``D`` or ``ND`` has more than
                one worker.
        """
        if not ND:  # Case 3
            return
        if D and (len(D) != 1 or len(ND) != 1):
            raise NotImplementedError(
                "ADA-GD case 1 supports exactly one worker in D and one in ND")
        with torch.no_grad():
            if not D:  # Case 2
                lead = min(ND)
                direction = self._mean_grads([grads[i] for i in sorted(ND)])
            else:  # Case 1
                (lead,) = ND
                (d_worker,) = D
                direction = [self._remove_component(g, d)
                             for g, d in zip(grads[lead], grads[d_worker])]
            direction = self._clip_per_tensor(direction)
            self._apply_step(lead, direction)

    @staticmethod
    def _mean_grads(grad_lists):
        """Entry-wise mean of gradient lists, summed in the given order."""
        total = grad_lists[0]
        for worker_grads in grad_lists[1:]:
            total = [x + y for x, y in zip(total, worker_grads)]
        return [x / len(grad_lists) for x in total]

    @staticmethod
    def _remove_component(g, d):
        """Return ``g`` minus its projection onto ``d`` (flattened dot product).

        ``g`` is returned unchanged when either entry is a non-tensor
        placeholder or ``d`` has zero norm, which would otherwise produce NaN.
        """
        if not (torch.is_tensor(g) and torch.is_tensor(d)):
            return g
        d_sq_norm = torch.norm(d) ** 2
        if d_sq_norm == 0:
            return g
        return g - (torch.dot(d.view(-1), g.view(-1)) / d_sq_norm) * d

    @staticmethod
    def _clip_per_tensor(direction):
        """Rescale each tensor entry whose L2 norm exceeds 1 to unit norm."""
        clipped = []
        for x in direction:
            if torch.is_tensor(x):
                norm = torch.norm(x)
                if norm > 1:
                    x = x / norm
            clipped.append(x)
        return clipped

    def _apply_step(self, lead, direction):
        """Step worker ``lead``'s params by ``-eta * direction``, then copy them to all workers."""
        params_lst = self._optimizers[lead].param_groups[0]["params"]
        eta = self._optimizers[lead]._eta
        with torch.no_grad():
            for p, step in zip(params_lst, direction):
                # Non-tensor entries mean no worker had a gradient for p.
                if torch.is_tensor(step):
                    p.add_(step, alpha=-eta)
        for optimizer in self._optimizers:
            optimizer.copy_params({"params": params_lst})
                    
