from sklearn.metrics import f1_score, accuracy_score
import re
import string
import collections
import json
import sys

# ReCoRD ground truth for the FewGLUE 32-shot dev split. It is read once, at
# import time, from a path relative to the current working directory, and is
# consumed by compute_metrics_record_evaluate and
# compute_metrics_record_evaluate_classify. Per-query dicts are keyed by
# "<passage idx>_<query idx>".
record_dev_gt_candidates = {}  # query key -> candidate entity strings, in passage order
record_dev_p = {}  # str(passage idx) -> passage text
record_dev_q = {}  # query key -> query text
record_dev_gt = {}  # query key -> distinct gold answer strings
record_dev_gt_index = {}  # query key -> candidate positions whose text equals a gold answer

with open('./FewGLUE_dev32/ReCoRD/val.jsonl', encoding='utf-8') as record_dev_file:
    for line in record_dev_file:
        items = json.loads(line)
        passage = items['passage']['text']
        # The +1 treats each entity's 'end' offset as inclusive.
        entities_token = [
            passage[entity['start']: entity['end'] + 1]
            for entity in items['passage']['entities']
        ]
        passage_id = items['idx']
        record_dev_p[str(passage_id)] = passage
        for q_as in items['qas']:
            id_key = '_'.join([str(passage_id), str(q_as['idx'])])
            # All per-query dicts are filled together, so checking one suffices.
            if id_key in record_dev_gt:
                raise ValueError(f'duplicate ReCoRD query id {id_key!r} in dev file')
            # Deduplicate answers while keeping first-seen order; iterating a
            # set would make the order depend on string-hash randomization.
            answers_token = list(dict.fromkeys(ans_span['text'] for ans_span in q_as['answers']))
            record_dev_q[id_key] = q_as['query']
            record_dev_gt[id_key] = answers_token
            # Each query gets its own copy of the passage's candidate list.
            record_dev_gt_candidates[id_key] = list(entities_token)
            # Every position whose entity text matches some gold answer.
            record_dev_gt_index[id_key] = [
                i
                for ans_token in answers_token
                for i, e in enumerate(entities_token)
                if e == ans_token
            ]


def exact_match(gt, pred, indexes):
    """Return the percentage of questions whose every answer option is predicted correctly.

    ``gt``, ``pred`` and ``indexes`` are parallel sequences with one entry per
    answer option. Each ``indexes`` entry is a ``"<passage>_<question>_<answer>"``
    id; rows are grouped by their ``"<passage>_<question>"`` prefix, and a
    question counts as correct only when the prediction equals the ground truth
    for all of its rows.

    Returns a value on a 0-100 scale, or 0.0 when the inputs are empty.
    Raises AssertionError if the three sequences differ in length.
    """
    assert len(gt) == len(pred) and len(indexes) == len(gt)
    # question id -> True while every row seen so far for it is correct.
    all_correct = {}
    for gt_i, pred_i, index_i in zip(gt, pred, indexes):
        passage_id, q_id, _ans_id = index_i.split('_')
        question = '_'.join([passage_id, q_id])
        all_correct[question] = all_correct.get(question, True) and gt_i == pred_i
    if not all_correct:
        return 0.0
    return 100 * sum(bool(ok) for ok in all_correct.values()) / len(all_correct)


def wsc_accuracy(y_true, y_pred):
    """Return the fraction (0-1) of WSC predictions judged correct.

    A ground truth wrapped in asterisks (``"*span*"``) marks a negative
    example: the prediction is correct when it differs from ``span``
    (case-insensitive). Any other ground truth is correct when the prediction
    equals it (case-insensitive). Pairs are consumed with ``zip``, so extra
    items in the longer sequence are ignored. Returns 0.0 when there are no
    pairs.
    """
    total = 0
    correct = 0
    for gt, pred in zip(y_true, y_pred):
        total += 1
        # startswith/endswith (unlike gt[0]) do not raise on an empty string.
        if gt.startswith('*') and gt.endswith('*'):
            # Negative example: anything other than the starred span is right.
            if pred.lower() != gt[1:-1].lower():
                correct += 1
        elif pred.lower() == gt.lower():
            correct += 1
    return correct / total if total else 0.0


def normalize_answer(s):
    """Normalize an answer string for SQuAD-style comparison.

    Applied in order: lower-case; delete every character in
    ``string.punctuation``; replace the whole words "a", "an" and "the" with a
    space; collapse whitespace runs to single spaces and trim both ends.
    """
    text = s.lower()
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text, flags=re.UNICODE)
    return ' '.join(text.split())


def get_tokens(s):
    """Split the normalized form of ``s`` on whitespace; empty or None gives []."""
    return normalize_answer(s).split() if s else []


def compute_f1(a_gold, a_pred):
    """Return the token-level F1 between a gold answer and a prediction.

    Both strings are normalized and tokenized with get_tokens. When either
    side yields no tokens the score is 1 if both are empty and 0 otherwise;
    when they share no tokens the score is 0.
    """
    gold_tokens = get_tokens(a_gold)
    pred_tokens = get_tokens(a_pred)
    if not gold_tokens or not pred_tokens:
        # A no-answer on either side: full credit only if both are no-answers.
        return int(gold_tokens == pred_tokens)
    shared = collections.Counter(gold_tokens) & collections.Counter(pred_tokens)
    overlap = sum(shared.values())
    if overlap == 0:
        return 0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)


def record_eval(gt, pred, indexes):
    """Return ``(EM, max_token_F1)``, both in percent, for grouped ReCoRD rows.

    ``gt``, ``pred`` and ``indexes`` are parallel sequences. Each ``indexes``
    entry is ``"<passage>_<query>_<answer>"``; rows sharing the
    ``"<passage>_<query>"`` prefix form one query. For every query:

    * EM scores 1 when any of its predicted strings is among its gold strings;
    * token F1 is the best compute_f1 score over all (gold, prediction) pairs.

    Both are averaged over queries and scaled to 0-100. Returns ``(0.0, 0.0)``
    for empty input. Raises AssertionError when the sequences differ in length.
    """
    assert len(gt) == len(pred) and len(indexes) == len(gt)
    gt_ans = collections.defaultdict(list)
    pred_ans = collections.defaultdict(list)
    for gt_i, pred_i, index_i in zip(gt, pred, indexes):
        passage_id, q_id, _ans_id = index_i.split('_')
        idx = '_'.join([passage_id, q_id])
        gt_ans[idx].append(gt_i)
        pred_ans[idx].append(pred_i)

    if not gt_ans:
        return 0.0, 0.0

    # EM: a query is correct if at least one prediction hits a gold string.
    correct = sum(
        any(a_p in gt_ans[k] for a_p in pred_ans[k]) for k in gt_ans
    )
    EM = 100 * correct / len(gt_ans)

    # max-token-F1: every group is non-empty, so max() always has a value.
    f1s = [
        max(
            compute_f1(a_golden, a_pred)
            for a_pred in pred_ans[k]
            for a_golden in gt_ans[k]
        )
        for k in gt_ans
    ]
    token_f1 = 100 * sum(f1s) / len(f1s)

    return EM, token_f1


# Maps each task name to the metric callables compute_metrics applies, in
# order. The sklearn metrics are called with (y_true, y_pred); exact_match and
# record_eval additionally receive the per-row "<passage>_<question>_<answer>"
# ids.
# NOTE(review): compute_metrics calls f1_score with average='macro' for every
# task, including multirc; SuperGLUE's official MultiRC F1a is a binary F1
# over answer options — confirm the macro average is intended.
task_metrics = dict(
    boolq=[accuracy_score],
    cb=[accuracy_score, f1_score],
    copa=[accuracy_score],
    rte=[accuracy_score],
    multirc=[f1_score, exact_match],
    wsc=[wsc_accuracy],
    wic=[accuracy_score],
    record=[record_eval],
)


def compute_metrics_record_evaluate(candidate, loss, data_index, output_path=None):
    """Score ReCoRD candidate strings against the dev gold answers loaded at import.

    Args:
        candidate: candidate answer string per row.
        loss: loss per row (lower is better), parallel to ``candidate``.
        data_index: ``"<passage>_<query>_<answer>"`` id per row.
        output_path: if given, a human-readable dump of every dev query
            (prediction, gold answers, passage, query) is written there.

    For each dev query in record_dev_gt the lowest-loss candidate (earliest
    row on ties) is the prediction. EM is the percentage of queries whose
    prediction is one of the gold answers; max token F1 is the mean over
    queries of the best compute_f1 score against any gold answer (0 for a
    query without gold answers).

    Returns:
        ``{'max token f1': ..., 'exact match': ...}`` on a 0-100 scale.

    Raises:
        KeyError: a dev query has no rows in ``data_index``.
    """
    preds = {}
    for candidate_i, loss_i, index_i in zip(candidate, loss, data_index):
        passage_id, q_id, _ans_id = index_i.split('_')
        preds.setdefault('_'.join([passage_id, q_id]), []).append((candidate_i, loss_i))

    # Pick each query's lowest-loss candidate once. min() returns the first of
    # several equal losses, matching element 0 of a stable sort.
    best = {k: min(preds[k], key=lambda x: x[1])[0] for k in record_dev_gt}

    if best:
        EM = 100 * sum(best[k] in record_dev_gt[k] for k in best) / len(best)
        # default=0.0 keeps a query without gold answers from contributing a
        # huge negative sentinel to the mean.
        f1s = [
            max((compute_f1(a_golden, best[k]) for a_golden in record_dev_gt[k]), default=0.0)
            for k in best
        ]
        token_f1 = 100 * sum(f1s) / len(f1s)
    else:
        EM = token_f1 = 0.0

    if output_path is not None:
        with open(output_path, 'w', encoding='utf-8') as out:
            for k, max_pred in best.items():
                out.write(f'key: {k}, max pred: {max_pred}\n')
                out.write(str(record_dev_gt[k]) + '\n')
                out.write(record_dev_p[k.split('_')[0]] + '\n')
                out.write(record_dev_q[k] + '\n')
                out.write('\n')

    return {'max token f1': token_f1, 'exact match': EM}


def compute_metrics_record_evaluate_classify(loss, data_index):
    """Score ReCoRD candidate-position predictions against the dev gold answers.

    Args:
        loss: loss per row (lower is better).
        data_index: ``"<passage>_<query>_<candidate position>"`` id per row;
            the position indexes the query's list in record_dev_gt_candidates.

    For every query present in ``data_index`` the lowest-loss row (earliest on
    ties) gives the predicted position. EM is the percentage of queries whose
    position is in record_dev_gt_index; max token F1 is the mean over queries
    of the best compute_f1 score of the predicted entity against any gold
    answer (0 for a query without gold answers).

    Returns:
        ``{'max token f1': ..., 'exact match': ...}`` on a 0-100 scale, or
        zeros when ``data_index`` is empty.

    Raises:
        AssertionError: a query's number of rows differs from its number of
            candidates.
    """
    preds = {}
    for loss_i, index_i in zip(loss, data_index):
        passage_id, q_id, ans_id = index_i.split('_')
        preds.setdefault('_'.join([passage_id, q_id]), []).append((ans_id, loss_i))

    if not preds:
        return {'max token f1': 0.0, 'exact match': 0.0}

    # Lowest-loss position per query, computed once; min() keeps the first of
    # equal losses, matching element 0 of a stable sort.
    best = {k: int(min(rows, key=lambda x: x[1])[0]) for k, rows in preds.items()}

    EM = 100 * sum(best[k] in record_dev_gt_index[k] for k in best) / len(best)

    f1s = []
    # max-token-F1
    for k, best_idx in best.items():
        # Sanity check: one scored row per candidate of the query.
        assert len(preds[k]) == len(record_dev_gt_candidates[k])
        best_entity = record_dev_gt_candidates[k][best_idx]
        f1s.append(
            max((compute_f1(a_golden, best_entity) for a_golden in record_dev_gt[k]), default=0.0)
        )
    token_f1 = 100 * sum(f1s) / len(f1s)
    return {'max token f1': token_f1, 'exact match': EM}


def compute_metrics(dataset_name, gt, pred, data_index=None):
    """Compute every metric registered for ``dataset_name`` in task_metrics.

    Args:
        dataset_name: key into task_metrics (e.g. 'boolq', 'multirc', 'record').
        gt: ground-truth labels/answers, one per row.
        pred: predictions, parallel to ``gt``.
        data_index: per-row "<passage>_<question>_<answer>" ids; passed to
            exact_match and record_eval, ignored by the other metrics.

    Returns:
        Dict of metric name -> value on a 0-100 scale. sklearn metrics and
        wsc_accuracy are multiplied by 100 here; exact_match and record_eval
        already return percentages. record_eval's two results are stored
        under 'max token f1' and 'exact match'.
    """
    return_dict = {}
    for metric in task_metrics[dataset_name]:
        name = metric.__name__
        if name == 'f1_score':
            metric_value = metric(y_true=gt, y_pred=pred, average='macro') * 100
        elif name == 'exact_match':
            metric_value = metric(gt, pred, data_index)
        elif name == 'record_eval':
            # record_eval returns (EM, token F1) in that order; unpacking them
            # the other way round swapped the two reported numbers.
            em, max_token_f1 = metric(gt, pred, data_index)
            return_dict['max token f1'] = max_token_f1
            return_dict['exact match'] = em
            # Keep going so metrics registered after record_eval still run.
            continue
        else:
            metric_value = metric(y_true=gt, y_pred=pred) * 100
        return_dict[name] = metric_value
    return return_dict


def _read_eval_rows(evaluate_result_file):
    """Yield ``(gen_y, gt_y, idx)`` for each line of a tab-separated result file.

    Every non-blank line must hold exactly four tab-separated fields: model
    input (discarded), generated output, ground truth and example id.
    """
    with open(evaluate_result_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            _input_x, gen_y, gt_y, idx = line.split('\t')
            yield gen_y, gt_y, idx


def _parse_bool_string(text):
    """Return True/False for 'true'/'false' in any letter case, else None."""
    return {'true': True, 'false': False}.get(text.lower())


def format_evaluate_result(evaluate_result_file, data_file, out_put_file, dataset):
    """Rewrite ``data_file`` with labels taken from the model's predictions.

    ``data_file`` is JSON Lines with one example per line, keyed by its
    ``idx``. ``evaluate_result_file`` has one tab-separated prediction per
    line (input, generated output, ground truth, id). Labels are set by
    ``dataset``:

    * ``'wsc'``: ``label`` is True when the generated text equals the ground
      truth, with one pair of enclosing asterisks stripped if present
      (case-insensitive).
    * ``'multirc'``: ids are ``"<passage>_<question>_<answer>"``; every answer
      option's ``label`` becomes 1 when its prediction reads "true"
      (case-insensitive) and 0 otherwise. Every answer option in
      ``data_file`` must have a prediction (KeyError otherwise).
    * anything else: ``label`` is the generated text, converted to a bool when
      it reads "true"/"false" (case-insensitive).

    The examples are written to ``out_put_file`` as JSON Lines in their
    original order.
    """
    data_dict = {}
    with open(data_file, 'r', encoding='utf-8') as data_f:
        for line in data_f:
            item = json.loads(line)
            data_dict[item['idx']] = item

    if dataset == 'wsc':
        for gen_y, gt_y, idx in _read_eval_rows(evaluate_result_file):
            # startswith/endswith avoid IndexError on an empty ground truth.
            if gt_y.startswith('*') and gt_y.endswith('*'):
                target = gt_y[1:-1]
            else:
                target = gt_y
            data_dict[int(idx)]['label'] = gen_y.lower() == target.lower()

    elif dataset == 'multirc':
        # (passage id, question id, answer id) -> predicted truth value.
        predictions = {}
        for gen_y, _gt_y, idx in _read_eval_rows(evaluate_result_file):
            passage_id, q_id, ans_id = (int(part) for part in idx.split('_'))
            # bool(gen_y) is True for any non-empty string, including
            # 'False', so the text has to be parsed.
            # NOTE(review): assumes the generator emits True/False for MultiRC.
            predictions[(passage_id, q_id, ans_id)] = _parse_bool_string(gen_y) is True
        for p_id, item in data_dict.items():
            for q in item['passage']['questions']:
                for ans in q['answers']:
                    ans['label'] = 1 if predictions[(p_id, q['idx'], ans['idx'])] else 0

    else:
        for gen_y, _gt_y, idx in _read_eval_rows(evaluate_result_file):
            parsed = _parse_bool_string(gen_y)
            # Non-boolean predictions (e.g. class names) are kept verbatim.
            data_dict[int(idx)]['label'] = gen_y if parsed is None else parsed

    with open(out_put_file, 'w', encoding='utf-8') as f:
        for v in data_dict.values():
            f.write(json.dumps(v))
            f.write('\n')


if __name__ == '__main__':
    # CLI: relabel a data file from an evaluation result file.
    if len(sys.argv) != 5:
        sys.exit(f'usage: {sys.argv[0]} EVALUATE_RESULT_FILE DATA_FILE OUTPUT_FILE DATASET')
    format_evaluate_result(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4])