from typing import Dict, List
from ....core.llm import LLMBase
from ....core.sql_execute import *
from ....core.utils import load_json, load_jsonl
from ..prompts.divide_and_conquer_cot_prompts import SQL_GENERATION_SYSTEM, DIVIDE_PROMPT, CONQUER_PROMPT_WO_EXAMPLES, ASSEMBLE_PROMPT
from ....core.utils import TextExtractor
from ..base import ModuleBase
from typing import Any, Dict, Optional, Tuple, Callable



class DivideConqueror():
    """Generate SQL through a divide / conquer / assemble chain of LLM calls.

    The question is first split into sub-questions, each sub-question is then
    answered on its own, and finally the sub-answers are handed back to the
    LLM together with the original question to produce the final result.
    """

    # Keys of an LLM result whose values are summed over every call made
    # during one ``generate_sql`` run.
    _TOKEN_KEYS = ("input_tokens", "output_tokens", "total_tokens")

    def __init__(self, 
                llm: LLMBase, 
                model: str = "gpt-3.5-turbo-0613",
                temperature: float = 0.0,
                max_tokens: int = 5000,
                module_name: str = "EnhancedSQLGenerator"):
        """Store the LLM client and the settings forwarded to every call.

        Args:
            llm: Client exposing an awaitable ``call_llm`` method.
            model: Model identifier passed to ``call_llm``.
            temperature: Sampling temperature passed to ``call_llm``.
            max_tokens: Token limit passed to ``call_llm``.
            module_name: Value passed to ``call_llm`` as ``module_name``.
        """
        self.llm = llm
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.module_name = module_name
        self.extractor = TextExtractor()

    async def _ask(self, user_content: str) -> Dict[str, Any]:
        """Send one system + user exchange to the LLM and return its result."""
        messages = [
            {"role": "system", "content": SQL_GENERATION_SYSTEM},
            {"role": "user", "content": user_content},
        ]
        return await self.llm.call_llm(
            messages,
            self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            module_name=self.module_name
        )

    async def generate_sql(self, query: str, formatted_schema: str, curr_evidence: str) -> Dict[str, Any]:
        """Run the divide, conquer and assemble steps for one question.

        Args:
            query: The natural-language question.
            formatted_schema: Schema text inserted into every prompt.
            curr_evidence: Extra evidence text; a falsy value is rendered as
                the literal string ``"None"`` in the prompts.

        Returns:
            The result object of the final (assemble) LLM call, with its
            ``input_tokens``, ``output_tokens`` and ``total_tokens`` entries
            increased by the usage of the divide and conquer calls. The
            object is updated in place, not copied.
        """
        evidence = curr_evidence if curr_evidence else "None"

        # Divide: ask the LLM to break the question into sub-questions.
        divide_result = await self._ask(DIVIDE_PROMPT.format(
            schema=formatted_schema,
            query=query,
            evidence=evidence
        ))
        usage = {key: divide_result[key] for key in self._TOKEN_KEYS}
        sub_questions = self.extractor.extract_sub_questions(divide_result["response"])

        # Conquer: answer the sub-questions one at a time, in order.
        sub_sqls = []
        for sub_question in sub_questions:
            conquer_result = await self._ask(CONQUER_PROMPT_WO_EXAMPLES.format(
                schema=formatted_schema,
                query=sub_question,
                evidence=evidence
            ))
            for key in self._TOKEN_KEYS:
                usage[key] += conquer_result[key]
            sub_sqls.append(conquer_result["response"])

        # Numbering starts at 0, matching the text this prompt has always used.
        subs = "".join(
            f"Sub-question {index}: {question}\nSQL query {index}: {sql}\n\n"
            for index, (question, sql) in enumerate(zip(sub_questions, sub_sqls))
        )

        # Assemble: combine the sub-answers into the answer to the question.
        assemble_result = await self._ask(ASSEMBLE_PROMPT.format(
            schema=formatted_schema,
            query=query,
            evidence=evidence,
            subs=subs
        ))

        # Report the usage of the whole chain, not just of the last call.
        for key in self._TOKEN_KEYS:
            assemble_result[key] += usage[key]

        return assemble_result
        

                
        