                           

import json
import logging
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Union


class QuestionSampler:
    """Load question templates from a JSON Lines file and sample them by type.

    Each non-blank line of the template file is parsed as one JSON object.
    A template's type is read from its ``'question_type'`` key; templates
    without that key are grouped under ``'UNCATEGORIZED'`` when sampling.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, template_file: Union[str, Path]):
        """Remember the template path and load every template from it.

        Args:
            template_file: Path of the JSON Lines template file. A plain
                string is accepted and converted to ``Path``.

        Raises:
            SystemExit: With exit code 1 if the file is missing or unreadable.
        """
        self.template_file = Path(template_file)
        self.templates = self._load_question_templates()

    def _load_question_templates(self) -> List[Dict[str, Any]]:
        """Parse ``self.template_file`` and return its templates in file order.

        Blank lines are ignored. Lines that are not valid JSON, or that decode
        to something other than a JSON object, are skipped with a logged
        warning that names the line number.

        Returns:
            The list of decoded template dictionaries.

        Raises:
            SystemExit: With exit code 1 if the file does not exist or cannot
                be read or decoded as UTF-8. The reason is logged first.
        """
        templates: List[Dict[str, Any]] = []
        try:
            with open(self.template_file, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    record = line.strip()
                    if not record:
                        continue
                    try:
                        data = json.loads(record)
                    except json.JSONDecodeError as e:
                        self._logger.warning(
                            "Skipping malformed JSON on line %d of %s: %s",
                            line_num, self.template_file, e,
                        )
                        continue
                    if not isinstance(data, dict):
                        # Every consumer calls ``template.get(...)``, so a
                        # non-object record would only crash later.
                        self._logger.warning(
                            "Skipping non-object JSON on line %d of %s",
                            line_num, self.template_file,
                        )
                        continue
                    templates.append(data)
        except FileNotFoundError:
            self._logger.error("Template file not found: %s", self.template_file)
            # ``raise SystemExit(1)`` keeps the historical exit code without
            # depending on the ``exit`` helper injected by the ``site`` module.
            raise SystemExit(1) from None
        except (OSError, UnicodeError) as e:
            self._logger.error(
                "Failed to read template file %s: %s", self.template_file, e
            )
            raise SystemExit(1) from e
        return templates

    def _group_templates_by_type(self) -> Dict[str, List[Dict[str, Any]]]:
        """Group ``self.templates`` by ``'question_type'`` (default ``'UNCATEGORIZED'``)."""
        grouped: Dict[str, List[Dict[str, Any]]] = {}
        for template in self.templates:
            qtype = template.get('question_type', 'UNCATEGORIZED')
            grouped.setdefault(qtype, []).append(template)
        return grouped

    @staticmethod
    def _allocate_counts(ranked_types: List[Any],
                         capacity_by_type: Dict[str, int],
                         inherited_qas_by_type: Dict[str, int],
                         target_sample_size: int) -> Dict[str, int]:
        """Decide how many templates to draw per type, never exceeding supply.

        Args:
            ranked_types: ``(qtype, proportion)`` pairs, highest proportion first.
            capacity_by_type: Number of available templates for each type.
            inherited_qas_by_type: Already-existing item counts per type.
            target_sample_size: Number of new templates wanted in total.

        Returns:
            Mapping of type to the number of templates to sample. Types are
            inserted in the order they first receive quota.
        """
        total_target_with_inherited = target_sample_size + sum(inherited_qas_by_type.values())
        allocated: Dict[str, int] = {}
        remaining = target_sample_size

        # Pass 1: give each type its proportional share of the combined
        # (new + inherited) total, minus what it already has through
        # inheritance, capped by the remaining budget and by its supply.
        for qtype, proportion in ranked_types:
            if remaining <= 0:
                break
            capacity = capacity_by_type.get(qtype, 0)
            if capacity <= 0:
                continue
            ideal_total = int(total_target_with_inherited * proportion)
            need = max(0, ideal_total - inherited_qas_by_type.get(qtype, 0))
            need = min(need, remaining, capacity)
            if need > 0:
                allocated[qtype] = need
                remaining -= need

        # Pass 2: hand out whatever is left one at a time in priority order,
        # but only to types that still have unused templates. The loop ends
        # as soon as the budget or the total supply is exhausted.
        open_types = [
            qtype for qtype, _ in ranked_types
            if allocated.get(qtype, 0) < capacity_by_type.get(qtype, 0)
        ]
        while remaining > 0 and open_types:
            still_open = []
            for qtype in open_types:
                if remaining <= 0:
                    break
                allocated[qtype] = allocated.get(qtype, 0) + 1
                remaining -= 1
                if allocated[qtype] < capacity_by_type[qtype]:
                    still_open.append(qtype)
            open_types = still_open

        return allocated

    def sample_templates_by_distribution(self,
                                         segment: Dict[str, Any],
                                         target_sample_size: int = 150,
                                         test_mode: bool = False,
                                         test_size: int = 10,
                                         inherited_qas_by_type: Optional[Dict[str, int]] = None) -> List[Dict[str, Any]]:
        """Sample templates so their types follow the segment's distribution.

        The per-type share is derived from ``segment['type_distribution']``.
        Counts in ``inherited_qas_by_type`` are treated as already present, so
        types that are already well represented receive fewer new templates.
        No type is allocated more templates than it has; surplus quota is
        redistributed to types that still have unused templates. Templates
        whose type does not appear in the distribution are never sampled.

        Diagnostic lines about the inputs are printed to standard output.

        Args:
            segment: Mapping that must contain ``'type_distribution'``
                (type name -> count). Other keys are not read.
            target_sample_size: Number of new templates to sample.
            test_mode: If true, the result is randomly reduced to at most
                ``test_size`` templates.
            test_size: Maximum result size in test mode.
            inherited_qas_by_type: Optional mapping of type name -> number of
                already-existing items of that type.

        Returns:
            The sampled template dictionaries (the same objects held in
            ``self.templates``), grouped by type in allocation order unless
            test mode reshuffles them. Fewer than ``target_sample_size`` are
            returned only when the eligible types run out of templates.

        Raises:
            KeyError: If ``segment`` has no ``'type_distribution'`` key.
        """
        type_distribution = segment['type_distribution']
        inherited = inherited_qas_by_type or {}

        for qtype, count in sorted(inherited.items(), key=lambda x: x[1], reverse=True):
            print(f"inherited qas by type: {qtype} {count}")

        original_total = sum(type_distribution.values())
        if original_total:
            type_proportions = {
                qtype: count / original_total
                for qtype, count in type_distribution.items()
            }
        else:
            # An all-zero (or empty) distribution carries no preference;
            # treat every listed type equally instead of dividing by zero.
            type_proportions = {qtype: 0.0 for qtype in type_distribution}

        # Highest proportion first; ties keep the distribution's own order.
        ranked_types = sorted(type_proportions.items(), key=lambda x: x[1], reverse=True)
        for qtype, proportion in ranked_types:
            print(f"type proportions: {qtype} {proportion}")

        templates_by_type = self._group_templates_by_type()
        for qtype, templates in templates_by_type.items():
            print(f"templates by type: {qtype} {len(templates)}")

        capacity_by_type = {qtype: len(pool) for qtype, pool in templates_by_type.items()}
        allocated_counts = self._allocate_counts(
            ranked_types, capacity_by_type, inherited, target_sample_size
        )

        sampled_templates: List[Dict[str, Any]] = []
        for qtype, target_count in allocated_counts.items():
            pool = templates_by_type[qtype]
            if target_count >= len(pool):
                # Taking everything: keep file order and skip the RNG.
                sampled_templates.extend(pool)
            else:
                sampled_templates.extend(random.sample(pool, target_count))

        if test_mode and len(sampled_templates) > test_size:
            sampled_templates = random.sample(sampled_templates, test_size)

        return sampled_templates

    def get_templates_by_type(self, question_type: str) -> List[Dict[str, Any]]:
        """Return every loaded template whose ``'question_type'`` equals ``question_type``."""
        return [t for t in self.templates if t.get('question_type') == question_type]

    def get_random_template(self, question_type: Optional[str] = None) -> Dict[str, Any]:
        """Return one randomly chosen template.

        Args:
            question_type: If given (and non-empty), restrict the choice to
                templates of this type; otherwise choose among all templates.

        Returns:
            A single template dictionary.

        Raises:
            ValueError: If there is no template to choose from.
        """
        if question_type:
            available_templates = self.get_templates_by_type(question_type)
        else:
            available_templates = self.templates

        if not available_templates:
            raise ValueError(f"No templates found for type {question_type}")

        return random.choice(available_templates)
