#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import random
import math
import torch.optim as optim


class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``.

    Args:
        dataset: Indexable object whose items unpack as ``(image, label)``.
        idxs: Iterable of positions into ``dataset``; every entry is coerced
            to a plain ``int`` (so numpy integers are accepted).
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        """Return the ``item``-th selected sample as an ``(image, label)`` tensor pair."""
        image, label = self.dataset[self.idxs[item]]
        return self._to_tensor(image), self._to_tensor(label)

    @staticmethod
    def _to_tensor(value):
        """Return an independent tensor copy of ``value``."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(tensor) performs exactly this copy but also emits a
            # "copy construct" UserWarning for every fetched sample.
            return value.detach().clone()
        return torch.tensor(value)

class LocalUpdateL(object):
    """Local trainer/evaluator working on one client's share of a dataset.

    Args:
        args: Configuration object. The attributes read here are ``gpu``,
            ``local_bs``, ``local_ep``, ``lr``, ``optimizer`` and ``verbose``.
        dataset: Dataset whose items unpack as ``(image, label)``.
        idxs: Positions in ``dataset`` that belong to this client.
        logger: Object exposing ``add_scalar(tag, value)``.
        idx: Accepted for signature compatibility; not used by this class.
        ep: Round counter; while ``ep < 1`` the number of local epochs is
            doubled (unless ``args.local_ep == 1``).
    """

    def __init__(self, args, dataset, idxs, logger, idx, ep):
        self.ep = ep
        self.args = args
        self.logger = logger
        self.trainloader, self.validloader, self.testloader = self.train_val_test(
            dataset, idxs)
        self.device = 'cuda' if args.gpu else 'cpu'
        # Default criterion set to NLL loss function
        self.criterion = nn.NLLLoss().to(self.device)

    def train_val_test(self, dataset, idxs):
        """Build train, validation and test dataloaders for this client.

        A random tenth (rounded down) of ``idxs`` becomes the test split, the
        remainder is the training split and the validation split is empty.
        All three loaders drop their last incomplete batch.

        Returns:
            Tuple ``(trainloader, validloader, testloader)``.
        """
        # The positional fancy-indexing below needs an ndarray; a set has to go
        # through a list first because numpy would wrap it as a 0-d object.
        if isinstance(idxs, (set, frozenset)):
            idxs = list(idxs)
        idxs = np.asarray(idxs)

        n_total = idxs.size
        test_sel = np.random.choice(n_total, n_total // 10, replace=False).astype(int)
        idxs_train = np.delete(idxs, test_sel)
        idxs_val = []
        idxs_test = idxs[test_sel]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=self.args.local_bs, shuffle=True,
                                 drop_last=True)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=max(len(idxs_val) // 10, 1),
                                 shuffle=False, drop_last=True)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=max(len(idxs_test) // 10, 1),
                                shuffle=False, drop_last=True)
        return trainloader, validloader, testloader

    def update_weights(self, model, global_round):
        """Train ``model`` in place on the local training split.

        Args:
            model: Network to train; it is moved to ``self.device``. Its output
                is fed to ``NLLLoss`` (presumably log-probabilities).
            global_round: Only used in the verbose progress message.

        Returns:
            Tuple ``(state_dict, mean_loss)`` where ``mean_loss`` is the mean
            of the per-epoch mean batch losses (``0.0`` if no epoch ran).

        Raises:
            ValueError: If ``args.optimizer`` is neither ``'sgd'`` nor ``'adam'``.
        """
        # Set mode to train model
        model.to(self.device)
        model.train()

        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                        momentum=0.5)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
                                         weight_decay=1e-4)
        else:
            raise ValueError(
                'Unsupported optimizer: {!r}'.format(self.args.optimizer))
        # NOTE(review): a StepLR(step_size=1, gamma=0.5) scheduler used to be
        # created here but was never stepped, so it had no effect on the
        # learning rate. Re-add it together with scheduler.step() if per-epoch
        # decay is actually wanted.

        # Early rounds (ep < 1) train twice as long, except when a single
        # local epoch is configured.
        local_ep = self.args.local_ep
        eps = 2 * local_ep if (self.ep < 1 and local_ep != 1) else local_ep

        epoch_loss = []
        for epoch in range(eps):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()

                # Read the scalar once; every .item() call forces a host sync.
                loss_value = loss.item()
                if self.args.verbose and (batch_idx % 10 == 0):
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, epoch, batch_idx * len(images),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss_value))
                self.logger.add_scalar('loss', loss_value)
                batch_loss.append(loss_value)
            # drop_last=True can leave the loader without a single batch.
            epoch_loss.append(sum(batch_loss) / max(len(batch_loss), 1))

        mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0.0
        return model.state_dict(), mean_loss

    def inference(self, model):
        """Evaluate ``model`` on the local test split.

        Returns:
            Tuple ``(accuracy, loss)``: the fraction of correctly classified
            samples and the sum of the per-batch losses. Both are ``0.0`` when
            the test loader yields no batch.
        """
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        # No gradients are needed for evaluation.
        with torch.no_grad():
            for images, labels in self.testloader:
                images, labels = images.to(self.device), labels.to(self.device)

                # Inference
                outputs = model(images)
                loss += self.criterion(outputs, labels).item()

                # Prediction
                pred_labels = torch.max(outputs, 1)[1].view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

        accuracy = correct / total if total else 0.0
        return accuracy, loss

class LocalUpdateA(object):
    """Local trainer/evaluator working on one client's share of a dataset.

    Args:
        args: Configuration object. The attributes read here are ``gpu``,
            ``local_bs``, ``local_ep``, ``lr``, ``optimizer`` and ``verbose``.
        dataset: Dataset whose items unpack as ``(image, label)``.
        idxs: Positions in ``dataset`` that belong to this client.
        logger: Object exposing ``add_scalar(tag, value)``.
        idx: Accepted for signature compatibility; not used by this class.
        ep: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, args, dataset, idxs, logger, idx, ep):
        self.args = args
        self.logger = logger
        self.trainloader, self.validloader, self.testloader = self.train_val_test(
            dataset, idxs)
        self.device = 'cuda' if args.gpu else 'cpu'
        # Default criterion set to NLL loss function
        self.criterion = nn.NLLLoss().to(self.device)

    def train_val_test(self, dataset, idxs):
        """Build train, validation and test dataloaders for this client.

        A random tenth (rounded down) of ``idxs`` becomes the test split, the
        remainder is the training split and the validation split is empty.

        Returns:
            Tuple ``(trainloader, validloader, testloader)``.
        """
        # The positional fancy-indexing below needs an ndarray; a set has to go
        # through a list first because numpy would wrap it as a 0-d object.
        if isinstance(idxs, (set, frozenset)):
            idxs = list(idxs)
        idxs = np.asarray(idxs)

        n_total = idxs.size
        test_sel = np.random.choice(n_total, n_total // 10, replace=False).astype(int)
        idxs_train = np.delete(idxs, test_sel)
        idxs_val = []
        idxs_test = idxs[test_sel]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=self.args.local_bs, shuffle=True)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=max(len(idxs_val) // 10, 1),
                                 shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=max(len(idxs_test) // 10, 1),
                                shuffle=False)
        return trainloader, validloader, testloader

    def update_weights(self, model, global_round):
        """Train ``model`` in place on the local training split.

        The model is not moved between devices here; the batches are sent to
        ``self.device``, so the caller must have placed the model there.

        Args:
            model: Network to train. Its output is fed to ``NLLLoss``
                (presumably log-probabilities).
            global_round: Only used in the verbose progress message.

        Returns:
            Tuple ``(state_dict, mean_loss)`` where ``mean_loss`` is the mean
            of the per-epoch mean batch losses (``0.0`` if no epoch ran).

        Raises:
            ValueError: If ``args.optimizer`` is neither ``'sgd'`` nor ``'adam'``.
        """
        # Set mode to train model
        model.train()

        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                        momentum=0.5)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
                                         weight_decay=1e-4)
        else:
            raise ValueError(
                'Unsupported optimizer: {!r}'.format(self.args.optimizer))

        epoch_loss = []
        for epoch in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()

                # Read the scalar once; every .item() call forces a host sync.
                loss_value = loss.item()
                if self.args.verbose and (batch_idx % 10 == 0):
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, epoch, batch_idx * len(images),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss_value))
                self.logger.add_scalar('loss', loss_value)
                batch_loss.append(loss_value)
            # An empty training split would otherwise divide by zero.
            epoch_loss.append(sum(batch_loss) / max(len(batch_loss), 1))

        mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0.0
        return model.state_dict(), mean_loss

    def inference(self, model):
        """Evaluate ``model`` on the local test split.

        Returns:
            Tuple ``(accuracy, loss)``: the fraction of correctly classified
            samples and the sum of the per-batch losses. Both are ``0.0`` when
            the test split is empty.
        """
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        # No gradients are needed for evaluation.
        with torch.no_grad():
            for images, labels in self.testloader:
                images, labels = images.to(self.device), labels.to(self.device)

                # Inference
                outputs = model(images)
                loss += self.criterion(outputs, labels).item()

                # Prediction
                pred_labels = torch.max(outputs, 1)[1].view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

        accuracy = correct / total if total else 0.0
        return accuracy, loss

def test_inference_M2(args, model2, test_dataset):
    """Evaluate one model per user on a freshly sampled test subset.

    For user ``u`` the subset holds ``p1 * 500`` positions drawn from
    ``[0, 5000)`` and ``(10 - p1) * 500`` positions drawn from
    ``[5000, 10000)``, where ``p1 = (u % 5) * 2 + 1``; that is always 5000
    samples. Batches of 128 are used and the last incomplete batch is dropped.

    Args:
        args: Configuration object; ``num_users`` and ``gpu`` are read.
        model2: Sequence of models indexed by user. Each model must already
            live on the evaluation device (``'cuda:1'`` when ``args.gpu`` is
            truthy, else ``'cpu'``); it is not moved here.
        test_dataset: Dataset whose items unpack as ``(image, label)``.
            NOTE(review): presumably ordered so that positions 0-4999 and
            5000-9999 form two distinct groups - confirm with the caller.

    Returns:
        Tuple ``(Aacc, Aloss)`` of 1-D tensors of length ``args.num_users``
        holding, per user, the accuracy and the sum of per-batch losses.
    """
    Aacc = torch.zeros(args.num_users)
    Aloss = torch.zeros(args.num_users)

    # Loop-invariant objects are built once instead of once per user.
    device = 'cuda:1' if args.gpu else 'cpu'
    criterion = nn.NLLLoss().to(device)
    idxs = np.arange(5000)

    for users in range(args.num_users):
        # Mixing ratio between the two halves: p1 is 1, 3, 5, 7 or 9 (of 10).
        p1 = (users % 5) * 2 + 1
        p2 = 10 - p1
        # Sampling order (first half, then second half) is kept so seeded
        # runs draw the same subsets as before.
        idx1 = np.random.choice(idxs, p1 * 500, replace=False)
        idx2 = np.random.choice(idxs, p2 * 500, replace=False) + 5000
        idxs_test = np.concatenate((idx1, idx2), axis=0)

        testloader = DataLoader(DatasetSplit(test_dataset, idxs_test),
                                batch_size=128, shuffle=False, drop_last=True)

        model = model2[users]
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        # No gradients are needed for evaluation.
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)

                # Inference
                outputs = model(images)
                loss += criterion(outputs, labels).item()

                # Prediction
                pred_labels = torch.max(outputs, 1)[1].view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

        Aacc[users] = correct / total if total else 0.0
        Aloss[users] = loss

    return Aacc, Aloss

def test_inference_M3(args, model2, test_dataset, users, ta, tl, epoch):
    """Evaluate one model for one user and record the result in ``ta``/``tl``.

    The test subset holds ``p1 * 500`` positions drawn from ``[0, 5000)`` and
    ``(10 - p1) * 500`` positions drawn from ``[5000, 10000)``, where
    ``p1 = (users % 5) * 2 + 1``; that is always 5000 samples.

    Args:
        args: Configuration object; only ``gpu`` is read.
        model2: Model to evaluate. It must already live on the evaluation
            device (``'cuda:1'`` when ``args.gpu`` is truthy, else ``'cpu'``);
            it is not moved here.
        test_dataset: Dataset whose items unpack as ``(image, label)``.
            NOTE(review): presumably ordered so that positions 0-4999 and
            5000-9999 form two distinct groups - confirm with the caller.
        users: Index of the user being evaluated.
        ta: 2-D container written at ``[users, epoch]`` with the accuracy.
        tl: 2-D container written at ``[users, epoch]`` with the summed
            per-batch loss.
        epoch: Column index used when writing into ``ta`` and ``tl``.

    Returns:
        The same ``ta`` and ``tl`` objects, updated in place.
    """
    idxs = np.arange(5000)
    # Mixing ratio between the two halves: p1 is 1, 3, 5, 7 or 9 (of 10).
    p1 = (users % 5) * 2 + 1
    p2 = 10 - p1
    # Sampling order (first half, then second half) is kept so seeded runs
    # draw the same subsets as before.
    idx1 = np.random.choice(idxs, p1 * 500, replace=False)
    idx2 = np.random.choice(idxs, p2 * 500, replace=False) + 5000
    idxs_test = np.concatenate((idx1, idx2), axis=0)

    device = 'cuda:1' if args.gpu else 'cpu'
    criterion = nn.NLLLoss().to(device)
    testloader = DataLoader(DatasetSplit(test_dataset, idxs_test),
                            batch_size=128, shuffle=False)

    model2.eval()
    loss, total, correct = 0.0, 0.0, 0.0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)

            # Inference
            outputs = model2(images)
            loss += criterion(outputs, labels).item()

            # Prediction
            pred_labels = torch.max(outputs, 1)[1].view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)

    ta[users, epoch] = correct / total if total else 0.0
    tl[users, epoch] = loss

    return ta, tl


def test_inference_UB2_M(args, model2, test_dataset, idxs):
    """Evaluate the first model of every user on a pre-selected test subset.

    User ``u`` is evaluated with ``model2[u][0]`` on the samples listed in
    ``idxs[u % 5]``. Batches of 128 are used and the last incomplete batch is
    dropped, so a subset with fewer than 128 samples contributes no batch.

    Args:
        args: Configuration object; ``num_users`` and ``gpu`` are read.
        model2: Sequence indexed by user whose entries are themselves
            indexable; element ``0`` of each entry is the model evaluated.
            Each model must already live on the evaluation device
            (``'cuda:1'`` when ``args.gpu`` is truthy, else ``'cpu'``).
        test_dataset: Dataset whose items unpack as ``(image, label)``.
        idxs: Indexable collection (positions 0-4 are used) of iterables of
            sample positions in ``test_dataset``.

    Returns:
        Tuple ``(Aacc, Aloss)`` of 1-D tensors of length ``args.num_users``
        holding, per user, the accuracy and the sum of per-batch losses.
        The accuracy is ``0.0`` for a user whose loader yields no batch.
    """
    Aacc = torch.zeros(args.num_users)
    Aloss = torch.zeros(args.num_users)

    # Loop-invariant objects are built once instead of once per user.
    device = 'cuda:1' if args.gpu else 'cpu'
    criterion = nn.NLLLoss().to(device)

    for users in range(args.num_users):
        model = model2[users][0]
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        testloader = DataLoader(DatasetSplit(test_dataset, idxs[users % 5]),
                                batch_size=128, shuffle=False, drop_last=True)

        # No gradients are needed for evaluation.
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)

                # Inference
                outputs = model(images)
                loss += criterion(outputs, labels).item()

                # Prediction
                pred_labels = torch.max(outputs, 1)[1].view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

        # drop_last=True yields nothing for subsets smaller than one batch.
        Aacc[users] = correct / total if total else 0.0
        Aloss[users] = loss

    return Aacc, Aloss

def test_inference_UB3(args, model2, test_dataset, users, ta, tl, epoch, idxs_test):
    """Evaluate one model on a given test subset and record the result.

    Args:
        args: Configuration object; only ``gpu`` is read.
        model2: Model to evaluate. It must already live on the evaluation
            device (``'cuda'`` when ``args.gpu`` is truthy, else ``'cpu'``);
            it is not moved here.
        test_dataset: Dataset whose items unpack as ``(image, label)``.
        users: Row index used when writing into ``ta`` and ``tl``.
        ta: 2-D container written at ``[users, epoch]`` with the accuracy.
        tl: 2-D container written at ``[users, epoch]`` with the summed
            per-batch loss.
        epoch: Column index used when writing into ``ta`` and ``tl``.
        idxs_test: Iterable of sample positions in ``test_dataset``.

    Returns:
        The same ``ta`` and ``tl`` objects, updated in place. The recorded
        accuracy is ``0.0`` when ``idxs_test`` is empty.
    """
    device = 'cuda' if args.gpu else 'cpu'
    criterion = nn.NLLLoss().to(device)
    testloader = DataLoader(DatasetSplit(test_dataset, idxs_test),
                            batch_size=128, shuffle=False)

    model2.eval()
    loss, total, correct = 0.0, 0.0, 0.0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)

            # Inference
            outputs = model2(images)
            loss += criterion(outputs, labels).item()

            # Prediction
            pred_labels = torch.max(outputs, 1)[1].view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)

    # An empty subset would otherwise divide by zero.
    ta[users, epoch] = correct / total if total else 0.0
    tl[users, epoch] = loss

    return ta, tl
