import os
import json
import argparse
from refinement.classifier import ScenarioClassifier
from refinement.refiner import PromptRefiner
from refinement.scenarios import Scenario
from verification.verify import verify
from config import CFG_PR, CFG_CHUNK

def load_prompts(input_path: str) -> list:
    """Read *input_path* as UTF-8 and return its non-blank lines, stripped.

    Lines that are empty or contain only whitespace are dropped, so a
    position in the returned list is an index among non-empty lines, not a
    raw line number in the file.
    """
    prompts = []
    with open(input_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            text = raw.strip()
            if text:
                prompts.append(text)
    return prompts

def load_checkpoint(checkpoint_path: str) -> int:
    """Return the highest ``index`` recorded in the checkpoint file.

    Each line is expected to be a JSON object such as
    ``{"index": 3, "refined": "..."}``. Blank lines, lines that are not valid
    JSON, lines whose JSON is not an object, and records without an integer
    ``index`` are skipped. A crash while a record was being written can leave
    a truncated final line, and that must not block resuming the run.

    Args:
        checkpoint_path: Path to the JSON-lines checkpoint file.

    Returns:
        The largest ``index`` found, or -1 if the file does not exist or
        holds no usable records. Callers resume at ``result + 1``.
    """
    if not os.path.exists(checkpoint_path):
        return -1
    last_idx = -1
    with open(checkpoint_path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                # Most likely a partially written record from an interrupted
                # run; the max() over the remaining records still gives the
                # correct resume point.
                print(f"[Checkpoint] skipping malformed line {lineno} in {checkpoint_path}")
                continue
            if not isinstance(rec, dict):
                continue
            idx = rec.get("index")
            if isinstance(idx, int):
                last_idx = max(last_idx, idx)
    return last_idx

def append_checkpoint(fp, index: int, refined: str):
    """Write one ``{"index", "refined"}`` JSON line to *fp* and force it to disk.

    Non-ASCII characters are kept as-is (``ensure_ascii=False``). The stream
    is flushed and then ``os.fsync``'d, so the record has been handed to the
    storage device by the time this returns.
    """
    payload = json.dumps({"index": index, "refined": refined}, ensure_ascii=False)
    fp.write(f"{payload}\n")
    fp.flush()
    os.fsync(fp.fileno())

def _resolve_scenario(label) -> Scenario:
    """Return the ``Scenario`` member whose value equals *label*, else ``Scenario.NORMAL``."""
    return next((s for s in Scenario if s.value == label), Scenario.NORMAL)


def _ensure_parent_dir(path: str) -> None:
    """Create the directory that contains *path*, if *path* names one."""
    parent = os.path.dirname(path)
    # A bare filename has dirname "", and os.makedirs("") raises
    # FileNotFoundError, so directory creation must be skipped for it.
    if parent:
        os.makedirs(parent, exist_ok=True)


def run_batch(input_path: str, output_path: str, checkpoint_path: str):
    """Classify, refine and verify each prompt, resuming from a checkpoint.

    For every non-empty line of *input_path* whose index is beyond the last
    checkpointed one, this function does the following in order:

    1. Classify the line into a ``Scenario``, defaulting to ``NORMAL``.
    2. Refine it with ``PromptRefiner``.
    3. Run ``verify`` on the original/refined pair. The results are only
       printed.
    4. Append the refined prompt as a line to *output_path*.
    5. Append a checkpoint record for its index.

    Both files are opened in append mode and fsync'd after every prompt, so an
    interrupted run can be restarted with the same arguments.

    Args:
        input_path: Text file with one prompt per line.
        output_path: File the refined prompts are appended to.
        checkpoint_path: JSON-lines file of ``{"index", "refined"}`` records.
    """
    prompts = load_prompts(input_path)
    total = len(prompts)

    clf = ScenarioClassifier()
    refiner = PromptRefiner(use_refine_llm=CFG_PR.use_refine_llm)

    # Resume right after the highest index already checkpointed.
    # NOTE(review): indices count non-empty input lines, so editing the input
    # file between runs shifts the resume point.
    start_idx = load_checkpoint(checkpoint_path) + 1

    # Either file may live in a directory that does not exist yet.
    _ensure_parent_dir(output_path)
    _ensure_parent_dir(checkpoint_path)

    # Open output and checkpoint files for streaming writes
    with open(output_path, "a", encoding="utf-8") as out_fp, \
         open(checkpoint_path, "a", encoding="utf-8") as cfp:

        for idx in range(start_idx, total):
            original = prompts[idx]
            print(f"\n=== Processing line {idx+1}/{total}: {original}")

            # Scenario classification
            pred = clf.predict(original)
            scenario = _resolve_scenario(pred.label)
            print(f"[Scenario] {scenario} (p={pred.score:.2f}) | source={pred.source} | stats={pred.stats}")

            # Prompt refinement
            refined_prompt = refiner.refine(scenario, original)
            print(f"[Refined] {refined_prompt}")

            # Verification (informational only; the result is not persisted)
            verify_results = verify(
                original_prompt=original,
                expanded_prompt=refined_prompt,
                chunk_cfg=CFG_CHUNK,
            )
            print("[Verification Results]", verify_results)

            # The output line is made durable before the checkpoint record.
            # If the process dies between the two writes, this prompt is
            # redone on resume and appears twice in output_path: it is
            # duplicated, not lost.
            # NOTE(review): if refine() can return text containing newlines,
            # output_path stops being one line per prompt. The checkpoint
            # record still holds the exact text.
            out_fp.write(refined_prompt + "\n")
            out_fp.flush()
            os.fsync(out_fp.fileno())

            # Write checkpoint
            append_checkpoint(cfp, idx, refined_prompt)

    print(f"\nAll prompts processed. Refined prompts saved to {output_path}")


if __name__ == "__main__":
    # Command-line entry point. Every path has a default, so the script can
    # be run with no arguments.
    cli = argparse.ArgumentParser(
        description="Batch refine prompts with checkpoint support"
    )
    cli.add_argument(
        "--input",
        type=str,
        default="data/prompt700.txt",
        help="Path to the input prompt file",
    )
    cli.add_argument(
        "--output",
        type=str,
        default="data/prompt700_refined.txt",
        help="Path to save the refined prompts",
    )
    cli.add_argument(
        "--checkpoint",
        type=str,
        default="data/prompt700_refined.checkpoint.jsonl",
        help="Path to checkpoint file for resuming",
    )
    cli_args = cli.parse_args()

    run_batch(
        input_path=cli_args.input,
        output_path=cli_args.output,
        checkpoint_path=cli_args.checkpoint,
    )
