import json
from utils import read_json_file, RougeL
from arguments import get_args
import os


def _infer_model_name(generate_path):
    """Guess the model name from the location of the generated-answers file.

    The name is taken to be the directory two levels above the file, i.e.
    ``.../<model_name>/<seed-or-rank>/answers.jsonl``. Returns ``"model"``
    when the path cannot be interpreted or is too shallow to contain such a
    directory.
    """
    try:
        abs_path = os.path.abspath(generate_path)
    except (TypeError, ValueError, OSError):
        # Non-path argument, or the current directory is unavailable.
        return "model"
    # basename() yields "" when the grandparent is the filesystem root; an
    # empty name would give a malformed output filename, so fall back.
    return os.path.basename(os.path.dirname(os.path.dirname(abs_path))) or "model"


def main():
    """Attach a ROUGE-L score to each generated answer and save the result.

    Reads the generated answers from ``args.generate_data_dir`` and the
    references from ``args.truth_data_dir``, pairs them by position, and
    stores ``RougeL(LLM_output, output)`` under the ``'Rouge_L'`` key of every
    paired generated record. If the two inputs differ in length, a warning is
    printed and only the first ``min(len)`` records are scored and written.

    The records are written as JSON Lines (UTF-8) to
    ``<out_dir>/answers_with_metrics_<model_name>.new.jsonl`` where
    ``out_dir`` is ``args.save_dir`` if set, otherwise
    ``<grandparent of truth_data_dir>/<model_name>``. All other fields of the
    generated records (e.g. an existing ``lm_loss``) are written back
    unchanged.
    """
    args = get_args()

    # Generated answers are the dataset being augmented; the truth data is
    # only read to provide the reference text for ROUGE-L.
    generate_data = read_json_file(args.generate_data_dir)
    truth_data = read_json_file(args.truth_data_dir)

    # Records are paired by position, so only the common prefix is usable.
    n_gen = len(generate_data)
    n_truth = len(truth_data)
    if n_gen != n_truth:
        print(f"Warning: generate ({n_gen}) and truth ({n_truth}) lengths differ; truncating to min")
    records = generate_data[:min(n_gen, n_truth)]

    # zip() stops at the shorter input, i.e. at len(records).
    for record, reference in zip(records, truth_data):
        gen_text = record.get('LLM_output', '')
        truth_text = reference.get('output', '')
        record['Rouge_L'] = RougeL(gen_text, truth_text)

    # An explicit model name wins; otherwise derive it from the input path.
    model_name = getattr(args, "model_name", None)
    if model_name is None:
        model_name = _infer_model_name(args.generate_data_dir)

    # Default output directory: two levels above the truth file, then
    # <model_name>/, unless the caller supplied save_dir.
    out_dir = args.save_dir or os.path.join(
        os.path.dirname(os.path.dirname(args.truth_data_dir)), model_name
    )
    os.makedirs(out_dir, exist_ok=True)
    filename = f"answers_with_metrics_{model_name}.new.jsonl"
    output_file = os.path.join(out_dir, filename)

    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 explicitly rather than with the locale's default.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in records
        )

    print(f"Saved modified dataset with metrics to: {output_file}")

    
# Run the metric computation only when this file is executed as a script,
# so importing the module has no side effects.
if __name__ == "__main__":
    main()