#!/usr/bin/env python3
"""
End-to-end geometry pipeline that wires together:

- ``LLMGenerator`` for problem generation
- ``LLMJudge`` for automatic quality filtering
- ``LLMPlotter`` + ``RealWorldPlotter`` for plotting code and rendered figures
- ``NumericalCheck`` for numeric/proof verification
- ``VLImageQuality`` for image QA + captioning
- ``VisualizeQA`` for caption-aware question/CoT rewriting.
"""

import json
import time
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, List, Optional
import argparse

from .LLMGenerator import LLMGenerator
from .LLMPlotter import LLMPlotter
from .LLMJudge import LLMJudge
from .Plotter import RealWorldPlotter
from ..utils.config import Config
from ..utils.model_urls import get_base_url_for_model
from ..utils import latex_to_float
from .NumericalCheck import NumericalCheck
from .VLImageQuality import VLImageQuality
from .VisualizeQA import VisualizeQA



class Pipeline:
    """High-level pipeline that runs generation → judging → plotting → checks."""
    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        generator_model: Optional[str] = None,
        judge_model: Optional[str] = None,
        plotter_model: Optional[str] = None,
        vl_model: Optional[str] = None,
        visualize_qa_model: Optional[str] = None,
        base_url: Optional[str] = None,
        delay: Optional[float] = None,
        figures_dir: Optional[Path] = None,
        max_retries: Optional[int] = None,
        debug: bool = False,
    ):
        """Construct every stage client used by the pipeline.

        Args:
            api_key: API key forwarded to every LLM/VL client.
            model: Default model name for any stage without a dedicated model.
            generator_model: Model for problem generation (falls back to ``model``).
            judge_model: Model for LLM judging (falls back to ``model``).
            plotter_model: Model for plotting-code generation (falls back to ``model``).
            vl_model: Vision-language model for image QA/captioning
                (falls back to ``model``).
            visualize_qa_model: Model for caption-aware rewriting. Falls back to
                the resolved generator model when not given.
            base_url: Currently unused. Each stage's base URL is inferred from
                its model name via ``get_base_url_for_model``.
            delay: Seconds to sleep between API calls (``None``/<=0 disables).
            figures_dir: Directory for rendered figures (``str`` or ``Path``).
                It is created if missing. ``None`` skips rendering.
            max_retries: Retry budget forwarded to every LLM/VL client.
            debug: When True, ``process_one_item`` prints verbose step dumps.
        """
        generator_model = generator_model or model
        judge_model = judge_model or model
        plotter_model = plotter_model or model
        vl_model = vl_model or model
        # VisualizeQA is a text-only step. Without a dedicated model it reuses
        # the generator model, which was previously the hard-wired behavior.
        visualize_qa_model = visualize_qa_model or generator_model

        # Automatically infer each stage's base_url from its model name.
        generator_base_url = get_base_url_for_model(generator_model)
        judge_base_url = get_base_url_for_model(judge_model)
        plotter_base_url = get_base_url_for_model(plotter_model)
        vl_base_url = get_base_url_for_model(vl_model)
        visualize_qa_base_url = get_base_url_for_model(visualize_qa_model)

        # Accept a plain string path for convenience.
        if isinstance(figures_dir, str):
            figures_dir = Path(figures_dir)

        self.generator = LLMGenerator(
            api_key=api_key,
            model=generator_model,
            base_url=generator_base_url,
            max_retries=max_retries,
        )

        self.judge = LLMJudge(
            api_key=api_key,
            model=judge_model,
            base_url=judge_base_url,
            max_retries=max_retries,
        )

        self.plotter = LLMPlotter(
            api_key=api_key,
            model=plotter_model,
            base_url=plotter_base_url,
            max_retries=max_retries,
        )

        self.visual_plotter = RealWorldPlotter()
        self.numerical_check = NumericalCheck()

        self.vl_image_quality = VLImageQuality(
            api_key=api_key,
            model=vl_model,
            base_url=vl_base_url,
            max_retries=max_retries,
        )

        self.visualize_qa = VisualizeQA(
            api_key=api_key,
            model=visualize_qa_model,
            base_url=visualize_qa_base_url,
            max_retries=max_retries,
        )

        self.delay = delay
        self.figures_dir = figures_dir
        if self.figures_dir:
            self.figures_dir.mkdir(parents=True, exist_ok=True)

        self.debug = debug

        # Global error collection, appended to from worker threads.
        self.global_errors: List[Dict[str, Any]] = []
        self.errors_lock = threading.Lock()  # Protects shared error list
    
    @staticmethod
    def _serialize_usage(usage_obj) -> Optional[Dict[str, int]]:
        """
        Convert a usage object into a JSON-serializable dict.

        Args:
            usage_obj: Usage object (possibly a CompletionUsage instance or a dict).

        Returns:
            A dict with token counts, or None if conversion is not possible.
        """
        # TODO: what is actually the usage object?
        if usage_obj is None:
            return None
        
        # If it's already a dictionary, return it directly
        if isinstance(usage_obj, dict):
            return {
                "prompt_tokens": usage_obj.get("prompt_tokens", 0) or 0,
                "completion_tokens": usage_obj.get("completion_tokens", 0) or 0,
                "total_tokens": usage_obj.get("total_tokens", 0) or 0,
            }
        
        # If it's an object, extract the attributes
        prompt_tokens = getattr(usage_obj, 'prompt_tokens', None)
        completion_tokens = getattr(usage_obj, 'completion_tokens', None)
        total_tokens = getattr(usage_obj, 'total_tokens', None)
        
        # If the attributes exist, convert them into a dictionary
        if prompt_tokens is not None or completion_tokens is not None or total_tokens is not None:
            return {
                "prompt_tokens": int(prompt_tokens) if prompt_tokens is not None else 0,
                "completion_tokens": int(completion_tokens) if completion_tokens is not None else 0,
                "total_tokens": int(total_tokens) if total_tokens is not None else 0,
            }
    
        return None
    
    def _calculate_final_tokens(self, usage_items: List[Any]) -> int:
        """Sum ``total_tokens`` across usage records; unusable records count as 0."""
        normalized = (self._serialize_usage(item) for item in usage_items)
        return sum((record.get("total_tokens", 0) or 0) for record in normalized if record)
    
    @staticmethod
    def _sleep(delay: Optional[float]) -> None:
        if delay and delay > 0:
            time.sleep(delay)
    
    @staticmethod
    def _build_problem_context(
        row_data: Dict[str, Any],
        prompt_payload: Optional[Dict[str, Any]],
    ) -> Dict[str, Optional[str]]:
        """Build a normalized problem context (conditions / conclusion / constructions)."""
        def pick_value(translated_key: str, original_key: str) -> Optional[str]:
            if prompt_payload:
                if prompt_payload.get(translated_key):
                    return prompt_payload.get(translated_key)
                if prompt_payload.get(original_key):
                    return prompt_payload.get(original_key)
            return row_data.get(original_key) or row_data.get(translated_key)
        
        context = {
            "conditions": pick_value("problem_translated", "problem_original") or row_data.get("llm_input_renamed"),
            "conclusion": pick_value("conclusion_translated", "conclusion_original"),
            "constructions": row_data.get("original_constructions")
                or (prompt_payload.get("aux_translated") if prompt_payload else None)
                or (prompt_payload.get("aux_original") if prompt_payload else None),
        }
        return context
    
    def _append_log_entry(
        self,
        entry: Dict[str, Any],
        result_file: Path,
        lock: Optional[threading.Lock] = None,
    ) -> None:
        """Serialize ``entry`` as one JSON line and append it to ``result_file``.

        The parent directory is created on demand. When ``lock`` is supplied,
        the file append happens while holding it, so workers sharing that lock
        do not interleave their lines.
        """
        result_file.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(entry, ensure_ascii=False) + "\n"

        def write_line() -> None:
            with open(result_file, "a", encoding="utf-8") as handle:
                handle.write(payload)

        if not lock:
            write_line()
            return
        with lock:
            write_line()
    
    def _save_errors(self, result_file: Path) -> None:
        """Persist all collected errors into a sibling ``<stem>_errors.json`` file.

        The error list is snapshotted under ``errors_lock`` before it is
        checked or written. This keeps the emptiness check and the dump
        consistent with the thread-safe appends done by ``record_error``.

        Args:
            result_file: Path of the JSONL result log. The error file is written
                next to it as ``<result_file.stem>_errors.json``.

        Write failures are reported on stdout with a traceback rather than
        raised, so saving errors never aborts the caller.
        """
        with self.errors_lock:
            errors_to_save = list(self.global_errors)

        if not errors_to_save:
            print("No errors to save")
            return

        error_file = result_file.parent / f"{result_file.stem}_errors.json"

        try:
            # The result directory may not exist yet if no entry was ever appended.
            error_file.parent.mkdir(parents=True, exist_ok=True)
            with open(error_file, "w", encoding="utf-8") as f:
                json.dump(errors_to_save, f, ensure_ascii=False, indent=2)
            print(f"Error log saved: {error_file} (total {len(errors_to_save)} errors)")
        except Exception as e:
            # Best-effort: e.g. non-JSON-serializable additional_info or an I/O error.
            print(f"Failed to save error log: {e}")
            traceback.print_exc()

    def process_one_item(
        self,
        index: int,
        row_data: Dict[str, Any],
        result_file: Path,
        lock: Optional[threading.Lock] = None,
    ) -> Dict[str, Any]:
        """
        Process a single sample and append its result to the JSONL result log.

        This runs:
        1) LLM generation
        2) LLM judging
        3) plotting-code generation + figure rendering
        4) numerical check
        5) VL image quality check + caption
        6) VisualizeQA caption-aware rewriting (optional)
        """
        
        entry: Dict[str, Any] = {
            "index": index,
            "problem_context": {
                "conditions": row_data.get("llm_input_renamed"),
                "conclusion": None,
                "constructions": row_data.get("original_constructions"),
            },
            "generation": {"status": "pending"},
            "validation": {"status": "pending"},
            "plotting": {"status": "pending"},
            "image_quality": {"status": "pending"},
            "numerical_check": {"status": "pending"},
            "visualize_qa": {"status": "pending"},
        }
        usage_records: List[Any] = []
        
        def debug_print(step_name: str, data: Dict[str, Any]) -> None:
            """Print detailed debug information for a given pipeline step."""
            if self.debug:
                print(f"\n{'='*80}")
                print(f"[DEBUG] {step_name} [sample {index}]")
                print(f"{'='*80}")
                print(json.dumps(data, ensure_ascii=False, indent=2))
                print(f"{'='*80}\n")
        
        def record_error(
            step: str,
            error: Exception,
            error_msg: Optional[str] = None,
            additional_info: Optional[Dict[str, Any]] = None,
        ) -> None:
            """Record an error into the global error list (thread-safe)."""
            error_info = {
                "index": index,
                "step": step,
                "error_type": type(error).__name__,
                "error_message": str(error) if error_msg is None else error_msg,
                "traceback": traceback.format_exc(),
                "timestamp": time.time(),
            }
            if additional_info:
                error_info["additional_info"] = additional_info
            
            with self.errors_lock:
                self.global_errors.append(error_info)
        
        def finalize(status: str, reason: Optional[str] = None) -> Dict[str, Any]:
            entry["status"] = status
            if reason:
                entry["status_reason"] = reason
            entry["token_usage"] = self._calculate_final_tokens(usage_records)
            self._append_log_entry(entry, result_file, lock=lock)
            return entry
        
        # Step 1: problem generation
        print(f"Step 1: generate problem [sample {index}]")
        # Get problem type from row_data or configuration
        problem_type = row_data.get("problem_type", "proof")
        
        if self.debug:
            debug_print("Step 1: generation - input", {
                "problem_type": problem_type,
                "row_data_keys": list(row_data.keys()),
                "llm_input_renamed": row_data.get("llm_input_renamed"),
                "original_constructions": row_data.get("original_constructions"),
            })
        
        gen_result = self.generator.generate(row_data, index=index, problem_type=problem_type)
        usage_records.append(gen_result.get("usage"))
        
        # Full generation result (debug only)
        debug_print("Step 1: generation - full result", {
            "success": gen_result.get("success"),
            "generated": gen_result.get("generated"),
            "llm_response": gen_result.get("llm_response"),
            "prompt_payload": gen_result.get("prompt_payload"),
            "prompt_text": gen_result.get("prompt_text"),
            "usage": self._serialize_usage(gen_result.get("usage")),
            "finish_reason": gen_result.get("finish_reason"),
            "call_details": gen_result.get("call_details"),
            "error": gen_result.get("error"),
        })
        
        if not gen_result.get("success"):
            error_msg = gen_result.get("error", "Problem generation failed")
            entry["generation"] = {"status": "failed", "error": error_msg}
            return finalize("failed", error_msg)
        
        generated = gen_result.get("generated", {})
        question = generated.get("question")
        cot = generated.get("cot")
        answer = generated.get("answer")
        entry["generation"] = {
            "status": "success",
            "question": question,
            "cot": cot,
            "answer": answer,
        }
        entry["problem_context"] = self._build_problem_context(row_data, gen_result.get("prompt_payload"))
        self._sleep(self.delay)
        
        if not question:
            error_msg = "Missing generated question text"
            entry["generation"]["status"] = "failed"
            entry["generation"]["error"] = error_msg
            return finalize("failed", error_msg)
        
        # Step 2: LLM judging
        if True:
            print(f"Step 2: judge problem [sample {index}]")
            judge_result = self.judge.judge(question, cot=cot or "", answer=answer or "", index=index)
            usage_records.append(judge_result.get("usage"))
            
            # Full judge result (debug only)
            debug_print("Step 2: judging - full result", {
                "success": judge_result.get("success"),
                "passed": judge_result.get("passed"),
                "reason": judge_result.get("reason"),
                "score": judge_result.get("score"),
                "llm_response": judge_result.get("llm_response"),
                "usage": self._serialize_usage(judge_result.get("usage")),
                "finish_reason": judge_result.get("finish_reason"),
                "error": judge_result.get("error"),
            })
            
            if not judge_result.get("success"):
                reason = judge_result.get("error", "Judging failed")
                entry["validation"] = {"status": "failed", "passed": False, "reason": reason, "score": 0}
                return finalize("failed", reason)
            
            passed = bool(judge_result.get("passed"))
            reason = judge_result.get("reason", "")
            entry["validation"] = {
                "status": "success" if passed else "failed",
                "passed": passed,
                "reason": reason,
                "score": judge_result.get("score"),
            }
            self._sleep(self.delay)
            if not passed:
                return finalize("failed", reason or "Problem did not pass judge")
        
        # Step 3: plotting-code generation and figure rendering
        print(f"Step 3: generate plotting code [sample {index}]")
        try:
            problem_type = row_data.get("problem_type", "proof")
            plot_result = self.plotter.plot(question, index=index, problem_type=problem_type)
            usage_records.append(plot_result.get("usage"))
            
            # Full plotting result (debug only)
            debug_print("Step 3: plotting - full result", {
                "success": plot_result.get("success"),
                "plotting_code": plot_result.get("plotting_code"),
                "code": plot_result.get("code"),
                "llm_response": plot_result.get("llm_response"),
                "usage": self._serialize_usage(plot_result.get("usage")),
                "finish_reason": plot_result.get("finish_reason"),
                "call_details": plot_result.get("call_details"),
                "error": plot_result.get("error"),
            })
            
            if not plot_result.get("success"):
                error_msg = plot_result.get("error", "Failed to generate plotting code")
                entry["plotting"] = {"status": "failed", "error": error_msg}
                return finalize("failed", error_msg)
            
            plotting_code = plot_result.get("plotting_code") or {}

            entry["plotting"] = {
                "status": "success",
                "code": plot_result.get("code"),
                "plotting_code": plotting_code,
                "actual_data": plot_result.get("result"),
            }
            self._sleep(self.delay)
            
            if not plotting_code:
                error_msg = "Missing plotting_code"
                entry["plotting"]["status"] = "failed"
                entry["plotting"]["error"] = error_msg
                return finalize("failed", error_msg)
            
            if self.figures_dir:
                figure_path = self.figures_dir / f"figure_{index}.png"
                try:
                    success, actual_plotting_data = self._draw_geometry(
                        plotting_code=plotting_code,
                        output_path=figure_path,
                        index=index,
                    )
                    if success:
                        entry["plotting"]["figure_path"] = str(figure_path)
                        entry["plotting"]["actual_data"] = actual_plotting_data
                        print(f"[sample {index}] Figure saved to: {figure_path}")
                    else:
                        error_msg = "Failed to render figure"
                        entry["plotting"]["status"] = "failed"
                        entry["plotting"]["error"] = error_msg
                        return finalize("failed", error_msg)
                except Exception as draw_exc:
                    error_msg = f"Exception while rendering geometry: {type(draw_exc).__name__}: {str(draw_exc)}"
                    print(f"[sample {index}] {error_msg}")
                    traceback.print_exc()
                    record_error("Step 3 (render geometry)", draw_exc, error_msg)
                    entry["plotting"]["status"] = "failed"
                    entry["plotting"]["error"] = error_msg
                    return finalize("failed", error_msg)
            else:
                print(f"[sample {index}] Figure output directory not set; skip rendering")
                figure_path = None
        except Exception as e:
            error_msg = f"Step 3 (plotting) raised an exception: {type(e).__name__}: {str(e)}"
            print(f"[sample {index}] {error_msg}")
            traceback.print_exc()
            record_error("Step 3 (plotting)", e, error_msg)
            entry["plotting"] = {"status": "failed", "error": error_msg}
            return finalize("failed", error_msg)
        
        # Step 4: numerical consistency check (for both proof and computation problems)
        problem_type = row_data.get("problem_type", "proof")
        print(f"Step 4: numerical check [sample {index}]")
        try:
            if self.debug:
                debug_print("Step 4: numerical check - input", {
                    "problem_type": problem_type,
                    "answer": answer if problem_type.lower() != "proof" else 0,
                    "plotting_code_keys": list(plotting_code.keys()) if plotting_code else [],
                    "quantities": plotting_code.get("quantities", []) if plotting_code else [],
                })
            
            if problem_type.lower() == "proof":
                if not plotting_code or not plotting_code.get("quantities"):
                    error_msg = "Missing quantities for proof-type numerical check"
                    entry["numerical_check"] = {"status": "failed", "error": error_msg}
                    return finalize("failed", error_msg)
                
                try:
                    # For proof problems, quantities are "left - right" expressions; we expect 0.
                    proof_answer = 0
                    numerical_passed = bool(self.numerical_check.check(answer=proof_answer, meta_data=plotting_code))
                    
                    if self.debug:
                        debug_print("Step 4: numerical check (proof) - result", {
                            "answer": proof_answer,
                            "numerical_passed": numerical_passed,
                            "quantities": plotting_code.get("quantities", []),
                        })
                except Exception as exc:
                    error_msg = (
                        f"Step 4 (numerical check - proof) raised an exception: "
                        f"{type(exc).__name__}: {str(exc)}"
                    )
                    print(f"[sample {index}] {error_msg}")
                    print(f"[sample {index}] plotting_code: {plotting_code}")
                    traceback.print_exc()
                    record_error(
                        "Step 4 (numerical check - proof, internal)",
                        exc,
                        error_msg,
                        {"plotting_code": plotting_code}
                    )
                    entry["numerical_check"] = {"status": "failed", "error": error_msg}
                    return finalize("failed", error_msg)
                
                if not numerical_passed:
                    error_msg = (
                        "Proof-type numerical check failed: equality does not hold "
                        "(quantities evaluation is not 0)."
                    )
                    entry["numerical_check"] = {"status": "failed", "error": error_msg}
                    return finalize("failed", error_msg)
                
                entry["numerical_check"] = {"status": "success"}
            else:
                if not answer:
                    error_msg = "Missing answer; cannot run numerical check"
                    entry["numerical_check"] = {"status": "failed", "error": error_msg}
                    return finalize("failed", error_msg)
                
                numeric_answer = latex_to_float(answer)
                if numeric_answer is None:
                    error_msg = f"Answer could not be parsed as a numeric value: {answer}"
                    entry["numerical_check"] = {"status": "failed", "error": error_msg}
                    return finalize("failed", error_msg)
                
                try:
                    numerical_passed = bool(self.numerical_check.check(answer=numeric_answer, meta_data=plotting_code))
                    
                    if self.debug:
                        debug_print("Step 4: numerical check (computation) - result", {
                            "answer": answer,
                            "numerical_passed": numerical_passed,
                            "quantities": plotting_code.get("quantities", []),
                        })
                except Exception as exc:
                    error_msg = (
                        f"Step 4 (numerical check - computation) raised an exception: "
                        f"{type(exc).__name__}: {str(exc)}"
                    )
                    print(f"[sample {index}] {error_msg}")
                    print(f"[sample {index}] plotting_code: {plotting_code}")
                    print(f"[sample {index}] answer: {answer}")
                    traceback.print_exc()
                    record_error(
                        "Step 4 (numerical check - computation, internal)",
                        exc,
                        error_msg,
                        {"plotting_code": plotting_code, "answer": answer}
                    )
                    entry["numerical_check"] = {"status": "failed", "error": error_msg}
                    return finalize("failed", error_msg)
                
                if not numerical_passed:
                    error_msg = "Numerical check failed"
                    entry["numerical_check"] = {"status": "failed", "error": error_msg}
                    return finalize("failed", error_msg)
                
                entry["numerical_check"] = {"status": "success"}
        except Exception as e:
            error_msg = f"Step 4 (numerical check) raised an exception: {type(e).__name__}: {str(e)}"
            print(f"[sample {index}] {error_msg}")
            traceback.print_exc()
            record_error("Step 4 (numerical check)", e, error_msg)
            entry["numerical_check"] = {"status": "failed", "error": error_msg}
            return finalize("failed", error_msg)
        

        # Step 5: image quality check (and caption generation)
        if figure_path and figure_path.exists():
            print(f"Step 5: image quality check [sample {index}]")
            try:
                quality_result = self.vl_image_quality.check_image_quality(
                    image_path=figure_path,
                    index=index,
                )
                usage_records.append(quality_result.get("usage"))
                
                debug_print("Step 5: image quality check - full result", {
                    "success": quality_result.get("success"),
                    "passed": quality_result.get("passed"),
                    "reason": quality_result.get("reason"),
                    "llm_response": quality_result.get("llm_response"),
                    "usage": self._serialize_usage(quality_result.get("usage")),
                    "finish_reason": quality_result.get("finish_reason"),
                    "error": quality_result.get("error"),
                })
                
                if not quality_result.get("success"):
                    error_msg = quality_result.get("error", "Image quality check failed")
                    entry["image_quality"] = {
                        "status": "failed",
                        "error": error_msg,
                    }
                    return finalize("failed", error_msg)
                
                passed = quality_result.get("passed", False)
                reason = quality_result.get("reason", "")
                
                entry["image_quality"] = {
                    "status": "success" if passed else "failed",
                    "passed": passed,
                    "reason": reason,
                }
                
                if not passed:
                    error_msg = (
                        "Image quality check did not pass: "
                        "the rendered figure is not suitable as a geometry illustration."
                    )
                    if reason:
                        error_msg += f" ({reason})"
                    return finalize("failed", error_msg)
                
                print(f"[sample {index}] Image quality check passed")
                
                self._sleep(self.delay)
                
                # Step 5.2: generate caption (based on image and plotting_code)
                print(f"Step 5.2: generate caption [sample {index}]")
                try:
                    quality_plotting_code = actual_plotting_data if 'actual_plotting_data' in locals() else plotting_code
                    
                    caption_result = self.vl_image_quality.generate_caption(
                        image_path=figure_path,
                        plotting_code=quality_plotting_code,
                        index=index,
                    )
                    usage_records.append(caption_result.get("usage"))
                    
                    debug_print("Step 5.2: caption generation - full result", {
                        "success": caption_result.get("success"),
                        "caption": caption_result.get("caption"),
                        "llm_response": caption_result.get("llm_response"),
                        "usage": self._serialize_usage(caption_result.get("usage")),
                        "finish_reason": caption_result.get("finish_reason"),
                        "error": caption_result.get("error"),
                    })
                    
                    if not caption_result.get("success"):
                        error_msg = caption_result.get("error", "Caption generation failed")
                        entry["image_quality"]["caption_generation"] = {
                            "status": "failed",
                            "error": error_msg,
                        }
                        caption = ""  # Empty caption will cause VisualizeQA to be skipped later.
                        print(f"[sample {index}] Caption generation failed: {error_msg}")
                    else:
                        caption = caption_result.get("caption", "")
                        entry["image_quality"]["caption_generation"] = {
                            "status": "success",
                            "caption": caption,
                        }
                        print(f"[sample {index}] Caption generated successfully")
                    
                    self._sleep(self.delay)
                except Exception as e:
                    error_msg = (
                        f"Step 5.2 (caption generation) raised an exception: "
                        f"{type(e).__name__}: {str(e)}"
                    )
                    print(f"[sample {index}] {error_msg}")
                    traceback.print_exc()
                    record_error("Step 5.2 (caption generation)", e, error_msg)
                    entry["image_quality"]["caption_generation"] = {
                        "status": "failed",
                        "error": error_msg,
                    }
                    caption = ""  # Empty caption will cause VisualizeQA to be skipped later.
                    # Caption generation failure does not abort the whole pipeline;
                    # it only causes VisualizeQA to be skipped.
            except Exception as e:
                error_msg = (
                    f"Step 5 (image quality check) raised an exception: "
                    f"{type(e).__name__}: {str(e)}"
                )
                print(f"[sample {index}] {error_msg}")
                traceback.print_exc()
                record_error("Step 5 (image quality check)", e, error_msg)
                entry["image_quality"] = {
                    "status": "failed",
                    "error": error_msg,
                }
                return finalize("failed", error_msg)
        else:
            if not self.figures_dir:
                reason = "Figure output directory is not configured"
            else:
                reason = "Figure file does not exist"
            entry["image_quality"] = {"status": "skipped", "reason": reason}
            caption = ""  # No image implies no caption.
        
        # Step 6: VisualizeQA - rewrite problem and CoT conditioned on the caption
        if caption and question and cot:
            print(f"Step 6: VisualizeQA [sample {index}]")
            try:
                plotting_code_for_vqa = entry.get("plotting", {}).get("plotting_code")
                actual_data_for_vqa = entry.get("plotting", {}).get("actual_data")
                
                visualize_result = self.visualize_qa.visualize_qa(
                    question=question,
                    cot=cot,
                    caption=caption,
                    index=index,
                    plotting_code=plotting_code_for_vqa,
                    actual_data=actual_data_for_vqa,
                )
                usage_records.append(visualize_result.get("usage"))
                
                debug_print("Step 6: VisualizeQA - full result", {
                    "success": visualize_result.get("success"),
                    "question": visualize_result.get("question"),
                    "cot": visualize_result.get("cot"),
                    "llm_response": visualize_result.get("llm_response"),
                    "usage": self._serialize_usage(visualize_result.get("usage")),
                    "finish_reason": visualize_result.get("finish_reason"),
                    "error": visualize_result.get("error"),
                })
                
                if not visualize_result.get("success"):
                    error_msg = visualize_result.get("error", "VisualizeQA failed")
                    entry["visualize_qa"] = {
                        "status": "failed",
                        "error": error_msg,
                    }
                    # VisualizeQA failure does not abort the pipeline; we only record the error.
                    print(f"[sample {index}] VisualizeQA failed: {error_msg}")
                else:
                    visualized_question = visualize_result.get("question", "")
                    visualized_cot = visualize_result.get("cot", "")
                    
                    entry["visualize_qa"] = {
                        "status": "success",
                        "question": visualized_question,
                        "cot": visualized_cot,
                    }
                    print(f"[sample {index}] VisualizeQA succeeded")
                
                self._sleep(self.delay)
            except Exception as e:
                error_msg = (
                    f"Step 6 (VisualizeQA) raised an exception: "
                    f"{type(e).__name__}: {str(e)}"
                )
                print(f"[sample {index}] {error_msg}")
                traceback.print_exc()
                record_error("Step 6 (VisualizeQA)", e, error_msg)
                entry["visualize_qa"] = {
                    "status": "failed",
                    "error": error_msg,
                }
        else:
            if not caption:
                reason = "Missing caption"
            elif not question:
                reason = "Missing question"
            elif not cot:
                reason = "Missing cot"
            else:
                reason = "Unknown reason"
            entry["visualize_qa"] = {"status": "skipped", "reason": reason}
            print(f"[sample {index}] Skip VisualizeQA: {reason}")
        
        
        # All steps passed
        return finalize("success")
    
    def _draw_geometry(
        self,
        plotting_code: Dict[str, Any],
        output_path: Path,
        index: Optional[int] = None,
    ) -> tuple[bool, Optional[Dict[str, Any]]]:
        """
        Render a realistic PNG figure from plotting_code using RealWorldPlotter.

        Returns:
            (success, actual_plotting_code): success flag and the plotting_code
            snapshot actually used for rendering (with pixel coordinates).

        Raises:
            Exception: Re-raised if rendering fails so the caller can record it.
        """
        try:
            result = self.visual_plotter.render_image(plotting_code, output_path, return_plotting_code=True)
            if isinstance(result, tuple):
                success, actual_plotting_code = result
                return success, actual_plotting_code
            else:
                return bool(result), None
        except Exception as e:
            # Print debugging information to help diagnose rendering issues.
            print(plotting_code)
            if index is not None:
                print(
                    f"[sample {index}] Exception while rendering geometry: "
                    f"{type(e).__name__}: {str(e)}"
                )
            traceback.print_exc()
            raise
            
    def process_file(
        self,
        input_file: Path,
        result_file: Path,
        start_index: int = 0,
        end_index: Optional[int] = None,
        max_workers: int = 1,
        log_file: Optional[Path] = None,
    ):
        """
        Process an input JSONL file and append each sample result to a JSONL log.

        Supports multi-threading: when ``max_workers > 1`` (and there is more
        than one sample) a thread pool is used; otherwise samples are processed
        sequentially. The default ``max_workers`` normally comes from
        configuration.

        Failures never abort the run:

        - A line that is not valid JSON is recorded in ``self.global_errors``
          and written to ``result_file`` as a ``failed`` entry. It never becomes
          a task, so it is not counted in ``total``.
        - A sample whose ``process_one_item`` call raises is handled the same
          way, in both the sequential and the threaded mode. It is counted as
          failed in the statistics.

        Args:
            input_file: Path to the input JSONL file.
            result_file: Path to the JSONL result file.
            start_index: Starting index (for sharding / resume).
            end_index: Exclusive end index; ``None`` means "until end of file".
            max_workers: Maximum number of worker threads.
            log_file: Optional JSON statistics log file with aggregate metrics.
        """
        start_time = time.time()

        print(f"📖 Reading input file: {input_file}")
        with open(input_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        total_lines = len(lines)
        end_idx = end_index if end_index is not None else total_lines
        # Clamp once so the printed range and the iterated range always agree.
        stop_idx = min(end_idx, total_lines)
        print(f"Processing range: [{start_index}, {stop_idx}) / total {total_lines}")
        print(f"💾 Result log: {result_file}")
        if max_workers > 1:
            print(f"🧵 Using {max_workers} worker threads")
        print("=" * 60)

        tasks: List[tuple[int, Dict[str, Any]]] = []
        for i in range(start_index, stop_idx):
            line = lines[i].strip()
            if not line:
                continue
            try:
                tasks.append((i, json.loads(line)))
            except json.JSONDecodeError as exc:
                error_msg = f"Failed to parse JSON: {type(exc).__name__}: {str(exc)}"
                print(f"[sample {i}] {error_msg}")
                error_info = {
                    "index": i,
                    "step": "file_read (JSON parse)",
                    "error_type": type(exc).__name__,
                    "error_message": str(exc),
                    "traceback": traceback.format_exc(),
                    "timestamp": time.time(),
                }
                with self.errors_lock:
                    self.global_errors.append(error_info)
                temp_entry = {
                    "index": i,
                    "status": "failed",
                    "status_reason": error_msg,
                }
                # Still single-threaded here, so no lock is needed for the write.
                self._append_log_entry(temp_entry, result_file, lock=None)

        stats_lock = threading.Lock()
        # The result-file lock is only needed when several threads append to it.
        lock = threading.Lock() if max_workers > 1 else None
        stats = {"total": len(tasks), "success": 0, "failed": 0, "completed": 0, "tokens": 0}

        def record_task_exception(idx: int, exc: Exception) -> Dict[str, Any]:
            """Record a sample whose processing raised and return a ``failed`` result.

            Must be called from inside an ``except`` block so that
            ``traceback.print_exc`` / ``traceback.format_exc`` see the active
            exception.
            """
            error_msg = f"Worker task raised an exception: {type(exc).__name__}: {str(exc)}"
            print(f"[sample {idx}] {error_msg}")
            traceback.print_exc()
            error_info = {
                "index": idx,
                "step": "task_execution",
                "error_type": type(exc).__name__,
                "error_message": str(exc),
                "traceback": traceback.format_exc(),
                "timestamp": time.time(),
            }
            with self.errors_lock:
                self.global_errors.append(error_info)
            temp_entry = {
                "index": idx,
                "status": "failed",
                "status_reason": error_msg,
            }
            self._append_log_entry(temp_entry, result_file, lock=lock)
            # Return a copy so later bookkeeping never aliases the logged entry.
            return dict(temp_entry)

        def process_task(idx: int, payload: Dict[str, Any]) -> Dict[str, Any]:
            """Run one sample, turn an exception into a failure, and update stats."""
            try:
                result = self.process_one_item(idx, payload, result_file, lock=lock)
            except Exception as exc:
                # Handled here (not by the caller) so sequential and threaded
                # runs behave identically and the failure is always counted.
                result = record_task_exception(idx, exc)
            with stats_lock:
                stats["completed"] += 1
                if result.get("status") == "success":
                    stats["success"] += 1
                else:
                    stats["failed"] += 1
                stats["tokens"] += result.get("token_usage", 0) or 0
                if stats["completed"] % 5 == 0 or stats["completed"] == stats["total"]:
                    print(
                        f"Progress: {stats['completed']}/{stats['total']} "
                        f"(success {stats['success']} failed {stats['failed']})"
                    )
            return result

        if max_workers > 1 and len(tasks) > 1:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [executor.submit(process_task, idx, payload) for idx, payload in tasks]
                for future in as_completed(futures):
                    # process_task already converts per-sample exceptions into
                    # failures; result() re-raises anything truly unexpected.
                    future.result()
        else:
            for idx, payload in tasks:
                process_task(idx, payload)

        end_time = time.time()
        total_time = end_time - start_time

        print("\n" + "=" * 60)
        print("Processing completed")
        print(f"   total: {stats['total']}")
        print(f"   success: {stats['success']}")
        print(f"   failed: {stats['failed']}")
        if stats["total"]:
            print(f"   success rate: {stats['success'] / stats['total'] * 100:.1f}%")
        if stats["tokens"]:
            print(f"   total recorded tokens: {stats['tokens']}")
        print(f"   total wall time: {total_time:.2f} seconds")
        if stats['success'] > 0:
            avg_time_per_success = total_time / stats['success']
            print(f"   avg time per successful sample: {avg_time_per_success:.2f} seconds")
        print(f"Results written to: {result_file}")

        # Persist global errors to a separate file
        self._save_errors(result_file)

        if log_file:
            statistics = {
                "total_samples": stats['total'],
                "success_samples": stats['success'],
                "failed_samples": stats['failed'],
                "success_rate": stats['success'] / stats['total'] * 100 if stats['total'] > 0 else 0.0,
                "total_tokens": stats['tokens'],
                "total_time_seconds": round(total_time, 2),
                "average_time_per_success_seconds": round(total_time / stats['success'], 2) if stats['success'] > 0 else 0.0,
                "start_time": start_time,
                "end_time": end_time,
                "input_file": str(input_file),
                "result_file": str(result_file),
            }

            try:
                log_file.parent.mkdir(parents=True, exist_ok=True)
                with open(log_file, "w", encoding="utf-8") as f:
                    json.dump(statistics, f, ensure_ascii=False, indent=2)
                print(f"Statistics log saved: {log_file}")
            except Exception as e:
                # Best-effort: a statistics-log failure must not fail the run.
                print(f"Failed to save statistics log: {e}")
                traceback.print_exc()


def main():
    """Parse command-line options, build a ``Pipeline`` and run it over one file.

    The command line supplies:

    - the input and output JSONL paths;
    - the ``[start_index, end_index)`` range;
    - an optional statistics log path;
    - a debug switch.

    Everything else comes from ``Config``: API key, models, base URL, delay,
    figures directory, retry count and worker count.
    """
    parser = argparse.ArgumentParser(
        description="Run the geometry pipeline: generation -> judge -> plotting",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Mandatory options as (flag, converter, help), registered in this order.
    mandatory_options = (
        ("--input", Path, "Path to the input JSONL file"),
        ("--output", Path, "Path to the output JSONL file (per-sample pipeline results)"),
        ("--start_index", int, "Start index (0-based, inclusive)"),
        ("--end_index", int, "End index (0-based, exclusive)"),
    )
    for flag, converter, help_text in mandatory_options:
        parser.add_argument(flag, type=converter, required=True, help=help_text)

    parser.add_argument(
        "--log_file",
        type=Path,
        default=None,
        help="Optional JSON file for aggregate statistics about this run",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Enable debug mode and print detailed model responses",
    )

    cli = parser.parse_args()

    # Debug is on if either the --debug flag or Config.DEBUG asks for it.
    debug_enabled = cli.debug or Config.DEBUG

    pipeline = Pipeline(
        api_key=Config.API_KEY,
        model=Config.MODEL,
        generator_model=Config.GENERATOR_MODEL,
        judge_model=Config.JUDGE_MODEL,
        plotter_model=Config.PLOTTER_MODEL,
        vl_model=Config.VL_MODEL,
        visualize_qa_model=Config.VISUALIZE_QA_MODEL,
        base_url=Config.BASE_URL,
        delay=Config.DELAY,
        figures_dir=Config.FIGURES_DIR,
        max_retries=Config.MAX_RETRIES,
        debug=debug_enabled,
    )

    pipeline.process_file(
        input_file=cli.input,
        result_file=cli.output,
        start_index=cli.start_index,
        end_index=cli.end_index,
        max_workers=Config.MAX_WORKERS,
        log_file=cli.log_file,
    )


if __name__ == "__main__":
    # Run the CLI only when this module is executed directly, not when imported.
    main()
