
import os 
import sys 
import re 
from collections import defaultdict 

# Per-iteration result file name, e.g. "iter7888.txt" or "iter7888-guided.txt".
# Group 1 is the iteration number; group 2 is non-None only for guided runs.
ITER_RE = re.compile(r"^iter(\d+)(-guided)?\.txt$")


def _build_output_filename(target_dataset_name, iteration, use_guided_decoding, all_iters):
    """Return the aggregate file name (not a full path) for the given options."""
    if all_iters:
        # The all-iterations mode uses a fixed name; the guided flag is not encoded.
        return f"{target_dataset_name}_ALLIters_grouped.txt"
    guided_part = "_guided" if use_guided_decoding.lower() == "true" else ""
    return f"{target_dataset_name}_iter{iteration}{guided_part}_grouped.txt"


def _collect_groups(directory_path, task_dirs, target_dataset_name, iteration, all_iters):
    """Read matching iter*.txt files under each task's dataset directory.

    Returns ``(groups, total_files)``. ``groups`` maps ``(iter_id, guided)``
    to a dict of ``task_name -> [file contents]``; contents are stripped and
    empty files are skipped. ``total_files`` counts the non-empty files read.
    Unreadable files and directories are reported on stderr and skipped.
    """
    # all_iters=True or iteration "ALL" (any case) disables the iteration filter.
    accept_any_iter = all_iters or iteration.upper() == "ALL"
    groups = defaultdict(lambda: defaultdict(list))
    total_files = 0

    for task_name in sorted(task_dirs):
        dataset_dir = os.path.join(directory_path, task_name, target_dataset_name)
        if not os.path.isdir(dataset_dir):
            # Tasks without this dataset's subdirectory are skipped silently.
            continue
        try:
            filenames = sorted(os.listdir(dataset_dir))
        except OSError as e:
            print(f"警告: ディレクトリ '{dataset_dir}' のスキャン中にエラー: {e}", file=sys.stderr)
            continue

        for filename in filenames:
            # fullmatch: with plain match, "$" would also accept a trailing newline.
            m = ITER_RE.fullmatch(filename)
            if not m:
                continue
            iter_id = m.group(1)
            guided = m.group(2) is not None
            if not accept_any_iter and iter_id != iteration:
                continue

            filepath = os.path.join(dataset_dir, filename)
            try:
                with open(filepath, "r", encoding="utf-8") as f:
                    content = f.read().strip()
            except (OSError, UnicodeDecodeError) as e:
                # Best-effort: one bad file must not abort the whole aggregation.
                print(f"警告: '{filepath}' 読み込み中にエラー: {e}", file=sys.stderr)
                continue
            if content:
                groups[(iter_id, guided)][task_name].append(content)
                total_files += 1

    return groups, total_files


def _render_groups(groups):
    """Render the grouped contents as the text of the aggregate file."""

    def sort_key(key):
        iter_id, guided = key
        # Numeric iteration order (iter10 after iter9); unguided before guided.
        return (int(iter_id), 1 if guided else 0)

    parts = []
    for key in sorted(groups, key=sort_key):
        iter_id, guided = key
        guided_label = "guided" if guided else "unguided"
        parts.append(f"=== Iter {iter_id} / {guided_label} ===\n")
        tasks = groups[key]
        for task_name in sorted(tasks):
            parts.append(f"Task: {task_name}\n")
            # Every entry is followed by one "---" separator line (including the last).
            for content in tasks[task_name]:
                parts.append(content + "\n---\n")
        parts.append("\n")
    return "".join(parts)


def aggregate_results_grouped_by_iter(
    directory_path: str,
    target_dataset_name: str,
    iteration: str,
    use_guided_decoding: str,
    all_iters: bool = False,
) -> None:
    """Aggregate per-task iter*.txt results into one file grouped by iteration.

    For every task subdirectory of ``directory_path``, reads
    ``<task>/<target_dataset_name>/iter<N>[-guided].txt`` and writes them into
    ``directory_path/<output name>``. The output is grouped by
    (iteration, guided/unguided), then by task name.

    Args:
        directory_path: Root directory containing one subdirectory per task.
        target_dataset_name: Dataset subdirectory name inside each task.
        iteration: "ALL" (case-insensitive) or a concrete iteration such as
            "7888". Compared as a string against the number in the file name.
        use_guided_decoding: "true"/"false". Only affects the output file
            name; both guided and unguided files are aggregated either way.
        all_iters: If True, aggregate every iteration and ignore ``iteration``.

    Exits the process with status 1 when ``directory_path`` is missing,
    cannot be listed, or the output file cannot be written. Returns without
    writing anything when no non-empty matching files are found.
    """
    output_filepath = os.path.join(
        directory_path,
        _build_output_filename(target_dataset_name, iteration, use_guided_decoding, all_iters),
    )

    if not os.path.isdir(directory_path):
        print(f"エラー: ディレクトリ '{directory_path}' が見つかりません。", file=sys.stderr)
        sys.exit(1)

    try:
        task_dirs = [
            d for d in os.listdir(directory_path)
            if os.path.isdir(os.path.join(directory_path, d))
        ]
    except OSError as e:
        print(f"エラー: ディレクトリ '{directory_path}' の読み込みに失敗しました: {e}", file=sys.stderr)
        sys.exit(1)

    groups, total_files = _collect_groups(
        directory_path, task_dirs, target_dataset_name, iteration, all_iters
    )

    if not groups:
        # Warning goes to stderr, consistent with every other diagnostic here.
        print(
            f"警告: '{directory_path}' 内で対象ファイルが見つかりません（dataset='{target_dataset_name}').",
            file=sys.stderr,
        )
        return

    try:
        with open(output_filepath, "w", encoding="utf-8") as out:
            out.write(_render_groups(groups))
    except OSError as e:
        print(f"エラー: ファイル '{output_filepath}' の書き込みに失敗: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"成功: {total_files}ファイルを '{output_filepath}' に集約（iterごとグルーピング）")


def parse_cli_args(argv):
    """Parse the script's command line.

    Expected form::

        <prog> <dir> <dataset> <model> <iteration|ALL> <use_guided_decoding>
               [--all-iters true|false]

    Returns a tuple ``(directory, dataset, model_name, iteration,
    use_guided_decoding, all_iters)``. Prints the usage message to stderr and
    exits with status 1 on too few arguments. It does the same for a malformed
    optional part: a missing value, an unknown flag, or a value other than
    true/false. Previously such input was silently ignored and the run fell
    back to a single iteration.
    """
    usage = "使用法: python aggregate_script.py <ディレクトリパス> <データセット名> <モデル名> <イテレーション数|ALL> <use_guided_decoding> [--all-iters true|false]"
    if len(argv) < 6:
        print(usage, file=sys.stderr)
        sys.exit(1)

    directory, dataset, model_name, iteration, use_guided = argv[1:6]

    all_iters = False
    extra = argv[6:]
    if extra:
        if (
            len(extra) != 2
            or extra[0] != "--all-iters"
            or extra[1].lower() not in ("true", "false")
        ):
            print(usage, file=sys.stderr)
            sys.exit(1)
        all_iters = extra[1].lower() == "true"

    return directory, dataset, model_name, iteration, use_guided, all_iters


if __name__ == "__main__":
    # model_name is accepted for positional compatibility but not used.
    (
        target_directory,
        target_dataset_name,
        _model_name,
        iteration,
        use_guided_decoding,
        all_iters,
    ) = parse_cli_args(sys.argv)

    aggregate_results_grouped_by_iter(
        target_directory,
        target_dataset_name,
        iteration,
        use_guided_decoding,
        all_iters=all_iters,
    )
