# Client-side (local) update routines
import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class DatasetSplit(Dataset):
    """Subset view over ``dataset`` restricted to the positions in ``idxs``.

    Position ``k`` of this split maps to position ``idxs[k]`` of the wrapped
    dataset; each wrapped item is unpacked as an ``(image, label)`` pair and
    returned as a 2-tuple.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialise once so sets, generators and arrays can be indexed by position.
        self.idxs = [position for position in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        sample = self.dataset[source_position]
        image, label = sample
        return image, label


class LocalUpdate(object):
    """Client-side local training with SGD (+ momentum) and cross-entropy loss.

    Args:
        args: config namespace. ``local_bs`` is read here; ``curlr``,
            ``momentum``, ``local_ep`` and ``verbose`` are read by :meth:`train`.
        dataset: indexable dataset whose items unpack as ``(image, label)``.
        idxs: positions of ``dataset`` owned by this client.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        """Train ``net`` in place for ``args.local_ep`` epochs over the local data.

        ``net(images)`` must return a tuple whose first element is fed to
        ``CrossEntropyLoss`` (unnormalised class scores); the second is ignored.

        Returns:
            ``(net.state_dict(), mean over epochs of the per-epoch mean batch loss)``.
        """
        net.train()
        # Send batches to wherever the model's weights live instead of
        # hard-coding CUDA, so CPU-only runs (and CPU-resident models) work.
        model_device = next(net.parameters()).device
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.curlr, momentum=self.args.momentum)

        epoch_loss = []
        for epoch in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(model_device), labels.to(model_device)
                net.zero_grad()
                logits, _ = net(images)
                loss = self.loss_func(logits, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(images), len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

# Adagrad
class Adagrad:
    """Minimal Adagrad optimiser with optional gradient rescaling by ``y_t``.

    Without ``y_t`` each step is classic Adagrad::

        acc += g * g
        p   -= lr * g / (sqrt(acc) + eps)

    With ``y_t`` the gradient is first divided by ``y_t``. The scaled gradient
    is then used both for the accumulator and for the step.

    Args:
        model_params: iterable of tensors to optimise (materialised into a list).
        lr: step size.
        init_acc_sqr_grad: initial value of every squared-gradient accumulator.
        eps: added to ``sqrt(acc)`` to avoid division by zero.
        y_t: optional scalar divisor applied to every gradient.
    """

    def __init__(self, model_params, lr=1e-2, init_acc_sqr_grad=0, eps=1e-10, y_t=None):
        self.model_params = list(model_params)
        self.lr = lr
        self.acc_sqr_grads = [torch.full_like(p, init_acc_sqr_grad) for p in self.model_params]
        self.eps = eps
        # The original test was inverted: omitting y_t crashed in
        # torch.tensor(None), and passing y_t silently disabled the scaling.
        self.y_t = None if y_t is None else torch.tensor(y_t)

    def zero_grad(self):
        """Reset gradients by setting ``.grad`` to None on every parameter."""
        for param in self.model_params:
            param.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one Adagrad update to every parameter that has a gradient."""
        for param, acc_sqr_grad in zip(self.model_params, self.acc_sqr_grads):
            if param.grad is None:
                # Unused/frozen parameters (or right after zero_grad); torch.optim skips these too.
                continue
            grad = param.grad if self.y_t is None else param.grad / self.y_t
            acc_sqr_grad.add_(grad * grad)
            std = acc_sqr_grad.sqrt().add(self.eps)
            param.sub_((self.lr / std) * grad)

class LocalUpdateDDRO(object):
    """Client-side update for the DDRO variant.

    Uses ``BCEWithLogitsLoss`` and the file's custom ``Adagrad`` optimiser. It
    also maintains a moving-average estimate ``y_t`` of ``exp(loss / lmbda)``.

    Args:
        args: config namespace. ``local_bs`` is read here; ``beta`` and
            ``lmbda`` by :meth:`update_ykt`; ``curlr``, ``local_ep`` and ``I``
            by :meth:`train`.
        dataset: indexable dataset whose items unpack as ``(image, label)``.
        idxs: positions of ``dataset`` owned by this client.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.BCEWithLogitsLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def update_ykt(self, net, w_glob, w_glob_prev, ykt):
        """Return the updated moving average ``ykt`` as a Python float.

        Computes ``(1 - beta) * ykt + beta * S / B``, where:

        - ``B`` is the number of local mini-batches.
        - ``S`` sums ``exp(loss_b / lmbda)`` over those batches.
        - ``loss_b`` is the batch loss of a copy of ``net`` loaded with
          ``w_glob`` and evaluated in eval mode without gradients.

        Args:
            net: model template; it is deep-copied and left unmodified.
            w_glob: state dict to evaluate.
            w_glob_prev: previous-round state dict; currently unused (see note).
            ykt: previous estimate.
        """
        # NOTE(review): the original also ran a copy of the model loaded with
        # ``w_glob_prev`` over every batch, but that term was never used in
        # the result. The wasted forward passes were removed. If a correction
        # term based on the previous model was intended, it still has to be
        # added here.
        ykt = torch.tensor(ykt)
        exp_loss_sum = torch.tensor(0.0)
        beta = torch.tensor(self.args.beta)
        lmbda = torch.tensor(self.args.lmbda)

        model_cur = copy.deepcopy(net)
        model_cur.load_state_dict(w_glob)
        model_cur.eval()
        # Follow the model's placement rather than hard-coding CUDA.
        model_device = next(model_cur.parameters()).device
        with torch.no_grad():
            for images, labels in self.ldr_train:
                images, labels = images.to(model_device), labels.to(model_device)
                loss_cur = self.loss_func(model_cur(images), labels)
                exp_loss_sum += torch.mean(torch.exp(loss_cur / lmbda)).item()

        ykt = (1 - beta) * ykt + beta * exp_loss_sum / len(self.ldr_train)
        return ykt.item()

    def train(self, net, y_t=None):
        """Train ``net`` in place for at most ``args.I`` optimiser steps.

        Iterates over the local data for up to ``args.local_ep`` epochs and
        stops early once ``args.I`` steps have been taken. ``y_t`` is passed to
        ``Adagrad`` as its gradient divisor.

        Returns:
            ``(net.state_dict(), mean over epochs of the per-epoch mean batch loss)``.
        """
        net.train()
        model_device = next(net.parameters()).device
        optimizer = Adagrad(net.parameters(), lr=self.args.curlr, init_acc_sqr_grad=0.001, y_t=y_t)

        count = 0
        epoch_loss = []
        for _ in range(self.args.local_ep):
            batch_loss = []
            for images, labels in self.ldr_train:
                images, labels = images.to(model_device), labels.to(model_device)
                net.zero_grad()
                logits = net(images)
                loss = self.loss_func(logits, labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
                count += 1
                if count >= self.args.I:
                    break
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
            if count >= self.args.I:
                break
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

# FedAVG
import copy
import torch
from torch import nn


def FedAvg(w):
    """Return the element-wise mean of a sequence of state dicts.

    ``w[0]`` is deep-copied and accumulated into, so none of the input dicts
    are modified. Every dict must contain at least the keys of ``w[0]``.
    ``torch.div`` performs true division, so integer entries come back as
    floating point.
    """
    n_clients = len(w)
    averaged = copy.deepcopy(w[0])
    for key in averaged:
        running = averaged[key]
        for state in w[1:]:
            running += state[key]
        averaged[key] = torch.div(running, n_clients)
    return averaged


# Learning-rate schedule: sets args.curlr (and the optimizer's lr, if given)
def adjust_curlr(epoch, args, optimizer = None):
    """Set ``args.curlr`` for ``epoch`` and optionally push it into ``optimizer``.

    When ``args.epochs`` is one of 120, 150 or 500, the schedule is:

    - ``epoch <= 5``: linear warm-up ``args.lr * epoch / 5``.
    - ``epoch > 90``: decay to ``args.lr * 0.1``.
    - otherwise: ``args.lr``.

    For any other ``args.epochs`` the rate is constant at ``args.lr``.
    When ``optimizer`` is given, every param group's ``'lr'`` is overwritten
    with the new value.
    """
    uses_schedule = args.epochs in (120, 150, 500)
    if not uses_schedule:
        new_lr = args.lr
    elif epoch <= 5:
        new_lr = args.lr * epoch / 5
    elif epoch > 90:
        new_lr = args.lr * 0.1
    else:
        new_lr = args.lr
    args.curlr = new_lr

    if optimizer is None:
        return
    for group in optimizer.param_groups:
        group['lr'] = args.curlr