import os
from typing import Optional,Union,  List, Dict, Any
from pydantic import BaseModel, ConfigDict, Field
from dotenv import find_dotenv, load_dotenv
import time
from mcp import StdioServerParameters # For Stdio Server
import asyncio
import sys
from pathlib import Path
import random
sys.path.append(str(Path(__file__).resolve().parent))
from utils import get_schema_from_sqllite, get_schema_from_file, async_execute_sql, evaluate_execution_accuracy2
from crewai.tools import tool
from agentics import AG
from db import DB

from crewai_tools import MCPServerAdapter
from agentics import AG

# Load variables from the nearest .env file into the process environment at
# import time; functions in this module read SQL_DB_PATH and
# SQL_BENCHMARKS_FOLDER through os.getenv.
load_dotenv(find_dotenv())



# Graded assessment of one answer to a text-to-SQL question.
# Documented with `#` comments (not a class docstring) so the pydantic model
# definition itself stays untouched.
class AnswerAssessment(BaseModel):
    # Natural-language question the answer refers to (no default declared).
    question:Optional[str]
    # SQL text associated with the answer (no default declared).
    sql:Optional[str]
    # Query output kept as a string (no default declared).
    # presumably a serialized dataframe -- TODO confirm against the producer.
    output_dataframe:Optional[str]
    # Score bounded to [0, 1] by the Field constraints (ge=0, le=1).
    answer_quality_score:Optional[float]=Field(None,ge=0,le=1,description="1 if the answer is correct and you are certain about, 0 if the answer if likely wrong or resulted in an error. Return a number between 0 and 1 when you are uncertain")
    # Free-text rationale for the score above.
    answer_quality_assessment:Optional[str]=Field(None,description="Explain the rationale why you provide your grade to the answer.")


# State object for one text-to-SQL benchmark question. The functions in this
# module read and fill its fields step by step (schema loading, query
# generation, query execution).
# Documented with `#` comments (not a class docstring) so the pydantic model
# definition itself stays untouched.
class Text2sqlQuestion(BaseModel): 
    # Natural-language question; passed as an input field to query generation.
    question: Optional[str] =None
    # Database identifier; used as directory and file name of the SQLite file.
    db_id: Optional[str] =None
    # Benchmark identifier; used as a directory level when building the SQLite path.
    benchmark_id: Optional[str] =None

    # Reference SQL (single string or list); execute_query_map falls back to it
    # when `query` is empty, and get_training_data copies it to `generated_query`.
    sql: Optional[Union[str,list[str]]] =None
    # Schema text given to the model; set by get_schema / load_db / enrich_db.
    ddl: Optional[str] =None
    # Reference query preferred over `sql` by execute_query_map.
    query: Optional[str] =None
    # Not read by the code visible in this file.
    evidence: Optional[Union[str,list[str]]]=None
    # Not read by the code visible in this file.
    reasoning_type: Optional[str] =None
    # Passed as an input field to query generation in baseline_zero_shot.
    commonsense_knowledge: Optional[str] =None
    #schema: Optional[str] = None
    # DB handle created by get_schema; load_db skips loading when it is already set.
    db:Optional[DB] = None
    # Not read by the code visible in this file; the description is prompt text.
    alternative_sql_queries: Optional[List[str]] = Field(
        None, 
        description="""Generate alternative 5 SQL queries that have higher chances to 
        provide the correct answer to the question. Make sure those are diversified 
        to cover multiple different strategies to get the right answer""")
    # Not read by the code visible in this file.
    alternative_answer_assessments:Optional[List[AnswerAssessment]] = []
    # When set, execute_query_map passes it to async_execute_sql and does not
    # build a local SQLite path.
    endpoint_id:Optional[str] =None
    # Target field of the self_transduction call in baseline_zero_shot.
    generated_query: Optional[str] =Field(None, description="The query generated by AI")
    # Not written by the code visible in this file.
    # presumably filled by perform_answer_validation -- TODO confirm.
    answer_assessment: Optional[AnswerAssessment] = Field(None, description="An assessment on the quality of the generated answer")
    # Result of executing `generated_query` (set by execute_query_map).
    system_output_df: Optional[str] = None
    # Result of executing the reference query (set by execute_query_map).
    gt_output_df: Optional[str] = None


async def get_schema(state: Text2sqlQuestion) -> Text2sqlQuestion:
    """Attach a ``DB`` handle to *state* and copy its DDL onto the state.

    The SQLite file is expected at
    ``$SQL_DB_PATH/<benchmark_id>/<db_id>/<db_id>.sqlite``.

    Args:
        state: Question whose ``benchmark_id`` and ``db_id`` identify the
            database.

    Returns:
        The same *state*, with ``db`` and ``ddl`` populated.
    """
    path_parts = (
        os.getenv("SQL_DB_PATH"),
        state.benchmark_id,
        state.db_id,
        state.db_id + ".sqlite",
    )
    sqlite_path = os.path.join(*path_parts)

    state.db = DB(
        db_id=state.db_id,
        benchmark_id=state.benchmark_id,
        db_path=sqlite_path,
    )
    # Ask the DB object to read its schema before exposing the DDL.
    state.db.get_schema_from_sqllite()
    state.ddl = state.db.ddl
    return state

async def enrich_all_dbs(test: AG):
    """Run ``load_db`` once per distinct ``db_id`` found in *test*.

    The first question seen for each ``db_id`` is used as the representative
    for that database. Nothing is returned and the result of the mapping is
    discarded.

    Args:
        test: Collection of questions to scan for database ids.
    """
    seen_db_ids = set()
    representatives = AG(atype=Text2sqlQuestion)
    for item in test:
        if item.db_id in seen_db_ids:
            continue
        seen_db_ids.add(item.db_id)
        representatives.states.append(item)

    await representatives.amap(load_db)


async def load_db(state: Text2sqlQuestion) -> Text2sqlQuestion:
    """Ensure *state* carries a ``DB`` handle, loading it when it is missing.

    When a database is loaded here, ``ddl`` is set to the JSON dump of the
    database schema object. A state that already has a ``db`` is returned
    unchanged.

    Args:
        state: Question to prepare.

    Returns:
        The prepared state.
    """
    if state.db:
        return state

    state = await get_schema(state)
    # Replace the DDL set by get_schema with the JSON form of the schema.
    state.ddl = state.db.db_schema.model_dump_json()
    return state
async def enrich_db(state: Text2sqlQuestion) -> Text2sqlQuestion:
    """Swap the state's ``db`` for its enriched version and refresh ``ddl``.

    Args:
        state: Question whose ``db`` is already set.

    Returns:
        The same *state*, with ``db`` replaced by the result of
        ``load_enrichments`` and ``ddl`` set to the JSON dump of its schema.
    """
    enriched = await state.db.load_enrichments()
    state.db = enriched
    state.ddl = state.db.db_schema.model_dump_json()
    return state


@tool("execute_sql_query")
async def execute_sql_query(sql_query:str, db_id:str)-> str:
    """Execute a SQL query against the target db and return the execution results (error or json dataframe)"""
    # The docstring above is kept verbatim: the tool decorator may expose it
    # as the tool description, which would make it runtime text.
    # NOTE(review): this path has no benchmark_id level, unlike the paths built
    # in get_schema and execute_query_map -- confirm the directory layout.
    sqlite_file = os.path.join(
        os.getenv("SQL_DB_PATH"),
        db_id,
        db_id + ".sqlite",
    )
    return await async_execute_sql(sql_query, sqlite_file)


async def execute_query_map(state:Text2sqlQuestion)-> Text2sqlQuestion:
    """Execute the generated query and the reference query for *state*.

    When ``state.endpoint_id`` is empty, the queries run against the local
    SQLite file ``$SQL_DB_PATH/<benchmark_id>/<db_id>/<db_id>.sqlite``.
    Otherwise no local path is built and the endpoint id is handed to
    ``async_execute_sql``.

    The reference query is ``state.query`` when it is non-empty; otherwise it
    is ``state.sql`` (its first element when ``sql`` is a list). An empty
    ``sql`` list yields ``None`` instead of raising ``IndexError``.

    Args:
        state: Question with ``generated_query`` and a reference query.

    Returns:
        The same *state*, with ``system_output_df`` and ``gt_output_df`` set.
    """
    schema_path = None
    if not state.endpoint_id:
        schema_path = os.path.join(
            os.getenv("SQL_DB_PATH"),
            state.benchmark_id,
            state.db_id,
            state.db_id + ".sqlite",
        )

    reference_query = state.query
    if not reference_query:
        if isinstance(state.sql, list):
            # Guard against an empty list so a missing reference is passed on
            # as None, the same as when both `query` and `sql` are unset.
            reference_query = state.sql[0] if state.sql else None
        else:
            reference_query = state.sql

    state.system_output_df = await async_execute_sql(
        state.generated_query, db_path=schema_path, endpoint_id=state.endpoint_id
    )
    state.gt_output_df = await async_execute_sql(
        reference_query, db_path=schema_path, endpoint_id=state.endpoint_id
    )
    return state

def get_training_data(training_dataset:str, n_shots=3) -> AG:
    """Build a few-shot set by sampling questions from a training file.

    Each sampled question gets its reference SQL copied into
    ``generated_query``. Because ``sql`` may be a list while
    ``generated_query`` is declared as a string, a list is reduced to its
    first element (``None`` for an empty list), as ``execute_query_map`` does.

    Sampling uses ``random.choice`` once per shot, i.e. with replacement, so
    the same question can appear more than once.

    Args:
        training_dataset: Path of the training file, loaded with
            ``AG.from_jsonl(..., jsonl=False)``.
        n_shots: Number of examples to draw.

    Returns:
        An ``AG`` of ``Text2sqlQuestion`` holding the sampled examples.

    Raises:
        IndexError: If the training set is empty and ``n_shots`` is positive
            (raised by ``random.choice``).
    """
    training = AG.from_jsonl(training_dataset,jsonl=False)
    training = training.rebind_atype(Text2sqlQuestion)
    few_shots = AG(atype=Text2sqlQuestion)
    for _ in range(n_shots):
        selected_question = random.choice(training.states)
        gold_sql = selected_question.sql
        if isinstance(gold_sql, list):
            # generated_query holds one query string, not a list of them.
            gold_sql = gold_sql[0] if gold_sql else None
        selected_question.generated_query = gold_sql
        few_shots.states.append(selected_question)
    return few_shots


async def execute_questions(test:AG, 
                            few_shots_path:str = None,
                            answer_validation: bool = True,
                            enrichments:bool=True, 
                            multiple_runs:int = 1,
                            save_run_path:str=None):
    """Run the text-to-SQL pipeline on *test* and report execution accuracy.

    Each run starts from a fresh clone of the input. It prepends optional
    few-shot examples, loads (and optionally enriches) the databases,
    generates and executes queries, optionally validates the answers, and
    then evaluates the test questions with the few-shot examples removed.

    Args:
        test: Questions to answer.
        few_shots_path: Training file to sample few-shot examples from; no
            examples are used when empty.
        answer_validation: Whether to call ``perform_answer_validation``.
        enrichments: Whether to apply ``enrich_db`` to every state.
        multiple_runs: Number of independent repetitions.
        save_run_path: Directory for per-run outputs (``exp_<run>.jsonl`` and
            ``exp_<run>_eval.txt``); nothing is written when empty.

    Returns:
        A tuple ``(test, average_accuracy)``: the states of the last run
        (few-shot examples removed) and the mean accuracy over all runs.
        When ``multiple_runs`` is not positive, the input is returned
        unchanged with an accuracy of ``0.0``.
    """
    pristine = test.clone()
    total_accuracy = 0
    for run in range(multiple_runs):
        # Start every run from its own copy. Reusing one clone would carry the
        # prepended few-shot states and earlier results into later runs.
        test = pristine.clone()
        begin_time = time.time()

        training = AG(atype=Text2sqlQuestion)
        if few_shots_path:
            training = get_training_data(few_shots_path)
        n_few_shots = len(training.states)

        test.reasoning = False
        # Few-shot examples go first so they can be sliced off after the run.
        test.states = training.states + test.states

        test = await test.amap(load_db)
        if enrichments:
            test = await test.amap(enrich_db)
        test = await baseline_zero_shot(test)

        if answer_validation:
            # NOTE(review): perform_answer_validation is not defined in the
            # visible part of this file -- confirm it is available at runtime.
            test = await perform_answer_validation(test)

        if save_run_path:
            # makedirs also creates missing parents and tolerates an existing
            # directory.
            os.makedirs(save_run_path, exist_ok=True)
            test.to_jsonl(os.path.join(save_run_path, f"exp_{run}.jsonl"))

        print(f"task executed in {time.time() - begin_time} seconds")
        # Drop the few-shot examples before scoring.
        test.states = test.states[n_few_shots:]

        accuracy, full_eval = evaluate_execution_accuracy2(test)
        total_accuracy += accuracy
        if save_run_path:
            evaluation_file = os.path.join(save_run_path, f"exp_{run}_eval.txt")
            with open(evaluation_file, "w", encoding="utf-8") as f:
                f.write(full_eval+"\n")

    # Avoid a ZeroDivisionError when no run was requested.
    average_accuracy = total_accuracy / multiple_runs if multiple_runs > 0 else 0.0
    print(f"Average execution accuracy: {average_accuracy}")
    return test, average_accuracy




async def baseline_zero_shot(test: AG) -> AG:
    """Generate a SQL query for every state in *test*, then execute it.

    Query generation is one ``self_transduction`` call that reads the
    question, database id, DDL and commonsense knowledge and writes
    ``generated_query``. Each state is then passed through
    ``execute_query_map``.

    Args:
        test: Questions with their schema already loaded.

    Returns:
        The collection returned by mapping ``execute_query_map``.
    """
    source_fields = ["question", "db_id", "ddl", "commonsense_knowledge"]
    target_fields = ["generated_query"]
    # Prompt text is runtime data and is reproduced exactly.
    prompt_parts = (
        "Your task is to convert a natural language question into an accurate SQL query using the given the database schema.\n\n",
        "**Instructions:**\n",
        "- Only use columns listed in the schema.\n",
        "- Do not use any other columns or tables not mentioned in the schema.\n",
        "- Ensure the SQL query is valid and executable.\n",
        "- Use proper SQL syntax and conventions.\n",
        "- Generate a complete SQL query that answers the question.\n",
        "- Use the correct SQL dialect for SQLite \n",
        "- Do not include any explanations or comments in the SQL output.\n",
    )
    prompt = "".join(prompt_parts)

    with_queries = await test.self_transduction(
        source_fields,
        target_fields,
        instructions=prompt,
    )
    executed = await with_queries.amap(execute_query_map)
    return executed


async def run_evaluation_benchmark(benchmark_id = "archer_en_dev",
                       max_rows: int = None,
                       few_shots_path: str = None):
    """Load a benchmark file, run the pipeline on it and print its accuracy.

    The benchmark is read from ``$SQL_BENCHMARKS_FOLDER/<benchmark_id>.json``
    and every question is tagged with ``benchmark_id`` before execution.

    Args:
        benchmark_id: Benchmark name, i.e. the file name without ``.json``.
        max_rows: Forwarded to ``AG.from_jsonl`` as ``max_rows``.
        few_shots_path: Optional training file for few-shot examples.

    Returns:
        The ``(states, average_accuracy)`` tuple produced by
        ``execute_questions``, or ``None`` when the benchmark file is not
        present in the benchmarks folder.
    """
    benchmarks_folder = os.getenv("SQL_BENCHMARKS_FOLDER")
    benchmark_file = benchmark_id + ".json"
    if benchmark_file not in os.listdir(benchmarks_folder):
        return None

    test = AG.from_jsonl(
        os.path.join(benchmarks_folder, benchmark_file),
        jsonl=False,
        atype=Text2sqlQuestion,
        max_rows=max_rows,
    )

    new_states = []
    for state in test:
        state.benchmark_id = benchmark_id
        new_states.append(state)
    test.states = new_states

    # An empty save_run_path disables writing run artifacts.
    results = await execute_questions(test, few_shots_path=few_shots_path, save_run_path="")
    # execute_questions returns (states, average_accuracy); the evaluator is
    # called with the states only, as it is inside execute_questions.
    evaluated, _average_accuracy = results
    print(evaluate_execution_accuracy2(evaluated))
    return results
        

