import os 
import json 
import jsonlines

import argparse 
from tqdm import tqdm, trange
from subprocess import PIPE, Popen, TimeoutExpired
import tempfile
import re 
from pathlib import Path
import signal
from ming.utils import client
from collections import defaultdict
from sympy import sympify
try:
    from rouge import Rouge
except ImportError:
    import subprocess
    import sys
    
    # 使用 subprocess 执行 pip 安装
    subprocess.check_call([sys.executable, "-m", "pip", "install", "rouge"])
    
    # 重新尝试导入
    try:
        from rouge import Rouge
    except ImportError:
        print("无法安装或找到rouge包。请手动安装。")
        sys.exit(1)


def normalize_frac(x):
    # Pattern to match \frac{a}{b}
    pattern = r'\\frac\{([^\}]+)\}\{([^\}]+)\}'
    
    # Search for the pattern in the input string
    match = re.search(pattern, x)
    
    # If a match is found, extract 'a' and 'b'
    if match:
        a = match.group(1)  # Numerator
        b = match.group(2)  # Denominator
        
        # Convert to a simplified form, if necessary
        # For demonstration, just return the extracted parts
        return a, b
    else:
        # import pdb 
        # pdb.set_trace()
        return None

def normalize_dfrac(x):
    pattern = r'\\dfrac\{([^\}]+)\}\{([^\}]+)\}'
    
    # Search for the pattern in the input string
    match = re.search(pattern, x)
    
    # If a match is found, extract 'a' and 'b'
    if match:
        a = match.group(1)  # Numerator
        b = match.group(2)  # Denominator
        
        # Convert to a simplified form, if necessary
        # For demonstration, just return the extracted parts
        return a, b
    else:
        # import pdb 
        # pdb.set_trace()
        return None

def normalize(x):
    if "\\frac" in x and normalize_frac(x):
        a, b = normalize_frac(x)
        try:
            a = float(a)
            b = float(b)
            return a / b
        except:
            return x
        
    elif "\\dfrac" in x and normalize_dfrac(x):
        a, b = normalize_dfrac(x)
        try:
            a = float(a)
            b = float(b)
            return a / b
        except:
            return x
    else:
        try:
            x = sympify(x).evalf()
            return float(x)
        except:
            return x

def acc(pred, target):
    return 1 if pred == target else 0

def rouge(pred, target):
    # compute rouge-1, rouge-2, rouge-l
    pass

def extract_bbox_content(s):
    contents = []
    i = 0
    while i < len(s):
        if s[i:i+7] == '\\boxed{':
            depth = 1
            start = i + 7
            i += 7
            while i < len(s) and depth > 0:
                if s[i] == '{':
                    depth += 1
                elif s[i] == '}':
                    depth -= 1
                    if depth == 0:
                        contents.append(s[start:i])
                i += 1
        else:
            i += 1
    return contents


def extract_answer_content(s):
    match = re.search(r'the answer is (.*?)\.', s, re.IGNORECASE)
    return match.group(1) if match else None


def math_acc(line):
    pred = line['text']
    target = line['additional_info']['solution']

    target_answer = extract_bbox_content(target)[0]
    pred_answer = extract_bbox_content(pred)

    # print(target)
    # print(target_answer)
    # print(pred)

    if pred_answer != []:
        pred_answer = pred_answer[0]
        target_answer = normalize(target_answer)
        if isinstance(target_answer, float):
            pred_answer = normalize(pred_answer) #if pred_answer is not None else float("-inf")

        if isinstance(target_answer, str) and isinstance(pred_answer, str): # target type = str
            return 1.0 if target_answer in pred_answer else 0.0
        
        elif isinstance(pred_answer, str): # target type = float
            return 1.0 if pred_answer in target else 0.0
        
        elif isinstance(pred_answer, float):
            if abs(pred_answer - target_answer) < 1e-3:
                return 1.0
            else:
                return 0.0
    else:
        if "." in pred:
            pred_answer = pred.split(".")[-1]
        else:
            pred_answer = pred[len(pred) // 2:]

        return 1.0 if target_answer in str(pred_answer) else 0.0


def code_acc(line):
    cwd = os.getcwd()
    text = line['text']

    # 示例字符串

    # 使用正则表达式匹配第一对```之间的内容
    match = re.search(r"```python(.*?)```", text, re.DOTALL)
    
    # 如果找到匹配项，则提取并打印
    if match:
        extracted_content = match.group(1)
    else:
        extracted_content = text

    additional_info = line['additional_info']
    # function_name = additional_info['function_name']
    test = additional_info['test']
    executable_code = extracted_content
    if isinstance(test, str):
        test_code = executable_code + "\n" + test
    else:
        test_code = executable_code + "\n" + "\n".join(test)
    
    if additional_info.get("entry_point", None) is not None:
        test_code = test_code + "\n\n" + f"check({additional_info['entry_point']})"
    

    with tempfile.TemporaryDirectory() as tempdir_name:
        tempdir_name = Path(tempdir_name)
        with open(tempdir_name / "program.py", "w", encoding="UTF-8") as f:
            f.write(test_code)
        os.chdir(tempdir_name)
        

    # idx = additional_info["id"]
    # with open(f"/remote-home/syjiang/repo/MING-MOE/logs/diverse/humaneval/tmp/{idx}", 'w') as f:
    #     f.write(test_code)
        
        p = Popen(f'python program.py', shell=True, stdout=PIPE, stderr=PIPE)
        time_limit = 15  # seconds
        scores = 1
        try:
            stdout, stderr = p.communicate(timeout=time_limit)
        except TimeoutExpired:
            # Linux
            # os.killpg(p.pid, signal.SIGTERM)
            # Windows
            os.system("kill {pid}".format(pid=p.pid))
            scores = 0
        else:
            if stderr:
                scores = 0
    

    os.chdir(cwd)
    return scores

def gsm8k_acc(line):
    # extract answer after #### 
    pred = line['text']
    target = line['additional_info']['answer']

    index = target.find("####")
    target_answer = target[index+4:].strip()
    

    pred_answer = extract_answer_content(pred)
    # import pdb
    # pdb.set_trace()
    # if index != -1:
    #     pred_answer = pred[index + 4:].strip()  # Extract answer after "####" and strip any leading or trailing whitespace
    # else:
    #     pred_answer = pred
    # index = target.find("####")
    # target_answer = target[index + 4:].strip()
    if pred_answer is not None:
        return 1 if target_answer in pred_answer else 0
    else:
        return 0

def sum_acc(line):
    pred = line['text']
    target = line['additional_info']['answer']
    rouge = Rouge()
    rouge_score = rouge.get_scores(pred, target)
    rouge_1 = rouge_score[0]['rouge-1']['r']
    rouge_2 = rouge_score[0]['rouge-2']['r']
    rouge_l = rouge_score[0]['rouge-l']['r']
    return rouge_1, rouge_2, rouge_l
    

def mmedbench_acc(line):
    pred = line['text']
    pred = re.findall(r'[A-E]', pred)[0]

    answer = line['additional_info']['answer_idx']

    return 1 if pred == answer else 0 

def bbh_acc(line):
    pred = line['text']
    answer = line['additional_info']['target']
    if "(" in pred and ")" in pred:
        # extract the content in () [maybe many], and select the one which is a single capital letter
        pred = re.findall(r'\((.*?)\)', pred)
        for p in pred:
            if p in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
                pred = f"({p})"
                break


    return 1 if answer in pred else 0

def apps_acc(line):
    text = line['text']
    match = re.search(r"```python(.*?)```", text)

    # 如果找到匹配项，则提取并打印
    if match:
        extracted_content = match.group(1)
    else:
        extracted_content = text
    additional_info = line['additional_info']
    input_output = additional_info['input_output']
    # try:
    #     input_output = json.loads(input_output)
    # except:
    #     return None
    # input_output = json.loads(input_output)

    inputs = input_output['inputs']
    outputs = input_output['outputs']
    test_code = extracted_content 
    assert len(inputs) == len(outputs)
    
    ff = tempfile.NamedTemporaryFile(mode='w')
    ff.write(test_code)
    name = ff.name 
    scores = 1
    for i in range(len(inputs)):
        cur_input = inputs[i]
        cur_output = outputs[i]
        
        p = Popen(f'python {name} < {cur_input}', shell=True, stdout=PIPE, stderr=PIPE)
        time_limit = 15  # seconds
        try:
            stdout, stderr = p.communicate(timeout=time_limit)
        except TimeoutExpired:
            # Linux
            # os.killpg(p.pid, signal.SIGTERM)
            # Windows
            # Popen("TASKKILL /F /PID {pid} /T".format(pid=p.pid))
            scores = 0
            break
        if stderr:
            scores = 0
            break
        if stdout.strip() != cur_output.strip():
            scores = 0
            break
    ff.close()
    return scores

def triviaqa_acc(line):
    pred = line['text']
    answers = line['additional_info']['answer']
    for answer in answers:
        if pred == answer:
            return 1 
    return 0


def mc_acc(line):
    pred = line['text']
    answer = line['additional_info']['answer']
    return 1 if pred == answer else 0

winogrande = mmlu_acc = arc_acc = cmmlu_acc = ceval_acc = mc_acc
    

METRIC_FUNC_MAPPING = {
    "math": math_acc,
    "math_500": math_acc,
    "humaneval": code_acc,
    "mbpp": code_acc,
    "gsm8k": gsm8k_acc,
    "mmedbench_en": mmedbench_acc,
    "mmedbench_zh": mmedbench_acc,
    "bbh": bbh_acc,
    "apps": apps_acc,
    "triviaqa": triviaqa_acc,
    "winogrande": winogrande,
    "mmlu": mmlu_acc,
    "arc": arc_acc,
    "cmmlu": cmmlu_acc,
    "ceval": ceval_acc,
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True)

    parser.add_argument("--output_file", type=str, required=False)
    args = parser.parse_args()

    # input_file is a jsonl file with the following format:

    questions = client.read_jsonl(args.input_file)
    # questions = [json.loads(q) for q in open(os.path.expanduser(args.input_file), "r")]
    
    total_num = len(questions)
    total_score = 0
    rouge_score = [[], [], []]
    if "merge" in args.input_file:
        dataset_name = args.input_file.split("/")[-3]
    else:
        dataset_name  = args.input_file.split("/")[-2].replace(".jsonl", "")
    acc_func = METRIC_FUNC_MAPPING[dataset_name]
    each_type_counts = defaultdict(int)
    for line in questions:
        each_type_counts[line['additional_info']['type']] += 1
    each_type_correct = defaultdict(int)
    wrong_idx = []
    for line in tqdm(questions, total=total_num):
        scores = acc_func(line)
        if isinstance(scores, tuple):
            rouge_1, rouge_2, rouge_l = scores
            rouge_score[0].append(rouge_1)
            rouge_score[1].append(rouge_2)
            rouge_score[2].append(rouge_l)
            continue
        if scores is None:
            total_num -= 1
            wrong_idx.append(line)
            continue
        total_score += scores
        if scores == 0:
            wrong_idx.append(line)
        each_type_correct[line['additional_info']['type']] += scores
        
    for k, v in each_type_correct.items():
        print(f"{k}\t{(v / each_type_counts[k]):.4f}")
    # compute the average score
    print(f"Total score\t{total_score / total_num:.4f}")
    if args.output_file:
        client.write_json(wrong_idx, args.output_file)
    # with open(args.output_file, 'w') as f:
    #     json.dump(wrong_idx, f, ensure_ascii=False)