import copy
import json
import os
from typing import Dict

from .base import SchemaLinkerBase
from ...core.llm import LLMBase
from .prompts.schema_prompts import ENHANCED_SCHEMA_SYSTEM, BASE_SCHEMA_USER
from ...core.schema.manager import SchemaManager
from ...core.utils import load_json

class EnhancedSchemaLinker(SchemaLinkerBase):
    def __init__(
        self,
        llm: LLMBase,
        model: str = "gpt-3.5-turbo-0613",
        temperature: float = 0.0,
        max_tokens: int = 1000,
        max_retries: int = 3,
    ):
        """Set up the linker with its LLM client and generation settings.

        Args:
            llm: Client object whose ``call_llm`` coroutine is awaited by
                ``link_schema``.
            model: Model identifier passed to every LLM call.
            temperature: Sampling temperature passed to every LLM call.
            max_tokens: Completion token limit passed to every LLM call.
            max_retries: Forwarded to the base-class constructor together
                with the module name ``"EnhancedSchemaLinker"``.
        """
        super().__init__("EnhancedSchemaLinker", max_retries)

        # Helper used to render schemas as text for prompts and results.
        self.schema_manager = SchemaManager()

        # LLM client and the per-call generation settings.
        self.llm = llm
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        
    def _validate_linked_schema(self, linked_schema: Dict, database_schema: Dict) -> bool:
        try:
            db_tables = {table["table"]: set(table["columns"].keys()) 
                        for table in database_schema["tables"]}

            for table in linked_schema["tables"]:
                table_name = table["table"]

                if table_name not in db_tables:
                    self.logger.warning(f"table '{table_name}' at primary schema not exist")
                    return False

                for col in table["columns"]:
                    if col not in db_tables[table_name]:
                        self.logger.warning(
                            f"column '{col}' at table '{table_name}' not exist"
                        )
                        return False

                if "primary_keys" in table:
                    for pk in table["primary_keys"]:
                        if pk not in db_tables[table_name]:
                            self.logger.warning(
                                f"primary key '{pk}' at table '{table_name}' not exist"
                            )
                            return False

                if "foreign_keys" in table:
                    for fk in table["foreign_keys"]:
                        if fk["column"] not in db_tables[table_name]:
                            self.logger.warning(
                                f"foreign key column '{fk['column']}' at table '{table_name}' not exist"
                            )
                            return False

                        ref_table = fk["referenced_table"]
                        ref_column = fk["referenced_column"]
                        
                        if ref_table not in db_tables:
                            self.logger.warning(
                                f"referenced table '{ref_table}' at primary schema not exist"
                            )
                            return False
                            
                        if ref_column not in db_tables[ref_table]:
                            self.logger.warning(
                                f"referenced column '{ref_column}' at table '{ref_table}' not exist"
                            )
                            return False
                        
            return True
            
        except Exception as e:
            self.logger.error(f"An error occurred while validating linked schema: {str(e)}")
            return False
            
    def _enhance_linked_schema(self, linked_schema: Dict, database_schema: Dict) -> Dict:
        db_tables = {table["table"]: table for table in database_schema["tables"]}

        linked_table_names = [table["table"] for table in linked_schema["tables"]]

        for table in linked_schema["tables"]:
            table_name = table["table"]
            db_table = db_tables[table_name]

            if db_table["primary_keys"]:
                if "primary_keys" not in table:
                    table["primary_keys"] = db_table["primary_keys"]

                for pk in db_table["primary_keys"]:
                    if pk not in table["columns"]:
                        table["columns"].append(pk)
                        self.logger.info(f"for table '{table_name}' add primary key column '{pk}'")

        if len(linked_table_names) > 1: 
            for fk in database_schema.get("foreign_keys", []):
                src_table = fk["table"][0]
                dst_table = fk["table"][1]
                src_col = fk["column"][0]
                dst_col = fk["column"][1]
                
                if src_table in linked_table_names and dst_table in linked_table_names:
                    src_table_obj = next(t for t in linked_schema["tables"] if t["table"] == src_table)
                    dst_table_obj = next(t for t in linked_schema["tables"] if t["table"] == dst_table)
                    
                    for table_obj, col_name in [(src_table_obj, src_col), (dst_table_obj, dst_col)]:
                        if col_name not in table_obj["columns"]:
                            table_obj["columns"].append(col_name)
                            table_name = table_obj["table"]
                            if table_obj == src_table_obj:
                                self.logger.info(f"for table '{src_table}' add foreign key column '{src_col}'")
                            else:
                                self.logger.info(f"for table '{dst_table}' add referenced column '{dst_col}'")

                    if "foreign_keys" not in src_table_obj:
                        src_table_obj["foreign_keys"] = []

                    fk_exists = any(
                        fk["column"] == src_col and 
                        fk["referenced_table"] == dst_table and 
                        fk["referenced_column"] == dst_col 
                        for fk in src_table_obj.get("foreign_keys", [])
                    )
                    
                    if not fk_exists:
                        src_table_obj["foreign_keys"].append({
                            "column": src_col,
                            "referenced_table": dst_table,
                            "referenced_column": dst_col
                        })
                        self.logger.info(
                            f"Add a foreign key relationship: {src_table}.{src_col} -> {dst_table}.{dst_col}"
                        )

        for table in linked_schema["tables"]:
            table_name = table["table"]
            db_table = db_tables[table_name]

            columns_info = {}
            for col in table["columns"]:  
                if col in db_table["columns"]:
                    col_info = db_table["columns"][col]
                    columns_info[col] = {
                        "type": col_info["type"],
                        "expanded_name": col_info.get("expanded_name", ""),
                        "description": col_info.get("description", ""),
                        "data_format": col_info.get("data_format", ""),
                        "value_description": col_info.get("value_description", ""),
                        "value_examples": col_info.get("value_examples", [])
                    }
            table["columns_info"] = columns_info
        
        return linked_schema
    
    def enhance_schema_only_with_keys(self, linked_schema: Dict, database_schema: Dict) -> Dict:

        enhanced_schema = {"tables": []}

        table_info = {}
        for table in database_schema["tables"]:
            table_info[table["table"]] = {
                "primary_keys": table.get("primary_keys", []),
                "foreign_keys": table.get("foreign_keys", [])
            }

        for table in linked_schema.get("tables", []):
            table_name = table["table"]
            columns = set(table["columns"]) 
            
            if table_name in table_info:

                for pk in table_info[table_name]["primary_keys"]:
                    columns.add(pk)
                
                
                for fk in table_info[table_name]["foreign_keys"]:
                    columns.add(fk["column"])
            
          
            enhanced_table = {
                "table": table_name,
                "columns": list(columns)  
            }
            
        
            if table_name in table_info:
                if table_info[table_name]["primary_keys"]:
                    enhanced_table["primary_keys"] = table_info[table_name]["primary_keys"]
                
               
                foreign_keys = []
                for fk in table_info[table_name]["foreign_keys"]:
                    if fk["column"] in columns:
                        foreign_keys.append(fk)
                if foreign_keys:
                    enhanced_table["foreign_keys"] = foreign_keys
            
            enhanced_schema["tables"].append(enhanced_table)
            
        return enhanced_schema
    
    async def link_schema(self, query: str, database_schema: Dict, query_id: str = None) -> str:
        """Run LLM-based schema linking for one query and persist the results.

        Steps:

        1. Format the database schema as text and, when ``query_id`` is
           given, look up the matching item's ``"evidence"`` in the dataset
           file at ``self.data_file``.
        2. Ask the LLM to pick the relevant tables/columns.
        3. Extract the schema JSON from the response and validate it against
           ``database_schema``.
        4. Enhance the linked schema (keys, join columns, column metadata)
           and derive the compact keys-only variant.
        5. Save intermediate data, log the I/O and save the final linked
           schema result.

        NOTE(review): ``self.data_file``, ``self.extractor``,
        ``save_intermediate``, ``log_io`` and ``save_linked_schema_result``
        are not defined in this class; presumably they come from
        ``SchemaLinkerBase`` - confirm.

        Args:
            query: Natural-language question to link against the schema.
            database_schema: Full database schema dict; its ``"source"`` and
                ``"database"`` entries (if present) tag the saved result.
            query_id: Identifier matched against ``"question_id"`` in the
                dataset file and passed along when saving results. When
                None, no evidence lookup is performed.

        Returns:
            The raw LLM response text.

        Raises:
            ValueError: If no schema could be extracted from the response or
                the extracted schema fails validation.
        """
        schema_str = self.schema_manager.format_enriched_db_schema(database_schema)

        # Only look up evidence for a real id: comparing against None would
        # match the first dataset item that has no "question_id" at all and
        # inject evidence belonging to an unrelated question.
        curr_evidence = ""
        if query_id is not None:
            for item in load_json(self.data_file):
                if item.get("question_id") == query_id:
                    curr_evidence = item.get("evidence", "")
                    break

        self.logger.debug("Formatted Schema information:")
        self.logger.debug("\n" + schema_str)

        messages = [
            {"role": "system", "content": ENHANCED_SCHEMA_SYSTEM},
            {"role": "user", "content": BASE_SCHEMA_USER.format(
                schema_str=schema_str,
                query=query,
                evidence=curr_evidence if curr_evidence else "None"
            )}
        ]

        result = await self.llm.call_llm(
            messages,
            self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            module_name=self.name
        )

        raw_output = result["response"]
        extracted_linked_schema = self.extractor.extract_schema_json(raw_output)

        if not extracted_linked_schema or not self._validate_linked_schema(extracted_linked_schema, database_schema):
            raise ValueError("Schema linking result validation failure: contains table or column that does not exist")

        # _enhance_linked_schema mutates its argument in place. Snapshot the
        # raw extraction first so the saved "extracted_linked_schema" really
        # is what the LLM produced instead of an alias of the enhanced result.
        raw_linked_schema = copy.deepcopy(extracted_linked_schema)

        enhanced_linked_schema = self._enhance_linked_schema(extracted_linked_schema, database_schema)

        # The keys-only variant is derived from the enhanced schema so that
        # join columns added above are kept.
        enhanced_linked_schema_wo_info = self.enhance_schema_only_with_keys(enhanced_linked_schema, database_schema)

        formatted_linked_schema = self.schema_manager.format_linked_schema(enhanced_linked_schema)

        self.save_intermediate(
            input_data={
                "query": query,
            },
            output_data={
                "raw_output": raw_output,
                "extracted_linked_schema": raw_linked_schema,
                "enhanced_linked_schema": enhanced_linked_schema,
                "enhanced_linked_schema_wo_info": enhanced_linked_schema_wo_info,
                "formatted_linked_schema": formatted_linked_schema
            },
            model_info={
                "model": self.model,
                "input_tokens": result["input_tokens"],
                "output_tokens": result["output_tokens"],
                "total_tokens": result["total_tokens"]
            },
            query_id=query_id
        )

        self.log_io(
            {
                "query": query,
                "database_schema": database_schema,
                "formatted_schema": schema_str,
                "messages": messages
            },
            raw_output
        )

        source = database_schema.get("source", "unknown")
        self.save_linked_schema_result(
            query_id=query_id,
            source=source,
            linked_schema={
                "database": database_schema.get("database", ""),
                "tables": enhanced_linked_schema_wo_info.get("tables", [])
            }
        )

        return raw_output
