import json
import logging
import sys
from pathlib import Path
from typing import List, Dict, Any, Optional

from openai import OpenAI

from utils.utils import (
    clean_response_text,
    get_question_type_description,
    load_question_types_mapping
)
from prompt import data_filter_system, data_filter_user
from task_requirements import task_description


class QAFilter:
    """Filter generated QA pairs by asking an LLM judge to score each one.

    For every QA item, the judge receives:

    * the item's retrieved documents;
    * a description of its question type;
    * the requirements for its task type;
    * the QA pair itself.

    These are rendered through ``data_filter_user`` with ``data_filter_system``
    as the system prompt. The judge's reply is parsed as JSON, and its
    ``evaluation`` field is used as the item's score. Items scoring at least
    ``min_score`` (default 1) are kept.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self,
                 output_dir: str,
                 api_key: str,
                 base_url: str,
                 model_name: str = "gpt-4o",
                 question_types_map: Optional[Dict[str, Any]] = None):
        """Create a filter bound to one OpenAI-compatible endpoint.

        Args:
            output_dir: Directory passed to ``load_question_types_mapping``
                when no usable ``question_types_map`` is supplied.
            api_key: API key for the OpenAI client.
            base_url: Base URL of the OpenAI-compatible endpoint.
            model_name: Chat model used as the judge.
            question_types_map: Pre-loaded question-type mapping. A falsy
                value (``None`` or an empty mapping) triggers loading from
                ``output_dir``.
        """
        self.output_dir = Path(output_dir)
        self.model_name = model_name

        if question_types_map:
            self.question_types_map = question_types_map
        else:
            self.question_types_map = load_question_types_mapping(self.output_dir)
        # Looked up by the item's ``task_type`` to obtain the requirement text.
        self.task_description = task_description

        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def get_related_documents(self, qa_item: Dict[str, Any]) -> str:
        """Render the item's ``retrieved_docs`` as numbered document blocks.

        Each entry's ``content`` is emitted as "Document <i>:" followed by a
        newline and the content (numbering is 1-based). Blocks are separated
        by a blank line, and surrounding whitespace is stripped from the
        result. A missing or ``None`` ``retrieved_docs`` list, or a missing or
        ``None`` ``content``, is treated as empty.
        """
        retrieved_docs = qa_item.get('retrieved_docs') or []
        blocks = []
        for i, doc in enumerate(retrieved_docs, 1):
            content = doc.get('content')
            blocks.append(f"Document {i}:\n{'' if content is None else content}")
        return "\n\n".join(blocks).strip()

    def _build_user_prompt(self, qa_item: Dict[str, Any]) -> str:
        """Fill ``data_filter_user`` with the documents, type/task context and QA pair."""
        question_type = qa_item.get('question_type', '')
        task_type = qa_item.get('task_type', '')

        # Normalise a single string answer to a one-element list before
        # serialising.
        answer = qa_item.get('answer', '')
        if isinstance(answer, str):
            answer = [answer]

        gen_data = [{
            "question": qa_item.get('question', ''),
            "answer": answer,
            "time": qa_item.get('time', ''),
            "relevant_passage": qa_item.get('references', []),
        }]

        return data_filter_user.format(
            doc_str=self.get_related_documents(qa_item),
            topic_name=get_question_type_description(question_type, self.question_types_map),
            task_name=task_type,
            task_require=self.task_description.get(task_type, ''),
            gen_datas=json.dumps(gen_data, ensure_ascii=False, indent=2),
        )

    @staticmethod
    def _coerce_score(value: Any) -> Any:
        """Convert the judge's ``evaluation`` value to a number.

        Ints and floats pass through unchanged, and bools become 0/1. Numeric
        strings (e.g. ``"1"``) are parsed. Anything else scores 0.
        """
        if isinstance(value, bool):
            return int(value)
        if isinstance(value, (int, float)):
            return value
        if isinstance(value, str):
            try:
                number = float(value.strip())
            except ValueError:
                return 0
            return int(number) if number.is_integer() else number
        return 0

    @classmethod
    def _score_from_text(cls, text: str) -> Any:
        """Parse the judge's JSON reply and return its ``evaluation`` score.

        Returns 0 (after logging a warning) when the text is not valid JSON
        or is not a JSON object.
        """
        try:
            eval_result = json.loads(text)
        except json.JSONDecodeError:
            cls._logger.warning("Judge reply is not valid JSON: %.200r", text)
            return 0
        if not isinstance(eval_result, dict):
            cls._logger.warning("Judge reply is not a JSON object: %.200r", text)
            return 0
        return cls._coerce_score(eval_result.get('evaluation', 0))

    def evaluate_qa_quality(self, qa_item: Dict[str, Any]) -> int:
        """Ask the judge model to score ``qa_item``.

        Returns the ``evaluation`` value from the model's JSON reply, coerced
        to a number. Any failure (request error, empty reply, malformed JSON,
        missing or unusable score) is logged and scored as 0, so a single bad
        item never aborts a batch.
        """
        try:
            user_prompt = self._build_user_prompt(qa_item)
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": data_filter_system},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=1000,
                temperature=0.1,
            )
            # message.content may be None; treat that as an empty reply.
            response_text = (response.choices[0].message.content or "").strip()
            # Fall back to the raw reply when cleaning leaves nothing.
            cleaned_text = clean_response_text(response_text) or response_text
        except Exception:
            # Per-item boundary around prompt building and the remote call:
            # log the failure and score 0 rather than abort the batch.
            self._logger.exception("QA evaluation request failed")
            return 0
        return self._score_from_text(cleaned_text)

    def filter_qa_pair(self, qa_item: Dict[str, Any],
                       min_score: float = 1) -> Optional[Dict[str, Any]]:
        """Return ``qa_item`` if its judge score is at least ``min_score``, else ``None``."""
        quality_score = self.evaluate_qa_quality(qa_item)
        return qa_item if quality_score >= min_score else None

    def filter_qa_batch(self, qa_pairs: List[Dict[str, Any]],
                        min_score: float = 1) -> List[Dict[str, Any]]:
        """Score every item in ``qa_pairs`` and return those that qualify.

        Items are evaluated sequentially, and the input order is preserved in
        the result. Totals of kept and discarded items are logged at INFO
        level.
        """
        filtered_qa_pairs = []
        stats = {'total': 0, 'qualified': 0, 'discarded': 0}

        for qa_item in qa_pairs:
            stats['total'] += 1
            filtered_qa = self.filter_qa_pair(qa_item, min_score=min_score)
            # Compare against None so a qualifying empty dict is still kept.
            if filtered_qa is not None:
                filtered_qa_pairs.append(filtered_qa)
                stats['qualified'] += 1
            else:
                stats['discarded'] += 1

        self._logger.info(
            "QA filtering done: %d total, %d qualified, %d discarded",
            stats['total'], stats['qualified'], stats['discarded'],
        )
        return filtered_qa_pairs
