# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import editdistance
import numpy as np
from collections import defaultdict
from nltk.tokenize import RegexpTokenizer
from utils.utils import CONSTANTS,load_jsonl
from typing import FrozenSet
import keyword
import re
from utils.metrics import *

def compute_EM(target, predictions, passk):
    """Return the pass@k exact-match score of *predictions* against *target*.

    Both sides are compared line by line after stripping surrounding
    whitespace and dropping blank lines. Each prediction is truncated to the
    number of non-blank target lines, so extra trailing lines are ignored.

    Args:
        target: Ground-truth code string.
        predictions: Candidate completions, in ranked order.
        passk: Only the first ``passk`` predictions are considered.

    Returns:
        1 if any of the first ``passk`` predictions matches exactly, else 0.
        (A pass@k score is a hit/miss indicator; summing the per-prediction
        matches could exceed 1 and disagreed with ``compute_ES``'s max.)
    """
    target_lines = [line.strip() for line in target.splitlines() if line.strip()]
    for prediction in predictions[:passk]:
        prediction_lines = [
            line.strip() for line in prediction.splitlines() if line.strip()
        ][:len(target_lines)]
        # List equality already implies equal length, so no separate check is needed.
        if prediction_lines == target_lines:
            return 1
    return 0


def compute_ES(target, predictions, passk):
    """Return the best normalized edit similarity among the first *passk* predictions.

    Target and predictions are normalized by stripping each line, dropping
    blank lines and re-joining with ``'\\n'``. Each prediction is truncated to
    the number of non-blank target lines before comparison.

    Args:
        target: Ground-truth code string.
        predictions: Candidate completions, in ranked order.
        passk: Only the first ``passk`` predictions are considered.

    Returns:
        ``max`` over the considered predictions of
        ``1 - levenshtein(t, p) / max(len(t), len(p))``, a value in [0, 1].
        Two empty normalized strings count as identical (1.0). If no
        prediction is considered, 0.0 is returned instead of raising.
    """
    target_lines = [line.strip() for line in target.splitlines() if line.strip()]
    target_str = '\n'.join(target_lines)
    ES_scores = []
    for prediction in predictions[:passk]:
        prediction_lines = [
            line.strip() for line in prediction.splitlines() if line.strip()
        ][:len(target_lines)]
        prediction_str = '\n'.join(prediction_lines)
        longest = max(len(target_str), len(prediction_str))
        if longest == 0:
            # Both sides are empty: they are identical, and dividing would raise.
            ES_scores.append(1.0)
            continue
        ES_scores.append(1 - editdistance.eval(target_str, prediction_str) / longest)
    return max(ES_scores, default=0.0)

# Built once at import time: is_identifier() consults this set for every
# token, so there is no reason to rebuild it on each call.
_LANGUAGE_KEYWORDS: FrozenSet[str] = frozenset(
    k for k in keyword.kwlist if k not in ('True', 'False')
)


def get_language_keywords() -> FrozenSet[str]:
    """Return the reserved words that are not counted as identifiers.

    This is Python's ``keyword.kwlist`` with ``True`` and ``False`` removed
    (so those two literals can still count as identifiers). NOTE: the same
    Python keyword set is used no matter which ``language`` the callers in
    this file are given.
    """
    return _LANGUAGE_KEYWORDS


def is_identifier(token, language="java"):
    """Return whether *token* should be counted as an identifier.

    A token qualifies when it matches ``IDENTIFIER_REGEX`` and, unless
    *language* is None, is not one of the words returned by
    ``get_language_keywords()``.
    """
    # IDENTIFIER_REGEX is not defined in this file; presumably it comes from
    # the `utils.metrics` star import — confirm.
    if not IDENTIFIER_REGEX.match(token):
        return False
    if language is None:
        # With no language given, keyword filtering is skipped entirely.
        return True
    return token not in get_language_keywords()


def extract_identifiers(source_code, language="java"):
    """Return the identifier tokens of *source_code*, in order, duplicates kept.

    String literals (anything matched by ``string_pattern``) are deleted first
    so that words inside them are not counted. The remaining text is split
    with ``code_tokenizer``, and each token is kept only if ``is_identifier``
    accepts it for *language*.
    """
    # string_pattern / code_tokenizer are not defined in this file; presumably
    # they come from the `utils.metrics` star import — confirm.
    literal_free_code = re.sub(string_pattern, '', source_code)
    tokens = code_tokenizer.tokenize(literal_free_code)
    return [tok for tok in tokens if is_identifier(tok, language=language)]


def compute_id_match(pred_ids, target_ids):
    """Count set-level identifier matches between a prediction and the target.

    Duplicates are ignored: both inputs are reduced to sets before counting.

    Args:
        pred_ids: Identifiers extracted from the prediction.
        target_ids: Identifiers extracted from the ground truth.

    Returns:
        ``(tp, fp, fn)``: the number of distinct identifiers present in both,
        only in the prediction, and only in the target, respectively.
    """
    pred_set = set(pred_ids)
    target_set = set(target_ids)
    # Set algebra replaces per-element list membership tests (O(n*m) -> O(n+m)).
    tp = len(pred_set & target_set)
    fp = len(pred_set - target_set)
    fn = len(target_set - pred_set)
    return tp, fp, fn


def _code_lines(text, comment_prefix):
    """Return the stripped, non-blank lines of *text*, minus whole-line comments."""
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if comment_prefix:
        # Only filter when a comment marker is known: every string starts with
        # "", so an empty prefix would otherwise discard every line.
        lines = [line for line in lines if not line.startswith(comment_prefix)]
    return lines


def compute_identifier_match(target, predictions, passk, language="java"):
    """Return the best identifier exact-match and identifier F1 over *passk* predictions.

    Whole-line comments (``#`` for python, ``//`` for java; no filtering for
    other languages) and blank lines are dropped. Each prediction is truncated
    to the number of remaining target lines. Identifiers are then extracted
    with ``extract_identifiers``.

    Args:
        target: Ground-truth code string.
        predictions: Candidate completions, in ranked order.
        passk: Only the first ``passk`` predictions are considered.
        language: Language name forwarded to ``extract_identifiers``; also
            selects the comment marker.

    Returns:
        ``(id_em, id_f1)``. ``id_em`` is 1 if some prediction's ordered
        identifier list equals the target's, else 0. ``id_f1`` is the best
        set-level F1 (0 when both identifier sets are empty). Both are 0 when
        no prediction is considered.
    """
    comment_prefix = {"python": "#", "java": "//"}.get(language, "")

    target_lines = _code_lines(target, comment_prefix)
    # Join with newlines (not "") so the last token of one line is not fused
    # with the first token of the next, e.g. "return a" + "b = c" -> "ab".
    gt_ids = extract_identifiers("\n".join(target_lines), language=language)

    identifier_em_scores = []
    identifier_f1_scores = []
    for prediction in predictions[:passk]:
        prediction_lines = _code_lines(prediction, comment_prefix)[:len(target_lines)]
        pred_ids = extract_identifiers("\n".join(prediction_lines), language=language)
        identifier_em_scores.append(int(pred_ids == gt_ids))

        id_tp, id_fp, id_fn = compute_id_match(pred_ids, gt_ids)
        denominator = 2 * id_tp + id_fp + id_fn
        identifier_f1_scores.append(2 * id_tp / denominator if denominator else 0)

    return max(identifier_em_scores, default=0), max(identifier_f1_scores, default=0)


def _average_by_repo(scores_by_repo):
    """Return ({repo: mean score rounded to 4 decimals}, {repo: number of scores})."""
    averages = {repo: round(sum(values) / len(values), 4) for repo, values in scores_by_repo.items()}
    counts = {repo: len(values) for repo, values in scores_by_repo.items()}
    return averages, counts


def compute_score_by_repo_with_metadata(repos, lines, stype, passk=1, language="java"):
    """Print, and return, per-repository averages of one metric.

    Each record in *lines* is scored with a single sample,
    ``line['generate_response']``, against ``line['metadata']['ground_truth']``.
    The repository is the part of ``line['metadata']['task_id']`` before the
    first ``/``. Records whose repository is not in *repos* are skipped.

    Args:
        repos: Repository names to include.
        lines: Records with the keys described above.
        stype: ``'EM'``, ``'ES'`` or ``'ID_match'``.
        passk: Forwarded to the metric function.
        language: Forwarded to ``compute_identifier_match`` for ``'ID_match'``.
            Defaults to ``"java"``, matching that function's own default, so
            existing callers get the same results as before.

    Returns:
        ``{stype: {repo: avg}}`` for ``'EM'``/``'ES'``, or
        ``{'id_em': {repo: avg}, 'id_f1': {repo: avg}}`` for ``'ID_match'``.

    Raises:
        ValueError: if *stype* is not one of the supported metric names
            (previously an unknown name silently printed nothing).
    """
    if stype not in ('EM', 'ES', 'ID_match'):
        raise ValueError(f"unsupported stype: {stype!r}")

    scores = defaultdict(list)
    id_em_scores = defaultdict(list)
    id_f1_scores = defaultdict(list)
    for line in lines:
        repo = line['metadata']['task_id'].split('/')[0]
        if repo not in repos:
            continue
        samples = [line['generate_response']]
        ground_truth = line['metadata']['ground_truth']
        if stype == 'EM':
            scores[repo].append(compute_EM(ground_truth, samples, passk))
        elif stype == 'ES':
            scores[repo].append(compute_ES(ground_truth, samples, passk))
        else:
            id_em_score, id_f1_score = compute_identifier_match(
                ground_truth, samples, passk, language=language
            )
            id_em_scores[repo].append(id_em_score)
            id_f1_scores[repo].append(id_f1_score)

    print(stype)
    if stype in ('EM', 'ES'):
        avg_scores, repo_count = _average_by_repo(scores)
        for repo in avg_scores:
            print(f'{avg_scores[repo]}\t{repo_count[repo]}\t{repo}')
        return {stype: avg_scores}

    id_em_avg_scores, id_em_repo_count = _average_by_repo(id_em_scores)
    id_f1_avg_scores, id_f1_repo_count = _average_by_repo(id_f1_scores)
    for repo in id_em_repo_count:
        print(f'id_em {id_em_avg_scores[repo]}\t{id_em_repo_count[repo]}\t{repo}')
    for repo in id_f1_repo_count:
        print(f'id_f1 {id_f1_avg_scores[repo]}\t{id_f1_repo_count[repo]}\t{repo}')
    return {'id_em': id_em_avg_scores, 'id_f1': id_f1_avg_scores}


if __name__ == '__main__':

    file_path = './generation_results/deepseek-coder/line_level.python.coarse2fine.10.0.retrieval.deepseek-coder.gen_res.jsonl'
    # This was the original ground-truth file.
    ground_truth_file_path = './graph_based_query/line_level.python.test.jsonl'

    # compute_batch_EM / compute_batch_ES / compute_bath_identifier_match are
    # not defined in this file; presumably they come from `utils.metrics`.
    em = compute_batch_EM(ground_truth_file_path, file_path, language="python")
    es = compute_batch_ES(ground_truth_file_path, file_path, language="python")
    id_em, id_f1 = compute_bath_identifier_match(ground_truth_file_path, file_path, language="python")
    print(f'em:{em}')
    print(f'es:{es}')
    print(f'id_em:{id_em}')
    print(f'id_f1:{id_f1}')

    # Read the generation results once and reuse them for every per-repo
    # metric. list() keeps this safe even if load_jsonl returns an iterator.
    generation_results = list(load_jsonl(file_path))
    # NOTE(review): the per-repo ID_match pass is not told language="python",
    # so identifier extraction there uses the function's default language —
    # confirm this is intended.
    for stype in ('EM', 'ES', 'ID_match'):
        compute_score_by_repo_with_metadata(CONSTANTS.repos, generation_results, stype, passk=1)

