import jax
import jax.numpy as np
from jax import grad, jit, vmap

import time

def eval_ds_all(ds, params, state, net_apply, statistics_to_record, with_logits=False):
    """Run the network over a whole dataset and evaluate summary statistics.

    Args:
        ds: Iterable of ``(inputs, labels)`` pairs. It is consumed once by
            ``get_logits``.
        params: Network parameters, forwarded to ``net_apply``.
        state: Network state, forwarded to ``net_apply``.
        net_apply: Callable invoked as ``net_apply(params, state, inputs)``.
            The first element of its return value is used as the logits.
        statistics_to_record: Sequence of callables, each invoked as
            ``stat(logits, labels)`` on the concatenated dataset outputs.
        with_logits: When true, also return the concatenated logits.

    Returns:
        A list with one entry per statistic, in input order. If
        ``with_logits`` is true, returns the tuple ``(stats_list, logits)``
        instead.
    """
    logits, labels = get_logits(ds, params, state, net_apply)
    # Every statistic sees the full-dataset logits and labels at once.
    results = [stat(logits, labels) for stat in statistics_to_record]
    return (results, logits) if with_logits else results
        
def get_logits(ds, params, state, net_apply):
    """Apply the network batch by batch and concatenate the outputs.

    Args:
        ds: Iterable of ``(inputs, labels)`` pairs. It is iterated exactly
            once.
        params: Network parameters, passed positionally to ``net_apply``.
        state: Network state, passed positionally to ``net_apply``.
        net_apply: Callable invoked as ``net_apply(params, state, inputs)``.
            Only element ``[0]`` of its result is kept. (Presumably this is
            ``(logits, new_state)``; confirm against the caller.)

    Returns:
        ``(logits, labels)``: the per-batch logits and labels, each
        concatenated with ``jax.numpy.concatenate`` along its default axis.
    """
    # Single pass over ds, so one-shot iterators/generators work too.
    per_batch = [(net_apply(params, state, inputs)[0], targets) for inputs, targets in ds]
    logit_chunks = [batch_logits for batch_logits, _ in per_batch]
    label_chunks = [batch_labels for _, batch_labels in per_batch]
    return np.concatenate(logit_chunks), np.concatenate(label_chunks)

def get_labels(ds):
    """Collect and concatenate the labels from a dataset.

    Args:
        ds: Iterable of ``(inputs, labels)`` pairs. The inputs are ignored.

    Returns:
        The per-batch labels, concatenated with ``jax.numpy.concatenate``
        along its default axis.
    """
    label_chunks = [targets for _, targets in ds]
    return np.concatenate(label_chunks)
