from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
from ..base import ModuleBase
import json
import os

class SchemaLinkerBase(ModuleBase):
    """Base class for schema linkers that select the part of a database schema relevant to a query.

    Subclasses implement :meth:`link_schema`. :meth:`link_schema_with_retry` wraps it with
    retries, JSON extraction and enrichment (primary/foreign keys), and falls back to the
    complete database schema when every attempt fails.

    NOTE(review): this class uses ``self.logger``, ``self.extractor``, ``self.schema_manager``,
    ``self.intermediate`` and ``self.save_intermediate`` without defining them; presumably they
    are provided by ``ModuleBase`` or by subclasses — confirm.
    """

    def __init__(self, name: str, max_retries: int = 3):
        """Initialise the linker.

        Args:
            name: Module name, forwarded to ``ModuleBase.__init__``.
            max_retries: Number of ``link_schema`` attempts made by
                :meth:`link_schema_with_retry` before falling back to the full schema.
                A value <= 0 skips linking and falls back immediately.
        """
        super().__init__(name)
        self.max_retries = max_retries

    @staticmethod
    def _foreign_key_entry(fk: Dict) -> Dict:
        """Convert one database-schema foreign key into the per-table linked-schema format.

        ``fk["table"]`` and ``fk["column"]`` are read as ``(source, referenced)`` pairs.
        A new dict is returned on every call, so entries are never shared between tables.
        """
        return {
            "column": fk["column"][0],
            "referenced_table": fk["table"][1],
            "referenced_column": fk["column"][1],
        }

    def enrich_schema_info(self, linked_schema: Dict, database_schema: Dict) -> Dict:
        """Copy key information from the full database schema into a linked schema.

        ``linked_schema`` is mutated in place and also returned.

        - For each linked table that exists in ``database_schema`` and has ``primary_keys``
          there, those primary keys overwrite the linked table's ``primary_keys``.
        - Each database foreign key is appended to the ``foreign_keys`` list of its source
          table, but only when both its source and referenced tables are linked.

        Args:
            linked_schema: Dict with a ``"tables"`` list; each entry has a ``"table"`` name.
            database_schema: Dict with a ``"tables"`` list and an optional
                ``"foreign_keys"`` list.

        Returns:
            The same ``linked_schema`` object, enriched.
        """
        db_tables = {table["table"]: table for table in database_schema["tables"]}

        for table in linked_schema["tables"]:
            db_table = db_tables.get(table["table"])
            if db_table is not None and "primary_keys" in db_table:
                table["primary_keys"] = db_table["primary_keys"]

        foreign_keys = database_schema.get("foreign_keys")
        if foreign_keys:
            # Table names do not change while foreign keys are attached, so build the
            # lookup set once instead of once per foreign key.
            linked_table_names = {t["table"] for t in linked_schema["tables"]}
            for fk in foreign_keys:
                src_table, dst_table = fk["table"][0], fk["table"][1]
                if src_table not in linked_table_names or dst_table not in linked_table_names:
                    continue
                for table in linked_schema["tables"]:
                    if table["table"] == src_table:
                        table.setdefault("foreign_keys", []).append(self._foreign_key_entry(fk))

        return linked_schema

    @abstractmethod
    async def link_schema(self,
                          query: str,
                          database_schema: Dict,
                          query_id: str = None) -> Dict:
        """Produce raw schema-linking output for ``query``.

        The result is passed to ``self.extractor.extract_schema_json`` by
        :meth:`link_schema_with_retry`. Returning ``None`` counts as a failed attempt.

        Args:
            query: The natural-language question.
            database_schema: The complete database schema.
            query_id: Identifier of the question. :meth:`link_schema_with_retry` always
                passes it as the third positional argument.
        """
        pass

    def _build_full_schema(self, database_schema: Dict) -> Dict:
        """Convert the complete database schema into linked-schema format (fallback path).

        Every table keeps all of its columns. When ``database_schema`` has a
        ``"foreign_keys"`` entry, every table gets a ``foreign_keys`` list, possibly empty,
        holding the keys whose source is that table.
        """
        tables = [
            {
                "table": table["table"],
                "columns": list(table["columns"].keys()),
                "columns_info": table["columns"],
                "primary_keys": table.get("primary_keys", []),
            }
            for table in database_schema["tables"]
        ]

        foreign_keys = database_schema.get("foreign_keys")
        if foreign_keys is not None:
            for table in tables:
                table["foreign_keys"] = [
                    self._foreign_key_entry(fk)
                    for fk in foreign_keys
                    if fk["table"][0] == table["table"]
                ]

        return {"tables": tables}

    async def link_schema_with_retry(self, query: str, database_schema: Dict, query_id: str = None) -> Dict:
        """Run :meth:`link_schema` with retries and return an enriched linked schema.

        Each attempt calls ``link_schema``, extracts the schema JSON with
        ``self.extractor`` and enriches it with :meth:`enrich_schema_info`. An attempt
        fails when it raises, when ``link_schema`` returns ``None``, or when extraction
        returns ``None``.

        If all ``self.max_retries`` attempts fail, the complete database schema is used
        instead. That fallback is recorded through ``save_intermediate`` and
        :meth:`save_linked_schema_result` before it is returned.

        Args:
            query: The natural-language question.
            database_schema: The complete database schema.
            query_id: Identifier of the question, used for logging and persistence.

        Returns:
            The enriched linked schema, or the full-schema fallback (a dict in both cases).
        """
        last_error = None
        for attempt in range(1, self.max_retries + 1):
            try:
                raw_schema_output = await self.link_schema(query, database_schema, query_id)
                if raw_schema_output is None:
                    self.logger.warning(
                        f"Schema linking attempt {attempt}/{self.max_retries} returned no output"
                    )
                    continue

                extracted_linked_schema = self.extractor.extract_schema_json(raw_schema_output)
                # Check before enriching: enrich_schema_info cannot handle None.
                if extracted_linked_schema is None:
                    self.logger.warning(
                        f"Schema linking attempt {attempt}/{self.max_retries}: "
                        f"could not extract schema JSON from output"
                    )
                    continue

                return self.enrich_schema_info(extracted_linked_schema, database_schema)
            except Exception as e:
                last_error = e
                self.logger.warning(
                    f"Schema linking failed on attempt {attempt}/{self.max_retries}: {str(e)}"
                )

        self.logger.error(
            f"Schema linking failed after {self.max_retries} attempts, returning complete "
            f"database schema. Last error: {str(last_error)}. Question ID: {query_id}. "
            f"Program continuing..."
        )
        full_schema = self._build_full_schema(database_schema)

        self.save_intermediate(
            input_data={
                "query": query,
            },
            output_data={
                "raw_output": "Schema linking failed, using full database schema",
                "extracted_linked_schema": full_schema,
                "formatted_linked_schema": self.schema_manager.format_linked_schema(full_schema)
            },
            model_info={
                "model": "none",
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0
            },
            query_id=query_id
        )

        source = database_schema.get("source", "unknown")
        self.save_linked_schema_result(
            query_id=query_id,
            source=source,
            linked_schema={
                "database": database_schema.get("database", ""),
                "tables": full_schema.get("tables", [])
            }
        )
        return full_schema

    def _format_basic_schema(self, schema: Dict) -> str:
        """Render a database schema as plain text.

        The output has a database header, then one paragraph per table with its name,
        columns and, if any, primary keys. A ``Foreign keys:`` section follows when the
        schema has a non-empty ``"foreign_keys"`` list.

        Args:
            schema: Dict with ``"tables"``, optional ``"database"`` and optional
                ``"foreign_keys"``. Table ``"columns"`` may be a list of names or a dict
                keyed by column name; only the names are rendered.

        Returns:
            The rendered schema, lines joined with ``"\\n"``.
        """
        result = [f"Database: {schema.get('database', '')}\n"]

        for table in schema['tables']:
            result.append(f"Table name: {table['table']}")
            result.append(f"Columns: {', '.join(table['columns'])}")
            # Tolerate tables without primary keys, as the full-schema fallback does.
            primary_keys = table.get('primary_keys')
            if primary_keys:
                result.append(f"Primary keys: {', '.join(primary_keys)}")
            result.append("")

        if schema.get('foreign_keys'):
            result.append("Foreign keys:")
            for fk in schema['foreign_keys']:
                result.append(
                    f"  {fk['table'][0]}.{fk['column'][0]} = "
                    f"{fk['table'][1]}.{fk['column'][1]}"
                )

        return "\n".join(result)

    def save_linked_schema_result(self, query_id: str, source: str, linked_schema: Dict) -> None:
        """Append one linked-schema record to ``linked_schema_results.jsonl``.

        The file lives in ``self.intermediate.pipeline_dir``. Each record holds the query
        id, source, database name and, for every table, only its name and columns. It is
        written as one UTF-8 JSON line with non-ASCII characters preserved.

        Args:
            query_id: Identifier of the question.
            source: Origin label of the database schema.
            linked_schema: Dict with optional ``"database"`` and ``"tables"``. Each table
                must have ``"table"`` and ``"columns"`` keys.
        """
        pipeline_dir = self.intermediate.pipeline_dir

        output_file = os.path.join(pipeline_dir, "linked_schema_results.jsonl")

        result = {
            "query_id": query_id,
            "source": source,
            "database": linked_schema.get("database", ""),
            "tables": [
                {
                    "table": table["table"],
                    "columns": table["columns"]
                }
                for table in linked_schema.get("tables", [])
            ]
        }

        with open(output_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(result, ensure_ascii=False) + "\n")