from collections import defaultdict

from .gene_evaluate import gene_evaluate
from .image_evaluate import image_evaluate

from utils.data_utils import move_tensors

def evaluate(model, datasets, data_name="image", **kwargs):
    """Dispatch evaluation to the routine matching ``data_name``.

    ``"gene"`` is routed to ``gene_evaluate``; ``"celebA"``, ``"openBHB"`` and
    ``"morphomnist"`` are routed to ``image_evaluate``; ``"pendulum"`` is
    routed to ``image_evaluate`` with ``save_orig`` defaulting to ``False``.

    Args:
        model: Model forwarded to the selected evaluation routine.
        datasets: Datasets forwarded to the selected evaluation routine.
        data_name: Dataset identifier. Note that the default ``"image"`` is
            not one of the recognized names, so calling without an explicit
            ``data_name`` raises ``ValueError``.
        **kwargs: Extra keyword arguments forwarded to the routine.

    Returns:
        Whatever the selected evaluation routine returns.

    Raises:
        ValueError: If ``data_name`` is not a recognized dataset name.
    """
    if data_name == "gene":
        return gene_evaluate(model, datasets, **kwargs)
    if data_name in ("celebA", "openBHB", "morphomnist"):
        return image_evaluate(model, datasets, **kwargs)
    if data_name == "pendulum":
        # Default to not saving originals. Using setdefault instead of passing
        # save_orig=False explicitly lets a caller-supplied value win, rather
        # than failing with "got multiple values for keyword argument".
        kwargs.setdefault("save_orig", False)
        return image_evaluate(model, datasets, **kwargs)
    raise ValueError(f"name not recognized: {data_name!r}")

def evaluate_loss(model, datasets, name=None, **kwargs):
    """Average the model's loss statistics over the test loader.

    Args:
        model: Object exposing ``eval()``, ``train()``, ``device``,
            ``loss(batch, batch_idx)`` (returning a pair whose second element
            is a mapping of statistic name to value) and ``early_stopping``.
        datasets: Mapping with a ``"test_loader"`` entry that yields batches;
            each batch is unpacked into ``move_tensors``.
        name: Unused; accepted for call-site compatibility.
        **kwargs: Unused; accepted for call-site compatibility.

    Returns:
        A ``(stats, early_stop)`` pair. ``stats`` is a ``defaultdict(float)``
        holding each statistic summed over the test batches and divided by
        the number of batches (an unweighted mean of per-batch values).
        ``early_stop`` is the result of ``model.early_stopping(None)``.
    """
    # Remember the prior mode so evaluating a model that was already in eval
    # mode does not flip it into training mode. Objects without a ``training``
    # attribute are treated as training, which preserves the old behaviour.
    was_training = getattr(model, "training", True)
    model.eval()
    epoch_eval_stats = defaultdict(float)
    n_batches = 0
    try:
        for batch_idx, batch in enumerate(datasets["test_loader"]):
            _, minibatch_eval_stats = model.loss(
                move_tensors(*batch, device=model.device), batch_idx
            )
            for key, val in minibatch_eval_stats.items():
                epoch_eval_stats[key] += val
            n_batches += 1
    finally:
        # Restore the mode even if a batch raises, so a failed evaluation
        # does not leave the model stuck in eval mode.
        if was_training:
            model.train()

    # Divide by the number of batches actually seen rather than calling
    # len() on the loader, so loaders without __len__ also work. With no
    # batches the dict is empty and no division happens.
    for key, val in epoch_eval_stats.items():
        epoch_eval_stats[key] = val / n_batches
    return epoch_eval_stats, model.early_stopping(None)
