import hashlib
import json
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

from openai import OpenAI


class QAInheritanceManager:
    """Plan which QA pairs from a previous segment carry over to the current one.

    Paths (relative to ``base_path``):

    * QA pairs: ``generation/data/<game_name>/segment_<id>/generated_qa_pairs.jsonl``
    * corpus:   ``data/<game_name>/corpus/segment_<id>/corpus.jsonl``

    A previous-segment QA pair is a candidate for inheritance only if none of
    its entities appear among the current segment's corpus entities.
    Candidates are then sampled per ``question_type`` to match the current
    segment's type distribution. Any per-type shortfall is reported as QA
    pairs that still need to be generated.
    """

    def __init__(self, base_path: Path, game_name: str, openai_client: Optional["OpenAI"] = None):
        """Store the data locations for *game_name* under *base_path*.

        Args:
            base_path: Root directory containing ``data/`` and ``generation/``.
            game_name: Sub-directory name of the game under both roots.
            openai_client: Optional client, stored as ``self.openai_client``.
                No method in this class reads it.
        """
        self.base_path = base_path
        self.game_name = game_name
        self.corpus_path = base_path / "data" / game_name / "corpus"
        self.qa_data_path = base_path / "generation" / "data" / game_name
        self.openai_client = openai_client
        # segment_id -> set of entity texts parsed from that segment's corpus.
        self._segment_entities_cache: Dict[int, Set[str]] = {}

    @staticmethod
    def _entity_texts(entities: Any) -> List[str]:
        """Return the stripped, non-empty entity strings contained in *entities*.

        Items may be plain strings or dicts carrying a ``'text'`` key. Any other
        item is ignored, and a non-list input yields ``[]``.
        """
        if not isinstance(entities, list):
            return []
        texts = []
        for entity in entities:
            text = entity.get('text') if isinstance(entity, dict) else entity
            if isinstance(text, str) and text.strip():
                texts.append(text.strip())
        return texts

    def load_qa_pairs_from_segment(self, segment_id: int) -> List[Dict[str, Any]]:
        """Load the generated QA pairs of *segment_id* from its JSONL file.

        Returns:
            A list of QA-pair dicts in file order. The list is empty if the file
            does not exist. Blank lines, malformed JSON lines and records that
            are not JSON objects are skipped.
        """
        qa_pairs: List[Dict[str, Any]] = []
        qa_file = self.qa_data_path / f"segment_{segment_id}" / "generated_qa_pairs.jsonl"

        if not qa_file.exists():
            return qa_pairs

        with open(qa_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    # Best-effort load: one bad line should not discard the file.
                    continue
                # Downstream code calls dict methods on every QA pair.
                if isinstance(record, dict):
                    qa_pairs.append(record)

        return qa_pairs

    def extract_and_save_qa_entities(self, qa_pair: Dict[str, Any]) -> Dict[str, Any]:
        """Annotate *qa_pair* in place with its entity strings and a content hash.

        Sets ``qa_pair['extracted_entities']`` to the stripped entity texts found
        in ``qa_pair['entities']``. Items there may be dicts with ``'text'`` or
        plain strings.

        Also sets ``qa_pair['entity_extraction_timestamp']`` to a JSON string
        holding the first 8 hex digits of the MD5 of the question/answer text.
        Despite the key name, that value is a content fingerprint, not a time.

        Returns:
            The same, mutated ``qa_pair`` dict.
        """
        qa_pair['extracted_entities'] = self._entity_texts(qa_pair.get('entities', []))

        question = qa_pair.get('question', '')
        answer = qa_pair.get('answer', '')
        qa_text = f"Question: {question}\nAnswer: {answer}"
        qa_pair['entity_extraction_timestamp'] = json.dumps(
            {"timestamp": hashlib.md5(qa_text.encode()).hexdigest()[:8]}
        )

        return qa_pair

    def _extract_new_segment_entities_from_corpus(self, segment_id: int) -> Set[str]:
        """Collect entity texts from ``metadata.entities`` of each corpus record.

        Returns:
            The set of stripped, non-empty entity strings. The set is empty if
            the corpus file is missing. Malformed lines and records without a
            usable metadata dict are skipped.
        """
        all_entities: Set[str] = set()
        corpus_file = self.corpus_path / f"segment_{segment_id}" / "corpus.jsonl"

        if not corpus_file.exists():
            return all_entities

        with open(corpus_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(record, dict):
                    continue
                # ``or {}`` also covers an explicit ``"metadata": null``.
                metadata = record.get('metadata') or {}
                if isinstance(metadata, dict):
                    all_entities.update(self._entity_texts(metadata.get('entities')))

        return all_entities

    def get_segment_entities(self, segment_id: int) -> Set[str]:
        """Return the corpus entities of *segment_id*, memoised per instance.

        The cached set object itself is returned, so callers must not mutate it.
        """
        cached = self._segment_entities_cache.get(segment_id)
        if cached is not None:
            return cached

        entities = self._extract_new_segment_entities_from_corpus(segment_id)
        self._segment_entities_cache[segment_id] = entities
        return entities

    def clear_segment_entities_cache(self, segment_id: Optional[int] = None):
        """Drop the cached entities of *segment_id*, or of every segment when ``None``."""
        if segment_id is None:
            self._segment_entities_cache.clear()
        else:
            self._segment_entities_cache.pop(segment_id, None)

    def _check_qa_entity_overlap(self, qa_pair: Dict[str, Any],
                                 current_segment_entities: Set[str]) -> bool:
        """Return True if any entity of *qa_pair* is in *current_segment_entities*.

        Entities are read from ``extracted_entities``, which holds plain strings
        as written by ``extract_and_save_qa_entities``. If that key is missing,
        they are read from ``entities`` instead. Items may be strings or dicts
        with ``'text'``.
        """
        entities = qa_pair.get('extracted_entities')
        if entities is None:
            entities = qa_pair.get('entities')
        qa_entities = set(self._entity_texts(entities))

        if not qa_entities or not current_segment_entities:
            return False
        return not qa_entities.isdisjoint(current_segment_entities)

    def plan_qa_inheritance(self, current_segment_id: int,
                            current_segment_info: Dict[str, Any],
                            target_sample_size: int,
                            previous_segment_id: Optional[int] = None) -> Dict[str, Any]:
        """Decide which previous-segment QA pairs the current segment inherits.

        Args:
            current_segment_id: Segment being planned.
            current_segment_info: Must contain ``'type_distribution'``, a mapping
                of question type to count. That key is optional only when there
                is no previous segment.
            target_sample_size: Total number of QA pairs wanted for the segment.
            previous_segment_id: Defaults to ``current_segment_id - 1``.

        Returns:
            A dict with these keys:

            * ``inherit_count``: number of QA pairs inherited.
            * ``discarded_count``: previous QA pairs not inherited, whether
              dropped for entity overlap or simply not sampled.
            * ``overlap_discarded_count``: QA pairs dropped for entity overlap only.
            * ``inherited_qas``: the inherited QA-pair dicts.
            * ``previous_segment_total``: number of QA pairs in the previous segment.
            * ``candidate_count``: QA pairs that passed the overlap filter.
            * ``generation_needed``: total number of QA pairs still to generate.
            * ``inheritance_by_type``: inherited count per question type.
            * ``generation_by_type``: still-to-generate count per question type.
        """
        if previous_segment_id is None:
            previous_segment_id = current_segment_id - 1

        if previous_segment_id < 1:
            # Nothing to inherit. Report the whole target as still to generate,
            # with the same keys as the normal path so callers can index uniformly.
            type_distribution = (current_segment_info or {}).get('type_distribution') or {}
            empty_plan = self._calculate_inheritance_and_generation_plan(
                [], type_distribution, target_sample_size
            )
            return {
                'inherit_count': 0,
                'discarded_count': 0,
                'inherited_qas': [],
                'previous_segment_total': 0,
                'candidate_count': 0,
                'overlap_discarded_count': 0,
                'generation_needed': empty_plan['generation_needed'],
                'inheritance_by_type': empty_plan['inheritance_by_type'],
                'generation_by_type': empty_plan['generation_by_type'],
            }

        current_segment_entities = self.get_segment_entities(current_segment_id)
        previous_qas = self.load_qa_pairs_from_segment(previous_segment_id)

        # A QA pair whose entities reappear in the new corpus is not inherited.
        candidate_qas = []
        overlap_discarded = 0
        for position, qa in enumerate(previous_qas, start=1):
            if self._check_qa_entity_overlap(qa, current_segment_entities):
                overlap_discarded += 1
            else:
                candidate_qas.append(qa)

            if position % 50 == 0:
                print(f"candidate qas count: {len(candidate_qas)}")
        print(f"candidate qas count: {len(candidate_qas)}")

        inheritance_plan = self._calculate_inheritance_and_generation_plan(
            candidate_qas,
            current_segment_info['type_distribution'],
            target_sample_size
        )

        return {
            'inherit_count': inheritance_plan['total_inherit'],
            'discarded_count': len(previous_qas) - inheritance_plan['total_inherit'],
            'inherited_qas': inheritance_plan['inherited_qas'],
            'previous_segment_total': len(previous_qas),
            'candidate_count': len(candidate_qas),
            'overlap_discarded_count': overlap_discarded,
            'generation_needed': inheritance_plan['generation_needed'],
            'inheritance_by_type': inheritance_plan['inheritance_by_type'],
            'generation_by_type': inheritance_plan['generation_by_type'],
        }

    def _calculate_inheritance_and_generation_plan(self, candidate_qas: List[Dict[str, Any]],
                                                   original_type_distribution: Dict[str, int],
                                                   target_sample_size: int) -> Dict[str, Any]:
        """Split *target_sample_size* across question types and fill it from candidates.

        Each type's target is ``int(target_sample_size * share)``, where share is
        the type's fraction of *original_type_distribution*. Any rounding
        remainder goes to the type with the largest share. For each type, up to
        its target is sampled at random from the candidates of that type; the
        rest is reported as needing generation. Candidates whose type is not in
        the distribution are never selected.

        Returns:
            A dict with keys ``inherited_qas``, ``total_inherit``,
            ``generation_needed``, ``inheritance_by_type`` and
            ``generation_by_type``.
        """
        original_total = sum(original_type_distribution.values())
        if original_total == 0:
            return {
                'inherited_qas': [],
                'total_inherit': 0,
                'generation_needed': target_sample_size,
                'inheritance_by_type': {},
                'generation_by_type': {'UNKNOWN': target_sample_size}
            }

        type_proportions = {qtype: count / original_total
                            for qtype, count in original_type_distribution.items()}

        target_distribution = {qtype: int(target_sample_size * proportion)
                               for qtype, proportion in type_proportions.items()}

        # Flooring can leave a few slots unassigned; give them to the largest type.
        remaining = target_sample_size - sum(target_distribution.values())
        if remaining > 0:
            max_type = max(type_proportions, key=type_proportions.get)
            target_distribution[max_type] += remaining

        if not candidate_qas:
            return {
                'inherited_qas': [],
                'total_inherit': 0,
                'generation_needed': sum(target_distribution.values()),
                'inheritance_by_type': {qtype: 0 for qtype in target_distribution},
                'generation_by_type': dict(target_distribution)
            }

        candidate_qas_by_type: Dict[str, List[Dict[str, Any]]] = {}
        for qa in candidate_qas:
            candidate_qas_by_type.setdefault(qa.get('question_type', 'UNKNOWN'), []).append(qa)

        for qa_type, qas in candidate_qas_by_type.items():
            print(f"candidate qas by type: {qa_type} {len(qas)}")

        for qtype, count in sorted(target_distribution.items(), key=lambda x: x[1], reverse=True):
            print(f"target distribution: {qtype} {count}")

        inheritance_by_type = {}
        generation_by_type = {}
        selected_qas = []

        for qtype, target_count in target_distribution.items():
            available_qas = candidate_qas_by_type.get(qtype, [])

            if target_count <= 0:
                # Also guards random.sample against a negative count.
                inherit_count, generate_count = 0, 0
            elif len(available_qas) >= target_count:
                inherit_count, generate_count = target_count, 0
                selected_qas.extend(random.sample(available_qas, target_count))
            else:
                # Not enough candidates: take them all and generate the rest.
                inherit_count = len(available_qas)
                generate_count = target_count - inherit_count
                selected_qas.extend(available_qas)

            inheritance_by_type[qtype] = inherit_count
            generation_by_type[qtype] = generate_count

        return {
            'inherited_qas': selected_qas,
            'total_inherit': sum(inheritance_by_type.values()),
            'generation_needed': sum(generation_by_type.values()),
            'inheritance_by_type': inheritance_by_type,
            'generation_by_type': generation_by_type
        }

    def _resample_qas_by_distribution(self, candidate_qas: List[Dict[str, Any]],
                                      target_distribution: Dict[str, int],
                                      target_sample_size: int) -> List[Dict[str, Any]]:
        """Randomly pick up to ``target_distribution[t]`` candidates of each type ``t``.

        Types absent from the candidates are reported and skipped. Candidates
        whose type is not in *target_distribution* are never picked.

        Each type is allocated ``min(target, available)``, so no type can take
        more without exceeding its target or its supply. No top-up pass is
        therefore needed. ``target_sample_size`` is accepted for interface
        compatibility but does not affect the result.
        """
        if not candidate_qas:
            return []

        qas_by_type: Dict[str, List[Dict[str, Any]]] = {}
        for qa in candidate_qas:
            qas_by_type.setdefault(qa.get('question_type', 'UNKNOWN'), []).append(qa)

        for qa_type, qas in qas_by_type.items():
            print(f"qas by type: {qa_type} {len(qas)}")

        total_target = sum(target_distribution.values())
        # An empty or all-zero distribution has no proportions to report.
        if total_target:
            target_proportions = {qtype: count / total_target
                                  for qtype, count in target_distribution.items()}
            for qtype, proportion in sorted(target_proportions.items(),
                                            key=lambda x: x[1], reverse=True):
                print(f"target proportions: {qtype} {proportion}")

        allocated_counts = {}
        for qtype, target_count in target_distribution.items():
            if qtype not in qas_by_type:
                print(f"qtype not in qas_by_type: {qtype}")
                continue
            actual_inherit = min(target_count, len(qas_by_type[qtype]))
            if actual_inherit > 0:
                allocated_counts[qtype] = actual_inherit

        selected_qas = []
        for qtype, count in allocated_counts.items():
            available_qas = qas_by_type[qtype]
            if count >= len(available_qas):
                selected_qas.extend(available_qas)
            else:
                selected_qas.extend(random.sample(available_qas, count))

        return selected_qas