from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Optional
from ..base import ModuleBase
from ...core.sql_execute import validate_sql_execution
from ...core.utils import load_json
from ...core.config import Config
import os

class SQLGeneratorBase(ModuleBase):
    """Abstract base for SQL generators with retry and fallback handling.

    Subclasses implement :meth:`generate_sql`. The retry wrappers defined here
    read ``self.extractor``, ``self.logger`` and (for the validating variant)
    ``self.data_file``. None of these is assigned in this class, so they are
    expected to be provided by ``ModuleBase`` or by the subclass.
    """

    def __init__(self, name: str = "SQLGenerator", max_retries: int = 3):
        """Initialise the generator.

        Args:
            name: Module name forwarded to ``ModuleBase.__init__``.
            max_retries: Number of generation attempts the retry wrappers make
                before giving up and returning a fallback statement.
        """
        super().__init__(name)
        self.max_retries = max_retries

    @abstractmethod
    async def generate_sql(self,
                        query: str,
                        schema_linking_output: Dict,
                        query_id: str,
                        module_name: Optional[str] = None) -> str:
        """Produce the raw generator output for one question.

        The parameter order matches the way the retry wrappers in this class
        invoke the method: ``(query, schema_linking_output, query_id,
        module_name)``, all passed positionally.

        Args:
            query: Natural-language question.
            schema_linking_output: Schema-linking result for the question.
            query_id: Identifier of the question being processed.
            module_name: Optional module name passed through by the caller.

        Returns:
            The raw output that ``self.extractor.extract_sql`` is applied to.
            The retry wrappers treat a ``None`` result as a failed attempt.
        """
        pass

    async def _generate_and_extract(self,
                                    query: str,
                                    schema_linking_output: Dict,
                                    query_id: str,
                                    module_name: Optional[str],
                                    attempt: int) -> Optional[str]:
        """Run one generation attempt and return the extracted SQL, or None."""
        raw_sql_output = await self.generate_sql(query, schema_linking_output, query_id, module_name)
        if raw_sql_output is None:
            # Previously this path was silent, which hid empty model outputs.
            self.logger.warning(f"SQL generation returned no output on attempt {attempt + 1}/{self.max_retries}")
            return None

        extracted_sql = self.extractor.extract_sql(raw_sql_output)
        if extracted_sql is None:
            self.logger.warning(f"Unable to extract SQL from output, first {attempt + 1}/{self.max_retries} attempt")
        return extracted_sql

    @staticmethod
    def _build_fallback_sql(schema_linking_output: Dict) -> str:
        """Build a trivial SELECT over the first column of the first table.

        Raises:
            KeyError: If ``original_schema``, ``tables``, ``table`` or
                ``columns`` is missing from the schema-linking output.
            IndexError: If the table list or the column mapping is empty.
        """
        first_table = schema_linking_output["original_schema"]["tables"][0]
        table_name = first_table["table"]
        first_column = list(first_table["columns"])[0]
        # NOTE(review): identifiers are interpolated unquoted, exactly as
        # before; names containing spaces or reserved words would produce
        # invalid SQL. Confirm the target dialect before adding quoting.
        return f"SELECT {first_column} FROM {table_name} LIMIT 1;"

    def _resolve_db_path(self, query_id: str) -> str:
        """Look up the SQLite database path for ``query_id`` in the data file.

        Raises:
            ValueError: If ``self.data_file`` is not set, or if no entry with
                a matching ``question_id`` exists in the data file.
        """
        if not self.data_file:
            raise ValueError("Data file path not set. Please call set_data_file() first.")

        for item in load_json(self.data_file):
            if item.get("question_id") == query_id:
                db_id = item.get("db_id", "")
                source = item.get("source", "")
                return str(Config().database_dir / f"{source}_{db_id}" / f"{db_id}.sqlite")

        raise ValueError(f"The database path for problem ID {query_id} could not be found")

    async def generate_sql_with_retry(self, query: str, schema_linking_output: Dict, query_id: str, module_name: Optional[str] = None) -> str:
        """Generate SQL, retrying up to ``max_retries`` times.

        Each attempt calls :meth:`generate_sql` and then
        ``self.extractor.extract_sql``. The first successfully extracted
        statement is returned. When every attempt fails, an error is logged
        and a trivial fallback ``SELECT`` built from the schema is returned.

        Args:
            query: Natural-language question.
            schema_linking_output: Schema-linking result; must contain
                ``original_schema.tables`` for the fallback statement.
            query_id: Identifier of the question (used in log messages and
                forwarded to :meth:`generate_sql`).
            module_name: Optional module name forwarded to
                :meth:`generate_sql`.

        Returns:
            The extracted SQL, or the fallback statement.

        Raises:
            KeyError, IndexError: If the fallback statement cannot be built
                from ``schema_linking_output``.
        """
        last_error = None

        for attempt in range(self.max_retries):
            try:
                extracted_sql = await self._generate_and_extract(
                    query, schema_linking_output, query_id, module_name, attempt
                )
            except Exception as e:
                # Deliberately broad: any failure in generation or extraction
                # consumes one attempt instead of aborting the pipeline.
                last_error = e
                self.logger.warning(f"SQL generation failed at {attempt + 1}/{self.max_retries} times attempt: {str(e)}")
                continue

            if extracted_sql is not None:
                return extracted_sql

        error_message = f"SQL generation fails completely after {self.max_retries}, no valid SQL can be extracted. The last error: {str(last_error)}。Question ID: {query_id} Program continuation..."
        self.logger.error(error_message)

        return self._build_fallback_sql(schema_linking_output)

    async def generate_sql_with_retry_with_validate(self, query: str, schema_linking_output: Dict, query_id: str, module_name: Optional[str] = None) -> str:
        """Generate SQL with retries, validating each candidate by execution.

        The database path is resolved from ``self.data_file`` by matching
        ``question_id``. Each extracted statement is checked with
        ``validate_sql_execution``; the first valid one is returned. If no
        attempt validates, the last extracted statement is returned when there
        is one; otherwise a trivial fallback ``SELECT`` is built, validated
        (a failure is only logged) and returned.

        Args:
            query: Natural-language question.
            schema_linking_output: Schema-linking result; must contain
                ``original_schema.tables`` for the fallback statement.
            query_id: Identifier of the question; also the lookup key in the
                data file.
            module_name: Optional module name forwarded to
                :meth:`generate_sql`.

        Returns:
            A validated statement, the last extracted statement, or the
            fallback statement, in that order of preference.

        Raises:
            ValueError: If the data file is not set or has no entry for
                ``query_id``.
            KeyError, IndexError: If the fallback statement cannot be built
                from ``schema_linking_output``.
        """
        db_path = self._resolve_db_path(query_id)

        last_error = None
        last_extracted_sql = None

        for attempt in range(self.max_retries):
            try:
                extracted_sql = await self._generate_and_extract(
                    query, schema_linking_output, query_id, module_name, attempt
                )
                if extracted_sql is None:
                    continue
                # Remember the candidate before validating so it survives a
                # validation error and can still be returned as a last resort.
                last_extracted_sql = extracted_sql
                is_valid, error_msg = validate_sql_execution(db_path, extracted_sql)
            except Exception as e:
                # Deliberately broad: any failure consumes one attempt.
                last_error = e
                self.logger.warning(f"SQL generation failed at {attempt + 1}/{self.max_retries} times attempt: {str(e)}")
                continue

            if is_valid:
                return extracted_sql

            # Record the validation message so the final log does not report
            # "None" when the only failures were validation failures.
            last_error = error_msg
            self.logger.warning(f"SQL validation failed: {error_msg}")

        if last_extracted_sql:
            error_message = f"SQL generation fails completely after {self.max_retries}, Use the last extracted SQL. The last error: {str(last_error)}。Question ID: {query_id} Program continuation..."
            self.logger.error(error_message)
            return last_extracted_sql

        error_message = f"SQL generation fails completely after {self.max_retries}, no valid SQL can be extracted. The last error: {str(last_error)}。Question ID: {query_id} Program continuation..."
        self.logger.error(error_message)

        fallback_sql = self._build_fallback_sql(schema_linking_output)
        is_valid, error_msg = validate_sql_execution(db_path, fallback_sql)

        if not is_valid:
            self.logger.fatal(f"[FATAL] Fallback SQL execution failed: {error_msg}。The last error: {str(last_error)}。Question ID: {query_id}")

        return fallback_sql