import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import math

def gaussian_noise(data_shape, clip_constant, sigma, device=None):
    """Sample zero-mean Gaussian noise.

    The noise has standard deviation ``sigma * clip_constant``. It is drawn
    with ``torch.normal`` using the default generator and then moved to
    ``device``.

    Args:
        data_shape: Size of the tensor to draw (forwarded as ``torch.normal``'s
            ``size``).
        clip_constant: Scale factor for the standard deviation. Presumably the
            gradient-clipping bound, judging by the name; confirm with callers.
        sigma: Noise multiplier applied to ``clip_constant``.
        device: Destination passed to ``Tensor.to``. With ``None`` the tensor
            stays where ``torch.normal`` created it.

    Returns:
        A tensor of shape ``data_shape`` containing the noise sample.
    """
    std = sigma * clip_constant
    noise = torch.normal(0, std, data_shape)
    return noise.to(device)


class DatasetSplit(Dataset):
    """Map-style view over a subset of another dataset.

    Position ``k`` of this view reads ``dataset[idxs[k]]``. Each underlying
    sample must unpack into exactly two parts, which are returned as an
    ``(image, label)`` tuple.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Copy into a list so any iterable of indices supports len() and
        # positional lookup.
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_index = self.idxs[item]
        sample = self.dataset[source_index]
        image, label = sample
        return (image, label)



class LocalUpdate(object):
    """Local (client-side) trainer for one federated participant.

    The client's indices ``idxs`` into ``dataset`` are split at
    ``int(PERCN_OF_SHADOW * len(idxs))``. The leading part is the "shadow"
    portion and the remainder is the regular training portion. ``shadow``
    selects which of the two portions this instance trains on.

    ``args`` must provide ``local_bs``, ``lr``, ``momentum``, ``local_ep`` and
    ``device``.
    """

    def __init__(self, args, dataset=None, idxs=None, shadow=False, PERCN_OF_SHADOW=None):
        """Build the shuffled training loader for the selected portion.

        Args:
            args: Hyper-parameter namespace (see the class docstring).
            dataset: Indexable dataset whose items unpack to
                ``(image, label)``.
            idxs: Iterable of indices into ``dataset`` owned by this client.
            shadow: If true, train on the leading shadow portion; otherwise
                train on the remaining samples.
            PERCN_OF_SHADOW: Fraction in ``[0, 1]`` of ``idxs`` reserved for
                the shadow portion. It may be ``None`` only when ``shadow`` is
                false, in which case all samples are used.

        Raises:
            ValueError: If ``PERCN_OF_SHADOW`` is missing while ``shadow`` is
                true, or lies outside ``[0, 1]``.
        """
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []

        idx_list = list(idxs)
        if PERCN_OF_SHADOW is None:
            if shadow:
                raise ValueError("PERCN_OF_SHADOW must be given when shadow=True")
            cut = 0
        elif not 0 <= PERCN_OF_SHADOW <= 1:
            raise ValueError(
                f"PERCN_OF_SHADOW must be in [0, 1], got {PERCN_OF_SHADOW!r}"
            )
        else:
            cut = int(PERCN_OF_SHADOW * len(idx_list))

        # Slice the index list rather than materialising the samples. Items are
        # then fetched from ``dataset`` per batch, so memory stays bounded and
        # any per-sample transform in ``dataset`` runs on every access instead
        # of being computed once and reused for all epochs.
        selected = idx_list[:cut] if shadow else idx_list[cut:]
        self.ldr_train = DataLoader(
            DatasetSplit(dataset, selected),
            batch_size=self.args.local_bs,
            shuffle=True,
        )

    def train(self, net):
        """Train ``net`` in place with SGD for ``args.local_ep`` epochs.

        Args:
            net: ``nn.Module`` whose outputs are fed to ``CrossEntropyLoss``
                together with the batch labels.

        Returns:
            ``(net.state_dict(), avg_loss)``. ``avg_loss`` is the mean over
            epochs of each epoch's mean batch loss, or ``math.nan`` when no
            batch was processed (e.g. ``local_ep == 0`` or an empty loader).
        """
        net.train()
        optimizer = torch.optim.SGD(
            net.parameters(), lr=self.args.lr, momentum=self.args.momentum
        )
        device = self.args.device

        epoch_loss = []
        for _ in range(self.args.local_ep):
            batch_loss = []
            for images, labels in self.ldr_train:
                images, labels = images.to(device), labels.to(device)

                net.zero_grad()
                outputs = net(images)
                loss = self.loss_func(outputs, labels)
                loss.backward()
                # NOTE: a differential-privacy step (gradient clipping plus
                # gaussian_noise) was prototyped here but is disabled; the raw
                # gradients are applied unchanged.
                optimizer.step()

                batch_loss.append(loss.item())

            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

        avg_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else math.nan
        return net.state_dict(), avg_loss