import ast
import json
import re
from typing import Dict, List, Optional

def load_json(file_path):
    """Read the UTF-8 file at *file_path* and return its parsed JSON content.

    Errors from ``open`` and ``json.load`` propagate to the caller.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def load_jsonl(file_path):
    """Read a UTF-8 JSON Lines file and return its records as a list.

    Each non-blank line is parsed as one JSON document. Blank or
    whitespace-only lines are skipped rather than handed to the JSON
    parser, where they would raise ``json.JSONDecodeError``.

    Args:
        file_path: Path of the ``.jsonl`` file to read.

    Returns:
        A list with one parsed value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            stripped = line.strip()
            if not stripped:
                # Tolerate empty separator lines and trailing blank lines.
                continue
            data.append(json.loads(stripped))
    return data

class TextExtractor:
    """Static helpers that pull structured fragments out of free-form text.

    All methods are stateless; they rely on regular expressions to locate
    ``<<...>>`` markers and fenced (triple-backtick) code blocks.
    """

    _FK_FIELDS = ("column", "referenced_table", "referenced_column")

    @staticmethod
    def extract_sub_questions(text: str) -> List[str]:
        """Return every substring enclosed in ``<<`` and ``>>``, in order.

        Matching is non-greedy and may span multiple lines.
        """
        pattern = r"<<(.*?)>>"
        return re.findall(pattern, text, re.DOTALL)

    @staticmethod
    def extract_sql(text: str) -> str:
        """Return the content of the last ```` ```sql ```` fenced block.

        If no such block is present, the whole text (stripped) is returned.
        """
        sql_pattern = r"```sql\s*(.*?)\s*```"
        matches = re.findall(sql_pattern, text, re.DOTALL)
        return matches[-1].strip() if matches else text.strip()

    @staticmethod
    def extract_code_block(text: str, language: Optional[str] = None) -> str:
        """Return the content of the last fenced code block in *text*.

        Args:
            text: Text that may contain triple-backtick fenced blocks.
            language: Fence tag to require (e.g. ``"python"``). When omitted
                or empty, a block with any (or no) tag matches.

        Returns:
            The stripped content of the last matching block, or the stripped
            input text when no block matches.
        """
        if language:
            # Escape the tag so names containing regex metacharacters
            # (e.g. "c++") are matched literally.
            pattern = rf"```{re.escape(language)}\s*(.*?)\s*```"
        else:
            pattern = r"```\w*\s*(.*?)\s*```"

        matches = re.findall(pattern, text, re.DOTALL)
        return matches[-1].strip() if matches else text.strip()

    @staticmethod
    def _parse_literal(payload: str):
        """Parse *payload* as JSON, else as a Python literal; None on failure."""
        # Deliberately not eval(): the text is external data and eval would
        # execute arbitrary expressions embedded in it. JSON is tried first so
        # that true/false/null are understood; ast.literal_eval then covers
        # Python-style literals (single quotes, True/None) without running code.
        try:
            return json.loads(payload)
        except (ValueError, RecursionError):
            pass
        try:
            return ast.literal_eval(payload)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return None

    @staticmethod
    def _is_str_list(value) -> bool:
        """Return True if *value* is a list whose items are all strings."""
        return isinstance(value, list) and all(isinstance(v, str) for v in value)

    @staticmethod
    def _is_valid_foreign_key(fk) -> bool:
        """Return True if *fk* is a dict with the three string FK fields."""
        if not isinstance(fk, dict):
            return False
        fields = TextExtractor._FK_FIELDS
        if not all(k in fk for k in fields):
            return False
        return all(isinstance(fk[k], str) for k in fields)

    @staticmethod
    def _is_valid_table(table) -> bool:
        """Return True if *table* has the expected table-entry shape."""
        if not isinstance(table, dict):
            return False
        if "table" not in table or "columns" not in table:
            return False
        if not isinstance(table["table"], str):
            return False
        if not TextExtractor._is_str_list(table["columns"]):
            return False

        # primary_keys and foreign_keys are optional, but validated if present.
        if "primary_keys" in table and not TextExtractor._is_str_list(table["primary_keys"]):
            return False

        if "foreign_keys" in table:
            foreign_keys = table["foreign_keys"]
            if not isinstance(foreign_keys, list):
                return False
            if not all(TextExtractor._is_valid_foreign_key(fk) for fk in foreign_keys):
                return False

        return True

    @staticmethod
    def extract_schema_json(text: str) -> Optional[Dict]:
        """Extract and validate a schema dictionary from *text*.

        The first ```` ```json ```` fenced block is used when present;
        otherwise the whole text is parsed. The payload is parsed as JSON,
        falling back to a Python literal, and is never executed as code.

        A valid schema is a dict with a ``"tables"`` list. Each table is a
        dict with a string ``"table"`` and a list-of-strings ``"columns"``;
        optional ``"primary_keys"`` must be a list of strings and optional
        ``"foreign_keys"`` a list of dicts with string ``"column"``,
        ``"referenced_table"`` and ``"referenced_column"`` entries.

        Returns:
            The parsed schema dict, or ``None`` if parsing or validation fails.
        """
        json_pattern = r"```json\s*(.*?)\s*```"
        matches = re.findall(json_pattern, text, re.DOTALL)
        payload = matches[0].strip() if matches else text.strip()

        schema = TextExtractor._parse_literal(payload)

        if not isinstance(schema, dict) or "tables" not in schema:
            return None
        if not isinstance(schema["tables"], list):
            return None
        if not all(TextExtractor._is_valid_table(t) for t in schema["tables"]):
            return None

        return schema