import re
import string
from collections import Counter, defaultdict

import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score


def simple_accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels``.

    Args:
        preds: Array-like of predicted labels (NumPy array, list, tuple, ...).
        labels: Array-like of gold labels, aligned position-by-position
            with ``preds``.

    Returns:
        The mean of the element-wise equality as a NumPy float. Empty input
        yields ``nan`` (NumPy emits a "Mean of empty slice" RuntimeWarning).
    """
    # np.asarray is a no-op for ndarrays but makes plain sequences work too:
    # without it, ``list == list`` compares whole lists and returns a bool,
    # which has no ``.mean()``.
    return (np.asarray(preds) == np.asarray(labels)).mean()


def acc_and_f1(preds, labels, average="binary"):
    """Compute accuracy, F1, and the arithmetic mean of the two.

    Args:
        preds: Predicted labels.
        labels: Gold labels.
        average: Averaging mode passed straight through to
            ``sklearn.metrics.f1_score`` (``"binary"`` unless overridden).

    Returns:
        dict with keys ``"acc"``, ``"f1"`` and ``"acc_and_f1"`` (the mean of
        the first two).
    """
    scores = {
        "acc": simple_accuracy(preds, labels),
        "f1": f1_score(y_true=labels, y_pred=preds, average=average),
    }
    scores["acc_and_f1"] = (scores["acc"] + scores["f1"]) / 2
    return scores


def pearson_and_spearman(preds, labels):
    """Compute Pearson and Spearman correlation plus their average.

    Args:
        preds: Predicted scores.
        labels: Gold scores, aligned with ``preds``.

    Returns:
        dict with keys ``"pearson"``, ``"spearmanr"`` and ``"corr"`` (the mean
        of the two coefficients). P-values are discarded.
    """
    pearson, _ = pearsonr(preds, labels)
    spearman, _ = spearmanr(preds, labels)
    average = (pearson + spearman) / 2
    return dict(pearson=pearson, spearmanr=spearman, corr=average)


def _mcc_metric(preds, labels):
    """Return the Matthews correlation coefficient under key ``"mcc"``."""
    return {"mcc": matthews_corrcoef(labels, preds)}


def _acc_metric(preds, labels):
    """Return plain accuracy under key ``"acc"``."""
    return {"acc": simple_accuracy(preds, labels)}


def _acc_binary_f1_metric(preds, labels):
    """Return accuracy, binary F1 and their mean via ``acc_and_f1``."""
    return acc_and_f1(preds, labels)


def _acc_macro_f1_metric(preds, labels):
    """Return accuracy, macro-averaged F1 and their mean via ``acc_and_f1``."""
    return acc_and_f1(preds, labels, average="macro")


# Task name -> metric function. The small wrappers above (instead of direct
# references to ``acc_and_f1`` etc.) resolve their targets only when called,
# so building this table does not depend on definition order.
_TASK_METRIC_FNS = {
    "cola": _mcc_metric,
    "boolq": _acc_metric,
    "sst-2": _acc_metric,
    "qqp": _acc_binary_f1_metric,
    "yahooqa": _acc_metric,
    "yelp": _acc_metric,
    "event": _acc_metric,
    "argument": _acc_macro_f1_metric,
    "dmarker": _acc_macro_f1_metric,
    "qnli": _acc_metric,
    "rocstory": _acc_metric,
    "mnli": _acc_metric,
    "mnli-mm": _acc_metric,
    "scitail": _acc_metric,
    "drelation": _acc_macro_f1_metric,
    "emotion": _acc_macro_f1_metric,
    "agnews": _acc_metric,
    "dbpedia": _acc_metric,
    "amzn": _acc_metric,
    "yahooqa1": _acc_metric,
    "yahooqa2": _acc_metric,
    "yahooqa3": _acc_metric,
    "yahooqa4": _acc_metric,
    "yahooqa5": _acc_metric,
}


def compute_metrics(task_name, preds, labels, guids=None):
    """Compute the evaluation metrics registered for ``task_name``.

    Args:
        task_name: Name of the task (e.g. ``"cola"``, ``"mnli"``, ``"qqp"``).
        preds: Predicted labels.
        labels: Gold labels; must have the same length as ``preds``.
        guids: Unused; accepted only for call-site compatibility.

    Returns:
        dict of metric name -> value. Depending on the task this is
        ``{"mcc"}``, ``{"acc"}``, or ``{"acc", "f1", "acc_and_f1"}`` (binary
        F1 for ``"qqp"``, macro F1 for the multi-class tasks).

    Raises:
        ValueError: If ``preds`` and ``labels`` differ in length.
        KeyError: If ``task_name`` is not a known task.
    """
    # Explicit check rather than ``assert``: asserts are stripped under -O.
    if len(preds) != len(labels):
        raise ValueError(
            "preds and labels must have the same length, "
            f"got {len(preds)} and {len(labels)}"
        )
    metric_fn = _TASK_METRIC_FNS.get(task_name)
    if metric_fn is None:
        raise KeyError(task_name)
    return metric_fn(preds, labels)


# Task name -> key of the single summary score inside the dict that
# ``compute_metrics`` returns for that task ("mcc", "acc" or "acc_and_f1").
# Presumably used by callers to pick one number for reporting / model
# selection -- confirm against callers. Keep in sync with ``compute_metrics``.
task_metrics = {
    "cola": "mcc",
    "boolq": "acc",
    "sst-2": "acc",
    "qqp": "acc_and_f1",
    "yahooqa": "acc",
    "yelp": "acc",
    "event": "acc",
    "argument": "acc_and_f1",
    "dmarker": "acc_and_f1",
    "qnli": "acc",
    "rocstory": "acc",
    "mnli": "acc",
    # compute_metrics supports the mismatched MNLI split too; without this
    # entry a lookup for it raised KeyError.
    "mnli-mm": "acc",
    "scitail": "acc",
    "drelation": "acc_and_f1",
    "emotion": "acc_and_f1",
    "agnews": "acc",
    "dbpedia": "acc",
    "amzn": "acc",
    "yahooqa1": "acc",
    "yahooqa2": "acc",
    "yahooqa3": "acc",
    "yahooqa4": "acc",
    "yahooqa5": "acc",
}
