from seqeval.metrics import f1_score as seqeval_f1_score
from seqeval.scheme import IOB2
from utils.conlleval import evaluate

def _mean_loss(metrics):
    """Return the mean of the ``"loss"`` entries in *metrics*, or 0 when *metrics* is empty."""
    if not metrics:
        return 0
    return sum(metric["loss"] for metric in metrics) / len(metrics)


def metric_fn(items, config):
    """Aggregate per-batch evaluation outputs into one metric dict.

    Args:
        items: sequence of per-batch dicts. Every item must have a ``"metrics"``
            dict containing ``"loss"``. For accuracy it must also contain
            ``"correct_predictions"`` and ``"total"``. For sequence-labelling F1,
            every item must also have ``"display_items"`` with list-valued
            ``"predictions"`` and ``"labels"``.
        config: must provide ``"display_metric"`` (``"accuracy"`` or ``"F1"``).
            ``"model_type"`` is only read when ``display_metric == "F1"``.

    Returns:
        ``{"loss": mean_loss, "accuracy": percent}`` for accuracy (0-100 scale).
        ``{"loss": mean_loss, "F1": score}`` for sequence labelling, where
        ``score`` is seqeval's F1 as returned (not multiplied by 100).
        The loss is the unweighted mean of the per-batch losses, or 0 with no items.

    Raises:
        ValueError: if the display_metric / model_type combination is unsupported.
    """
    display_metric = config["display_metric"]
    metrics = [item["metrics"] for item in items]

    if display_metric == "accuracy":
        correct_predictions = sum(metric["correct_predictions"] for metric in metrics)
        total = sum(metric["total"] for metric in metrics)
        # Guard against an empty evaluation set instead of dividing by zero.
        accuracy = correct_predictions / total if total > 0 else 0
        return {"loss": _mean_loss(metrics),
                "accuracy": accuracy * 100}

    if display_metric == "F1" and config["model_type"] == "seq_label":
        # Flatten the per-batch sequence lists into corpus-level lists so that
        # F1 is computed over the whole evaluation set, not averaged per batch.
        all_predictions = []
        all_labels = []
        for item in items:
            display = item["display_items"]
            all_predictions += display["predictions"]
            all_labels += display["labels"]

        # NOTE(review): seqeval only honours `scheme=` when `mode="strict"`.
        # Without it, evaluation is conlleval-compatible and IOB2 is ignored.
        # Confirm which behaviour is intended before changing the call.
        f1 = seqeval_f1_score(all_labels, all_predictions, scheme=IOB2)
        return {"loss": _mean_loss(metrics),
                "F1": f1}

    # Previously this fell through to an UnboundLocalError on `composed_metric`.
    raise ValueError(
        f"Unsupported metric configuration: display_metric={display_metric!r} "
        "(expected 'accuracy', or 'F1' with model_type 'seq_label')"
    )


def compose_dev_metric(metrics, config):
    """Average one metric over several dev sets and apply a direction multiplier.

    Args:
        metrics: mapping whose values are metric dicts, presumably one per dev
            set (e.g. dicts shaped like ``metric_fn``'s output; confirm with caller).
            Every value must contain the key ``config["save_by"]``.
        config: must provide ``"save_by"`` (which metric key to average) and
            ``"metric_direction"`` (multiplier applied to the total). The
            multiplier is presumably +1 for higher-is-better and -1 for
            lower-is-better metrics; TODO confirm.

    Returns:
        ``metric_direction * sum(selected values) / len(metrics)``.

    Raises:
        ValueError: if *metrics* is empty, instead of dividing by zero.
    """
    if not metrics:
        raise ValueError("compose_dev_metric needs at least one dev-set metric dict")
    save_by = config["save_by"]
    total_metric = sum(metric[save_by] for metric in metrics.values())
    return config["metric_direction"] * total_metric / len(metrics)
