
import argparse 
import json 
import sys 
import re 
from pathlib import Path 
import shutil 

def parse_final_output (text :str )->str :
    pattern =r"answer>\s*(.*?)\s*</answer"
    matches =re .findall (pattern ,text ,re .DOTALL )
    if matches :
        return matches [-1 ].strip ()
    return ""

def normalize_quotes (s :str )->str :
    """
    シングルクォートや特殊な引用符をダブルクォートに置換し、
    さらに文字列が両端にダブルクォートで囲まれている場合はそれを削除します。
    """
    s =s .replace ("'","\"").strip ()
    if s .startswith ('"')and s .endswith ('"'):
        s =s [1 :-1 ].strip ()
    return s 


def evaluate_line (data :dict )->dict :
    """
    JSONオブジェクト（1行分）から model_output を再解析し、
    expected_final_output が抽出された model_final_output に含まれていれば
    is_correct を True に更新する関数です。

    このバージョンでは、引用符の違いだけの差異を無視するため、
    expected_final_output と model_final_output の両方を normalize_quotes 関数で正規化した上で比較します。
    """
    model_output =data .get ("model_output","")
    expected_final_output =data .get ("expected_final_output","")
    model_final_output =parse_final_output (model_output )
    data ["model_final_output"]=model_final_output 





    is_correct =False 
    model_final_output =model_final_output .encode ('unicode_escape').decode ('utf-8')
    if expected_final_output ==model_final_output :
        is_correct =True 
    elif expected_final_output =="'"+model_final_output +"'":
        is_correct =True 




    else :
        print (f"{expected_final_output ==model_final_output }")
        print (f"Expected: {expected_final_output }, type: {type (expected_final_output )}")
        print (f"Model   : {model_final_output }, type: {type (model_final_output )}")
        print ("-"*40 )
    data ["is_correct"]=is_correct 

    return data 

def process_file (file_path :str ):
    p =Path (file_path )
    if not p .exists ():
        print (f"[Error] ファイルが存在しません: {file_path }",file =sys .stderr )
        return 


    update_path =p .with_suffix (p .suffix +".update")
    updated_lines =[]
    correct_count =0 

    with p .open ("r",encoding ="utf-8")as f :
        for line_idx ,line in enumerate (f ,start =1 ):
            line =line .strip ()
            if not line :
                continue 
            try :
                data =json .loads (line )
            except json .JSONDecodeError :
                print (f"[Warning] 行 {line_idx } は有効なJSONではありません。スキップします。",file =sys .stderr )
                continue 
            updated_data =evaluate_line (data )
            if updated_data .get ("is_correct"):
                correct_count +=1 
            updated_lines .append (json .dumps (updated_data ,ensure_ascii =False ))

    with update_path .open ("w",encoding ="utf-8")as f :
        for line in updated_lines :
            f .write (line +"\n")
    print (f"[Info] New file created:\n {update_path }\n",file =sys .stderr )
    print (f"Is_correct True Count = {correct_count }")


def main ():
    parser =argparse .ArgumentParser (
    description ="既存のJSONLファイルをバックアップし、新しい parse_final_output により正解判定を再評価して is_correct を更新し、各ファイルごとの正解数を出力するスクリプト"
    )
    parser .add_argument ("jsonl_files",type =str ,nargs ="+",
    help ="絶対パスで指定するJSONLファイルのパス（複数指定可能）")
    args =parser .parse_args ()

    for file_path in args .jsonl_files :
        process_file (file_path )

if __name__ =="__main__":
    main ()
