import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import copy

# split data for clients
class DatasetSplit(Dataset):
    """Expose a subset of a wrapped dataset, selected by a list of indices.

    Items of the wrapped dataset are unpacked as ``(old_idx, image, label)``
    triples; only ``image`` and ``label`` are handed back to the caller.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Normalise whatever index container was given into plain ints.
        self.idxs = list(map(int, idxs))

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        # The first element (the sample's index in the wrapped dataset) is dropped.
        _, image, label = self.dataset[source_position]
        label_tensor = torch.tensor(label)
        return image, label_tensor

# sample iid data for CIFAR
def cifar_iid(dataset, num_users):
    """
    Randomly partition dataset indices into equally sized, disjoint shards.

    Each user receives ``int(len(dataset) / num_users)`` indices drawn without
    replacement via ``np.random.choice``. Indices already handed out are
    removed from the pool before the next draw, so shards never overlap;
    any remainder is left unassigned.

    :param dataset: sized object; only ``len(dataset)`` is used
    :param num_users: number of shards to create
    :return: dict mapping user id (0 .. num_users-1) to a set of indices
    """
    total = len(dataset)
    shard_size = int(total / num_users)
    remaining = list(range(total))
    dict_users = {}
    for user_id in range(num_users):
        picked = np.random.choice(remaining, shard_size, replace=False)
        dict_users[user_id] = set(picked)
        # Shrink the pool so later users cannot draw the same indices.
        remaining = list(set(remaining) - dict_users[user_id])
    return dict_users

# FedAvg for FL
def average_weights(w):
    """
    Return the element-wise mean of a sequence of state dicts.

    The first entry is deep-copied, so none of the inputs are modified. The
    remaining entries are added into that copy key by key, and each sum is
    then divided by the number of entries.
    """
    num_models = len(w)
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        accumulated = w_avg[key]
        for other in w[1:]:
            accumulated += other[key]
        w_avg[key] = torch.div(accumulated, num_models)
    return w_avg

# Adagrad
class Adagrad:
    """Minimal Adagrad optimizer with optional gradient rescaling by ``y_t``.

    For every parameter a running sum of squared gradients is kept and the
    update is ``param -= lr / (sqrt(acc) + eps) * grad``. When ``y_t`` is
    given, each gradient is first divided by ``y_t`` and the rescaled
    gradient is used both for the accumulator and for the step.

    :param model_params: iterable of parameters to optimise
    :param lr: learning rate
    :param init_acc_sqr_grad: initial value of every squared-gradient accumulator
    :param eps: constant added to the square root of the accumulator
    :param y_t: optional scalar divisor applied to every gradient
    """

    def __init__(self, model_params, lr=1e-2, init_acc_sqr_grad=0, eps=1e-10, y_t=None):
        self.model_params = list(model_params)
        self.lr = lr
        self.acc_sqr_grads = [torch.full_like(p, init_acc_sqr_grad) for p in self.model_params]
        self.eps = eps
        # The original test was inverted: y_t=None reached torch.tensor(None)
        # (which raises) and a supplied y_t was silently discarded.
        self.y_t = None if y_t is None else torch.tensor(y_t)

    def zero_grad(self):
        """Drop the gradients of all managed parameters."""
        for param in self.model_params:
            param.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one Adagrad update to every parameter that has a gradient."""
        for param, acc_sqr_grad in zip(self.model_params, self.acc_sqr_grads):
            if param.grad is None:
                # Parameter did not take part in the last backward pass.
                continue
            grad = param.grad if self.y_t is None else param.grad / self.y_t
            acc_sqr_grad.add_(grad * grad)
            std = acc_sqr_grad.sqrt().add(self.eps)
            param.sub_((self.lr / std) * grad)

# using now
def update_ykt_batch(model_pre, model_cur, global_round, ykt, trainloader,
               beta=0.8, lmbda=0.2, local_ep=1):
    """
    Update the running estimate ``ykt`` from the first batch of ``trainloader``.

    For each model the term ``mean(exp(cross_entropy / lmbda))`` is evaluated
    on that single batch, and the new estimate is
    ``(1 - beta) * (ykt - term_pre) + term_cur``.

    :param model_pre: previous model; must return a 2-tuple whose first
        element is passed to the cross-entropy loss
    :param model_cur: current model, same calling convention
    :param global_round: unused, kept for signature compatibility
    :param ykt: previous value of the estimate (scalar)
    :param trainloader: iterable yielding ``(images, labels)`` batches
    :param beta: mixing coefficient of the update
    :param lmbda: divisor applied to the per-sample loss before ``exp``
    :param local_ep: unused, kept for signature compatibility
    :return: updated estimate as a Python float
    """
    criterion = nn.CrossEntropyLoss(reduction='none')
    lmbda = torch.tensor(lmbda)
    beta = torch.tensor(beta)
    ykt = torch.tensor(ykt)

    # Run on the device the current model lives on instead of hard-coding
    # CUDA; fall back to CUDA when the model exposes no parameters.
    try:
        device = next(model_cur.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')

    with torch.no_grad():
        images, labels = next(iter(trainloader))
        images, labels = images.to(device), labels.to(device)

        # previous model term
        outputs_pre, _ = model_pre(images)
        loss_pre = criterion(outputs_pre, labels)
        term_pre = torch.mean(torch.exp(loss_pre / lmbda)).item()

        # current model term
        outputs_cur, _ = model_cur(images)
        loss_cur = criterion(outputs_cur, labels)
        term_cur = torch.mean(torch.exp(loss_cur / lmbda)).item()

    ykt = (1 - beta) * (ykt - term_pre) + term_cur

    return ykt.item()



# not used 
def update_ykt(model_pre, model_cur, global_round, ykt, trainloader,
               beta=0.8, lmbda=0.2, local_ep=1):
    """
    Update the running estimate ``ykt`` using every batch of ``trainloader``.

    For each model, ``mean(exp(cross_entropy / lmbda))`` is evaluated on every
    batch of every local epoch and the per-batch values are summed. The new
    estimate is ``(1 - beta) * (ykt - term_pre) + term_cur``.

    NOTE(review): the per-batch terms are summed, not averaged, so both terms
    grow with the number of batches and epochs; confirm this is intended.

    :param model_pre: previous model; must return a 2-tuple whose first
        element is passed to the cross-entropy loss
    :param model_cur: current model, same calling convention
    :param global_round: unused, kept for signature compatibility
    :param ykt: previous value of the estimate (scalar)
    :param trainloader: iterable yielding ``(images, labels)`` batches
    :param beta: mixing coefficient of the update
    :param lmbda: divisor applied to the per-sample loss before ``exp``
    :param local_ep: number of passes over ``trainloader``
    :return: updated estimate as a Python float
    """
    criterion = nn.CrossEntropyLoss(reduction='none')
    lmbda = torch.tensor(lmbda)
    beta = torch.tensor(beta)
    ykt = torch.tensor(ykt)
    tmp_term_pre = torch.tensor(0.0)
    tmp_term_cur = torch.tensor(0.0)

    # Run on the device the current model lives on instead of hard-coding
    # CUDA; fall back to CUDA when the model exposes no parameters.
    try:
        device = next(model_cur.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')

    with torch.no_grad():
        for _ in range(local_ep):
            for images, labels in trainloader:
                images, labels = images.to(device), labels.to(device)

                # previous model term
                outputs_pre, _unused = model_pre(images)
                loss_pre = criterion(outputs_pre, labels)
                tmp_term_pre += torch.mean(torch.exp(loss_pre / lmbda)).item()

                # current model term
                outputs_cur, _unused = model_cur(images)
                loss_cur = criterion(outputs_cur, labels)
                tmp_term_cur += torch.mean(torch.exp(loss_cur / lmbda)).item()

    ykt = (1 - beta) * (ykt - tmp_term_pre) + tmp_term_cur

    return ykt.item()

# original method for local weights update
# for the local update using SGD or Adam
def update_x_k_v0(model_cur, y_t, global_round, trainloader,
                  lmbda=10, eta=0.1, local_ep=1, args=None,
                  verbose=False):
    """
    Train ``model_cur`` in place with a standard optimizer and mean cross-entropy.

    :param model_cur: model to train; it is modified in place and must return
        a 2-tuple whose first element is passed to the loss
    :param y_t: scalar forwarded to the ``Adagrad`` optimizer; ignored by the
        other optimizers
    :param global_round: only used in the progress message
    :param trainloader: iterable yielding ``(images, labels)`` batches
    :param lmbda: unused, kept for signature compatibility
    :param eta: unused, kept for signature compatibility
    :param local_ep: number of passes over ``trainloader``
    :param args: required despite the ``None`` default; must provide
        ``local_opt`` ('sgd', 'adam' or 'adagrad') and ``curlr``
    :param verbose: print a progress line every 10 batches
    :return: ``model_cur.state_dict()`` after training
    :raises ValueError: if ``args.local_opt`` is not a supported optimizer
    """
    model = model_cur
    model.train()
    criterion_cur = nn.CrossEntropyLoss()

    # Set optimizer for the local updates
    if args.local_opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.curlr, momentum=0.5)
    elif args.local_opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.curlr, weight_decay=1e-4)
    elif args.local_opt == 'adagrad':
        optimizer = Adagrad(model.parameters(), lr=args.curlr, init_acc_sqr_grad=0.001, y_t=y_t)
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on the first optimizer.step().
        raise ValueError('unsupported local_opt: {!r}'.format(args.local_opt))

    # Feed batches on the model's own device instead of hard-coding CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    for epoch in range(local_ep):
        for batch_idx, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)

            model.zero_grad()
            log_probs, _ = model(images)
            loss = criterion_cur(log_probs, labels)
            loss.backward()
            optimizer.step()

            if verbose and (batch_idx % 10 == 0):
                print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    global_round, epoch, batch_idx * len(images),
                    len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss.item()))
    # The per-epoch loss average was never used and divided by zero on an
    # empty loader, so it is no longer computed.
    return model.state_dict()

# v1
def update_x_k_v1(model_cur, y_t, global_round, trainloader,
               lmbda=10, eta=0.1, local_ep=1, args=None):
    """
    Local update (v1): plain gradient steps on ``log(mean(exp(loss / lmbda)))``.

    A deep copy of ``model_cur`` is trained; the input model is left untouched.
    On every batch the gradient of the objective is divided by ``y_t`` and the
    parameters are moved by ``-eta`` times that value.

    :param model_cur: model to start from; must return a 2-tuple whose first
        element is passed to the cross-entropy loss
    :param y_t: scalar divisor applied to every gradient
    :param global_round: only used in the progress message
    :param trainloader: iterable yielding ``(images, labels)`` batches
    :param lmbda: divisor applied to the per-sample loss before ``exp``
    :param eta: step size
    :param local_ep: number of passes over ``trainloader``
    :param args: unused, kept for signature compatibility
    :return: state dict of the updated copy
    """
    criterion_cur = nn.CrossEntropyLoss(reduction='none')
    y_t = torch.tensor(y_t)

    model_new = copy.deepcopy(model_cur)

    # Feed batches on the model's own device instead of hard-coding CUDA.
    first_param = next(model_new.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    for _ in range(local_ep):
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            # current model term
            outputs_cur, _unused = model_new(images)
            loss_cur = criterion_cur(outputs_cur, labels)
            g_obj = torch.mean(torch.exp(loss_cur / lmbda))
            # f_obj f(g)
            f_g_obj = torch.log(g_obj)
            model_new.zero_grad()
            f_g_obj.backward()

            # The scaled gradient is recomputed on every batch. Previously it
            # was stored only the first time a parameter was seen, so every
            # later step re-applied the first batch's gradient.
            with torch.no_grad():
                for param in model_new.parameters():
                    if param.requires_grad and param.grad is not None:
                        # Keyword alpha replaces the deprecated add_(Number, Tensor) overload.
                        param.add_(param.grad / y_t, alpha=-eta)
        print("global round: ", global_round)
    return model_new.state_dict()

# v2
def update_x_k_v2(model_cur, y_t, global_round, trainloader,
               lmbda=10, eta=0.1, local_ep=1, args=None):
    """
    Local update (v2): momentum steps on ``exp(mean_loss / lmbda)``.

    A deep copy of ``model_cur`` is trained; the input model is left untouched.
    The loss uses mean reduction, so ``loss_cur`` is already a scalar and the
    surrounding ``torch.mean`` does not change it. Per parameter, a moving
    average of ``grad / y_t`` is kept with weight ``args.curbeta`` on the new
    gradient. The first batch only initialises that average; parameters are
    updated from the second batch on.

    NOTE(review): skipping the parameter update on the first batch is kept
    from the original code; confirm it is intended.

    :param model_cur: model to start from; must return a 2-tuple whose first
        element is passed to the cross-entropy loss
    :param y_t: scalar divisor applied to every gradient
    :param global_round: unused, kept for signature compatibility
    :param trainloader: iterable yielding ``(images, labels)`` batches
    :param lmbda: divisor applied to the loss before ``exp``
    :param eta: unused, kept for signature compatibility
    :param local_ep: number of passes over ``trainloader``
    :param args: required despite the ``None`` default; must provide
        ``curbeta``, ``curlr`` and ``weight_decay``
    :return: state dict of the updated copy
    """
    criterion_cur = nn.CrossEntropyLoss(reduction='mean')
    y_t = torch.tensor(y_t)
    w_grad_state = dict()

    model_new = copy.deepcopy(model_cur)

    # Feed batches on the model's own device instead of hard-coding CUDA.
    first_param = next(model_new.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    for _ in range(local_ep):
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            # zero gradients
            model_new.zero_grad(set_to_none=True)
            # forward pass
            outputs_cur, _unused = model_new(images)
            loss_cur = criterion_cur(outputs_cur, labels)
            g_obj = torch.mean(torch.exp(loss_cur / lmbda))
            # backward pass
            g_obj.backward()

            # update the model parameters, like optimizer.step()
            with torch.no_grad():
                for name, param in model_new.named_parameters():
                    if not param.requires_grad or param.grad is None:
                        continue
                    if name not in w_grad_state:
                        w_grad_state[name] = param.grad / y_t
                    else:
                        w_grad_state[name] = (1 - args.curbeta) * w_grad_state[name] + args.curbeta * param.grad / y_t
                        # Weight decay is added to the step; keyword alpha
                        # replaces the deprecated add_(Number, Tensor) overload.
                        param.add_(w_grad_state[name] + args.weight_decay * param, alpha=-args.curlr)
    return model_new.state_dict()

# v3
def update_x_k_v3(model_cur, y_t, global_round,
               trainloader,
               lmbda = 10, eta=0.1, local_ep=1, args=None):
    """
    Local update (v3): momentum steps on ``log(mean(exp(loss / lmbda)))``.

    A deep copy of ``model_cur`` is trained; the input model is left untouched.
    Per parameter, a moving average of ``grad / y_t`` is kept with weight
    ``args.curbeta`` on the new gradient. The first batch only initialises
    that average; parameters are updated from the second batch on.

    NOTE(review): skipping the parameter update on the first batch is kept
    from the original code; confirm it is intended.

    :param model_cur: model to start from; must return a 2-tuple whose first
        element is passed to the cross-entropy loss
    :param y_t: scalar divisor applied to every gradient
    :param global_round: unused, kept for signature compatibility
    :param trainloader: iterable yielding ``(images, labels)`` batches
    :param lmbda: divisor applied to the per-sample loss before ``exp``
    :param eta: unused, kept for signature compatibility
    :param local_ep: number of passes over ``trainloader``
    :param args: required despite the ``None`` default; must provide
        ``curbeta``, ``curlr`` and ``weight_decay``
    :return: state dict of the updated copy
    """
    criterion_cur = nn.CrossEntropyLoss(reduction='none')
    y_t = torch.tensor(y_t)
    w_grad_state = dict()

    model_new = copy.deepcopy(model_cur)

    # Feed batches on the model's own device instead of hard-coding CUDA.
    first_param = next(model_new.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    for _ in range(local_ep):
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            # current model term
            outputs_cur, _unused = model_new(images)
            loss_cur = criterion_cur(outputs_cur, labels)
            g_obj = torch.mean(torch.exp(loss_cur / lmbda))

            # f_obj f(g)
            f_obj = torch.log(g_obj)
            model_new.zero_grad()
            f_obj.backward()

            with torch.no_grad():
                for name, param in model_new.named_parameters():
                    if not param.requires_grad or param.grad is None:
                        continue
                    if name not in w_grad_state:
                        w_grad_state[name] = param.grad / y_t
                    else:
                        w_grad_state[name] = (1 - args.curbeta) * w_grad_state[name] + args.curbeta * param.grad / y_t
                        # Weight decay is added to the step; keyword alpha
                        # replaces the deprecated add_(Number, Tensor) overload.
                        param.add_(w_grad_state[name] + args.weight_decay * param, alpha=-args.curlr)

    return model_new.state_dict()

# v4    
def update_x_k_v4(model_cur, y_t, global_round,
               trainloader,
               lmbda = 10, eta=0.1, local_ep=1, args=None):
    """
    Local update (v4): momentum steps on ``log(mean(exp(loss / lmbda)))``
    without dividing the gradient by ``y_t``.

    A deep copy of ``model_cur`` is trained; the input model is left untouched.
    Per parameter, a moving average of the gradient is kept with weight
    ``args.curbeta`` on the new gradient. The first batch only initialises
    that average; parameters are updated from the second batch on.

    NOTE(review): skipping the parameter update on the first batch is kept
    from the original code; confirm it is intended.

    :param model_cur: model to start from; must return a 2-tuple whose first
        element is passed to the cross-entropy loss
    :param y_t: unused in this variant, kept for signature compatibility
    :param global_round: unused, kept for signature compatibility
    :param trainloader: iterable yielding ``(images, labels)`` batches
    :param lmbda: divisor applied to the per-sample loss before ``exp``
    :param eta: unused, kept for signature compatibility
    :param local_ep: number of passes over ``trainloader``
    :param args: required despite the ``None`` default; must provide
        ``curbeta``, ``curlr`` and ``weight_decay``
    :return: state dict of the updated copy
    """
    # `reduce=None` is the default of a deprecated argument and left the loss
    # mean-reduced, which made the mean(exp(.)) below a no-op on a scalar.
    # Per-sample losses are what the objective needs.
    criterion_cur = nn.CrossEntropyLoss(reduction='none')
    w_grad_state = dict()

    model_new = copy.deepcopy(model_cur)

    # Feed batches on the model's own device instead of hard-coding CUDA.
    first_param = next(model_new.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    for _ in range(local_ep):
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            # current model term
            outputs_cur, _unused = model_new(images)
            loss_cur = criterion_cur(outputs_cur, labels)
            g_obj = torch.mean(torch.exp(loss_cur / lmbda))

            # f_obj f(g)
            f_obj = torch.log(g_obj)
            model_new.zero_grad()
            f_obj.backward()

            with torch.no_grad():
                for name, param in model_new.named_parameters():
                    if not param.requires_grad or param.grad is None:
                        continue
                    if name not in w_grad_state:
                        w_grad_state[name] = param.grad
                    else:
                        w_grad_state[name] = (1 - args.curbeta) * w_grad_state[name] + args.curbeta * param.grad
                        # Weight decay is added to the step; keyword alpha
                        # replaces the deprecated add_(Number, Tensor) overload.
                        param.add_(w_grad_state[name] + args.weight_decay * param, alpha=-args.curlr)

    return model_new.state_dict()
