def amr(pred: str, label: str) -> float:
    """Return the case-insensitive character-overlap ratio of *pred* to *label*.

    Both strings are lower-cased and treated as multisets of characters.
    The score is the size of their multiset intersection (for every
    character, the smaller of its two occurrence counts) divided by the
    length of the lower-cased label. Character order is ignored.

    Args:
        pred: Predicted string.
        label: Reference string.

    Returns:
        A float in ``[0.0, 1.0]``. When *label* is empty the ratio is
        undefined; by convention ``1.0`` is returned if *pred* is empty as
        well and ``0.0`` otherwise, instead of raising ZeroDivisionError.
    """
    from collections import Counter

    pred = pred.lower()
    label = label.lower()

    if not label:
        # Nothing to match against: avoid dividing by zero.
        return 1.0 if not pred else 0.0

    # Counter & Counter keeps min(count_pred, count_label) per character,
    # so the total can never exceed len(label).
    overlap = Counter(pred) & Counter(label)
    return sum(overlap.values()) / len(label)

def eval_res(raw_res, label, AMR=False):
    """Score a raw response against *label*, ignoring case.

    The candidate answer is selected as follows:

    * If the lower-cased response contains "the answer is", the text
      between the first and second occurrence of that phrase (or up to the
      end of the response) is taken, stripped of surrounding spaces,
      newlines, quotes and backticks, and truncated to ``len(label)``
      characters.
    * Otherwise, if the response has exactly as many characters as the
      label, the whole lower-cased response is the candidate.
    * Otherwise there is no candidate and the response counts as wrong.

    Args:
        raw_res: Raw response text.
        label: Expected answer.
        AMR: When true, also return the ``amr`` score of the candidate.

    Returns:
        ``bool`` exact-match flag when ``AMR`` is false; otherwise a
        ``(bool, score)`` tuple, where the score is ``0`` if no candidate
        could be selected.
    """
    marker = "the answer is"
    response = raw_res.lower()

    if marker in response:
        tail = response.split(marker)[1]
        candidate = tail.strip(" \n'\"`")[:len(label)]
    elif len(response) == len(label):
        candidate = response
    else:
        # No usable candidate: the response is simply counted as wrong.
        return (False, 0) if AMR else False

    is_exact = candidate == label.lower()
    if not AMR:
        return is_exact
    return is_exact, amr(candidate, label)