#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics
import time
import datetime

class DatasetSplit(Dataset):
    """Dataset view restricted to a chosen list of positions.

    Element ``k`` of this view is element ``idxs[k]`` of the wrapped
    ``dataset``. Every wrapped element is unpacked into an
    ``(image, label)`` pair, so the wrapped dataset must yield exactly two
    values per index.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Copy the indices into a list once so positional lookup and len()
        # work no matter what iterable the caller supplied.
        self.idxs = [*idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        image, label = self.dataset[source_position]
        return image, label


class LocalUpdate(object):
    """Train a model on one shard of a dataset with mini-batch SGD.

    The shard is the subset of ``dataset`` selected by ``idxs``. ``train``
    runs a fixed number of optimizer steps (``iters``) over that shard rather
    than a fixed number of epochs.
    """

    def __init__(self, args, dataset=None, idxs=None, iters=None):
        """Build the shuffled data loader for this shard.

        Args:
            args: Options object. ``local_bs`` and ``num_workers`` are read
                here; ``device``, ``lr`` and ``momentum`` are read by
                ``train``.
            dataset: Indexable dataset whose items unpack to
                ``(image, label)``.
            idxs: Indices into ``dataset`` that make up this shard.
            iters: Total number of SGD steps ``train`` performs.
        """
        self.args = args
        # NOTE: an nn.MSELoss(reduction='mean') objective on one-hot labels
        # was previously toggled here; cross-entropy on class indices is used.
        self.loss_func = nn.CrossEntropyLoss()
        self.ldr_train = DataLoader(
            DatasetSplit(dataset, idxs),
            batch_size=self.args.local_bs,
            shuffle=True,
            num_workers=args.num_workers,
        )
        self.iters = iters

    def train(self, net, c_local, c_global):
        """Run exactly ``self.iters`` SGD steps on this shard.

        The loader is traversed repeatedly: full passes first, then a partial
        pass that stops once the step budget is used up. Each pass re-iterates
        the ``shuffle=True`` loader and therefore sees a fresh order.

        Args:
            net: Model to train in place. It is called on float64 image
                batches and its output is fed to ``nn.CrossEntropyLoss``.
            c_local: Accepted for interface compatibility but unused here
                (presumably SCAFFOLD control variates -- TODO confirm).
            c_global: Accepted for interface compatibility but unused here.

        Returns:
            ``net.state_dict()`` after training.

        Raises:
            ValueError: if ``iters`` was never set.
        """
        if self.iters is None:
            raise ValueError("LocalUpdate.train requires `iters` to be set")
        # A negative step count means "no training", not a stray remainder step.
        total_iters = max(int(self.iters), 0)

        # Put dropout/batch-norm layers in training mode even if the caller
        # left the model in eval mode after an evaluation pass.
        net.train()
        # A fresh optimizer per call: momentum state does not carry over
        # between calls.
        optimizer = torch.optim.SGD(
            net.parameters(), lr=self.args.lr, momentum=self.args.momentum
        )

        # An empty loader can never make progress; avoid looping forever.
        if len(self.ldr_train) == 0:
            return net.state_dict()

        steps = 0
        while steps < total_iters:
            for images, labels in self.ldr_train:
                self._train_step(net, optimizer, images, labels)
                steps += 1
                # Stop right after the last step, without fetching another batch.
                if steps == total_iters:
                    break
        return net.state_dict()

    def _train_step(self, net, optimizer, images, labels):
        """Run one forward/backward pass and optimizer update on a batch."""
        # The model is fed float64 inputs; labels keep their loader dtype.
        images = images.to(self.args.device).to(torch.float64)
        labels = labels.to(self.args.device)
        net.zero_grad()
        log_probs = net(images)
        loss = self.loss_func(log_probs, labels)
        loss.backward()
        optimizer.step()
