import json
import os
import random
import argparse

def load_jsonl(file_path):
    """Load a JSON Lines file into a list of parsed records.

    Blank lines are ignored, and lines that are not valid JSON are skipped
    with a console message. A missing file is reported with a warning and
    yields an empty list instead of raising.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        A list with one parsed JSON value per valid, non-blank line, in file
        order; an empty list if the file does not exist.
    """
    print(f"Reading {file_path}...")
    if not os.path.exists(file_path):
        print(f"Warning: {file_path} not found. Skipping.")
        return []

    records = []
    # 'utf-8-sig' decodes ordinary UTF-8 unchanged but strips a leading BOM.
    # With plain 'utf-8' the BOM stays attached to the first line, json.loads
    # rejects it, and the first record would be dropped as "invalid".
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON line in {file_path}")
    return records

def main(args):
    """Merge the KEEP records with the successfully refined records.

    Loads both JSONL inputs, discards refined records whose ``is_refined``
    field is not exactly ``True``, concatenates the remainder with the KEEP
    records, shuffles the result with a fixed seed, and writes it to the
    output path as JSONL. Progress and counts are printed to stdout.

    Args:
        args: Object exposing ``keep_path``, ``refined_path`` and
            ``output_path`` attributes (e.g. an ``argparse.Namespace``).
    """
    keep_data = load_jsonl(args.keep_path)
    refined_data = load_jsonl(args.refined_path)

    print(f"Loaded KEEP records: {len(keep_data)}")
    print(f"Loaded REFINED records (Total): {len(refined_data)}")

    # Strict filter: a missing, falsy, or merely truthy (non-bool) flag does
    # not count -- only a literal True marks a successful refinement.
    refined_success_data = [
        item for item in refined_data if item.get('is_refined', False) is True
    ]

    dropped_count = len(refined_data) - len(refined_success_data)
    print("-" * 40)
    print("Filtering Strategy: Strict (Discard Failed Refinements)")
    print(f"Dropped records: {dropped_count}")
    print(f"Retained REFINED records: {len(refined_success_data)}")
    print("-" * 40)

    all_data = keep_data + refined_success_data

    total_count = len(all_data)
    print(f"Total Combined records (D_final): {total_count}")

    # Fixed seed so the shuffled order is reproducible across runs.
    random.seed(42)
    random.shuffle(all_data)

    # os.path.dirname() returns "" for a bare filename, and os.makedirs("")
    # raises FileNotFoundError, so only create the directory when there is one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Writing to {args.output_path}...")
    with open(args.output_path, 'w', encoding='utf-8') as f:
        for item in all_data:
            # Records are written unchanged, with every field they carry.
            # To keep only training fields (instruction, input, output),
            # build the reduced dict here before dumping.
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    print("Merge Complete! 🎉")
    print(f"Final Dataset Ready: {args.output_path}")

if __name__ == "__main__":
    # Command-line entry point: every option is a string path with a default.
    path_options = (
        ("--keep_path", "alpaca/data/data_keep.jsonl"),
        ("--refined_path", "alpaca/data/data_refined.jsonl"),
        ("--output_path", "alpaca/data/cure-sft.jsonl"),
    )
    cli = argparse.ArgumentParser()
    for flag, default_path in path_options:
        cli.add_argument(flag, type=str, default=default_path)
    args = cli.parse_args()
    main(args)