def test_img_local(nets, dataset, args, idxs=None):
    """Evaluate an equally weighted ensemble of networks on one data split.

    Each network's logits are normalised into log-probabilities along
    dimension 1, and the members' probabilities are averaged in log space
    (``logsumexp`` over members minus ``log(len(nets))``).

    Args:
        nets: Non-empty sequence of models. Each is switched to eval mode
            and called as ``net(data)``.
        dataset: Underlying dataset, wrapped as ``DatasetSplit(dataset, idxs)``.
        args: Namespace providing ``bs`` (batch size) and ``num_workers``.
        idxs: Passed through to ``DatasetSplit``; presumably the sample
            indices that make up the split (confirm against DatasetSplit).

    Returns:
        Tuple ``(accuracy, test_loss, log_probs, targets)``: accuracy in
        percent, summed cross-entropy divided by the split size, and the
        ensemble log-probabilities and targets of all batches concatenated
        in loader order as NumPy arrays.

    Raises:
        ValueError: If ``nets`` is empty.
        ZeroDivisionError: If the split contains no samples.
    """
    if len(nets) == 0:
        raise ValueError("nets must contain at least one model")

    for net in nets:
        net.eval()

    data_loader = DataLoader(DatasetSplit(dataset, idxs), batch_size=args.bs, shuffle=False, num_workers=args.num_workers)

    test_loss = 0.0
    correct = 0
    log_prob_batches = []
    target_batches = []

    # Pure evaluation: disabling autograd avoids building a graph for every
    # forward pass that would be thrown away immediately afterwards.
    with torch.no_grad():
        for data, target in data_loader:
            member_log_probs = []
            for net in nets:
                logits = net(data)
                member_log_probs.append((logits - logits.logsumexp(1, keepdim=True)).detach())

            # log(mean_k p_k) = logsumexp_k(log p_k) - log(K)
            log_probs_ = torch.logsumexp(torch.stack(member_log_probs), 0) - np.log(len(nets))
            target = target.detach()

            test_loss += F.cross_entropy(log_probs_, target, reduction='sum').item()
            y_pred = log_probs_.max(1, keepdim=True)[1]
            correct += y_pred.eq(target.view_as(y_pred)).long().sum().item()

            # Collect batches and concatenate once after the loop instead of
            # re-copying the growing tensor on every iteration.
            log_prob_batches.append(log_probs_)
            target_batches.append(target)

    num_samples = len(data_loader.dataset)
    test_loss /= num_samples
    accuracy = 100.00 * float(correct) / num_samples

    log_probs = torch.cat(log_prob_batches, 0)
    targets = torch.cat(target_batches)
    return accuracy, test_loss, log_probs.numpy(), targets.numpy()

def test_img_local_all(post_phi, args, dataset, dict_users):
    """Sample an ensemble from ``post_phi`` and evaluate it on every user's split.

    Args:
        post_phi: Mapping with entries ``'m0'``, ``'V0'``, ``'n0'``, ``'d'``
            and ``'l0'``. ``'m0'`` is deep-copied once per sample to serve
            as the container for the sampled weights.
        args: Namespace providing ``local_part``, ``mst_nsamps`` and
            ``num_users`` (plus whatever ``test_img_local`` reads).
        dataset: Dataset shared by all users.
        dict_users: Indexable by user id ``0 .. args.num_users - 1``; each
            entry is forwarded as ``idxs`` to ``test_img_local`` and
            ``DatasetSplit``.

    Returns:
        Tuple ``(acc_test_local, loss_test_local, logprobs_test_local,
        targets_test_local, data_ratio_local)``. The first, second and last
        are NumPy arrays of length ``args.num_users``; the two in between
        are dicts keyed by user id.
    """
    # Draw args.mst_nsamps parameter vectors from the MST distribution.
    loc_vec = weights2vec(post_phi['m0'], args.local_part)
    v0_vec = weights2vec(post_phi['V0'], args.local_part)
    # First MST argument; presumably its degrees of freedom (confirm in MST).
    dof = post_phi['n0'] - post_phi['d'] + 1
    scale_vec = ((post_phi['l0'] + 1) * v0_vec / (post_phi['l0'] * dof)).sqrt()
    draws = MST(dof, loc_vec, scale_vec).rsample([args.mst_nsamps,])

    # Load every draw into its own copy of the m0 model.
    nets = []
    for draw_idx in range(args.mst_nsamps):
        sampled_net = copy.deepcopy(post_phi['m0'])
        vec2weights(draws[draw_idx], sampled_net, args.local_part)
        nets.append(sampled_net)

    # Evaluate the sampled ensemble on each user's split.
    num_users = args.num_users
    acc_test_local = np.zeros(num_users)
    loss_test_local = np.zeros(num_users)
    logprobs_test_local = {}
    targets_test_local = {}
    for user in range(num_users):
        (
            acc_test_local[user],
            loss_test_local[user],
            logprobs_test_local[user],
            targets_test_local[user],
        ) = test_img_local(nets, dataset, args, idxs=dict_users[user])

    # Fraction of the full dataset held by each user.
    data_ratio_local = np.array(
        [
            len(DatasetSplit(dataset, dict_users[user])) / len(dataset)
            for user in range(num_users)
        ],
        dtype=float,
    )

    return acc_test_local, loss_test_local, logprobs_test_local, targets_test_local, data_ratio_local

def single_test_img_local(net, dataset, args, idxs=None):
    """Evaluate a single network on one data split.

    The network's logits are normalised into log-probabilities along
    dimension 1 before the loss and predictions are computed.

    Args:
        net: Model; switched to eval mode and called as ``net(data)``.
        dataset: Underlying dataset, wrapped as ``DatasetSplit(dataset, idxs)``.
        args: Namespace providing ``bs`` (batch size).
        idxs: Passed through to ``DatasetSplit``; presumably the sample
            indices that make up the split (confirm against DatasetSplit).

    Returns:
        Tuple ``(accuracy, test_loss, log_probs, targets)``: accuracy in
        percent, summed cross-entropy divided by the split size, and the
        log-probabilities and targets of all batches concatenated in loader
        order as NumPy arrays.

    Raises:
        ZeroDivisionError: If the split contains no samples.
    """
    net.eval()

    data_loader = DataLoader(DatasetSplit(dataset, idxs), batch_size=args.bs, shuffle=False)

    test_loss = 0.0
    correct = 0
    log_prob_batches = []
    target_batches = []

    # Pure evaluation: disabling autograd avoids building a graph for every
    # forward pass that would be thrown away immediately afterwards.
    with torch.no_grad():
        for data, target in data_loader:
            logits = net(data)
            log_probs_ = (logits - logits.logsumexp(1, keepdim=True)).detach()
            target = target.detach()

            test_loss += F.cross_entropy(log_probs_, target, reduction='sum').item()
            y_pred = log_probs_.max(1, keepdim=True)[1]
            correct += y_pred.eq(target.view_as(y_pred)).long().sum().item()

            # Collect batches and concatenate once after the loop instead of
            # re-copying the growing tensor on every iteration.
            log_prob_batches.append(log_probs_)
            target_batches.append(target)

    num_samples = len(data_loader.dataset)
    test_loss /= num_samples
    accuracy = 100.00 * float(correct) / num_samples

    log_probs = torch.cat(log_prob_batches, 0)
    targets = torch.cat(target_batches)
    return accuracy, test_loss, log_probs.numpy(), targets.numpy()
