import asyncio
import json
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict, Any, List, Optional

from tqdm import tqdm

from src.core.llm import LLMBase
from src.core.utils import load_json
from src.modules.schema_linking.base import SchemaLinkerBase
from src.modules.sql_generation.base import SQLGeneratorBase
from src.modules.post_processing.base import PostProcessorBase
from src.core.intermediate import IntermediateResult
from src.core.logger import LoggerManager
from src.core.schema.manager import SchemaManager
from src.core.config import Config

class ElephantSQLPipeline:

    def __init__(self,
                 schema_linker: SchemaLinkerBase,
                 sql_generator: SQLGeneratorBase,
                 post_processor: PostProcessorBase,
                 pipeline_id: str = datetime.now().strftime("%Y%m%d_%H%M%S")):
        
        self.pipeline_id = pipeline_id

        self.logger_manager = LoggerManager(self.pipeline_id)
        self.logger = self.logger_manager.get_logger("pipeline")

        self.schema_linker = schema_linker
        self.sql_generator = sql_generator
        self.post_processor = post_processor
        
        self.schema_linker.intermediate = IntermediateResult(self.schema_linker.name, self.pipeline_id)
        self.schema_linker.logger = self.logger_manager.get_logger(self.schema_linker.name)
        
        self.sql_generator.intermediate = IntermediateResult(self.sql_generator.name, self.pipeline_id)
        self.sql_generator.logger = self.logger_manager.get_logger(self.sql_generator.name)
        
        self.post_processor.intermediate = IntermediateResult(self.post_processor.name, self.pipeline_id)
        self.post_processor.logger = self.logger_manager.get_logger(self.post_processor.name)

        self.sql_generator.set_previous_module(self.schema_linker)
        if hasattr(self.sql_generator, 'generators'): 
            for generator in self.sql_generator.generators.values():
                generator.set_previous_module(self.schema_linker)
                
        self.post_processor.set_previous_module(self.sql_generator)

        self.intermediate = IntermediateResult("pipeline", self.pipeline_id)

        self.stats_logger = self.logger_manager.get_logger("api_statistics")
        
        self.schema_manager = SchemaManager()
        
    def prepare_queries(self, data_file: str) -> List[Dict]:

        merge_dev_demo_data = load_json(data_file)

        queries = []
        for item in merge_dev_demo_data:
            question_id = item.get("question_id", "")
            if not question_id:
                self.logger.warning("question lack ID")
                continue
                
            source = item.get("source", "")
            db_id = item.get("db_id", "")

            db_folder = f"{source}_{db_id}"  
            db_file = f"{db_id}.sqlite"      
            db_path = str(Config().database_dir / db_folder / db_file)
            
            try:
                schema = self.schema_manager.get_schema(db_id, db_path)
            except Exception as e:
                self.logger.error(f"load database schema fail: {db_path}")
                self.logger.error(f"error message: {str(e)}")
                continue
            
            queries.append({
                "query_id": question_id,
                "query": item.get("question"),
                "database_schema": schema.to_dict(), 
                "source": source
            })
            
        return queries
        
    async def run_pipeline_parallel(self, data_file: str, max_workers: int = 5) -> None:

        start_time = datetime.now()
        self.logger.info(f"begin to process data file: {data_file}")
        
        self.schema_linker.set_data_file(data_file)
        self.sql_generator.set_data_file(data_file)
        if hasattr(self.sql_generator, 'generators'): 
            for generator in self.sql_generator.generators.values():
                generator.set_data_file(data_file)
                
        self.post_processor.set_data_file(data_file)

        queries = self.prepare_queries(data_file)
        total_queries = len(queries)
        self.logger.info(f"loaded {total_queries} queries waiting for process")
        
        results = await self.process_batch_parallel(
            queries=queries,
            desc="Processing",
            max_workers=max_workers
        )
        
        end_time = datetime.now()
        duration = end_time - start_time
        hours = duration.seconds // 3600
        minutes = (duration.seconds % 3600) // 60
        seconds = duration.seconds % 60

        self.logger.info("="*50)
        self.logger.info("Pipeline run completed")
        self.logger.info(f"total queries: {total_queries}")
        self.logger.info(f"Maximum parallel number: {max_workers}")
        self.logger.info(f"successfully process: {len(results)}, failed num: {total_queries - len(results)}, successful rate: {(len(results) / total_queries * 100):.2f}%")
        self.logger.info(f"total cost time: {hours} Hours {minutes} Minutes {seconds} Seconds, Average time per query: {duration.total_seconds() / total_queries:.2f} Seconds")
        self.logger.info("="*50)
        self.log_api_stats()
    
    async def run_pipeline(self, data_file: str) -> None:

        start_time = datetime.now()
        self.logger.info(f"begin to process data file: {data_file}")

        self.sql_generator.set_data_file(data_file)

        queries = self.prepare_queries(data_file)
        total_queries = len(queries)
        self.logger.info(f"loaded {total_queries} queries waiting process")

        results = await self.process_batch(
            queries=queries,
            desc="Processing"
        )

        end_time = datetime.now()
        duration = end_time - start_time
        hours = duration.seconds // 3600
        minutes = (duration.seconds % 3600) // 60
        seconds = duration.seconds % 60

        self.logger.info("="*50)
        self.logger.info("Pipeline run completed")
        self.logger.info(f"total queries: {total_queries}")
        self.logger.info(f"successfully process: {len(results)}, failed num: {total_queries - len(results)}, successful rate: {(len(results) / total_queries * 100):.2f}%")
        self.logger.info(f"total cost time: {hours} Hours {minutes} Minutes {seconds} Seconds, Average time per query: {duration.total_seconds() / total_queries:.2f} Seconds")
        self.logger.info("="*50)
        self.log_api_stats()
        
    def log_api_stats(self):
        stats_file = os.path.join(self.intermediate.pipeline_dir, "api_stats.json")
        if os.path.exists(stats_file):
            with open(stats_file, 'r', encoding='utf-8') as f:
                stats = json.load(f)

            self.logger.warning(f"LLM API total call times: {stats['total_calls']}")
            self.logger.warning(f"total input tokens: {stats['total_cost']['input_tokens']}; total output tokens: {stats['total_cost']['output_tokens']}")
            

            self.stats_logger.info("="*50)
            self.stats_logger.info("Detailed API call statistics")
            self.stats_logger.info("="*50)
            
            self.stats_logger.info("Model statistics:")
            for model, model_stats in stats['models'].items():
                self.stats_logger.info(f"{model}:")
                self.stats_logger.info(f"  call times: {model_stats['calls']}")
                self.stats_logger.info(f"  input tokens: {model_stats['input_tokens']}")
                self.stats_logger.info(f"  output tokens: {model_stats['output_tokens']}")
                self.stats_logger.info(f"  total tokens: {model_stats['total_tokens']}")
            
            self.stats_logger.info("Statistics by module:")
            for module, module_stats in stats['modules'].items():
                self.stats_logger.info(f"{module}:")
                self.stats_logger.info(f"  total call times: {module_stats['calls']}")
                for model, model_stats in module_stats['models'].items():
                    self.stats_logger.info(f"  {model}:")
                    self.stats_logger.info(f"    call times: {model_stats['calls']}")
                    self.stats_logger.info(f"   total tokens: {model_stats['total_tokens']}")
            self.stats_logger.info("="*50)
        
    async def process(self, 
                    query: str, 
                    database_schema: Dict,
                    query_id: str,
                    source: str = "") -> Dict[str, Any]:

        database_schema["source"] = source

        for module in [self.schema_linker, self.sql_generator, self.post_processor]:
            module.logger.info("="*70)
            module.logger.info(f"Processing Query ID: {query_id}")
            module.logger.info(f"Query: {query}")

        enriched_linked_schema = await self.schema_linker.link_schema_with_retry(
            query, 
            database_schema,
            query_id=query_id
        )
        
        schema_linking_output = {
            "original_schema": database_schema,
            "linked_schema": enriched_linked_schema
        }

        extracted_sql = await self.sql_generator.generate_sql_with_retry(
            query,
            schema_linking_output,
            query_id=query_id
        )

        processed_sql = await self.post_processor.process_sql_with_retry(
            extracted_sql,
            query_id=query_id
        )

        self.intermediate.save_sql_result(query_id, source, processed_sql)
        
        return {
            "query_id": query_id,
            "query": query,
            "extracted_schema": enriched_linked_schema,
            "extracted_sql": extracted_sql,
            "processed_sql": processed_sql
        } 
    
    async def process_batch(self, 
                         queries: List[Dict],
                         desc: str = "Processing") -> List[Dict]:
        results = []
        with tqdm(total=len(queries), desc=desc) as pbar:
            for query_item in queries:
                result = await self.process(
                    query=query_item["query"],
                    database_schema=query_item["database_schema"],
                    query_id=query_item["query_id"],
                    source=query_item.get("source", "")
                )
                results.append(result)
                pbar.update(1)
        return results
    
    async def process_batch_parallel(self, queries: List[Dict], desc: str = "", max_workers: int = 5) -> List[Dict]:

        results = []
        completed = 0
        total = len(queries)

        semaphore = asyncio.Semaphore(max_workers)

        self.logger.info(f"Start processing queries parallel {total}, Maximum parallel number: {max_workers}")
        
        async def process_with_semaphore(query):

            async with semaphore:
                nonlocal completed
                try:
                    self.logger.info(
                        f"Start processing queries [{query['query_id']}] "
                        f"({completed + 1}/{total}, Current number of parallel tasks: {len(asyncio.all_tasks()) - 1})"
                    )
                    
                    result = await self.process(
                        query=query["query"],
                        database_schema=query["database_schema"],
                        query_id=query["query_id"],
                        source=query["source"]
                    )
                    
                    completed += 1
                    self.logger.info(
                        f"Complete Query [{query['query_id']}] "
                        f"({completed}/{total}, remain: {total - completed})"
                    )
                    return result
                except Exception as e:
                    completed += 1
                    self.logger.error(
                        f"processing query [{query['query_id']}] error occur: {str(e)} "
                        f"({completed}/{total}, remain: {total - completed})"
                    )
                    return None

        tasks = [process_with_semaphore(query) for query in queries]

        with tqdm(total=total, desc=f"Processing (threads: {max_workers})") as pbar:
            for task in asyncio.as_completed(tasks):
                try:
                    result = await task
                    if result is not None:
                        results.append(result)
                    pbar.update(1)
                    pbar.set_postfix({
                        'remain': total - completed,
                        'fail': completed - len(results)
                    })
                except Exception as e:
                    self.logger.error(f"Task execution exception: {str(e)}")
                    pbar.update(1)
                    continue

        success_rate = (len(results) / total) * 100
        self.logger.info(
            f"Parallel processing is complete.total num: {total}, success: {len(results)}, "
            f"failed: {total - len(results)}, success rate: {success_rate:.2f}%"
        )
        
        return results 