from openai import OpenAI
import networkx as nx
import json
import os
from prompt_generate.util import load_env, logger, PydanticFormatError

import typing
from typing import TypedDict, Literal, TypeVar, overload, Type, get_args, Any, Callable
from collections.abc import Iterable

from pydantic import BaseModel, Field, model_validator
from typing import Literal, Self

# Load environment settings before the OpenAI client is created (presumably this provides the
# API key -- TODO confirm against prompt_generate.util.load_env).
load_env()

# Total number of structured-output requests `request_and_record` sends before giving up when
# every response raises PydanticFormatError.
format_retry_times = 3

# Step labels recorded in the trace summary (one label per message), grouped by the role of the
# message they describe: developer, user or assistant.
DevSummary = Literal["system_start"]
UserSummary = Literal[
    "factors_start",
    "factors_rule_check_not_pass",
    "factors_self_check",
    "graph_start",
    "graph_rule_check_not_pass",
    "graph_self_check",
    "rules_start",
    "rules_rule_check_not_pass",
    "rules_self_check",
]
AssistantSummary = Literal[
    "factors_proposed",
    "factors_keep",
    "factors_regenerate_from_graph",
    "graph_proposed",
    "graph_keep",
    "graph_regenerate_from_rules",
    "rules_proposed",
    "rules_keep",
]
# Nested Literal types are flattened, so this is one Literal holding every step label.
# (`Literal[A | B | C]` would put a Union inside a Literal, which type checkers reject and which
# makes `get_args(StepSummary)` return the Union instead of the labels.)
StepSummary = Literal[DevSummary, UserSummary, AssistantSummary]
# Runtime list (in declaration order) of every allowed label; used for validation because a
# type hint cannot be checked at runtime.
step_summary_all: list[StepSummary] = [
    label
    for category in (DevSummary, UserSummary, AssistantSummary)
    for label in get_args(category)
]

class Message(TypedDict):
    role: Literal["system", "developer", "user", "assistant"]
    content: str

class MessageWithReason(Message):
    reasoning: str | None

# The pydantic model class passed as `structured_output`; the request/cache helpers return an
# instance of that same class.
structured_output_T = TypeVar("structured_output_T", bound=BaseModel)

from prompt_generate.util import FactorList, Graph, RulesJson, parse_completion

class FactorListOrKeep(BaseModel):
    """
    If there is a factor list in the last step and no need to change it, set keep_flag to "KEEP_FACTOR" and return None for `factor_list`. Otherwise, set keep_flag to None and return the new factor list for `factor_list`.
    """
    # Structured-output schema. pydantic uses the class docstring above as the schema description
    # shown to the model, so developer notes belong in comments. Exactly one field must be set.
    factor_list: FactorList | None
    keep_flag: Literal["KEEP_FACTOR"] | None

    @model_validator(mode='after')
    def only_one_not_none(self) -> Self:
        # Each failure mode gets its own message. The text is presumably fed back to the model
        # on retry (request_and_record catches PydanticFormatError). NOTE(review): that only works
        # if PydanticFormatError is not a ValueError subclass, which pydantic would wrap in a
        # ValidationError -- confirm.
        if self.factor_list is None and self.keep_flag is None:
            raise PydanticFormatError("One of factor_list and keep_flag must be not None, but both are None.")
        if self.factor_list is not None and self.keep_flag is not None:
            raise PydanticFormatError("Only one of factor_list and keep_flag can be not None.")
        return self

class GraphOrKeep(BaseModel):
    """
    If there is a graph in the last step and no need to change it, set keep_flag to "KEEP_GRAPH" and return None for `graph`. Otherwise, set keep_flag to None and return the new graph for `graph`.
    """
    # Structured-output schema: the docstring and Field description are prompt text for the model.
    graph: Graph | None
    keep_flag: Literal["KEEP_GRAPH"] | None = Field(
        ...,
        description = 'Only value when one graph has been provided in the last step and no need to change it. If so, return "KEEP_GRAPH" for keep_flag and None for graph. Otherwise, return None for keep_flag and the new graph for graph.'
    )

    # Keep this method name: GraphKeepOrRegenerate overrides it by name to relax the rule.
    @model_validator(mode='after')
    def only_one_not_none(self) -> Self:
        # Exactly one of graph / keep_flag must be set; each failure mode gets its own message
        # because the text is presumably sent back to the model on retry.
        if self.graph is None and self.keep_flag is None:
            raise PydanticFormatError("One of graph and keep_flag must be not None, but both are None.")
        if self.graph is not None and self.keep_flag is not None:
            raise PydanticFormatError("Only one of graph and keep_flag can be not None.")
        return self

class GraphKeepOrRegenerate(GraphOrKeep):
    """
    If you find the factors are not good enough and need to regenerate the factors, explain the reason and return your modified factors in `regenerate_factors`. Otherwise, set `regenerate_factors_reason` and `regenerate_factors` to None.
    """
    # The docstring above is the schema description sent to the model, so the field names it
    # mentions must match the real fields exactly.
    regenerate_factors_reason: str | None
    regenerate_factors: FactorList | None

    @model_validator(mode='after')
    def both_reason_and_content(self) -> Self:
        # A regeneration must come with its reason and vice versa.
        if (self.regenerate_factors is None) != (self.regenerate_factors_reason is None):
            raise PydanticFormatError("Both regenerate_factors and regenerate_factors_reason should be None or not None.")
        return self

    # Same name as GraphOrKeep.only_one_not_none on purpose. It relies on pydantic replacing the
    # parent validator with this one, so a reply that only regenerates factors (graph and
    # keep_flag both None) is accepted.
    @model_validator(mode='after')
    def only_one_not_none(self) -> Self:
        if (self.graph is None) and (self.keep_flag is None) and (self.regenerate_factors is None):
            raise PydanticFormatError("At least one of graph, keep_flag and regenerate_factors should be not None.")
        if (self.graph is not None) + (self.keep_flag is not None) + (self.regenerate_factors is not None) > 1:
            raise PydanticFormatError("At most one of graph, keep_flag and regenerate_factors can be not None.")
        return self

class RulesJsonOrKeep(BaseModel):
    """
    If there is a rules json in the last step and no need to change it, set keep_flag to "KEEP_RULES" and return None for `rules_json`. Otherwise, set keep_flag to None and return the new rules json for `rules_json`.
    """
    # Structured-output schema: the docstring and Field description are prompt text for the model.
    rules_json: RulesJson | None
    keep_flag: Literal["KEEP_RULES"] | None = Field(
        ...,
        description = 'Only value when one rules json has been provided in the last step and no need to change it. If so, return "KEEP_RULES" for keep_flag and None for rules_json. Otherwise, return None for keep_flag and the new rules_json for rules_json.'
    )

    # Keep this method name: RulesKeepOrRegenerate overrides it by name to relax the rule.
    @model_validator(mode='after')
    def only_one_not_none(self) -> Self:
        # Exactly one of rules_json / keep_flag must be set; each failure mode gets its own
        # message because the text is presumably sent back to the model on retry.
        if self.rules_json is None and self.keep_flag is None:
            raise PydanticFormatError("One of rules_json and keep_flag must be not None, but both are None.")
        if self.rules_json is not None and self.keep_flag is not None:
            raise PydanticFormatError("Only one of rules_json and keep_flag can be not None.")
        return self

class RulesKeepOrRegenerate(RulesJsonOrKeep):
    """
    If you find the graph is not good enough and need to regenerate the graph, explain the reason and return your modified graph in `regenerate_graph`. Otherwise, set `regenerate_graph_reason` and `regenerate_graph` to None.
    """
    # The docstring above is the schema description sent to the model, so the field names it
    # mentions must match the real fields exactly.
    regenerate_graph_reason: str | None
    regenerate_graph: Graph | None

    @model_validator(mode='after')
    def both_reason_and_content(self) -> Self:
        # A regenerated graph must come with its reason and vice versa.
        if (self.regenerate_graph is None) != (self.regenerate_graph_reason is None):
            raise PydanticFormatError("Both regenerate_graph and regenerate_graph_reason should be None or not None.")
        return self

    # Same name as RulesJsonOrKeep.only_one_not_none on purpose. It relies on pydantic replacing
    # the parent validator with this one, so a reply that only regenerates the graph is accepted.
    @model_validator(mode='after')
    def only_one_not_none(self) -> Self:
        if (self.rules_json is None) and (self.keep_flag is None) and (self.regenerate_graph is None):
            raise PydanticFormatError("At least one of rules_json, keep_flag and regenerate_graph should be not None.")
        if (self.rules_json is not None) + (self.keep_flag is not None) + (self.regenerate_graph is not None) > 1:
            raise PydanticFormatError("At most one of rules_json, keep_flag and regenerate_graph can be not None.")
        return self
    
class TokenUsageMixin:
    """Accumulate the token usage reported by Responses API calls in ``self.total_usage``."""

    def _init_usage(self) -> None:
        """Reset ``self.total_usage`` to zero for every tracked counter."""
        self.total_usage: dict[str, int | dict[str, int]] = {
            "input_tokens": 0,
            "output_tokens": 0,
            "total_tokens": 0,
            "input_tokens_details": {
                "cached_tokens": 0
            },
            "output_tokens_details": {
                "reasoning_tokens": 0
            }
        }

    @typing.no_type_check
    def update_usage(self, completion: Any) -> None:
        """
        Add the counts found in ``completion.usage`` to ``self.total_usage``.

        Reads ``input_tokens``, ``output_tokens``, ``total_tokens``,
        ``input_tokens_details.cached_tokens`` and ``output_tokens_details.reasoning_tokens``.
        The OpenAI SDK declares ``usage`` as optional. When it is missing, a warning is logged and
        nothing is added, instead of crashing after the request has already succeeded.
        """
        u = getattr(completion, "usage", None)
        if u is None:
            logger.warning("The response from OpenAI api has no usage information, so the token usage is not updated.")
            return
        totals = self.total_usage
        totals["input_tokens"] += u.input_tokens
        totals["output_tokens"] += u.output_tokens
        totals["total_tokens"] += u.total_tokens
        totals["input_tokens_details"]["cached_tokens"] += u.input_tokens_details.cached_tokens
        totals["output_tokens_details"]["reasoning_tokens"] += u.output_tokens_details.reasoning_tokens
        
class TraceSummaryMixin:
    """
    Keep one step label (``StepSummary``) per message and validate the label sequence.

    The host class must provide ``self.messages``; the label list is meant to stay aligned with it.
    """

    def _init_trace(self) -> None:
        """Start with an empty label list."""
        self._trace_summary: list[StepSummary] = []

    def trace_summary_append(self, text: StepSummary) -> None:
        """Append one label, then re-validate the ordering of the whole sequence (not its length)."""
        if text in step_summary_all:
            self._trace_summary.append(text)
            self._check_trace_summary()
            return
        raise ValueError(f"The step summary {text} is not in the allowed list {step_summary_all}.")

    @property
    def trace_summary(self) -> list[StepSummary]:
        """The label list itself (not a copy), validated including its length against ``self.messages``."""
        self._check_trace_summary(check_length=True)
        return self._trace_summary

    @trace_summary.setter
    def trace_summary(self, value: Iterable[StepSummary]) -> None:
        """Replace the label list with a copy of ``value`` and validate its ordering."""
        self._trace_summary = [*value]
        self._check_trace_summary()

    def _check_trace_summary(self, check_length: bool = False) -> None:
        """
        Validate the label sequence and raise ValueError at the first violation.

        Rules: every label is known; the first label is a developer label; a developer label is
        followed by a user label; a user label by an assistant label; an assistant label by a
        user or developer label. ``check_length`` additionally requires exactly one label per
        message. Only the getter asks for it, because during append/set the two lists may
        briefly differ in length.
        """
        assert hasattr(self, "messages"), "The messages attribute is not set. Please make sure to initialize the messages attribute before checking the trace summary."

        labels = self._trace_summary
        if check_length and len(labels) != len(self.messages): # type: ignore
            raise ValueError(f"The length of trace_summary {len(labels)} is not equal to the length of messages {len(self.messages)}.") # type: ignore

        dev_labels = get_args(DevSummary)
        user_labels = get_args(UserSummary)
        assistant_labels = get_args(AssistantSummary)

        previous_role = None
        for label in labels:
            if label not in step_summary_all:
                raise ValueError(f"The step summary {label} is not in the allowed list {step_summary_all}.")

            # Which labels may follow the previous one?
            if previous_role is None:
                if label not in dev_labels:
                    raise ValueError(f"The first step summary {label} must be in the developer category {dev_labels}.")
            elif previous_role == "developer":
                if label not in user_labels:
                    raise ValueError(f"The step summary {label} must be in the user category {user_labels} after a developer step.")
            elif previous_role == "user":
                if label not in assistant_labels:
                    raise ValueError(f"The step summary {label} must be in the assistant category {assistant_labels} after a user step.")
            elif label not in user_labels + dev_labels:  # previous_role == "assistant"
                raise ValueError(f"The step summary {label} must be in the user category {user_labels} or the developer category {dev_labels} after an assistant step.")

            # Remember the role of the label that just passed.
            if label in dev_labels:
                previous_role = "developer"
            elif label in user_labels:
                previous_role = "user"
            elif label in assistant_labels:
                previous_role = "assistant"
                    
class CacheLoadMixin:
    """
    Resume a run from a cache dict written by an earlier run.

    After ``load_cache``, each ``_process_cached_step`` call returns the next cached assistant reply
    (in order) until all of them have been replayed.
    """

    def _init_cache(self, cache_file: str) -> None:
        """Remember the cache path; nothing is replayed until ``load_cache`` is called."""
        self.cache_file = cache_file
        # Number of assistant replies available in the loaded cache.
        self.cache_step: int = 0
        # Number of cached assistant replies already replayed.
        self.have_loaded_step: int = 0

    def load_cache(self, cache: dict[str, Any]) -> None:
        """
        Restore model, scenario, history, usage, trace summary and config from ``cache``.

        ``cache`` must contain the keys "model", "scenario", "history", "usage", "trace_summary" and
        "config". A config value from the cache overrides the current attribute. A warning is logged
        when the two differ and the current value is not None.
        """
        self.model = cache["model"]
        self.scenario = cache["scenario"]
        self.messages_with_reason = cache["history"]
        # The request history is the cached history without the per-message "reasoning" field.
        self.messages = [{"role": message["role"], "content": message["content"]} for message in self.messages_with_reason]
        self.total_usage = cache["usage"]
        self.cache_step = sum(message["role"] == "assistant" for message in self.messages)
        # Assigned after self.messages: a validating trace_summary setter (if the host has one)
        # needs the messages to exist.
        self.trace_summary = cache["trace_summary"]
        self.have_loaded_step = 0

        for key, value in cache['config'].items():
            current = getattr(self, key, None)
            if current is None:
                setattr(self, key, value)
            elif current != value:
                logger.warning(f"The value of {key} in the cache file ({value}) is different from the value in the current setting ({current}). Use the value in the cached file.")
                setattr(self, key, value)

        logger.debug(f"Load the cached file {self.cache_file} with {len(self.messages)} messages and {self.total_usage} usage.")
        logger.debug(f"The config in the cached file is {cache['config']}.")

    def _process_cached_step(self, structured_output: Type[structured_output_T] | None = None) -> str | structured_output_T | None:
        """
        Replay the next cached assistant reply, if one is left.

        Returns:
            None when nothing was loaded or every cached reply has been replayed. Otherwise it
            returns the next reply and advances ``self.have_loaded_step``. The reply is the raw
            string when ``structured_output`` is None, or the string decoded as JSON and validated
            with ``structured_output.model_validate``.

        Raises:
            RuntimeError: if the history has fewer assistant messages than ``self.cache_step``.
        """
        if not (self.cache_step > 0 and self.have_loaded_step < self.cache_step):
            # no cached message can be used
            return None

        step = self.have_loaded_step
        assistant_contents = [message['content'] for message in self.messages_with_reason if message['role'] == 'assistant']
        if step >= len(assistant_contents):
            raise RuntimeError(f"Cannot find the cached assistant message for step {self.have_loaded_step}.")
        return_message = assistant_contents[step]
        assert return_message is not None, "Cannot find the cached assistant message."

        # Log the 1-based number of this step before advancing the counter; the "Loaded" log
        # below then shows the same number.
        logger.debug(f'Try to load cached step {step + 1}\n{return_message}')
        self.have_loaded_step += 1

        if structured_output is None:
            return return_message
        structured_output_obj = structured_output.model_validate(json.loads(return_message))
        logger.debug(f'Loaded structured output correctly for cached step {self.have_loaded_step}\n with type {type(structured_output)} and content {structured_output_obj.model_dump_json()}', )
        return structured_output_obj


class SinglePromptGeneratorNoExample(TokenUsageMixin, TraceSummaryMixin, CacheLoadMixin):

    def __init__(self,
        model: str,
        scenario: str,
        idx: int,
        time: int,
        cache_file: str,
        output_file: str,
        graph_file: str,
        graph_img_file: str | None,
        graph_img_type: str,
        rule_file: str,
        self_check_times: int = 0,
        preview: bool = False,
        allow_regen_graph: bool = False,
        allow_regen_factors: bool = False,
        max_regen_graph_time: int = 3,
        max_regen_factors_time: int = 3,
        complex: Literal["low", "medium", "high"] = "low",
        max_nodes: int | None = None,
        reasoning_effort: Literal["minimal", "low", "medium", "high"] = "minimal",
        max_output_tokens: int = 12800,
        remind_last: bool = False,
        split_steps: bool = False,
        ) -> None:
        """
        Store the run configuration, create the OpenAI client and seed the conversation.

        ``idx`` and ``time`` only appear in log messages. A regeneration kind (graph / factors)
        is enabled only when its ``allow_*`` flag is set AND its ``max_regen_*_time`` is positive.
        A disabled kind gets a budget of 0.
        """
        # Model / API settings.
        self.model = model
        self.client = OpenAI()
        self.reasoning_effort = reasoning_effort
        self.max_output_tokens = max_output_tokens
        self.preview = preview

        # Task description and logging identifiers.
        self.scenario = scenario
        self.idx = idx  # only used for logging
        self.time = time  # only used for logging

        # Output locations.
        self.output_file = output_file
        self.graph_file = graph_file
        self.graph_img_file = graph_img_file
        self.graph_img_type = graph_img_type
        self.rule_file = rule_file

        # Generation-loop settings.
        self.self_check_times = self_check_times
        self.complex = complex
        self.max_nodes = max_nodes
        self.remind_last = remind_last
        self.split_steps = split_steps

        # Regeneration switches and budgets.
        regen_graph_enabled = allow_regen_graph if max_regen_graph_time > 0 else False
        regen_factors_enabled = allow_regen_factors if max_regen_factors_time > 0 else False
        self.allow_regen_graph = regen_graph_enabled
        self.allow_regen_factors = regen_factors_enabled
        self.max_regen_graph_time = max_regen_graph_time if regen_graph_enabled else 0
        self.max_regen_factors_time = max_regen_factors_time if regen_factors_enabled else 0

        # Mixin state first, then the conversation seeded with the system prompt.
        self._init_usage()
        self._init_trace()
        self._init_cache(cache_file)
        self._init_messages()
        
    def complex_range(self) -> tuple[int, int]:
        if self.complex == "low":
            return (2, 5)
        elif self.complex == "medium":
            return (3, 7)
        elif self.complex == "high":
            return (5, 10)
        else:
            raise ValueError(f"The complex level {self.complex} is not supported.")

    def _init_messages(self) -> None:
        """
        Reset the conversation to a single developer message holding the system prompt.

        ``self.messages_with_reason`` gets the same message with ``reasoning`` None, and the trace
        summary is reset to ``["system_start"]``.

        Raises:
            NotImplementedError: for the deprecated preview mode (an "o1" model with ``self.preview``).
        """
        self.messages: list[Message] = []
        self.messages_with_reason: list[MessageWithReason] = []
        from prompt_generate.prompts.prompts_zero_example import gpt5_system_prompt

        # NOTE: the complexity hint (complex_prompt filled from complex_range()) and the node-limit
        # note are currently not appended to the system prompt. max_nodes is still enforced by
        # check_graph.
        message_system = gpt5_system_prompt

        if "o1" in self.model and self.preview:
            raise NotImplementedError("The preview mode has been deprecated.")

        self.messages = [{"role": "developer", "content": message_system}]
        self.messages_with_reason = [{"role": "developer", "content": message_system, "reasoning": None}]
        self.trace_summary = ["system_start"]
            
    def _cutoff_message_for_request(self, messages: list[Message]) -> list[Message]:
        """
        Select the part of ``messages`` to send with the next request.

        With ``self.split_steps`` False, ``messages`` is returned unchanged (the same list object).
        With it True, a new list is returned that holds only the messages from the most recent
        developer-labelled step onward. This relies on ``self._trace_summary`` being index-aligned
        with ``messages``.

        Raises:
            ValueError: if the trace summary contains no developer-labelled step.
        """
        if not self.split_steps:
            logger.debug('The `split_steps` is False, do not cutoff any messages for request.')
            return messages

        dev_labels = get_args(DevSummary)
        last_part_start_idx = next(
            (i for i in reversed(range(len(self._trace_summary))) if self._trace_summary[i] in dev_labels),
            -1,
        )
        if last_part_start_idx == -1:
            raise ValueError(f"There is no part start step in the trace summary {self._trace_summary}.")

        kept = messages[last_part_start_idx:]
        logger.debug(f'The `split_steps` is True, cutoff messages before the last part start step for request. The current messages are from step {last_part_start_idx} to step {len(self.messages) - 1}.')
        assert kept[0]['role'] == "developer"
        return kept

    # Typing overload: without `structured_output`, the reply is returned as plain text and the
    # summary function receives that text.
    @overload
    def request_and_record(
        self,
        new_user_message: str,
        user_summary: UserSummary,
        assistant_summary_function: Callable[[str], AssistantSummary],
        structured_output: None = None,
        new_system_message: str | None = None,
        system_summary: DevSummary | None = None,
        max_output_tokens: int | None = None,
    ) -> str:
        ...

    # Typing overload: with a pydantic model class as `structured_output`, the reply is returned
    # as an instance of that class and the summary function receives that instance.
    @overload
    def request_and_record(
        self,
        new_user_message: str,
        user_summary: UserSummary,
        assistant_summary_function: Callable[[structured_output_T], AssistantSummary],
        structured_output: Type[structured_output_T],
        new_system_message: str | None = None,
        system_summary: DevSummary | None = None,
        max_output_tokens: int | None = None,
    ) -> structured_output_T:
        ...

    def request_and_record(self,
        new_user_message: str,
        user_summary: UserSummary,
        assistant_summary_function: Callable[[structured_output_T], AssistantSummary] | Callable[[str], AssistantSummary],
        structured_output: Type[structured_output_T] | None = None,
        new_system_message: str | None = None,
        system_summary: DevSummary | None = None,
        max_output_tokens: int | None = None,
        ) -> str | structured_output_T:
        """
        Add ``new_user_message`` as a user turn, request the model's reply, record it and return it.

        If a loaded cache still has unreplayed assistant replies, the next one is returned instead.
        In that case nothing is sent or recorded.

        Args:
            new_user_message: Text of the new user turn.
            user_summary: Trace label recorded for the user turn.
            assistant_summary_function: Maps the reply (the parsed object, or the text when
                ``structured_output`` is None) to the trace label recorded for the assistant turn.
            structured_output: Pydantic model class to parse the reply into; None for plain text.
            new_system_message: Optional extra instructions. With ``self.split_steps`` they become
                their own developer message labelled ``system_summary``. Otherwise they are
                prepended to the user message.
            system_summary: Trace label for ``new_system_message`` (required with ``split_steps``).
            max_output_tokens: Per-call override of ``self.max_output_tokens``.

        Returns:
            The parsed ``structured_output`` instance, or the reply text.

        Raises:
            RuntimeError: if no reply passes validation within ``format_retry_times`` requests, or
                the response does not consist of exactly two output items (reasoning and generation).
        """
        # Replay from the loaded cache first; nothing is sent or recorded in that case.
        if (cached_output := self._process_cached_step(structured_output)) is not None:
            return cached_output

        if not self.split_steps and new_system_message is not None:
            logger.debug('`split_steps` flag is False, so the provided new_system_message will be include in the user message directly.')
            new_user_message = new_system_message + "\n\n" + new_user_message
        elif self.split_steps and new_system_message is not None:
            assert system_summary is not None
            self.trace_summary_append(system_summary)
            self.messages.append({"role": "developer", "content": new_system_message})
            self.messages_with_reason.append({"role": "developer", "content": new_system_message, "reasoning": None})
        elif self.split_steps and new_system_message is None:
            assert system_summary is None

        self.trace_summary_append(user_summary)
        self.messages.append({"role": "user", "content": new_user_message})
        self.messages_with_reason.append({"role": "user", "content": new_user_message, "reasoning": None})

        max_output_tokens = max_output_tokens if max_output_tokens is not None else self.max_output_tokens
        if max_output_tokens != self.max_output_tokens:
            logger.warning(f"The max_output_tokens has been changed from {self.max_output_tokens} to {max_output_tokens}.")

        # Copy the list and its last message so the retry notes appended below stay local to this
        # request. Previously they were written into the dict stored in self.messages, leaking
        # into every later request and diverging from self.messages_with_reason (the cached history).
        request_messages = list(self._cutoff_message_for_request(self.messages))
        last_message = request_messages[-1]
        request_messages[-1] = {"role": last_message["role"], "content": last_message["content"]}

        if structured_output is not None:
            logger.debug(f"Request OpenAI api with structured output type {structured_output.__name__}.")
            logger.debug(json.dumps(request_messages, indent=2))
            error_msg = None
            for attempt in range(format_retry_times):
                try:
                    if error_msg is not None:
                        logger.debug(f'The error msg {error_msg} will be included in the message for retry.')
                        request_messages[-1]['content'] += f"\n\nNote: The previous response from the model is not valid. The error is: {error_msg}. Please make sure your response is in the correct format."
                    completion = self.client.responses.parse(
                        model=self.model,
                        input=request_messages, # type: ignore
                        reasoning={"effort": self.reasoning_effort, "summary": "auto"}, # type: ignore
                        text_format=structured_output,
                        max_output_tokens=max_output_tokens,
                    )
                    break
                except PydanticFormatError as e:
                    error_msg = str(e)
                    logger.warning(f"[WARNING] The response from OpenAI api is not valid. Error info: {e}. Retry {attempt+1}/{format_retry_times}.")
                    continue
            else:
                raise RuntimeError(f"The response from OpenAI api is not valid after {format_retry_times} retries. Last error info: {error_msg}")

        else:
            logger.debug("Request OpenAI api with plain text output.")
            logger.debug(json.dumps(request_messages, indent=2))
            completion = self.client.responses.parse(
                model=self.model,
                input=request_messages, # type: ignore
                reasoning={"effort": self.reasoning_effort, "summary": "auto"}, # type: ignore
                max_output_tokens=max_output_tokens,
            )

        # The parse way is suitable for both structured_output is None or not None.
        if completion is None: # type: ignore
            raise RuntimeError("The response from OpenAI api is None.")
        # Accumulate usage for responses with or without structured output. This is counted
        # before the output-shape check so tokens of a rejected response are still recorded.
        self.update_usage(completion)
        if len(completion.output) != 2:
            raise RuntimeError(f"The response from OpenAI api is not valid. Expected two parts in the output (reasoning and generation) but got {len(completion.output)} parts.")

        thinking, generation_obj, generation_str = parse_completion(completion)

        logger.debug(f"Get the response from OpenAI api:\n{generation_str}\nWith reasoning:\n{thinking}")
        if structured_output is not None:
            logger.debug(f"Returned structured output type {type(generation_obj).__name__}:\n{generation_obj}")

        self.messages.append({"role": "assistant", "content": generation_str})
        self.messages_with_reason.append({"role": "assistant", "content": generation_str, "reasoning": thinking})

        assistant_summary = assistant_summary_function(generation_obj if structured_output is not None else generation_str)
        self.trace_summary_append(assistant_summary)

        # Persist the whole conversation after every reply so an interrupted run can resume.
        # os.makedirs("") raises, so skip it when the cache file has no directory part.
        cache_dir = os.path.dirname(self.cache_file)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        config = {
            "self_check_times": self.self_check_times,
            "allow_regen_graph": self.allow_regen_graph,
            "max_regen_graph_time": self.max_regen_graph_time,
            "allow_regen_factors": self.allow_regen_factors,
            "max_regen_factors_time": self.max_regen_factors_time,
            "split_steps": self.split_steps,
            "remind_last": self.remind_last,
        }
        with open(self.cache_file, 'w') as f:
            f.write(json.dumps({
                "model": self.model,
                "scenario": self.scenario,
                "trace_summary": self.trace_summary,
                "history": self.messages_with_reason,
                "usage": self.total_usage,
                "config": config},
                indent=2
            ))

        return generation_obj if structured_output is not None else generation_str


    def save_result(self, graph: Graph, rules_json: RulesJson) -> None:
        """
        Finalize the run: move the cache file to ``self.output_file`` and export graph and rules.

        Steps:
        - Write ``graph`` as JSON to ``self.graph_file``.
        - When ``self.graph_img_file`` is set, render the graph there using networkx's pygraphviz
          bridge with the "dot" layout.
        - Write the scenario, the dumped ``rules_json`` and its ``rules_convert_to_dict()`` form
          (under "rules_dict") as JSON to ``self.rule_file``.

        Missing parent directories are created.
        """
        def _ensure_parent_dir(path: str) -> None:
            # os.makedirs("") raises, so only create a directory when the path has one.
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)

        # Just move the cache file to the output file. os.replace also overwrites an existing
        # target on Windows, where os.rename would fail.
        _ensure_parent_dir(self.output_file)
        os.replace(self.cache_file, self.output_file)
        logger.info(f"Save the result into {os.path.abspath(self.output_file)}")

        # Now, extract the final graph file, optional image and rules json file.
        _ensure_parent_dir(self.graph_file)
        _ensure_parent_dir(self.rule_file)
        if self.graph_img_file is not None:
            _ensure_parent_dir(self.graph_img_file)

        with open(self.graph_file, 'w', encoding='utf-8') as f:
            f.write(graph.model_dump_json(indent=2))
        logger.info(f"The graph file of sample {self.idx} (time {self.time}) has been saved into {os.path.abspath(self.graph_file)}")

        if self.graph_img_file is not None:
            # Only pygraphviz (via networkx) is needed here; matplotlib is not used.
            from networkx.drawing.nx_agraph import to_agraph
            agraph = to_agraph(graph.to_digraph())
            agraph.layout('dot')
            agraph.draw(self.graph_img_file)
            logger.info(f"The image file of sample {self.idx} (time {self.time}) has been saved into {os.path.abspath(self.graph_img_file)}")

        output_json = {"scenario": self.scenario}
        output_json.update(rules_json.model_dump(mode='json'))
        output_json['rules_dict'] = rules_json.rules_convert_to_dict()
        with open(self.rule_file, 'w') as f:
            f.write(json.dumps(output_json, indent=2))
        logger.info(f"The json file of sample {self.idx} (time {self.time}) has been saved into {os.path.abspath(self.rule_file)}")

    def _check_graph_and_factors_consistent(self, graph: Graph, factors: FactorList) -> tuple[bool, str | None]:
        """
        Check that ``graph`` matches ``factors``.

        1. Check whether all the factors and results are included in the graph.
        2. Check whether no additional nodes are included in the graph.
        3. Check whether each factor has no incoming edge and each result has at least one incoming edge.
        A node missing from the graph is reported only by check 1, not again by check 3.

        Returns:
            ``(True, None)`` when every check passes, otherwise ``(False, message)``. The message
            has one paragraph per failed check.
        """
        factor_names = [factor.name for factor in factors.factors if factor.type == 'factor']
        result_names = [factor.name for factor in factors.factors if factor.type == 'result']
        expected_names = factor_names + result_names
        # Sets for membership tests; the lists keep the reporting order.
        expected_set = set(expected_names)
        graph_nodes = graph.nodes
        graph_node_set = set(graph_nodes)
        error_message = []

        missing_in_graph = [name for name in expected_names if name not in graph_node_set]
        if missing_in_graph:
            error_message.append(f"The nodes {missing_in_graph} are included in the factors but not in the graph. Please make sure all the factors and results are included in the graph.")

        extra_in_graph = [name for name in graph_nodes if name not in expected_set]
        if extra_in_graph:
            error_message.append(f"The nodes {extra_in_graph} are included in the graph but not in the factors. Please make sure no additional nodes are included in the graph.")

        graph_nx = graph.to_digraph()
        # Only query in-degrees of nodes that exist. For an unknown name, networkx's
        # in_degree(name) returns a degree view instead of an int; `view != 0` is then True, so a
        # missing factor used to be misreported as having incoming edges.
        wrong_factors = [name for name in factor_names if name in graph_nx and graph_nx.in_degree(name) != 0]
        if wrong_factors:
            error_message.append(f"The factor nodes {wrong_factors} have incoming edges in the graph. Please make sure each factor has no incoming edge.")

        wrong_results = [name for name in result_names if name in graph_nx and graph_nx.in_degree(name) == 0]
        if wrong_results:
            error_message.append(f"The result nodes {wrong_results} have no incoming edges in the graph. Please make sure each result has at least one incoming edge.")

        if error_message:
            return False, "\n\n".join(error_message)
        return True, None

    def check_graph(self, graph: Graph | None, factors: FactorList) -> tuple[Literal[True], None] | tuple[Literal[False], str]:
        """
        Validate a proposed graph and return ``(True, None)`` or ``(False, feedback_for_the_model)``.

        Checks, in order:
        1. The graph is consistent with ``factors`` (all factors/results present, no extra nodes,
           correct in-degrees).
        2. The graph is a directed acyclic graph.
        3. No node is isolated.
        4. When ``self.max_nodes`` is set (truthy), the node count does not exceed it.
        """
        if graph is None:
            return False, "There is no valid dot file can be extracted from your answer. Please make sure your answer in the correct format. The dot content should be covered by <dot> and </dot>."

        consistent, consistency_error = self._check_graph_and_factors_consistent(graph, factors)
        if not consistent:
            assert consistency_error is not None
            return False, consistency_error

        digraph = graph.to_digraph()
        if not nx.is_directed_acyclic_graph(digraph):
            return False, "There is a cycle in the graph. Please check the graph again and regenerate the graph in the format defined above."

        isolated = list(nx.isolates(digraph))
        if isolated:
            return False, "There are isolated nodes: {} in the graph. Please add related edges or remove the nodes. Regenerate the causal graph in the format defined above.".format(isolated)

        node_limit = getattr(self, "max_nodes", None)
        node_count = digraph.number_of_nodes()
        if node_limit and node_count > node_limit:
            return False, (
                f"The graph contains {node_count} nodes, which exceeds the "
                f"limit of {node_limit}. Please regenerate a smaller graph by keeping "
                f"only the most important/common nodes, merging near-duplicates, "
                f"and dropping trivial nodes."
            )

        return True, None

    def check_json_and_rule(self, json_file: RulesJson | None, graph: Graph | None) -> tuple[Literal[True], None] | tuple[Literal[False], str]:
        """
        Rule-based validation of the rules json proposed by the model.

        Checks, returning on the first failure:
        0. A `RulesJson` object was extracted at all.
        1. When `graph` is given, the roots / non-roots declared in the json match
           `graph.root_and_non_root`.
        2. Every entry of `json_file.rules_convert_to_dict()` targets a declared non-root and is a DNF:
           a list of conjunction clauses, each a `dict[str, bool]`. When `graph` is given, every literal
           must be a parent of the target in the graph, and every graph parent must appear in at least
           one clause.

        Returns:
            (True, None) when all checks pass, otherwise (False, feedback message for the model).
        """
        if json_file is None:
            return False, "There is no valid json file can be extracted from your answer. Please make sure your answer in the correct format. The json content should be covered by <json> and </json>."
        if not isinstance(json_file, RulesJson):
            return False, f"The json file is not in the correct format. Expected a RulesJson object but get a {type(json_file)}, Please make sure your answer is in correct format."

        def mismatch_message(kind: str, only_in_json: set, only_in_graph: set) -> str | None:
            # Build the inconsistency feedback; sentences are joined with a space so they don't run together.
            if not only_in_json and not only_in_graph:
                return None
            sentences = [f"The {kind} in the json file is not consistent with the graph."]
            if only_in_json:
                sentences.append(f"The {kind} in the json file but not in the graph are: {only_in_json}.")
            if only_in_graph:
                sentences.append(f"The {kind} in the graph but not in the json file are: {only_in_graph}.")
            return " ".join(sentences)

        json_non_roots = set(json_file.get_non_roots_names())

        # 1. check if the root and non_root are consistent with the graph
        if graph is not None:
            roots, non_roots = graph.root_and_non_root
            json_roots = set(json_file.get_roots_names())
            graph_roots = set(roots)
            message = mismatch_message("roots", json_roots - graph_roots, graph_roots - json_roots)
            if message is not None:
                return False, message
            graph_non_roots = set(non_roots)
            message = mismatch_message("non-roots", json_non_roots - graph_non_roots, graph_non_roots - json_non_roots)
            if message is not None:
                return False, message

        # 2. check the rules format
        rules_dict = json_file.rules_convert_to_dict()
        for target_node, rules in rules_dict.items():
            # Each value is a DNF: list[dict[str, bool]], one dict per conjunction clause.
            if target_node not in json_non_roots:
                return False, f"The target node {target_node} is not in the non-roots."
            if not isinstance(rules, list):
                return False, f"The rules for target node {target_node} is not in the correct format. It is not a list but a {type(rules)}."
            if graph is not None:
                parent_in_graph = [node for node in graph.nodes if graph.has_edge(node, target_node)]
            else:
                parent_in_graph = None
            # Parents not yet mentioned by any clause; must be empty once all clauses are scanned.
            non_occurred = parent_in_graph.copy() if parent_in_graph is not None else None
            for rule in rules:
                # rule: dict[str, bool], a conjunction clause.
                if not isinstance(rule, dict):
                    return False, f"The rules {rules} for target node {target_node} is not in the correct format. The element {rule} in the list is not a dict but a {type(rule)}."
                for parent_node, value in rule.items():
                    if not isinstance(parent_node, str) or not isinstance(value, bool):
                        return False, f"The rule {rule} for target node {target_node} is not in the correct format. It should be a dict[str, bool]."
                    if parent_in_graph is not None and parent_node not in parent_in_graph:
                        return False, f"The cause {parent_node} in the rule {rule} for target node {target_node} is not in the parent list in the causal graph."
                    if non_occurred is not None and parent_node in non_occurred:
                        non_occurred.remove(parent_node)
            if non_occurred:
                return False, f"The parent nodes {non_occurred} in the causal graph are not included in the rules for target node \"{target_node}\"."
        return True, None

    def run_graph(self, allow_regen_factors: bool, factor_list: FactorList, generated_graph: Graph | None = None) -> tuple[Graph, RulesJson] | FactorList | None:
        """
        Build (or refine) the causal graph for `factor_list`, then generate its rules.

        If the generated_graph is None, generate the graph first. Otherwise, directly use the provided
        graph as the initial graph.

        Each of up to `self.max_regen_graph_time + 1` rounds:
        1. obtains an initial graph from the model (system message `message_analyze_graph`) unless a
           graph is already available,
        2. runs up to `self.self_check_times` rounds: a rule-based `check_graph`; if it fails, its
           message is sent back for a fix, otherwise the model is asked to self-check. A keep flag from
           the model ends these rounds early,
        3. calls `run_rules`; if that returns a regenerated graph, it seeds the next round.

        Args:
            allow_regen_factors: whether the model may answer with a new factor list (structured output
                `GraphKeepOrRegenerate` instead of `GraphOrKeep`).
            factor_list: the factors the graph must cover.
            generated_graph: optional initial graph; skips the first generation request.

        Returns:
            A `FactorList` when the model asks to regenerate the factors, otherwise `(graph, rules_json)`.
            NOTE(review): the `None` member of the return annotation is not produced by any return path here.

        Raises:
            RuntimeError: if no rules json or no graph is available after the loop.
        """
        
        def sublogic_of_regen(graph_keep_regen: GraphKeepOrRegenerate | GraphOrKeep) -> None | FactorList:
            # Return the new factor list if the model asked to regenerate the factors; otherwise
            # fall through and implicitly return None.
            if getattr(graph_keep_regen, "regenerate_factors", None) is not None:
                assert isinstance(graph_keep_regen, GraphKeepOrRegenerate), "The structured output should be GraphKeepOrRegenerate when allow_regen_factors is True."
                assert graph_keep_regen.regenerate_factors is not None
                return graph_keep_regen.regenerate_factors
            
        def assistant_summary_function_graph(output: GraphKeepOrRegenerate | GraphOrKeep) -> AssistantSummary:
            # Label the assistant turn for the recorded history.
            if isinstance(output, GraphKeepOrRegenerate) and output.regenerate_factors is not None:
                return "factors_regenerate_from_graph"
            if output.keep_flag is not None:
                return "graph_keep"
            else:
                return "graph_proposed"
            
        json_reply: RulesJson | None = None # initialize to compress the warning of type hint
        
        # Extra instruction appended to every check message when factor regeneration is allowed.
        if allow_regen_factors:
            from prompt_generate.prompts.prompts_zero_example import message_regenerate_factors_allow
            message_regenerate_factors_allow = message_regenerate_factors_allow
        else:
            message_regenerate_factors_allow = ""
        
        # The loop is used to regenerate the graph.
        for regen_graph_time in range(self.max_regen_graph_time + 1):
            if generated_graph is None: # generate the graph first
                from prompt_generate.prompts.prompts_zero_example import message_analyze_graph

                # With remind_last, the user turn restates the scenario and current factors.
                if self.remind_last:
                    from prompt_generate.prompts.prompts_zero_example import remind_factors
                    message_analyze_graph_user = remind_factors.format(
                        scenario=self.scenario,
                        factors = factor_list.model_dump_json(indent=2),
                    )
                else:
                    message_analyze_graph_user = ""

                graph_keep_regen = self.request_and_record(
                    new_system_message=message_analyze_graph,
                    system_summary="system_start",
                    new_user_message=message_analyze_graph_user,
                    user_summary="graph_start",
                    structured_output=GraphKeepOrRegenerate if allow_regen_factors else GraphOrKeep,
                    assistant_summary_function = assistant_summary_function_graph,
                    )
                    
                factorlist_or_none = sublogic_of_regen(graph_keep_regen)
                if factorlist_or_none is not None:
                    # If the factors should be regenerated, break the loop and give the flag to regenerate the factors
                    return factorlist_or_none
                
                # The first answer must contain a graph; "keep" makes no sense without a previous graph.
                assert graph_keep_regen.keep_flag is None, "At the first step, the graph should be provided."
                assert graph_keep_regen.graph is not None, "At the first step, the graph should be provided."
                generated_graph = graph_keep_regen.graph
                assert generated_graph is not None, "At the first step, the graph should be provided."
            else:
                pass # directly use the provided graph as the initial graph.

            # 2.1 Check the causal graph and retrieve groundtruth graph --- self-check & rule-check (total time <= self.self_check_times)
            next_time_short = False # A flag to indicate whether the next self-check prompt should be short.
            for i in range(self.self_check_times):
                # 2.1.1 rule-based check
                pass_check, message_check = self.check_graph(generated_graph, factors=factor_list)
                if not pass_check:
                    # if any rules of the causal graph has been broken.
                    assert message_check is not None
                    graph_keep_regen = self.request_and_record(
                        new_user_message=message_check + message_regenerate_factors_allow, 
                        user_summary="graph_rule_check_not_pass",
                        assistant_summary_function=assistant_summary_function_graph,
                        structured_output=GraphKeepOrRegenerate if allow_regen_factors else GraphOrKeep)
                    next_time_short = False
                else:
                    # 2.1.2 check some vague standard, which cannot be checked by the check_graph function
                    # After one full self-check prompt, later ones use the shorter "check again" prompt.
                    from prompt_generate.prompts.prompts_zero_example import message_self_check_graph, message_self_check_again
                    graph_keep_regen = self.request_and_record(
                        new_user_message=(message_self_check_graph if not next_time_short else message_self_check_again.format(object="causal graph")) + message_regenerate_factors_allow,
                        user_summary="graph_self_check",
                        assistant_summary_function=assistant_summary_function_graph,
                        structured_output=GraphKeepOrRegenerate if allow_regen_factors else GraphOrKeep
                        )
                    next_time_short = True
                    
                factorlist_or_none = sublogic_of_regen(graph_keep_regen)
                if factorlist_or_none is not None:
                    # If the factors should be regenerated, break the loop and give the flag to regenerate the factors
                    return factorlist_or_none
                
                # NOTE(review): a keep flag is honored even right after a failed rule-based check, and a
                # graph replaced in the final round goes to run_rules without another check_graph pass.
                if graph_keep_regen.keep_flag is not None:
                    # If the graph is good enough, break the loop and use the current graph.
                    generated_graph = generated_graph
                    break
                else:
                    assert graph_keep_regen.graph is not None, "The graph should be provided when the keep_flag is None."
                    generated_graph = graph_keep_regen.graph
                assert generated_graph is not None, "The graph should be provided when the keep_flag is None."
                assert isinstance(generated_graph, Graph), "The graph should be provided when the keep_flag is None."

            # 3. Generate the rules based on the graph
            # Graph regeneration is not offered in the last round, so run_rules must then yield rules.
            allow_regen_graph = self.allow_regen_graph and (regen_graph_time != self.max_regen_graph_time)
            logger.info(f"Start to generate the rules for sample {self.idx} (time {self.time}), regen_graph_time {regen_graph_time} with max_regen_graph {self.max_regen_graph_time}.")
            json_reply, regen_graph_or_none = self.run_rules(
                rules=None,
                factors=factor_list,
                groundtruth_graph=generated_graph,
                allow_regen_graph=allow_regen_graph,
            )

            # 4. If the graph should be regenerated, the regen_graph is not None, regenerate the graph. Otherwise, break the loop.
            if regen_graph_or_none is None:
                break # break will go to the return step
            else:
                assert isinstance(regen_graph_or_none, Graph)
                generated_graph = regen_graph_or_none
                continue # regenerate the graph

        if json_reply is None:
            raise RuntimeError("The rules json should be provided after the run_rules function.")
        if generated_graph is None:
            raise RuntimeError("The graph should be generated after the run_graph function.")
        return generated_graph, json_reply

    def run_rules(self, rules: RulesJson | None, factors: FactorList, groundtruth_graph: Graph, allow_regen_graph: bool) -> tuple[RulesJson, None] | tuple[None, Graph] | tuple[None, None]:
        """
        Generate (or refine) the rules json for `groundtruth_graph` through check / self-check rounds.

        If `rules` is None, the rules are first requested from the model with `message_analyze_rules`
        as the system message. Then up to `self.self_check_times` rounds follow:
        - first, the rule-based check `check_json_and_rule`;
        - if it fails, its message is sent back for a fix;
        - otherwise the model is asked to self-check.
        A keep flag from the model ends the rounds early.

        Args:
            rules: optional initial rules json; skips the first generation request.
            factors: the factor list (only used in the reminder message when `self.remind_last` is set).
            groundtruth_graph: the causal graph the rules must agree with.
            allow_regen_graph: whether the model may answer with a regenerated graph instead
                (structured output `RulesKeepOrRegenerate` instead of `RulesJsonOrKeep`).

        Returns:
            `(None, graph)` when the model asks to regenerate the graph, otherwise `(rules_json, None)`.
        """

        def sublogic_of_regen(rules_keep_regen: RulesKeepOrRegenerate | RulesJsonOrKeep) -> None | Graph:
            # Return the regenerated graph if the model asked for one, otherwise None.
            if getattr(rules_keep_regen, "regenerate_graph", None) is not None:
                assert isinstance(rules_keep_regen, RulesKeepOrRegenerate), "The structured output should be RulesKeepOrRegenerate when allow_regen_graph is True."
                assert rules_keep_regen.regenerate_graph is not None
                return rules_keep_regen.regenerate_graph
            return None

        def assistant_summary_function_rules(output: RulesKeepOrRegenerate | RulesJsonOrKeep) -> AssistantSummary:
            # Label the assistant turn for the recorded history.
            if isinstance(output, RulesKeepOrRegenerate) and output.regenerate_graph is not None:
                return "graph_regenerate_from_rules"
            if output.keep_flag is not None:
                return "rules_keep"
            else:
                return "rules_proposed"

        # Extra instruction appended to the system/check messages when graph regeneration is allowed.
        if allow_regen_graph:
            from prompt_generate.prompts.prompts_zero_example import message_regenerate_graph_allow
        else:
            message_regenerate_graph_allow = ""

        from prompt_generate.prompts.prompts_zero_example import message_self_check_again, message_self_check_rules

        if rules is None:
            from prompt_generate.prompts.prompts_zero_example import message_analyze_rules

            # With remind_last, the user turn restates the scenario, factors and graph.
            if self.remind_last:
                from prompt_generate.prompts.prompts_zero_example import remind_factors_and_graph
                message_analyze_rules_user = remind_factors_and_graph.format(
                    scenario=self.scenario,
                    factors = factors.model_dump_json(indent=2),
                    graph=groundtruth_graph.model_dump_json(indent=2)
                )
            else:
                message_analyze_rules_user = ""

            rules_keep_regen = self.request_and_record(
                new_system_message=message_analyze_rules + message_regenerate_graph_allow,
                system_summary="system_start",
                new_user_message=message_analyze_rules_user,
                user_summary="rules_start",
                assistant_summary_function = assistant_summary_function_rules,
                structured_output=RulesKeepOrRegenerate if allow_regen_graph else RulesJsonOrKeep
            )

            if (regen_graph := sublogic_of_regen(rules_keep_regen)) is not None:
                # If the graph should be regenerated, stop here and hand the new graph back to the caller.
                return None, regen_graph

            assert rules_keep_regen.rules_json is not None, "At the first step, the rules json should be provided."
            assert rules_keep_regen.keep_flag is None, "At the first step, the rules json should be provided."
            rules = rules_keep_regen.rules_json
        else:
            pass # directly use the provided rules as the initial rules.

        # 3.1 Check the rules
        # True right after a passed check, so the next self-check uses the shorter "check again" prompt.
        next_time_short = False
        for _ in range(self.self_check_times):
            pass_check, message_check = self.check_json_and_rule(rules, groundtruth_graph)
            if not pass_check:
                assert message_check is not None
                # Leading space: the check message ends with a full sentence.
                message_check += " Please regenerate your json file in the format defined above."
                rules_keep_regen = self.request_and_record(
                    new_user_message = message_check + message_regenerate_graph_allow,
                    user_summary="rules_rule_check_not_pass",
                    assistant_summary_function=assistant_summary_function_rules,
                    structured_output=RulesKeepOrRegenerate if allow_regen_graph else RulesJsonOrKeep
                    )
                next_time_short = False
            else:
                # check some vague standard, which cannot be checked by the check_json_and_rule function
                rules_keep_regen = self.request_and_record(
                    new_user_message=(message_self_check_rules if not next_time_short else message_self_check_again.format(object="json file")) + message_regenerate_graph_allow,
                    user_summary="rules_self_check",
                    assistant_summary_function=assistant_summary_function_rules,
                    structured_output=RulesKeepOrRegenerate if allow_regen_graph else RulesJsonOrKeep
                    )
                next_time_short = True

            if (regen_graph := sublogic_of_regen(rules_keep_regen)) is not None:
                # If the graph should be regenerated, stop here and hand the new graph back to the caller.
                return None, regen_graph

            if rules_keep_regen.keep_flag is not None:
                # If the rules is good enough, break the loop and use the current rules.
                break
            else:
                assert rules_keep_regen.rules_json is not None, "The rules json should be provided when the keep_flag is None."
                rules = rules_keep_regen.rules_json

        assert rules is not None

        return rules, None

    def run_factors(self, generated_factors: FactorList | None) -> FactorList:
        """
        Produce the factor list for the scenario and refine it through self-check rounds.

        If `generated_factors` is None, the factors are first requested from the model with
        `message_user`. Otherwise the provided list is used as the starting point.

        Up to `self.self_check_times` self-check requests follow. The first uses
        `message_self_check_factors` and later ones the shorter `message_self_check_again`.
        The loop stops early once the model answers with a keep flag.

        Returns:
            The final factor list.
        """

        def assistant_summary_function_factors(output: FactorListOrKeep) -> AssistantSummary:
            # Label the assistant turn for the recorded history.
            if output.keep_flag is not None:
                return "factors_keep"
            else:
                return "factors_proposed"

        from prompt_generate.prompts.prompts_zero_example import message_user, message_self_check_factors, message_self_check_again
        # 1. The first request: analyze the possible factors
        if generated_factors is None:
            first_returned_factors = self.request_and_record(
                new_user_message=message_user.format(scenario=self.scenario),
                user_summary="factors_start",
                structured_output=FactorListOrKeep,
                assistant_summary_function=assistant_summary_function_factors,
                )
            assert first_returned_factors.factor_list is not None, "At the first step, the factor list should be provided."
            factor_list = first_returned_factors.factor_list
        else:
            factor_list = generated_factors

        # 1.1 Self-check about factors
        for i in range(self.self_check_times):
            # Full self-check prompt on the first round, a short "check again" prompt afterwards.
            self_check_message = message_self_check_factors if i == 0 else message_self_check_again.format(object="factor list")
            # Passed by keyword like every other request_and_record call, so it is always sent as the user turn.
            feedback = self.request_and_record(
                new_user_message=self_check_message,
                user_summary="factors_self_check",
                assistant_summary_function=assistant_summary_function_factors,
                structured_output=FactorListOrKeep,
                )
            if feedback.keep_flag is not None:
                # The model is satisfied with the current list.
                break
            # keep_flag is None, so a new factor list should be provided.
            if feedback.factor_list is not None:
                factor_list = feedback.factor_list
        assert factor_list is not None, "The factor list should not be None."
        return factor_list

    def run(self) -> None:
        """
        Run the whole pipeline for one scenario: factors -> causal graph (+ rules) -> save.

        Each outer round gets a factor list from `run_factors` and hands it to `run_graph`.
        If `run_graph` answers with a `FactorList`, that list seeds the next round. Factor
        regeneration is only offered while rounds remain and `self.allow_regen_factors` is set.
        A `(Graph, RulesJson)` answer ends the pipeline and is written with `save_result`.

        Raises:
            RuntimeError: if no graph/rules pair was obtained after all rounds.
        """
        generated_factors = None  # the first round always builds the factors from scratch
        generated_graph = None
        json_reply = None

        finished = False  # set once run_graph delivers a (graph, rules) pair
        for factor_round in range(self.max_regen_factors_time + 1):
            # 1. factors
            factor_list: FactorList = self.run_factors(generated_factors=generated_factors)
            assert isinstance(factor_list, FactorList), "The factor list should be provided when the factors are not regenerated."

            # 2. graph (the rules are produced inside run_graph)
            is_last_factor_round = factor_round == self.max_regen_factors_time
            may_regen_factors = self.allow_regen_factors and not is_last_factor_round
            generated_graph = None  # every factor round starts the graph from scratch

            logger.info(f"Start to generate the graph for sample {self.idx} (time {self.time}), factor_regen_time {factor_round} with max_regen_factors {self.max_regen_factors_time}.")
            for _ in range(self.max_regen_graph_time + 1):
                outcome = self.run_graph(
                    allow_regen_factors=may_regen_factors,
                    factor_list=factor_list,
                    generated_graph=generated_graph,
                )
                if isinstance(outcome, FactorList):
                    # The graph step asked for new factors: seed the next factor round with them.
                    generated_factors = outcome
                    logger.info(f"The factors of sample {self.idx} (time {self.time}) are regenerated because the graph generation step indicates to regenerate the factors.")
                    break
                if outcome is None:
                    # No result: retry the graph from scratch.
                    generated_graph = None
                    continue
                assert len(outcome) == 2, "The return of run_graph should be a tuple of (Graph, RulesJson)."
                generated_graph, json_reply = outcome
                assert isinstance(generated_graph, Graph), "The graph should be provided when the factors are not regenerated."
                assert isinstance(json_reply, RulesJson), "The rules json should be provided when the factors are not regenerated."
                finished = True
                break

            if finished:
                break

        # 5. Save the final answer
        if json_reply is None or generated_graph is None:
            raise RuntimeError("The final graph and rules json should be provided after the run function.")
        self.save_result(graph=generated_graph, rules_json=json_reply)

class PromptGenerator:

    def __init__(self, model: str, scenario_file: str, cache_dir: str, output_dir: str, graph_dir: str, graph_img_dir: str | None, graph_img_type: str, rules_dir: str, self_check_times: int = 0, self_example_times: int = 0, preview: bool = False, allow_regen_factors: bool = False, max_regen_factors_time: int = 3, allow_regen_graph: bool = False, max_regen_graph_time: int = 3, complex: Literal['low', 'medium', 'high'] = 'low', max_nodes: int | None = None, reasoning_effort: Literal['minimal', 'low', 'medium', 'high'] = "minimal", max_output_tokens: int = 12800, remind_last: bool = False, split_steps: bool = True) -> None:
        self.model = model
        self.scenario_file = scenario_file
        self.scenarios = self.load_scenarios()
        self.cache_dir = os.path.abspath(cache_dir)
        self.output_dir = os.path.abspath(output_dir)
        self.graph_dir = os.path.abspath(graph_dir)
        self.graph_img_dir = os.path.abspath(graph_img_dir) if graph_img_dir is not None else None
        self.graph_img_type = graph_img_type
        self.rules_dir = os.path.abspath(rules_dir)
        self.self_check_times = self_check_times
        self.self_example_times = self_example_times
        self.preview = preview
        self.allow_regen_factors = allow_regen_factors if max_regen_factors_time > 0 else False
        self.max_regen_factors_time = max_regen_factors_time
        self.allow_regen_graph = allow_regen_graph if max_regen_graph_time > 0 else False
        self.max_regen_graph_time = max_regen_graph_time
        self.complex = complex
        self.max_nodes = max_nodes
        self.reasoning_effort = reasoning_effort
        self.max_output_tokens = max_output_tokens
        self.remind_last = remind_last
        self.split_steps = split_steps
        if self.split_steps and not self.remind_last:
            logger.warning('The `remind_last` flag must be True when the `split_steps` flag is True. Set the `remind_last` flag to True.')
            self.remind_last = True
        if not split_steps and self.remind_last:
            logger.warning("The `remind_last` flags is set to True when `split_steps` is False, which may lead to the increase of the token usage and cost. Only use it when the generation quality is not good enough.")

    def load_scenarios(self) -> list[str]:
        with open(self.scenario_file, 'r') as f:
            lines = f.readlines()
        logger.info("Load {} scenarios from {}".format(len(lines), os.path.abspath(self.scenario_file)))
        return lines

    def generate_prompt_zero_example(self, input: tuple[int, bool, int]):
        idx = input[0]
        force_regen = input[1]
        time = input[2]
        scenario = self.scenarios[idx]
        file_name = f"{idx}_{time}.json"
        if force_regen:
            logger.info("`force_regen` has been set to True, any existing results will be covered.")
        if not force_regen and (output_file:=os.path.exists(os.path.join(self.output_dir, file_name))):
            logger.info(f"The generation has existed at {output_file}; skip the generation of scenario \"{scenario}\" (index {idx} time {time}).")
            return
        if not force_regen and os.path.exists(os.path.join(self.cache_dir, file_name)):
            try:
                with open(os.path.join(self.cache_dir, file_name), 'r') as f:
                    cache = json.load(f)
                    logger.info(f"Load the cached file {os.path.join(self.cache_dir, file_name)}.")
            except json.JSONDecodeError as e:
                logger.warning(f"The cached file {os.path.join(self.cache_dir, file_name)} is not a valid json file. The error is {e}. Regenerate the prompt from scratch.")
                cache = None
            if not (cache is not None and cache.get("model") == self.model and cache.get("scenario") == scenario):
                logger.warning(f"The cached file {os.path.join(self.cache_dir, file_name)} has different setting from the current setting. Regenerate the prompt from scratch.")
                cache = None
        else:
            cache = None

        generator = SinglePromptGeneratorNoExample(
            model = self.model,
            scenario = scenario,
            idx = idx,
            time = time,
            cache_file = os.path.join(self.cache_dir, f"{idx}_{time}.json"),
            output_file = os.path.join(self.output_dir, f"{idx}_{time}.json"),
            graph_file = os.path.join(self.graph_dir, f"{idx}_{time}.json"),
            graph_img_file = os.path.join(self.graph_img_dir, f"{idx}_{time}.{self.graph_img_type}") if self.graph_img_dir is not None else None,
            graph_img_type = self.graph_img_type,
            rule_file = os.path.join(self.rules_dir, f"{idx}_{time}.json"),
            self_check_times = self.self_check_times,
            preview = self.preview,
            allow_regen_factors= self.allow_regen_factors if self.max_regen_factors_time > 0 else False,
            max_regen_factors_time = self.max_regen_factors_time,
            allow_regen_graph= self.allow_regen_graph if self.max_regen_graph_time > 0 else False,
            max_regen_graph_time = self.max_regen_graph_time,
            complex = self.complex,
            max_nodes = self.max_nodes,
            reasoning_effort=self.reasoning_effort,
            max_output_tokens=self.max_output_tokens,
            remind_last=self.remind_last,
            split_steps=self.split_steps,  # Fix: specify the value of the argument
        )
        if cache is not None:
            generator.load_cache(cache)
        generator.run()

    def generate(self, idxes: Iterable[int] | int | None, worker_num: int, times: int = 1, force_regen: bool = False):
        if idxes is None:
            idxes_times = sum([[(i, t) for t in range(times)] for i in range(len(self.scenarios))], start=[])
        elif isinstance(idxes, int):
            idxes_times = [(idxes, t) for t in range(times)]
        else:
            idxes_times = sum([[(i, t) for t in range(times)] for i in idxes], start=[])
        inputs = [(idx, force_regen, times) for idx, times in idxes_times]
        from multiprocessing import Pool
        with Pool(processes=worker_num) as pool:
            pool.map(self.generate_prompt_zero_example, inputs)

def parse_args():
    """
    Define the command-line interface of the prompt generator and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Generate the prompt for the causal graph generation task.")
    add = parser.add_argument

    # Model and scenario selection
    add("-m", "--model", type=str, default="o1")
    add("-i", "--indexes", type=int, nargs="*")
    add("-t", "--times", type=int, default=1, help="Generate how many times for each scenario.")

    # Input / output locations
    add("-s", "--scenario_file", type=str, default="dataset/scenarios.txt")
    add("-c", "--cache_dir", type=str, default="dataset/cache")
    add("-o", "--output_dir", type=str, default="dataset/output")
    add("-g", "--graph_dir", type=str, default="dataset/graph")
    add("--graph_img_dir", type=str, default="dataset/graph_img", help='path to save visualization of graphs. If "no", do not save the visualization.')
    add(
        "--graph_img_type",
        type=str,
        default="jpg",
        choices=["bmp", "cgimage", "eps", "exr", "fig", "gd", "gif", "ico", "jpg", "jpeg", "jpe", "jp2", "pdf", "png", "svg", "svgz"],
        help='The type of the image file.',
    )
    add("-r", "--rules_dir", type=str, default="dataset/rules")

    # Self-check and execution
    add("--self_check_times", type=int, default=3)
    add("--self_example_times", type=int, default=3)
    add("-n", "--worker_num", type=int, default=1)
    add("-f", "--force_regen", action="store_true", help="Force to regenerate the prompt, otherwise skip the existed prompt.")
    add("--preview", action="store_true", help="if the openai is the preview version")

    # Regeneration budget
    add("--allow_regen_factors", action="store_true", help="if allow to regenerate the factors")
    add("--max_regen_factors_time", type=int, default=1, help="The maximum times to regenerate the factors.")
    add("--allow_regen_graph", action="store_true", help="if allow to regenerate the graph")
    add("--max_regen_graph_time", type=int, default=1, help="The maximum times to regenerate the graph.")

    # Graph size and model request settings
    add("--complex", type=str, default="low", choices=["low", "medium", "high"], help="The complexity level of the generated graph. Vaguely control the number of nodes.")
    add("--max_nodes", type=int, default=None, help="Hard upper bound on total nodes (factors + results) in the DAG.")
    add("--reasoning_effort", type=str, default="minimal", choices=["minimal", "low", "medium", "high"], help="Reasoning effort level to pass to the API.")
    add("--max_output_tokens", type=int, default=12800, help="The maximum tokens for the output of the model.")

    # Conversation-history handling
    add("--remind_last", action="store_true", help="if remind the last factors and graphs to the model when generating the graph and rules. This may increase the token usage and cost.")
    add("--split_steps", action="store_true", help="if True, the history  of generating factors (graph) will be ignored when generating graph (rules). It will reduce the reading token usage and cost.")

    return parser.parse_args()

def main():
    """
    Command-line entry point: parse the arguments, set DEBUG logging and run the generation.
    """
    args = parse_args()
    logger.setLevel('DEBUG')

    # "no" on the command line disables saving graph visualizations.
    image_dir = None if args.graph_img_dir == 'no' else args.graph_img_dir
    # Regeneration is only enabled when a positive retry budget is given.
    regen_factors = False if args.max_regen_factors_time <= 0 else args.allow_regen_factors
    regen_graph = False if args.max_regen_graph_time <= 0 else args.allow_regen_graph

    generator = PromptGenerator(
        model=args.model,
        scenario_file=args.scenario_file,
        cache_dir=args.cache_dir,
        output_dir=args.output_dir,
        graph_dir=args.graph_dir,
        graph_img_dir=image_dir,
        graph_img_type=args.graph_img_type,
        rules_dir=args.rules_dir,
        self_check_times=args.self_check_times,
        self_example_times=args.self_example_times,
        preview=args.preview,
        allow_regen_factors=regen_factors,
        max_regen_factors_time=args.max_regen_factors_time,
        allow_regen_graph=regen_graph,
        max_regen_graph_time=args.max_regen_graph_time,
        complex=args.complex,
        max_nodes=args.max_nodes,
        reasoning_effort=args.reasoning_effort,
        max_output_tokens=args.max_output_tokens,
        remind_last=args.remind_last,
        split_steps=args.split_steps,
    )
    generator.generate(
        idxes=args.indexes,
        worker_num=args.worker_num,
        force_regen=args.force_regen,
        times=args.times,
    )

# Run the CLI only when this file is executed as a script, not when it is imported
# (e.g. by multiprocessing worker processes started with the "spawn" method).
if __name__ == '__main__':
    main()