

"""
目的:
  - 2つの JSONL ファイル (VAL と INFERENCE) を 1行ずつ対応付けて読み、
    それぞれの <think> / <answer> を抽出して比較・集計し、CSV と要約を出力する。

想定する入力:
  1) VAL_JSONL_PATH:
     各行が JSON オブジェクト。少なくとも "output": {"content": "..."} を含む。
     content 文字列の中に <think> … </think> と <answer> … </answer> が含まれる前提。

     例:
       {"output": {"content": "<think>1+1 を計算する…</think><answer>2</answer>"}}

  2) INFERENCE_JSONL_PATH:
     各行が JSON オブジェクト。少なくとも "model_output": "..." を含む。
     任意で "is_correct": true/false を含む（無ければ False とみなす）。
     model_output にも <think> … </think> と <answer> … </answer> が入る想定。

     例:
       {"model_output": "<think>足し算…</think><answer>2</answer>", "is_correct": false}

出力:
  1) OUTPUT_CSV_PATH:
     列: id, think, answer, pred_trace, pred, is_correct, pred_corrected, diff
       - id: 行番号(1始まり)
       - think: VAL 側の think 全文
       - answer: VAL 側の answer
       - pred_trace: INFERENCE 側の think 全文
       - pred: INFERENCE 側の answer
       - is_correct: 最終的な正誤 (軽い再検証後の値)
       - pred_corrected: 再検証で正解化した場合の理由 ("MISSING_SINGLE_QUOTE" or "EVAL_MATCH" など)
       - diff: think の差分 (unified diff 文字列)、または “NO_DIFFERENCE” / “NO_DIFFERENCE (both empty)” / “NO_DIFFERENCE (unexpected)” / “MISSING_INFERENCE_TRACE”

  2) SUMMARY_OUTPUT_PATH (本スクリプト中では inference_outputs_path という変数名):
     - 次の3指標を集計表示:
       Missing Inference（INFERENCE の <think> 欠落数）
       Before inference_outputs Correct（再検証前の正解数）
       After inference_outputs Correct（再検証後の正解数）
"""

import sys 
import json 
import re 
import difflib 
import csv 
import os 

def extract_trace(text):
    """Return the body of the first <think>...</think> block in *text* as a list of lines.

    The captured body is stripped of surrounding whitespace and then split with
    str.splitlines(). re.DOTALL lets the non-greedy match span newlines.
    When *text* contains no <think> block, an empty list is returned.
    """
    found = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    if found is None:
        return []
    body = found.group(1).strip()
    return body.splitlines()

def extract_answer(text):
    """Return the whitespace-stripped body of the first <answer>...</answer> block in *text*.

    re.DOTALL allows the answer to span several lines. When *text* has no
    <answer> block, the empty string is returned.
    """
    found = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    return found.group(1).strip() if found else ""

def _read_lines(path, label):
    """Return all lines of the UTF-8 text file *path*; exit with status 1 on any read error.

    *label* ("VAL" / "INFERENCE") is used only in the error message.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.readlines()
    except Exception as e:
        print(f"Error: Failed to read {label} file '{path}': {e}", file=sys.stderr)
        sys.exit(1)


def _ensure_parent_dir(path, description):
    """Create the parent directory of *path* if it has one; exit with status 1 on failure.

    A bare filename such as "out.csv" has an empty dirname, and
    os.makedirs("") raises FileNotFoundError, so that case is skipped.
    """
    parent = os.path.dirname(path)
    if not parent:
        return
    try:
        os.makedirs(parent, exist_ok=True)
    except Exception as e:
        print(f"Error: Failed to create directory for {description} '{path}': {e}", file=sys.stderr)
        sys.exit(1)


def _parse_json_object(line, label, idx):
    """Parse one JSONL line and return it if it is a JSON object; otherwise warn and return None."""
    try:
        obj = json.loads(line)
    except json.JSONDecodeError as e:
        print(f"Warning: Failed to parse JSON in {label} file at line {idx}: {e}", file=sys.stderr)
        return None
    if not isinstance(obj, dict):
        # Valid JSON that is not an object (array, number, ...) would crash on .get() later.
        print(
            f"Warning: Expected a JSON object in {label} file at line {idx}, got {type(obj).__name__}",
            file=sys.stderr,
        )
        return None
    return obj


def _as_text(value):
    """Coerce a JSON field to str: None (missing key or JSON null) -> "", other non-strings -> str()."""
    if value is None:
        return ""
    return value if isinstance(value, str) else str(value)


def _revalidate(pred, gold):
    """Lightly re-check a prediction that the INFERENCE file marked incorrect.

    Returns:
        "MISSING_SINGLE_QUOTE" if wrapping *pred* in single quotes equals *gold*,
        "EVAL_MATCH" if eval(pred) == eval(gold), otherwise "" (not corrected).
    """
    if f"'{pred}'" == gold:
        return "MISSING_SINGLE_QUOTE"
    # SECURITY(review): eval() executes arbitrary Python taken from the model
    # output (and the VAL file). Only run this script on trusted data; consider
    # ast.literal_eval if evaluating expressions such as "1+1" is not required.
    try:
        if eval(pred) == eval(gold):
            return "EVAL_MATCH"
    except (Exception, SystemExit):
        # Unparseable or non-comparable answers are simply "not corrected".
        # SystemExit is caught so an answer like "exit()" cannot abort the whole run.
        pass
    return ""


def _diff_traces(val_trace_lines, inf_trace_lines, idx):
    """Describe how the INFERENCE think lines differ from the VAL think lines at line *idx*."""
    if not val_trace_lines and not inf_trace_lines:
        return "NO_DIFFERENCE (both empty)"
    if val_trace_lines == inf_trace_lines:
        return "NO_DIFFERENCE"
    diff_lines = list(difflib.unified_diff(
        val_trace_lines,
        inf_trace_lines,
        fromfile=f"VAL_trace (line {idx})",
        tofile=f"INFERENCE_trace (line {idx})",
        lineterm="",
    ))
    # unified_diff only yields nothing for equal inputs, which was ruled out above.
    return "\n".join(diff_lines) if diff_lines else "NO_DIFFERENCE (unexpected)"


def _format_summary(missing, before, after, total):
    """Build the three-line summary text (or a placeholder when there are no results)."""
    if total <= 0:
        return "No results to revalidate."
    missing_inference_perc = (missing / total) * 100
    before_reval_perc = (before / total) * 100
    after_reval_perc = (after / total) * 100
    return (
        f"Missing Inference: {missing}/{total} ({missing_inference_perc}%)\n"
        f"Before inference_outputs Correct: {before}/{total} ({before_reval_perc}%)\n"
        f"After inference_outputs Correct: {after}/{total} ({after_reval_perc}%)"
    )


def main():
    """CLI entry point: compare VAL and INFERENCE JSONL files line by line.

    Usage: script VAL_JSONL_PATH INFERENCE_JSONL_PATH OUTPUT_CSV_PATH SUMMARY_OUTPUT_PATH

    Writes one CSV row per successfully parsed line pair and a short summary
    file. Exits with status 1 on a wrong argument count or any file
    read/write/directory-creation failure. Line pairs whose JSON cannot be
    parsed (or is not an object) are skipped and reported at the end.
    """
    if len(sys.argv) != 5:
        print("Usage: {} VAL_JSONL_PATH INFERENCE_JSONL_PATH OUTPUT_CSV_PATH SUMMARY_OUTPUT_PATH".format(sys.argv[0]))
        sys.exit(1)

    val_path = sys.argv[1]
    inf_path = sys.argv[2]
    csv_path = sys.argv[3]
    inference_outputs_path = sys.argv[4]  # SUMMARY_OUTPUT_PATH

    val_lines = _read_lines(val_path, "VAL")
    inf_lines = _read_lines(inf_path, "INFERENCE")

    if len(val_lines) != len(inf_lines):
        print(f"Warning: VAL file has {len(val_lines)} lines while INFERENCE file has {len(inf_lines)} lines.", file=sys.stderr)
        print(f"Processing only {min(len(val_lines), len(inf_lines))} lines.", file=sys.stderr)

    _ensure_parent_dir(csv_path, "CSV file")

    results = []
    missing_inference_trace_count = 0
    parse_errors = 0
    num_correct_before_inference_outputs = 0
    num_correct_after_inference_outputs = 0

    # zip() stops at the shorter file, matching the warning above.
    for idx, (val_line, inf_line) in enumerate(zip(val_lines, inf_lines), start=1):
        val_obj = _parse_json_object(val_line, "VAL", idx)
        if val_obj is None:
            parse_errors += 1
            continue
        inf_obj = _parse_json_object(inf_line, "INFERENCE", idx)
        if inf_obj is None:
            parse_errors += 1
            continue

        # VAL side: {"output": {"content": "..."}}; tolerate a missing/null/non-object "output".
        val_output = val_obj.get("output")
        val_content = _as_text(val_output.get("content") if isinstance(val_output, dict) else None)
        val_trace_lines = extract_trace(val_content)
        val_answer = extract_answer(val_content)

        # INFERENCE side: {"model_output": "...", "is_correct": bool (optional)}.
        inf_content = _as_text(inf_obj.get("model_output"))
        inf_trace_lines = extract_trace(inf_content)
        inf_answer = extract_answer(inf_content)

        inf_is_correct = inf_obj.get("is_correct", False)
        if inf_is_correct:
            num_correct_before_inference_outputs += 1

        pred_corrected_text = ""
        if not inf_is_correct:
            pred_corrected_text = _revalidate(inf_answer, val_answer)
            if pred_corrected_text:
                inf_is_correct = True
        if inf_is_correct:
            num_correct_after_inference_outputs += 1

        has_inference_trace = bool(re.search(r"<think>.*?</think>", inf_content, re.DOTALL))
        if not has_inference_trace and not inf_trace_lines:
            diff_text = "MISSING_INFERENCE_TRACE"
            missing_inference_trace_count += 1
        else:
            diff_text = _diff_traces(val_trace_lines, inf_trace_lines, idx)

        results.append({
            "id": idx,
            "think": "\n".join(val_trace_lines),
            "answer": val_answer,
            "pred_trace": "\n".join(inf_trace_lines),
            "pred": inf_answer,
            "is_correct": inf_is_correct,
            "pred_corrected": pred_corrected_text,
            "diff": diff_text,
        })

    try:
        with open(csv_path, "w", newline='', encoding="utf-8") as csvfile:
            fieldnames = ["id", "think", "answer", "pred_trace", "pred", "is_correct", "pred_corrected", "diff"]
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(results)
    except Exception as e:
        print(f"Error: Failed to write CSV file '{csv_path}': {e}", file=sys.stderr)
        sys.exit(1)

    inference_outputs_output = _format_summary(
        missing_inference_trace_count,
        num_correct_before_inference_outputs,
        num_correct_after_inference_outputs,
        len(results),
    )

    _ensure_parent_dir(inference_outputs_path, "inference_outputs file")
    try:
        with open(inference_outputs_path, "w", encoding="utf-8") as f_inference_outputs:
            f_inference_outputs.write(inference_outputs_output + "\n")
        print(f"Inference outputs successfully written to {inference_outputs_path}")
    except Exception as e:
        print(f"Error: Failed to write inference_outputs file '{inference_outputs_path}': {e}", file=sys.stderr)
        sys.exit(1)

    if parse_errors > 0:
        print(f"Warning: Skipped {parse_errors} lines due to JSON parsing errors.", file=sys.stderr)

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ =="__main__":
    main ()
