
import torch.nn.functional as F
from torch import nn, autograd
from torch.utils.data import DataLoader
import torch
import copy
import numpy as np
from torch.nn.utils import parameters_to_vector

###############################################
# Loss/prediction helpers and routines for computing test accuracy
##############################################
def calc_loss(out, y):
    """Return the mean cross-entropy between sequence scores ``out`` and targets ``y``.

    ``out`` presumably has shape (batch, seq, classes) -- TODO confirm with
    callers. Axes 1 and 2 are swapped so that the class axis sits at dim 1,
    which is where ``nn.CrossEntropyLoss`` expects it for N-d inputs.
    """
    criterion = nn.CrossEntropyLoss()
    class_first = out.transpose(1, 2)
    return criterion(class_first, y)

# Class indices that calc_pred replaces with -1 in its predictions.
pad_symbol = 0
unk_symbol = 1

def calc_pred(out):
    """Return the arg-max index along the last axis of ``out``.

    Positions whose predicted index equals ``unk_symbol`` or ``pad_symbol``
    are overwritten with -1.
    """
    pred = torch.max(out.data, -1).indices
    special = (pred == unk_symbol) | (pred == pad_symbol)
    pred[special] = -1
    return pred

def test_img(net_g, datatest, args):
    net_g.eval()
    # testing
    test_loss = 0
    correct = 0
    total = 0
    data_loader = DataLoader(datatest, batch_size=min(len(datatest),args.bs))
    for idx, (data, target) in enumerate(data_loader):

        data, target = data.to(args.device), target.to(args.device)
        target = target.type(torch.LongTensor).to(args.device)
        log_probs = net_g(data)

        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # get the index of the max log-probability
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        correct += (y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()).numpy()

    test_loss /= len(data_loader.dataset)
    accuracy = 100.00 * correct / len(data_loader.dataset)

    return accuracy, test_loss


def train_and_test(net, datatrain, dataval, datatest, args, local_epochs=None):
    """Take a few local SGD steps on ``net``, then evaluate it.

    Each of the ``local_epochs`` iterations builds a shuffled DataLoader,
    draws ONE minibatch of at most ``args.bs`` samples from ``datatrain`` and
    takes a single SGD step (lr ``args.local_eta``, no momentum). An "epoch"
    here is therefore one optimizer step, not a full pass over the data.
    ``net`` is updated in place.

    Args:
        net: model to train.
        datatrain, dataval, datatest: map-style datasets of ``(x, y)`` pairs.
        args: namespace providing ``local_eta``, ``bs``, ``device`` and
            ``local_train_ep``.
        local_epochs: number of SGD steps; ``None`` means ``args.local_train_ep``.

    Returns:
        ``(acc, val_loss, test_loss)``: the accuracy (%) on ``datatest`` and the
        losses on ``dataval`` / ``datatest``, all as returned by ``test_img``.
        The validation accuracy is computed but not returned.
    """
    net.train()

    optimizer = torch.optim.SGD(net.parameters(), lr=args.local_eta, momentum=0)
    loss_func = nn.CrossEntropyLoss()

    if local_epochs is None:
        local_epochs = args.local_train_ep

    for _ in range(local_epochs):
        data_loader = DataLoader(datatrain, batch_size=min(len(datatrain), args.bs), shuffle=True)
        images, labels = next(iter(data_loader))
        images = images.to(args.device)
        # .long() keeps the labels on-device; .type(torch.LongTensor) would
        # copy them to the CPU and back on every step.
        labels = labels.to(args.device).long()
        net.zero_grad()
        loss = loss_func(net(images), labels)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        _, val_loss = test_img(net, dataval, args)
        acc, test_loss = test_img(net, datatest, args)

    return acc, val_loss, test_loss


class LocalUpdate(object):
    """Client-side local training routines for several federated algorithms.

    Every training method runs plain SGD (no momentum, lr ``args.eta``) on
    ``self.dataset`` and updates ``net`` in place. Each local step draws ONE
    freshly shuffled minibatch of at most ``args.bs`` samples, so an "epoch"
    here means a single optimizer step. The ``train_and_sketch*`` methods
    return the dense change of the flattened parameter vector
    (``w_after - w_before``) as a NumPy array; no sparsification or sketching
    is applied despite the name.

    ``args`` attributes read: ``eta``, ``bs``, ``device``, ``train_ep``,
    ``local_tune_ep`` (``train_and_sketch_local``) and ``size``
    (``train_and_sketch_mwfed``).
    """

    def __init__(self, args, dataset=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []  # not used inside this class
        self.dataset = dataset

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _loader(self):
        """Return a shuffled DataLoader over the client dataset.

        Each ``iter()`` on it reshuffles, so ``next(iter(loader))`` yields a
        fresh random minibatch. The batch size is clamped to the dataset size.
        """
        return DataLoader(self.dataset, batch_size=min(len(self.dataset), self.args.bs), shuffle=True)

    def _to_device(self, batch):
        """Move an ``(images, labels)`` batch to ``args.device``; labels become int64.

        ``.long()`` converts on-device, whereas ``.type(torch.LongTensor)``
        would round-trip the labels through the CPU.
        """
        images, labels = batch
        return images.to(self.args.device), labels.to(self.args.device).long()

    def _local_step(self, net, optimizer, batch, penalty=None):
        """Take one SGD step on ``batch``; ``penalty(net)``, if given, is added to the loss."""
        images, labels = self._to_device(batch)
        optimizer.zero_grad()
        loss = self.loss_func(net(images), labels)
        if penalty is not None:
            loss = loss + penalty(net)
        loss.backward()
        optimizer.step()

    @staticmethod
    def _flat_params(net):
        """Return a detached copy of ``net``'s parameters as a single flat vector."""
        with torch.no_grad():
            return parameters_to_vector(net.parameters()).clone()

    @staticmethod
    def _delta_since(net, start_vec):
        """Return ``flatten(net params) - start_vec`` as a NumPy array on the CPU."""
        with torch.no_grad():
            return (parameters_to_vector(net.parameters()) - start_vec).cpu().numpy()

    # ------------------------------------------------------------------
    # Training routines
    # ------------------------------------------------------------------

    def train_and_sketch(self, net):
        """Run ``args.train_ep`` SGD steps and return the parameter delta (NumPy)."""
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.eta, momentum=0)
        start_vec = self._flat_params(net)

        for _ in range(self.args.train_ep):
            self._local_step(net, optimizer, next(iter(self._loader())))

        return self._delta_since(net, start_vec)

    def train_and_sketch_mwfed(self, net, scale):
        """Run ``int(train_ep * size * scale)`` SGD steps.

        Returns:
            ``(delta * local_ep, local_ep)``, where ``delta`` is the parameter
            change and ``local_ep`` is the number of steps actually taken.
        """
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.eta, momentum=0)
        start_vec = self._flat_params(net)
        local_ep = int(self.args.train_ep * self.args.size * scale)

        for _ in range(local_ep):
            self._local_step(net, optimizer, next(iter(self._loader())))

        return self._delta_since(net, start_vec) * local_ep, local_ep

    def train_and_sketch_local(self, net):
        """Fine-tune ``net`` in place for ``args.local_tune_ep`` SGD steps and return it."""
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.eta, momentum=0)

        for _ in range(self.args.local_tune_ep):
            self._local_step(net, optimizer, next(iter(self._loader())))

        return net

    def train_and_sketch_fedprox(self, net, mu=0.05):
        """FedProx local training: minimise ``CE + (mu / 2) * ||w - w_start||^2``.

        ``w_start`` are the parameters at entry and are held fixed (detached).
        The proximal term is the plain squared L2 distance. The previous
        implementation squared the already-squared ``MSELoss(reduction='sum')``
        value, which produced a 4th-power penalty.

        Args:
            net: model to train in place.
            mu: proximal coefficient; the default matches the previous hard-coded 0.05.

        Returns:
            The parameter delta as a NumPy array.
        """
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.eta, momentum=0)
        anchor = [p.detach().clone() for p in net.parameters()]
        start_vec = self._flat_params(net)

        def proximal(model):
            sq_dist = sum(((p - a) ** 2).sum() for p, a in zip(model.parameters(), anchor))
            return (mu / 2) * sq_dist

        for _ in range(self.args.train_ep):
            self._local_step(net, optimizer, next(iter(self._loader())), penalty=proximal)
            optimizer.zero_grad()

        return self._delta_since(net, start_vec)

    def train_and_sketch_qffl(self, net):
        """Run ``args.train_ep`` SGD steps and return ``delta * (1 / eta)`` (NumPy).

        Dividing by the learning rate turns the parameter change into the sum
        of the (negated) gradients used at each step.
        """
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.eta, momentum=0)
        start_vec = self._flat_params(net)

        for _ in range(self.args.train_ep):
            self._local_step(net, optimizer, next(iter(self._loader())))
            optimizer.zero_grad()

        return self._delta_since(net, start_vec) * (1 / self.args.eta)

    def train_and_sketch_perfedavg(self, net):
        """Per-FedAvg-style training: two SGD steps per iteration on two independent minibatches.

        NOTE(review): both steps are ordinary SGD steps with lr ``args.eta``,
        the second taken from the already-updated ``w_tilde``. Per-FedAvg as
        usually formulated applies the second gradient to the original ``w``,
        with a separate step size. Confirm this simplification is intended.

        Returns:
            The parameter delta as a NumPy array.
        """
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.eta, momentum=0)
        start_vec = self._flat_params(net)

        for _ in range(self.args.train_ep):
            # Create both iterators before consuming either. This keeps the
            # original order of RNG draws, so seeded runs see the same batches.
            first_iter = iter(self._loader())
            second_iter = iter(self._loader())

            # w -> w_tilde
            self._local_step(net, optimizer, next(first_iter))
            optimizer.zero_grad()

            # second step, taken from w_tilde
            self._local_step(net, optimizer, next(second_iter))
            optimizer.zero_grad()

        return self._delta_since(net, start_vec)


def _stable_sigmoid(t):
    """Logistic function ``1 / (1 + exp(-t))`` that does not overflow to NaN.

    The naive ``exp(t) / (1 + exp(t))`` gives ``inf / inf = nan`` once
    ``exp(t)`` overflows. Branching on the sign keeps the exponent <= 0.
    """
    if t >= 0:
        return 1.0 / (1.0 + np.exp(-t))
    e = np.exp(t)
    return e / (1.0 + e)


def aggr_func(alg, ind, grad_locals, val_loss_locals, local_val_loss):
    """Aggregate client updates with loss-dependent weights.

    ``ind`` only determines how many clients are aggregated. Entries are
    taken positionally from ``grad_locals``, ``val_loss_locals`` and
    ``local_val_loss``. With ``d_i = val_loss_locals[i] - local_val_loss[i]``
    and ``s`` the logistic function, client ``i`` gets weight:

    * ``'Sigmoid'``:  s(2 d_i) * (1 - s(2 d_i))
    * ``'ReLU'``:     1 if d_i > 0 else 0
    * ``'Leave'``:    0 if d_i > 0 else 1
    * ``'Softplus'``: s(2 d_i)
    * anything else (``'FedAvg'``, ``'qFFL'``, ``'MW-Fed'``, ...): 1

    Returns:
        ``(grad_avg, scales)``, both float64 arrays. For ``'qFFL'`` and
        ``'MW-Fed'`` they are the raw weighted sum and raw weights. Otherwise
        both are normalised by the total weight. If that total is 0 (e.g.
        ``'ReLU'`` when no client's loss difference is positive), zero arrays
        are returned instead of the NaNs that 0/0 would produce.
    """
    grad_avg = np.zeros((grad_locals[0].shape[0]))
    scales = np.zeros((len(ind)))
    total = 0

    for i, _ in enumerate(ind):
        diff = val_loss_locals[i] - local_val_loss[i]

        if alg == 'Sigmoid':
            w = _stable_sigmoid(2 * diff)
            scale = w * (1 - w)
        elif alg == 'ReLU':
            scale = 1 if diff > 0 else 0
        elif alg == 'Leave':
            scale = 0 if diff > 0 else 1
        elif alg == 'Softplus':
            scale = _stable_sigmoid(2 * diff)
        else:  # 'FedAvg', 'qFFL', 'MW-Fed' and unrecognised names
            scale = 1

        scales[i] = scale
        total = total + scale
        grad_avg = grad_avg + scale * grad_locals[i]

    if alg in ('qFFL', 'MW-Fed'):
        return grad_avg, scales

    if total == 0:
        # No client received any weight: return a zero update, not 0/0 = NaN.
        return np.zeros_like(grad_avg), np.zeros_like(scales)

    return grad_avg / total, scales / total


def make_gradient_dict(dict):
    """Snapshot the current gradient of every parameter of a model.

    Args:
        dict: an object exposing ``.parameters()`` (e.g. an ``nn.Module``). The
            name shadows the builtin but is kept for keyword-call compatibility.

    Returns:
        ``{position: gradient_copy}`` in ``parameters()`` order. Each value is a
        deep copy, so later in-place changes to ``.grad`` do not affect it.
        Raises ``AttributeError`` if a parameter's ``.grad`` is ``None``.
    """
    return {
        position: copy.deepcopy(param.grad.data)
        for position, param in enumerate(dict.parameters())
    }



