import os
import json
from datetime import datetime
from typing import List, Dict, Any
import argparse
from pathlib import Path
from openai import OpenAI

from llama_index.llms.openai import OpenAI as LlamaOpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import Settings

from components.index_builder import IndexBuilder
from components.question_sampler import QuestionSampler
from components.qa_generator import QAGenerator
from components.role_matcher import RoleMatcher
from components.qa_inheritance_manager import QAInheritanceManager
from components.qa_filter import QAFilter


class GameQAGenerationSystem:
                    

    def __init__(self,
                 game_name: str,
                 segment_id: int,
                 api_key: str = "{your api_key}",
                 base_url: str = "{your base_url}",
                 model_name: str = "gpt-4o",
                 batch_size: int = 5,
                 target_sample_size: int = 150,
                 enable_role_playing: bool = True,
                 force_rebuild_index: bool = False,
                 similarity_threshold: float = 0.3,
                 enable_qa_inheritance: bool = True,
                 enable_qa_filtering: bool = True):
                       

        self.game_name = game_name
        self.segment_id = segment_id
        self.batch_size = batch_size
        self.model_name = model_name
        self.target_sample_size = target_sample_size
        self.enable_role_playing = enable_role_playing
        self.force_rebuild_index = force_rebuild_index
        self.similarity_threshold = similarity_threshold
        self.enable_qa_inheritance = enable_qa_inheritance
        self.enable_qa_filtering = enable_qa_filtering

                
        self.base_path = Path(".")
        self.corpus_path = self.base_path / "data" / game_name / "corpus"
        self.template_file = self.base_path / "data" / "merged_question_templates_dup_mov.jsonl"
        self.role_data_file = self.base_path / "data" / "merged_question_data_roles_dup_mov.jsonl"
        self.output_dir = self.base_path / "generation" / "data" / game_name
        self.output_dir.mkdir(parents=True, exist_ok=True)

              
        self._validate_paths()

                       
        self.segment_info = self._load_segment_info()
        question_types_file = self.base_path / "global_vars" / "question_types.json"
        self.question_types_map = self._load_question_types_mapping(question_types_file)

                 
        self._setup_api(api_key, base_url)

               
        self.index_builder = IndexBuilder(self.corpus_path, self.output_dir)
        self.question_sampler = QuestionSampler(self.template_file)

                    
        if self.enable_qa_inheritance:
                                   
            openai_client = OpenAI(api_key=api_key, base_url=base_url)
            self.inheritance_manager = QAInheritanceManager(self.base_path, self.game_name, openai_client)
        else:
            self.inheritance_manager = None

                  
        role_matcher = None
        if self.enable_role_playing and self.role_data_file.exists():
            try:
                embedding_client = OpenAI(api_key=api_key, base_url=base_url)
                role_index_dir = self.output_dir / "role_index"
                role_matcher = RoleMatcher(
                    role_data_file=str(self.role_data_file),
                    openai_client=embedding_client,
                    use_semantic_matching=True,
                    role_index_dir=str(role_index_dir)
                )
            except Exception as e:
                print(f"role matcher setup failed: {e}")

        self.qa_generator = QAGenerator(
            client=OpenAI(api_key=api_key, base_url=base_url),
            model_name=model_name,
            question_types_map=self.question_types_map,
            role_matcher=role_matcher
        )

                  
        if self.enable_qa_filtering:
            self.qa_filter = QAFilter(
                output_dir=str(self.output_dir),
                api_key=api_key,
                base_url=base_url,
                model_name=model_name,
                question_types_map=self.question_types_map
            )
        else:
            self.qa_filter = None

        self._print_init_info()

    def _filter_documents_by_similarity(
            self, documents: List[Dict[str, Any]], threshold: float = None) -> List[Dict[str, Any]]:
        """Keep documents whose score reaches the threshold, best score first.

        Args:
            documents: Retrieved documents; each is a dict that may carry a
                numeric ``'score'`` entry.
            threshold: Minimum score to keep a document. Defaults to
                ``self.similarity_threshold`` when ``None``.

        Returns:
            A new list with the surviving documents sorted by descending
            score (ties keep their input order). An empty input is returned
            unchanged.
        """
        if threshold is None:
            threshold = self.similarity_threshold

        if not documents:
            return documents

        def score_of(doc: Dict[str, Any]) -> float:
            # A missing score and an explicit ``None`` are both treated as
            # 0.0; comparing ``None`` with a float would raise TypeError.
            value = doc.get('score')
            return 0.0 if value is None else value

        kept = [doc for doc in documents if score_of(doc) >= threshold]
        kept.sort(key=score_of, reverse=True)
        return kept

    def _apply_qa_inheritance(self, segment_id: int, inheritance_plan: Dict[str, Any]) -> int:
                      
        if not inheritance_plan:
            return 0

                
        segment_output_dir = self.output_dir / f"segment_{segment_id}"
        segment_output_dir.mkdir(exist_ok=True)

        qa_file = segment_output_dir / "generated_qa_pairs.jsonl"

                 

                  
        inherited_qas = inheritance_plan.get('inherited_qas', [])
        inherited_count = 0

        for qa in inherited_qas:
            qa['segment_id'] = segment_id
            qa['inherited_from'] = segment_id - 1
            qa['inheritance_status'] = 'inherited'
            self.append_qa_to_segment_file(qa, qa_file)
            inherited_count += 1

        if inherited_count > 0:
            print(f"applied {inherited_count} qa pairs from segment {segment_id - 1} to segment {segment_id}")

        return inherited_count

    def _validate_paths(self):
                         
        if not self.corpus_path.exists():
            raise FileNotFoundError(f"Corpus path does not exist: {self.corpus_path}")
        if not self.template_file.exists():
            raise FileNotFoundError(f"Question template file does not exist: {self.template_file}")

    def _load_segment_info(self) -> Dict[str, Any]:
                    
        overall_stats_file = self.corpus_path / "overall_stats.json"
        if not overall_stats_file.exists():
            raise FileNotFoundError(f"Segment information file does not exist: {overall_stats_file}")

        with open(overall_stats_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_question_types_mapping(self, question_types_file: Path) -> Dict[str, Any]:
                        
        if not question_types_file.exists():
            exit(1)

        try:
            with open(question_types_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except json.JSONDecodeError as e:
            exit(1)

    def _setup_api(self, api_key: str, base_url: str):
        """Publish the API credentials and configure the global LlamaIndex settings.

        Sets the ``OPENAI_API_KEY`` and ``OPENAI_API_BASE`` environment
        variables, then assigns ``Settings.llm`` (using ``self.model_name``)
        and ``Settings.embed_model`` (``text-embedding-3-small``).

        Args:
            api_key: API key for both the LLM and the embedding model.
            base_url: API base URL for both the LLM and the embedding model.
        """
        os.environ.update({
            'OPENAI_API_KEY': api_key,
            'OPENAI_API_BASE': base_url,
        })

        llm_options = dict(api_key=api_key, base_url=base_url, model=self.model_name)
        Settings.llm = LlamaOpenAI(**llm_options)

        # Note the differing keyword: the embedding class takes ``api_base``.
        embedding_options = dict(
            api_key=api_key,
            api_base=base_url,
            model="text-embedding-3-small",
        )
        Settings.embed_model = OpenAIEmbedding(**embedding_options)

    def _print_init_info(self):
        if self.enable_qa_inheritance:
            print(f"enable qa inheritance")
        if self.enable_qa_filtering:
            print(f"enable qa filtering")

    def get_segments_to_process(self) -> List[Dict[str, Any]]:
        """Return the segment entries matching ``self.segment_id``.

        Only entries with a positive ``segment_id`` are considered.

        Raises:
            ValueError: If no such entry matches ``self.segment_id``.
        """
        matches = [
            entry
            for entry in self.segment_info['segments']
            if entry['segment_id'] > 0 and entry['segment_id'] == self.segment_id
        ]
        if matches:
            return matches

        raise ValueError(f"Segment ID not found: {self.segment_id}")

    def save_segment_progress(self, processed_templates: List[str],
                              attempt_count: int, max_attempts: int, progress_file: Path,
                              generated_count: int = 0):
                    
        progress_data = {
            'processed_template_ids': processed_templates,
            'attempt_count': attempt_count,
            'max_attempts': max_attempts,
            'generated_count': generated_count,
            'last_updated': datetime.now().isoformat()
        }

        with open(progress_file, 'w', encoding='utf-8') as f:
            json.dump(progress_data, f, ensure_ascii=False, indent=2)

    def load_segment_progress(self, progress_file: Path) -> Dict[str, Any]:
                    
        if progress_file.exists():
            try:
                with open(progress_file, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                print(f"load segment progress failed: {e}")
        return {'processed_template_ids': [], 'attempt_count': 0, 'generated_count': 0}

    def append_qa_to_segment_file(self, qa_pair: Dict[str, Any], qa_file: Path):
                         
        with open(qa_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(qa_pair, ensure_ascii=False) + '\n')

    def generate_segment_qa_pairs(self, segment: Dict[str, Any]) -> int:
                        
        segment_id = segment['segment_id']

                  
        inheritance_plan = None
        inherited_count = 0
        inherited_qas_by_type = {}

        if self.enable_qa_inheritance and self.inheritance_manager and segment_id > 1:
            inheritance_plan = self.inheritance_manager.plan_qa_inheritance(
                segment_id, segment, self.target_sample_size)
            inherited_count = self._apply_qa_inheritance(segment_id, inheritance_plan)

                           
            inherited_qas = inheritance_plan.get('inherited_qas', [])
            for qa in inherited_qas:
                qa_type = qa.get('question_type', 'UNKNOWN')
                inherited_qas_by_type[qa_type] = inherited_qas_by_type.get(qa_type, 0) + 1


                  
        if not self.index_builder.build_segment_index(segment_id, self.force_rebuild_index):
            return 0

                               
        if self.enable_qa_inheritance and inheritance_plan:
            remaining_target = inheritance_plan.get('generation_needed', 0)
            generation_by_type = inheritance_plan.get('generation_by_type', {})
        else:
            remaining_target = max(0, self.target_sample_size - inherited_count)

        if remaining_target == 0:
            return inherited_count

                     
        segment_output_dir = self.output_dir / f"segment_{segment_id}"
        segment_output_dir.mkdir(exist_ok=True)

        qa_file = segment_output_dir / "generated_qa_pairs.jsonl"
        progress_file = segment_output_dir / "qa_generation_progress.json"

              
        progress = self.load_segment_progress(progress_file)
        processed_ids = set(progress.get('processed_template_ids', []))

                  
        start_generated_count = progress.get('generated_count', 0)
        start_attempt_count = progress.get('attempt_count', 0)

                 
        if not qa_file.exists():
            qa_file.touch()

        generated_count = start_generated_count
        batch_qa_pairs = []
        max_attempts = remaining_target * 3                 
        attempt_count = start_attempt_count

        if start_generated_count > 0:
                              
            remaining_target = max(0, remaining_target - start_generated_count)


        try:
            while (generated_count - start_generated_count) < remaining_target and attempt_count < max_attempts:
                                
                needed = remaining_target - (generated_count - start_generated_count)
                templates = self.question_sampler.sample_templates_by_distribution(
                    segment, min(10, needed), False, 0, inherited_qas_by_type)

                if not templates:
                    break

                for template_data in templates:
                    if (generated_count - start_generated_count) >= remaining_target:
                        break

                    attempt_count += 1
                    template_id = template_data.get('id', f't_{attempt_count}')
                    template_data["game_name"] = self.game_name

                              
                    if template_id in processed_ids:
                        continue

                    template_text = template_data.get('template', '')

                    try:
                                       
                        hypothetical_question, hypothetical_answer = self.qa_generator.generate_hypothetical_qa_from_template(
                            template_data
                        )

                        if not hypothetical_question or not hypothetical_answer:
                            processed_ids.add(template_id)
                            continue

                                
                        retrieved_docs = self.index_builder.retrieve_documents(
                            hypothetical_question + hypothetical_answer)

                                   
                        filtered_docs = self._filter_documents_by_similarity(retrieved_docs)

                        if not filtered_docs:
                            processed_ids.add(template_id)
                            continue

                                 
                        qa_pairs = self.qa_generator.generate_qa_pair(
                            template_data, filtered_docs
                        )

                                  
                        for qa in qa_pairs:
                            if (generated_count - start_generated_count) >= remaining_target:
                                break

                            if qa and qa.get('question') and qa.get('answer'):
                                qa['segment_id'] = segment_id
                                qa['game_name'] = self.game_name

                                                     
                                entities = qa.get('entities', [])
                                if entities:
                                    if self.inheritance_manager:
                                        qa['extracted_entities'] = entities

                                         
                                if self.enable_qa_filtering and self.qa_filter:
                                    filtered_qa = self.qa_filter.filter_qa_pair(qa)
                                    if filtered_qa:
                                        batch_qa_pairs.append(filtered_qa)
                                        generated_count += 1
                                    else:
                                        print(f"qa filtering failed for qa: {qa}")
                                else:
                                                
                                    batch_qa_pairs.append(qa)
                                    generated_count += 1
                            else:
                                print(f"qa is not valid: {qa}")
                        processed_ids.add(template_id)

                              
                        if len(batch_qa_pairs) >= self.batch_size:
                            for qa in batch_qa_pairs:
                                self.append_qa_to_segment_file(qa, qa_file)
                            batch_qa_pairs = []

                                
                        if attempt_count % 5 == 0:
                            self.save_segment_progress(
                                list(processed_ids), attempt_count, max_attempts, progress_file, generated_count)

                    except Exception as e:
                        processed_ids.add(template_id)
                        continue

                         
                if attempt_count % 10 == 0:
                    print(f"attempt count: {attempt_count}")
                      
            if batch_qa_pairs:
                for qa in batch_qa_pairs:
                    self.append_qa_to_segment_file(qa, qa_file)

                    
            self.save_segment_progress(list(processed_ids), attempt_count, max_attempts, progress_file, generated_count)

            total_qa_count = generated_count + inherited_count

            if (generated_count - start_generated_count) < remaining_target:
                print(f"generated count is less than remaining target: {generated_count - start_generated_count} < {remaining_target}")

        except KeyboardInterrupt:
            if batch_qa_pairs:
                for qa in batch_qa_pairs:
                    self.append_qa_to_segment_file(qa, qa_file)
            self.save_segment_progress(list(processed_ids), attempt_count, max_attempts, progress_file, generated_count)

        except Exception as e:
            if batch_qa_pairs:
                for qa in batch_qa_pairs:
                    self.append_qa_to_segment_file(qa, qa_file)
            self.save_segment_progress(list(processed_ids), attempt_count, max_attempts, progress_file, generated_count)

        return generated_count + inherited_count

    def generate_qa_pairs(self) -> int:
        """Run generation for every selected segment and return the summed count."""
        return sum(
            self.generate_segment_qa_pairs(segment)
            for segment in self.get_segments_to_process()
        )

    def process(self) -> str:
                         
        generated_count = self.generate_qa_pairs()

        if generated_count > 0:
            return str(self.output_dir)
        else:
            return ""

    def reset_progress(self):
                  
        segments_to_process = self.get_segments_to_process()

        for segment in segments_to_process:
            segment_id = segment['segment_id']
            segment_output_dir = self.output_dir / f"segment_{segment_id}"

            if segment_output_dir.exists():
                for file_pattern in ["*generated_qa_pairs.jsonl", "*qa_generation_progress.json"]:
                    for file_path in segment_output_dir.glob(file_pattern):
                        file_path.unlink()
                self.index_builder.cleanup_index(segment_id)



def main():
    """Command-line entry point for retrieval-augmented QA pair generation.

    Parses the command-line arguments, builds a ``GameQAGenerationSystem``,
    optionally resets earlier progress (``--reset``) and runs the generation.

    Raises:
        SystemExit: With exit code 1 when setup or generation fails, so that
            calling scripts can detect the failure. (``argparse`` itself
            exits with its own code on invalid arguments.)
    """
    parser = argparse.ArgumentParser(description='Retrieval-Augmented QA Pair Generation')
    parser.add_argument('--game_name', type=str, default='dyinglight2', help='Game name')
    parser.add_argument('--segment_id', type=int, required=True, help='Time segment ID (1-n)')
    parser.add_argument('--batch_size', type=int, default=5, help='Batch processing size')
    parser.add_argument('--reset', action='store_true', help='Reset generation progress')
    parser.add_argument('--api_key', default='{your api_key}', type=str, help='OpenAI API key')
    parser.add_argument('--base_url', default='{your base_url}', type=str, help='API base URL')
    parser.add_argument('--model_name', type=str, default='gpt-4o', help='Model name to use')

    parser.add_argument('--target_sample_size', type=int, default=150, help='Target sample size per segment')
    parser.add_argument('--disable_role_playing', action='store_true', help='Disable role playing functionality')
    parser.add_argument('--force_rebuild_index', action='store_true', help='Force rebuild index')
    # NOTE(review): this CLI default (0.5) differs from the constructor's
    # default similarity_threshold (0.3) -- confirm which one is intended.
    parser.add_argument('--similarity_threshold', type=float, default=0.5, help='Document similarity filtering threshold')
    parser.add_argument('--disable_qa_inheritance', action='store_true', help='Disable QA inheritance functionality')
    parser.add_argument('--disable_qa_filtering', action='store_true', help='Disable QA filtering functionality')

    args = parser.parse_args()

    try:
        generator = GameQAGenerationSystem(
            game_name=args.game_name,
            segment_id=args.segment_id,
            api_key=args.api_key,
            base_url=args.base_url,
            model_name=args.model_name,
            batch_size=args.batch_size,
            target_sample_size=args.target_sample_size,
            enable_role_playing=not args.disable_role_playing,
            force_rebuild_index=args.force_rebuild_index,
            similarity_threshold=args.similarity_threshold,
            enable_qa_inheritance=not args.disable_qa_inheritance,
            enable_qa_filtering=not args.disable_qa_filtering
        )

        if args.reset:
            generator.reset_progress()

        output_file = generator.process()

    except Exception as e:
        print(f"error: {e}")
        # Report the failure through the exit status instead of finishing
        # with status 0 after an error.
        raise SystemExit(1) from e

    if output_file:
        print(f"output file: {output_file}")
    else:
        print(f"output file is empty")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
