#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
import copy
import time
from torch import nn
from torch.utils.data import DataLoader
from FedProxOptimizer import FedProxOptimizer


class LocalUpdate(object):
    """One client's local training / privatisation / evaluation logic.

    Wraps the client's dataset in a shuffled ``DataLoader`` and exposes:

    * ``update_weights(model)`` -- run local epochs and return the new
      ``state_dict`` plus a flag telling whether every epoch was run;
    * ``private(w)`` -- clip a ``state_dict`` to L2 norm ``args.clip`` and add
      Gaussian noise;
    * ``inference(model)`` -- accuracy and mean loss over the client's data.
    """

    def __init__(self, args, dataset, logger):
        """Store the config and build the local data loader.

        Args:
            args: run configuration; this class reads ``local_bs``, ``gpu``,
                ``algorithm``, ``lr``, ``momentum``, ``prox_param``, ``alpha``,
                ``feature_len``, ``local_ep``, ``straggler``, ``dataset``,
                ``clip`` and ``sigma`` from it.
            dataset: the client's local dataset, iterated in shuffled
                mini-batches of ``args.local_bs``.
            logger: stored on the instance; not used by this class itself.
        """
        self.args = args
        self.logger = logger
        self.trainloader = DataLoader(dataset,
                                      batch_size=self.args.local_bs, shuffle=True)
        self.device = 'cuda' if args.gpu is not None else 'cpu'
        # Cross-entropy is used for every dataset handled below.
        self.criterion = nn.CrossEntropyLoss().to(self.device)

    @staticmethod
    def _reseed_from_clock():
        """Reseed torch's global RNG from the sub-second part of the wall clock.

        NOTE(review): only 1000 distinct seeds are possible, and this resets the
        *global* generator, which also affects later draws such as DataLoader
        shuffling. Kept as-is to preserve existing behavior.
        """
        now = time.time()
        torch.random.manual_seed(int((now - int(now)) * 1e3))

    def _prepare_batch(self, images, labels):
        """Move one batch to ``self.device`` with the dtype/shape each dataset expects."""
        if self.args.dataset == 'sent140':
            return images.to(self.device), labels.to(device=self.device)
        if self.args.dataset == 'shakespeare':
            return images.to(self.device), labels.to(device=self.device, dtype=torch.float)
        # Other datasets: float features flattened to (batch, -1), integer class labels.
        images = images.to(self.device, dtype=torch.float).flatten(start_dim=1)
        labels = labels.to(device=self.device, dtype=torch.long)
        return images, labels

    def _sample_noise(self, params):
        """Draw one random perturbation tensor per entry of ``params``.

        The direction is uniform noise in ``[-0.5, 0.5)``, rescaled so that the
        global L2 norm over all tensors equals a sample from
        ``Gamma(args.feature_len, args.alpha)``.

        NOTE(review): a uniform-cube direction is not uniformly distributed on
        the sphere; if that is the intent, ``torch.randn`` would be needed.
        """
        self._reseed_from_clock()
        directions = [torch.rand(p.shape, device=self.device) - 0.5 for p in params]
        if not directions:
            return directions
        # Global L2 norm across all tensors, not the sum of per-tensor norms,
        # so the rescaled noise really has the sampled magnitude.
        direction_norm = torch.norm(torch.cat([d.flatten() for d in directions]))
        gamma = torch.distributions.gamma.Gamma(self.args.feature_len, self.args.alpha)
        magnitude = gamma.sample().to(self.device)
        coef = magnitude / direction_norm
        return [d * coef for d in directions]

    def _sample_early_stop(self):
        """Decide how many local epochs to run this round.

        Returns:
            ``(num_epochs, completed)``. With probability ``args.straggler`` the
            client is a straggler: it runs a uniformly random number of epochs in
            ``[1, local_ep - 1]`` and ``completed`` is False. Otherwise it runs
            all ``args.local_ep`` epochs and ``completed`` is True.
        """
        local_ep = self.args.local_ep
        is_straggler = bool(torch.rand(1, device=self.device) < self.args.straggler)
        if not is_straggler:
            return local_ep, True
        if local_ep <= 1:
            # There is no value in [1, local_ep - 1]. The previous
            # resample-until-non-zero loop never terminated in this case.
            return local_ep, False
        return int(torch.randint(1, local_ep, (1,)).item()), False

    def update_weights(self, model):
        """Train ``model`` on the local data and return its new weights.

        Uses ``FedProxOptimizer`` with ``mu=0`` for FedAvg and
        ``mu=args.prox_param`` otherwise. When ``args.alpha > 0``, a noise list
        (one tensor per parameter) is sampled once and passed to every
        ``optimizer.step`` call. How it is applied is defined by the optimizer.

        Returns:
            ``(w, completed)``. ``w`` is a deep copy of ``model.state_dict()``,
            clipped and noised by ``private`` when ``args.clip > 0`` and
            ``args.alpha == 0``. ``completed`` is False if this client was
            selected as a straggler and stopped early.
        """
        model.train()
        # FedAvg is FedProx with a zero proximal term.
        mu = 0 if self.args.algorithm == "FedAvg" else self.args.prox_param
        optimizer = FedProxOptimizer(model.parameters(),
                                     lr=self.args.lr,
                                     mu=mu,
                                     momentum=self.args.momentum,
                                     nesterov=False,
                                     weight_decay=0)

        self.noise = None
        if self.args.alpha > 0:
            params = [p for group in optimizer.param_groups for p in group['params']]
            self.noise = self._sample_noise(params)

        num_epochs, completed = self._sample_early_stop()
        for _ in range(num_epochs):
            for images, labels in self.trainloader:
                images, labels = self._prepare_batch(images, labels)
                model.zero_grad()
                loss = self.criterion(model(images), labels)
                loss.backward()
                optimizer.step(noise=self.noise)

        w = copy.deepcopy(model.state_dict())
        if self.args.clip > 0 and self.args.alpha == 0:
            w = self.private(w)
        return w, completed

    def private(self, w):
        """Clip ``w`` to global L2 norm ``args.clip`` and add Gaussian noise, in place.

        Only floating-point entries are clipped and noised. Integer buffers
        (e.g. BatchNorm ``num_batches_tracked``) are left untouched, since
        norms and in-place float arithmetic are not defined for them.

        Returns:
            The same dict ``w``, modified in place.
        """
        torch.cuda.empty_cache()
        float_keys = [key for key, value in w.items() if torch.is_floating_point(value)]
        if float_keys:
            total_norm = torch.norm(torch.cat([w[key].flatten() for key in float_keys]))
            if total_norm > self.args.clip:
                scale = self.args.clip / total_norm
                for key in float_keys:
                    w[key] *= scale
        self._reseed_from_clock()
        for key in float_keys:
            # NOTE(review): torch.normal's second argument is the *standard
            # deviation*, so the noise std here is sigma**2 (variance sigma**4).
            # Confirm this is intended before relying on DP guarantees.
            w[key] += torch.normal(0, self.args.sigma ** 2, size=tuple(w[key].shape), device=self.device)
        return w

    def inference(self, model):
        """Return ``(accuracy, mean_loss)`` of ``model`` over the local data.

        The loss is re-weighted by batch size, so it is a per-sample mean over
        the whole loader. For 'shakespeare', labels are converted from the
        (presumably one-hot) float targets with an argmax before comparing.
        """
        model.eval()
        loss_sum, correct, seen = 0.0, 0.0, 0
        with torch.no_grad():
            for images, labels in self.trainloader:
                images, labels = self._prepare_batch(images, labels)
                outputs = model(images)
                batch_size = labels.shape[0]
                # The criterion averages over the batch; undo that to accumulate a sum.
                loss_sum += (self.criterion(outputs, labels) * batch_size).item()

                _, pred_labels = torch.max(outputs, 1)
                if self.args.dataset == 'shakespeare':
                    _, labels = torch.max(labels, 1)
                else:
                    pred_labels = pred_labels.reshape(labels.shape)
                correct += torch.sum(pred_labels == labels).item()
                seen += batch_size

        return correct / seen, loss_sum / seen
#