import asyncio
from typing import Dict, List, Optional
import json
from pathlib import Path
from enum import Enum
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from ..core.llm import LLMBase
from ..modules.schema_linking.enhanced_linker import EnhancedSchemaLinker
from ..modules.sql_generation.gpt_generator import GPTSQLGenerator
from ..modules.sql_generation.dc_refiner_generator import DCRefinerSQLGenerator
from ..modules.sql_generation.enhanced_generator import EnhancedSQLGenerator
from ..modules.post_processing.skip_post_processing import SkipPostProcessor
from ..pipeline import ElephantSQLPipeline
from ..evaluation.compute_ex import compare_sql_results
from ..core.utils import load_json, TextExtractor
from ..pipeline_factory import PipelineFactory, PipelineLevel
from ..core.config import Config
from ..core.logger import LoggerManager

class PipelineType(Enum):
    """Integer labels recording which pipeline level solved an example.

    Members are declared in escalation order, matching the order in which
    ``Labeler`` tries the pipelines; ``UNSOLVED`` is the label used when
    none of them produced a correct answer.
    """

    # The basic pipeline already produced a correct SQL.
    BASIC = 1
    # The basic pipeline failed but the intermediate one succeeded.
    INTERMEDIATE = 2
    # Only the advanced pipeline succeeded.
    ADVANCED = 3
    # Every pipeline was tried and none succeeded.
    UNSOLVED = 4

class Labeler:
    """Label dataset examples with the first pipeline level that solves them.

    For every example the BASIC, INTERMEDIATE and ADVANCED pipelines are tried
    in that order. The example is labelled with the first level whose
    generated SQL is accepted by ``compare_sql_results`` against the gold SQL,
    or with ``PipelineType.UNSOLVED`` when none is accepted.
    """

    def __init__(self, llm: LLMBase):
        """Build the three pipelines and the loggers used while labeling.

        Args:
            llm: LLM instance handed to the ``PipelineFactory``.
        """
        self.llm = llm
        self.extractor = TextExtractor()
        self.pipeline_factory = PipelineFactory(llm, backbone_model="gpt-3.5-turbo", temperature=0.0, max_retries=10)

        self.basic_pipeline = self.pipeline_factory.get_pipeline(PipelineLevel.BASIC)
        self.intermediate_pipeline = self.pipeline_factory.get_pipeline(PipelineLevel.INTERMEDIATE)
        self.advanced_pipeline = self.pipeline_factory.get_pipeline(PipelineLevel.ADVANCED)

        self.logger_manager = LoggerManager()
        self.logger = self.logger_manager.get_logger("labeler")
        self.stats_logger = self.logger_manager.get_logger("labeling_stats")

    async def _pipeline_solves(self, pipeline, item: Dict, db_path: str) -> bool:
        """Run one pipeline on ``item`` and report whether its SQL is accepted.

        Args:
            pipeline: One of the pipelines built in ``__init__``.
            item: Dataset example; reads ``question``, ``db_id``,
                ``question_id``, ``source`` and ``gold_SQL``.
            db_path: Path of the SQLite file the SQL is checked against.

        Returns:
            True when the pipeline returned a result and
            ``compare_sql_results`` reported it as correct, else False.
        """
        result = await pipeline.process(
            query=item["question"],
            # The schema is fetched per call from the pipeline's own manager,
            # so each pipeline receives its own freshly built dict.
            database_schema=pipeline.schema_manager.get_schema(item["db_id"], db_path).to_dict(),
            query_id=item["question_id"],
            source=item["source"],
        )
        if not result:
            return False

        processed_sql = result.get("processed_sql", "")
        check_sql_result = compare_sql_results(db_path, item["gold_SQL"], processed_sql)
        # compare_sql_results may return either a bare flag or a tuple whose
        # first element is the flag.
        is_correct = check_sql_result[0] if isinstance(check_sql_result, tuple) else check_sql_result
        return bool(is_correct)

    async def label_single_item(self, item: Dict, data_file: str) -> Optional[Dict]:
        """Label a single example with the first pipeline level that solves it.

        Args:
            item: Dataset example with at least ``question_id``, ``source``,
                ``db_id``, ``gold_SQL`` and ``question``.
            data_file: Unused here; kept for interface compatibility.

        Returns:
            The dict built by ``_create_result_dict``, or None when schema
            linking returns None or any exception is raised while processing.

        Raises:
            KeyError: If ``item`` lacks one of the required keys read before
                processing starts.
        """
        query_id = item["question_id"]
        source = item["source"]
        db_id = item["db_id"]
        db_folder = f"{source}_{db_id}"
        db_file = f"{db_id}.sqlite"
        db_path = str(Config().database_dir / db_folder / db_file)

        self.logger.info(f"begin to process example [ID: {query_id}] database: {db_path}, question: {item['question']}")

        try:
            schema_linker = self.basic_pipeline.schema_linker
            schema_manager = self.basic_pipeline.schema_manager

            linked_schema = await schema_linker.link_schema_with_retry(
                query=item["question"],
                database_schema=schema_manager.get_schema(db_id, db_path).to_dict(),
                query_id=query_id
            )

            if linked_schema is None:
                self.logger.error(f"[ID: {query_id}] Schema linking failed")
                return None

            enhanced_linked_schema_wo_info = schema_linker.enhance_schema_only_with_keys(
                linked_schema,
                schema_manager.get_schema(db_id, db_path).to_dict()
            )

            # Escalate through the pipelines, stopping at the first success.
            attempts = (
                (PipelineType.BASIC, self.basic_pipeline),
                (PipelineType.INTERMEDIATE, self.intermediate_pipeline),
                (PipelineType.ADVANCED, self.advanced_pipeline),
            )
            previous_level = None
            for level, pipeline in attempts:
                if previous_level is None:
                    self.logger.debug(f"[ID: {query_id}] try {level.name} Pipeline...")
                else:
                    self.logger.debug(
                        f"[ID: {query_id}] {previous_level.name} failed, try {level.name} Pipeline..."
                    )

                if await self._pipeline_solves(pipeline, item, db_path):
                    self.logger.info(f"[ID: {query_id}] {level.name} Pipeline successfully processed")
                    return self._create_result_dict(item, level.value, enhanced_linked_schema_wo_info)
                previous_level = level

            self.logger.warning(f"[ID: {query_id}] all Pipeline failed")
            return self._create_result_dict(
                item, PipelineType.UNSOLVED.value, enhanced_linked_schema_wo_info
            )

        except Exception as e:
            # Best effort: a failing example is logged and skipped so that one
            # bad item does not abort the whole labeling run.
            self.logger.error(f"[ID: {query_id}] error occured when processing example: {str(e)}")
            return None

    def _create_result_dict(self, item: Dict, label: int, enhanced_schema: Dict) -> Dict:
        """Assemble the output record for one labelled example.

        Args:
            item: Source example; ``question`` and ``difficulty`` default to
                an empty string when absent.
            label: A ``PipelineType`` value.
            enhanced_schema: Linked schema stored under
                ``enhanced_linked_schema_wo_info``.

        Returns:
            The record that is written to the output file.
        """
        return {
            "question_id": item["question_id"],
            "source": item["source"],
            "db_id": item["db_id"],
            "question": item.get("question", ""),
            "difficulty": item.get("difficulty", ""),
            "gold_sql": item["gold_SQL"],
            "pipeline_type": PipelineType(label).name,
            "label": label,
            "enhanced_linked_schema_wo_info": enhanced_schema
        }

    async def label_dataset_parallel(self, data_file: str, output_file: str, max_workers: int = 5) -> List[Dict]:
        """Label every example of a dataset with bounded concurrency.

        Args:
            data_file: Dataset JSON path, joined onto ``Config().data_dir``.
            output_file: Output path, joined onto ``Config().data_dir``; each
                labelled example is appended to it as one JSON line.
            max_workers: Maximum number of examples processed concurrently.

        Returns:
            The labelled records, in completion order. Examples for which
            ``label_single_item`` returned nothing are omitted.
        """
        data_file = str(Config().data_dir / Path(data_file))
        output_file = str(Config().data_dir / Path(output_file))

        self.logger.info(f"begin to process data file: {data_file}")
        self.logger.info(f"output file: {output_file}")

        dataset = load_json(data_file)
        labeled_results = []

        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        for pipeline in [self.basic_pipeline, self.intermediate_pipeline, self.advanced_pipeline]:
            pipeline.schema_linker.set_data_file(data_file)
            pipeline.sql_generator.set_data_file(data_file)
            pipeline.post_processor.set_data_file(data_file)

        semaphore = asyncio.Semaphore(max_workers)

        async def process_with_semaphore(item):
            # The semaphore caps how many examples are in flight at once.
            async with semaphore:
                return await self.label_single_item(item, data_file)

        tasks = [process_with_semaphore(item) for item in dataset]

        with tqdm(total=len(tasks), desc=f"Labeling dataset (threads: {max_workers})") as pbar:
            for coro in asyncio.as_completed(tasks):
                result = await coro
                if result:
                    labeled_results.append(result)
                    # Append and close per record so finished work is on disk
                    # even if the run is interrupted later.
                    with open(output_file, "a", encoding="utf-8") as f:
                        f.write(json.dumps(result, ensure_ascii=False) + "\n")
                pbar.update(1)

        self._log_labeling_stats(labeled_results)
        return labeled_results

    @staticmethod
    def _distribution_lines(counts: Dict[str, int], indent: str = "") -> List[str]:
        """Format ``label: count (pct%)`` lines for one group of counts.

        Percentages are relative to the sum of ``counts``; an all-zero group
        reports 0.00% instead of dividing by zero.
        """
        group_total = sum(counts.values())
        lines = []
        for label, count in counts.items():
            percentage = (count / group_total) * 100 if group_total else 0.0
            lines.append(f"{indent}{label}: {count} ({percentage:.2f}%)")
        return lines

    def _log_labeling_stats(self, results: List[Dict]):
        """Log and print label distributions overall, by source and by difficulty.

        Args:
            results: Records produced by ``_create_result_dict``. May be
                empty, in which case every percentage is reported as 0.00%.
        """
        total = len(results)
        label_counts = {label.name: 0 for label in PipelineType}
        source_stats = {}
        difficulty_stats = {}

        for result in results:
            pipeline_type = result["pipeline_type"]
            label_counts[pipeline_type] += 1

            source = result["source"]
            if source not in source_stats:
                source_stats[source] = {label.name: 0 for label in PipelineType}
            source_stats[source][pipeline_type] += 1

            if "difficulty" in result:
                difficulty = result["difficulty"]
                if difficulty not in difficulty_stats:
                    difficulty_stats[difficulty] = {label.name: 0 for label in PipelineType}
                difficulty_stats[difficulty][pipeline_type] += 1

        stats_lines = [
            "="*50,
            "Annotated statistics",
            "="*50,
            f"total sample num: {total}",
            "\nlabel distribution:"
        ]
        stats_lines.extend(self._distribution_lines(label_counts))

        stats_lines.append("\nStatistics by source:")
        for source, stats in source_stats.items():
            stats_lines.append(f"\n{source}:")
            stats_lines.extend(self._distribution_lines(stats, indent="  "))

        if difficulty_stats:
            stats_lines.append("\nStatistics by difficulty:")
            for difficulty, stats in difficulty_stats.items():
                stats_lines.append(f"\n{difficulty}:")
                stats_lines.extend(self._distribution_lines(stats, indent="  "))

        stats_lines.append("="*50)

        stats_text = "\n".join(stats_lines)
        self.stats_logger.info(stats_text)
        print("\n" + stats_text)

async def main():
    """Label the formatted BIRD dev file and write the labels as JSON lines.

    Both paths are passed to ``Labeler.label_dataset_parallel``, which joins
    them onto the configured data directory.
    """
    labeler = Labeler(LLMBase())
    await labeler.label_dataset_parallel(
        data_file="formatted_bird_dev.json",
        output_file="labeled/bird_dev_pipeline_label.jsonl",
        max_workers=100,
    )

if __name__ == "__main__":
    # Create a fresh event loop and install it as the current one.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Enlarge the default executor used by loop.run_in_executor(None, ...).
    loop.set_default_executor(ThreadPoolExecutor(max_workers=300))

    try:
        loop.run_until_complete(main())
    finally:
        # Close the loop even when main() raises so it is not leaked.
        loop.close()
