# Modified from: https://github.com/pliang279/LG-FedAvg/blob/master/models/test.py
# credit goes to: Paul Pu Liang

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @python: 3.6

import copy

import numpy
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import time
from models.language_utils import get_word_emb_arr, repackage_hidden, process_x, process_y

class DatasetSplit(Dataset):
    """Expose a subset of ``dataset`` selected by the positions in ``idxs``.

    Position ``i`` of this view maps to ``dataset[int(idxs[i])]``; each index
    is cast with ``int`` so float or NumPy index values are accepted.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialise the indices so they can be measured and indexed repeatedly.
        self.idxs = [i for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[int(self.idxs[item])]
        return image, label

class DatasetSplit_leaf(Dataset):
    """Per-client dataset wrapper where ``idxs`` only sets the reported length.

    Unlike ``DatasetSplit``, items are fetched from ``dataset`` directly by
    position; the stored ``idxs`` are never used for lookup.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [*idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        # NOTE(review): ``idxs`` is not consulted here — presumably ``dataset``
        # already holds only this client's samples; confirm against callers.
        sample_x, sample_y = self.dataset[item]
        return sample_x, sample_y

def test_img_local(net_g, net_trans, dataset, args, idx=None, indd=None, user_idx=-1, idxs=None):
    """Evaluate one client's model (``net_trans(net_g(x))``) on its test split.

    Args:
        net_g: network whose output is fed into ``net_trans``.
        net_trans: network producing class scores passed to ``F.cross_entropy``.
        dataset: test dataset; ``idxs`` selects this client's samples from it.
        args: namespace read for ``local_bs``, ``gpu`` and ``device``.
        idx, indd: unused; kept for call compatibility.
        user_idx: client id, only used in the error message.
        idxs: indices into ``dataset`` belonging to this client.

    Returns:
        ``(accuracy, loss)``: accuracy in percent and mean cross-entropy,
        both averaged over the samples actually evaluated.

    Raises:
        ValueError: if no full batch of ``args.local_bs`` samples is available
            (``drop_last=True`` leaves nothing to evaluate).
    """
    net_g.eval()
    net_trans.eval()
    test_loss = 0.0
    correct = 0
    count = 0

    data_loader = DataLoader(DatasetSplit(dataset, idxs), batch_size=args.local_bs,
                             shuffle=False, drop_last=True)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        for data, target in data_loader:
            if args.gpu != -1:
                data, target = data.to(args.device), target.to(args.device)

            log_probs = net_trans(net_g(data))
            # sum up batch loss; normalised once at the end
            test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
            y_pred = log_probs.argmax(dim=1, keepdim=True)
            correct += y_pred.eq(target.view_as(y_pred)).sum().item()
            # drop_last=True may skip a trailing partial batch, so normalise by
            # the number of samples actually seen rather than len(idxs).
            count += target.size(0)

    if count == 0:
        raise ValueError(
            f"no full batch of size {args.local_bs} among {len(data_loader.dataset)} "
            f"test samples for user {user_idx}; nothing to evaluate"
        )

    test_loss /= count
    accuracy = 100.00 * float(correct) / count
    return accuracy, test_loss

def test_img_local_all(net, net_trans, args, dataset_test, dict_users_test,w_locals=None,w_glob_keys=None, indd=None,dataset_train=None,dict_users_train=None, return_all=False, color=False):
    """Evaluate every client in one half of the user range on its local test set.

    ``color`` selects the half: False -> users ``[0, num_users // 2)``,
    True -> users ``[num_users // 2, num_users)``.

    For each user, fresh copies of ``net`` and ``net_trans`` are made. If
    ``w_locals`` is given, that user's entries whose key is in ``w_glob_keys``
    are loaded into the ``net`` copy; all other entries go into the
    ``net_trans`` copy. The caller's ``net`` and ``net_trans`` are not modified.

    Args:
        net, net_trans: models passed to ``test_img_local``.
        args: namespace providing ``num_users`` (plus whatever
            ``test_img_local`` reads).
        dataset_test: shared test dataset.
        dict_users_test: mapping user id -> test-sample indices.
        w_locals: optional mapping user id -> state-dict-like weights.
        w_glob_keys: keys routed to ``net``; required when ``w_locals`` is given.
        indd, dataset_train, dict_users_train: unused; kept for compatibility.
        return_all: if True, return the per-user arrays instead of averages.
        color: which half of the users to evaluate (see above).

    Returns:
        If ``return_all``: ``(acc_test_local, loss_test_local)``, arrays with one
        entry per evaluated user (position ``i`` is user ``start + i``). Each
        entry is that user's accuracy/loss multiplied by its test-set size.
        Otherwise: sample-weighted mean accuracy (percent) and mean loss.

    Raises:
        ValueError: if averaging is requested but the evaluated users have no
            test samples at all.
    """
    if color:
        start = args.num_users // 2
        end = args.num_users
    else:
        start = 0
        end = args.num_users // 2

    # Size by the actual range: with color=True (or odd num_users) the range
    # can be longer than num_users // 2, and user ids do not start at 0.
    num_evaluated = end - start
    acc_test_local = np.zeros(num_evaluated)
    loss_test_local = np.zeros(num_evaluated)
    glob_keys = set(w_glob_keys) if w_locals is not None else None
    tot = 0

    for pos, idx in enumerate(range(start, end)):
        # Copy from the caller's models every time, so one user's weights never
        # leak into the next user's evaluation.
        net_local = copy.deepcopy(net)
        trans_local = copy.deepcopy(net_trans)
        if w_locals is not None:
            w_local = net_local.state_dict()
            w_local_trans = trans_local.state_dict()
            for k, v in w_locals[idx].items():
                if k in glob_keys:
                    w_local[k] = v
                else:
                    w_local_trans[k] = v
            net_local.load_state_dict(w_local)
            trans_local.load_state_dict(w_local_trans)
        net_local.eval()
        trans_local.eval()

        user_idxs = dict_users_test[idx]
        n_test = len(user_idxs)
        a, b = test_img_local(net_local, trans_local, dataset_test, args, user_idx=idx, idxs=user_idxs)
        tot += n_test

        # Weight by test-set size so the final average is per-sample.
        acc_test_local[pos] = a * n_test
        loss_test_local[pos] = b * n_test

        del net_local, trans_local

    if return_all:
        return acc_test_local, loss_test_local
    if tot == 0:
        raise ValueError(f"no test samples for users in range [{start}, {end})")
    return sum(acc_test_local) / tot, sum(loss_test_local) / tot
