from typing import Dict, Optional

from ..base import ModuleBase

class PostProcessorBase(ModuleBase):
    """Base class for SQL post-processing modules with retry and fallback.

    Subclasses provide the actual transformation through an awaitable
    ``process_sql(sql, query_id)``. This class also relies on members that are
    not defined here and are presumably supplied by ``ModuleBase`` or the
    subclass (verify against those definitions):

    - ``self.logger``: must offer ``warning`` and ``error``.
    - ``self.extractor``: must offer ``extract_sql(raw_output)``.
    - ``self.load_previous_result(query_id)``.
    - ``self.save_intermediate(...)``.
    """

    def __init__(self, name: str, max_retries: int = 3):
        """Initialise the module.

        Args:
            name: Module name, forwarded to ``ModuleBase.__init__``.
            max_retries: Maximum number of attempts made by
                ``process_sql_with_retry``. With a value <= 0 no attempt is
                made and the original SQL is returned via the fallback path.
        """
        super().__init__(name)
        self.max_retries = max_retries

    async def process_sql_with_retry(self, sql: str, query_id: Optional[str] = None) -> str:
        """Post-process ``sql``, retrying up to ``self.max_retries`` times.

        Each attempt awaits ``self.process_sql(sql, query_id)``. Any non-None
        output is passed through ``self.extractor.extract_sql``, and the first
        non-None extracted SQL is returned.

        An attempt fails when either call raises an ``Exception`` or yields
        ``None``. A failed attempt is logged as a warning and retried.

        If every attempt fails, an error is logged and a fallback intermediate
        record is saved. That record marks the original SQL as the processed
        SQL, and ``sql`` is returned unchanged.

        Args:
            sql: The SQL string to post-process.
            query_id: Identifier forwarded to ``process_sql`` and used to look
                up and save intermediate results.

        Returns:
            The post-processed SQL, or ``sql`` itself if all attempts failed.

        Raises:
            Whatever ``self.save_intermediate`` raises on the fallback path.
            That call is not guarded.
        """
        last_error = None
        for attempt in range(1, self.max_retries + 1):
            try:
                raw_processed_sql = await self.process_sql(sql, query_id)
                processed_sql = (
                    None
                    if raw_processed_sql is None
                    else self.extractor.extract_sql(raw_processed_sql)
                )
            except Exception as e:
                last_error = e
                self.logger.warning(f"SQL postprocessing failed at attempt {attempt}/{self.max_retries}: {str(e)}")
                continue

            if processed_sql is not None:
                return processed_sql

            # Nothing raised, but nothing usable came back either. Record why,
            # so the final error log does not report "The last error: None".
            last_error = (
                "process_sql returned None"
                if raw_processed_sql is None
                else "extractor returned no SQL from the post-processing output"
            )
            self.logger.warning(f"SQL postprocessing produced no usable SQL at attempt {attempt}/{self.max_retries}: {last_error}")

        self.logger.error(f"SQL post-processing fails after a total of {self.max_retries} attempts, using the original SQL. The last error: {str(last_error)}。Question ID: {query_id} Program continuation...")

        self._save_fallback_result(sql, query_id)
        return sql

    def _load_original_query(self, query_id: Optional[str]):
        """Return ``["input"]["query"]`` of the previous result, or "Unknown" on any failure."""
        try:
            prev_result = self.load_previous_result(query_id)
            return prev_result["input"]["query"]
        except Exception:
            # Deliberately best-effort: the fallback record is still written
            # when the previous result is missing or malformed. Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
            # propagate.
            return "Unknown"

    def _save_fallback_result(self, sql: str, query_id: Optional[str]) -> None:
        """Persist an intermediate record stating that the original SQL was kept."""
        self.save_intermediate(
            input_data={
                "original_query": self._load_original_query(query_id),
                "original_sql": sql
            },
            output_data={
                "raw_output": "Post-processing failed, using original SQL",
                "processed_sql": sql
            },
            # No model call produced this record, so token usage is zero.
            model_info={
                "model": "none",
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0
            },
            query_id=query_id
        )