# Modified from: https://github.com/pliang279/LG-FedAvg/blob/master/models/Update.py
# credit goes to: Paul Pu Liang

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import math
import numpy as np
import time
import copy

class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``.

    Position ``i`` of this view maps to ``dataset[idxs[i]]``.
    """

    def __init__(self, dataset, idxs, name=None):
        """Store the wrapped dataset, the selected indices and an optional name.

        ``idxs`` may be any iterable; it is materialised into a list so it
        can be indexed by position.
        """
        self.name = name
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        """Return how many indices were selected."""
        return len(self.idxs)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at the ``item``-th selected index."""
        sample = self.dataset[self.idxs[item]]
        # Unpack explicitly: the underlying sample must hold exactly two fields.
        image, label = sample
        return image, label

class LocalUpdate(object):
    """Local SGD training on the subset of a dataset selected by ``idxs``.

    Args:
        args: Configuration object. This class reads ``local_bs`` (batch
            size), ``local_ep`` (number of local epochs) and ``device``.
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        idxs: Iterable of indices into ``dataset`` to train on.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.ldr_train = DataLoader(
            DatasetSplit(dataset, idxs),
            batch_size=self.args.local_bs,
            shuffle=True,
        )
        self.dataset = dataset
        self.idxs = idxs

    def train(self, net, lr=0.1):
        """Train ``net`` in place for ``args.local_ep`` epochs and report the loss.

        All parameters of ``net`` are optimised with SGD (momentum 0.9,
        weight decay 1e-4); ``requires_grad`` is switched on for every
        parameter as a side effect.

        Args:
            net: Model to train; it is called on each batch of images and
                its output is passed to the cross-entropy loss.
            lr: Learning rate for the SGD optimizer.

        Returns:
            A tuple ``(state_dict, loss)`` where ``state_dict`` is
            ``net.state_dict()`` after training and ``loss`` is the mean
            over epochs of the mean batch loss. ``loss`` is ``0.0`` when no
            batch was processed (zero epochs or an empty loader).
        """
        local_eps = self.args.local_ep

        # Every parameter is trained; re-enable gradients in case some were frozen.
        parameters = [param for _, param in net.named_parameters()]
        for param in parameters:
            param.requires_grad = True

        optimizer = torch.optim.SGD(parameters, momentum=0.9, weight_decay=1e-4, lr=lr)

        ## local train
        epoch_loss = []
        for _ in range(local_eps):
            batch_loss = []
            for images, labels in self.ldr_train:
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()

                batch_loss.append(loss.item())

            # An epoch without batches has no defined mean; leave it out
            # instead of dividing by zero.
            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

        # No epoch ran (or none saw data): report 0.0 rather than raising
        # ZeroDivisionError.
        mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0.0
        return net.state_dict(), mean_loss