import asyncio
from src.core.llm import LLMBase
from src.modules.schema_linking.basic_linker import BasicSchemaLinker
from src.modules.schema_linking.enhanced_linker import EnhancedSchemaLinker
from src.modules.schema_linking.skip_linker import SkipSchemaLinker
from src.modules.sql_generation.gpt_generator import GPTSQLGenerator

from src.modules.sql_generation.enhanced_generator import EnhancedSQLGenerator

from src.modules.sql_generation.dc_refiner_generator import DCRefinerSQLGenerator

from src.modules.post_processing.reflection import ReflectionPostProcessor
from src.modules.post_processing.feedback_based_reflection import FeedbackBasedReflectionPostProcessor
from src.modules.post_processing.skip_post_processing import SkipPostProcessor
from src.pipeline import ElephantSQLPipeline
from concurrent.futures import ThreadPoolExecutor

async def main(
    backbone_model: str = 'gpt-4o-mini-2024-07-18',
    *,
    variant: str = 'v',
    data_file: str = "./data/formatted_bird_dev.json",
    max_workers: int = 100,
):
    """Build one ElephantSQL pipeline variant and run it over a dataset.

    Every variant uses ``EnhancedSchemaLinker`` for schema linking and
    ``SkipPostProcessor`` for post-processing. The variants differ only in
    the SQL generator.

    Args:
        backbone_model: Model identifier passed to both the schema linker
            and the SQL generator.
        variant: Which SQL generator to use:
            ``'v'`` -> ``GPTSQLGenerator`` (the default, as before),
            ``'enh'`` -> ``EnhancedSQLGenerator``,
            ``'dcr'`` -> ``DCRefinerSQLGenerator``.
        data_file: Dataset path forwarded to ``run_pipeline_parallel``.
        max_workers: Concurrency level forwarded to ``run_pipeline_parallel``.

    Raises:
        ValueError: If ``variant`` is not one of the keys listed above.
    """
    generator_classes = {
        'v': GPTSQLGenerator,
        'enh': EnhancedSQLGenerator,
        'dcr': DCRefinerSQLGenerator,
    }
    # Validate before creating the LLM client, so a typo fails fast.
    try:
        generator_cls = generator_classes[variant]
    except KeyError:
        raise ValueError(
            f"Unknown pipeline variant {variant!r}; "
            f"expected one of {sorted(generator_classes)}"
        ) from None

    llm = LLMBase()

    # Identical settings for every LLM-backed component; kept in one place
    # so the linker and the generator cannot drift apart.
    component_kwargs = dict(
        model=backbone_model,
        temperature=0.0,
        max_tokens=10000,
        max_retries=10,
    )

    # Only the requested pipeline is constructed. Previously two unused
    # pipelines were also built and then discarded.
    pipeline = ElephantSQLPipeline(
        schema_linker=EnhancedSchemaLinker(llm, **component_kwargs),
        sql_generator=generator_cls(llm, **component_kwargs),
        post_processor=SkipPostProcessor(),
    )

    await pipeline.run_pipeline_parallel(
        data_file=data_file,
        max_workers=max_workers,
    )

if __name__ == "__main__":
    # Create a fresh loop and make it current, so code that calls
    # asyncio.get_event_loop() still finds it. This avoids the deprecated
    # event-loop-policy API.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Enlarge the default executor used by loop.run_in_executor(None, ...).
    # NOTE(review): presumably the pipeline offloads blocking LLM calls here;
    # confirm against ElephantSQLPipeline.
    loop.set_default_executor(ThreadPoolExecutor(max_workers=300))

    try:
        loop.run_until_complete(main(backbone_model="claude-3-haiku-20240307"))
    finally:
        # Always release resources, even if main() raises or is interrupted.
        # Mirrors the cleanup asyncio.run() performs.
        try:
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.run_until_complete(loop.shutdown_default_executor())
        finally:
            asyncio.set_event_loop(None)
            loop.close()