from refinement.classifier import ScenarioClassifier
from refinement.refiner import PromptRefiner
from refinement.scenarios import Scenario
from verification.verify import verify
from config import CFG_PR, CFG_CHUNK
import os

def _resolve_scenario(label):
    """Map a classifier label to its Scenario member, defaulting to Scenario.NORMAL.

    Returns the first member whose ``value`` equals ``label``; any label that
    matches no member falls back to ``Scenario.NORMAL``.
    """
    return next((s for s in Scenario if s.value == label), Scenario.NORMAL)


def run_batch(input_path=None, output_path=None):
    """Classify, refine and verify every prompt in a text file, then save the results.

    Each non-empty line of ``input_path`` (whitespace-stripped) is one prompt.
    For every prompt the scenario is predicted, the prompt is refined for that
    scenario, and the refined prompt is checked with ``verify``; verification
    results are only printed and do not affect what is saved. The refined
    prompts are written to ``output_path``, one per line, in input order.

    Args:
        input_path: UTF-8 prompt file to read. Defaults to
            ``data/test_prompts.txt`` relative to the current working directory.
        output_path: UTF-8 file to (over)write with refined prompts. Defaults to
            ``data/refined_prompts.txt`` relative to the current working directory.
    """
    if input_path is None:
        input_path = os.path.join("data", "test_prompts.txt")
    if output_path is None:
        output_path = os.path.join("data", "refined_prompts.txt")

    clf = ScenarioClassifier()
    refiner = PromptRefiner(use_refine_llm=CFG_PR.use_refine_llm)

    # Keep the physical (1-based) line number alongside each prompt so the log
    # points at the real line in the file even when blank lines are skipped.
    with open(input_path, "r", encoding="utf-8") as f:
        numbered = ((lineno, line.strip()) for lineno, line in enumerate(f, start=1))
        prompts = [(lineno, text) for lineno, text in numbered if text]

    refined_list = []

    for lineno, original in prompts:
        print(f"\n=== Processing line {lineno}: {original}")

        # classification
        pred = clf.predict(original)
        scenario = _resolve_scenario(pred.label)
        print(f"[Scenario] {scenario} (p={pred.score:.2f}) | source={pred.source} | stats={pred.stats}")

        # prompt refinement
        refined_prompt = refiner.refine(scenario, original)
        print(f"[Refined] {refined_prompt}")

        # verification (reported only; it does not gate what gets saved)
        verify_results = verify(
            original_prompt=original,
            expanded_prompt=refined_prompt,
            chunk_cfg=CFG_CHUNK,
        )
        print("[Verification Results]", verify_results)

        refined_list.append(refined_prompt)

    # NOTE(review): a refined prompt containing "\n" would span several output
    # lines and break the one-prompt-per-line layout — confirm refine() never
    # returns multi-line text.
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(rp + "\n" for rp in refined_list)

    print(f"\nDone，Results saved in {output_path}")

if __name__ == "__main__":
    # Script entry point: run the batch pipeline (with run_batch's default
    # data/ input and output paths) only when this file is executed directly,
    # not when it is imported.
    run_batch()
