                       
\
\
\
   

import asyncio
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional

from .config import GenerationConfig
from .generator import TextGenerator


class GenerationPipeline:
    """Generate answers for retrieval results using a ``TextGenerator``.

    Each input record is a dict produced by a retrieval step. The required
    fields read here are ``question``, ``retrieved_docs`` and
    ``retrieved_doc_ids``. The optional fields are ``retrieval_time``,
    ``retrieval_config``, ``question_index``, ``ground_truth_answer``,
    ``ground_truth_doc_ids`` and ``ground_truth_docs``. Each output record
    combines the generated answer with that retrieval metadata and timing
    information.
    """

    def __init__(self, config: GenerationConfig):
        """Store *config* and build the ``TextGenerator`` used to answer questions."""
        self.config = config
        self.generator = TextGenerator(config)

        if config.verbose:
            print("GenerationPipeline initialized")

    def generate_single(self, retrieval_result: Dict[str, Any]) -> Dict[str, Any]:
        """Generate an answer for one retrieval record, blocking until done.

        Args:
            retrieval_result: A retrieval record (see the class docstring).

        Returns:
            The result record built by ``_build_result``.

        Raises:
            KeyError: If ``question``, ``retrieved_docs`` or
                ``retrieved_doc_ids`` is missing.
            Exception: Whatever ``self.generator.generate`` raises.
        """
        start_time = time.perf_counter()
        question = retrieval_result['question']
        retrieved_docs = retrieval_result['retrieved_docs']

        if self.config.verbose:
            print("Generating single")

        answer = self.generator.generate(question, retrieved_docs)
        result = self._build_result(retrieval_result, answer, time.perf_counter() - start_time)

        if self.config.verbose:
            print("Generated single")
        return result

    async def generate_single_async(self, retrieval_result: Dict[str, Any], index: int = 0) -> Dict[str, Any]:
        """Async counterpart of ``generate_single``.

        Args:
            retrieval_result: A retrieval record (see the class docstring).
            index: Position of the record in the overall input. Currently
                unused; kept for API compatibility.

        Returns:
            The result record built by ``_build_result``.

        Raises:
            KeyError: If a required field is missing.
            Exception: Whatever ``self.generator.generate_async`` raises.
        """
        start_time = time.perf_counter()
        question = retrieval_result['question']
        retrieved_docs = retrieval_result['retrieved_docs']

        if self.config.verbose:
            print("Generating single async")

        answer = await self.generator.generate_async(question, retrieved_docs)
        result = self._build_result(retrieval_result, answer, time.perf_counter() - start_time)

        if self.config.verbose:
            print("Generated single async")
        return result

    def _build_result(self, retrieval_result: Dict[str, Any], answer: Any,
                      generation_time: float) -> Dict[str, Any]:
        """Assemble the success record shared by the sync and async paths."""
        retrieval_time = retrieval_result.get('retrieval_time', 0)
        result = {
            'question': retrieval_result['question'],
            'answer': answer,
            # Measured once by the caller, so this value and the generation
            # part of total_processing_time always agree.
            'generation_time': generation_time,
            'timestamp': datetime.now().isoformat(),
            'generation_config': {
                'llm_model': self.config.llm_model,
                'temperature': self.config.temperature
            },
            'retrieved_docs': retrieval_result['retrieved_docs'],
            'retrieved_doc_ids': retrieval_result['retrieved_doc_ids'],
            'retrieval_time': retrieval_time,
            'retrieval_config': retrieval_result.get('retrieval_config', {}),
            'total_processing_time': retrieval_time + generation_time,
        }
        self._copy_optional_fields(retrieval_result, result)
        return result

    @staticmethod
    def _build_error_result(retrieval_result: Dict[str, Any], error: BaseException) -> Dict[str, Any]:
        """Build the placeholder record stored when generation for an item fails."""
        error_result = {
            'question': retrieval_result.get('question', ''),
            'answer': f"Generation failed: {str(error)}",
            'error': str(error),
            'timestamp': datetime.now().isoformat(),
            'retrieved_docs': retrieval_result.get('retrieved_docs', []),
            'retrieved_doc_ids': retrieval_result.get('retrieved_doc_ids', [])
        }
        GenerationPipeline._copy_optional_fields(retrieval_result, error_result)
        return error_result

    @staticmethod
    def _copy_optional_fields(source: Dict[str, Any], target: Dict[str, Any]) -> None:
        """Copy the ground-truth fields and ``question_index`` from *source* into *target* when present."""
        if 'ground_truth_answer' in source:
            target['ground_truth_answer'] = source['ground_truth_answer']
            # .get(): a record with an answer but no doc ids must not raise here.
            # The error path also goes through this helper, and an exception
            # there would abort the entire batch.
            target['ground_truth_doc_ids'] = source.get('ground_truth_doc_ids', [])
            target['ground_truth_docs'] = source.get('ground_truth_docs', [])
        if 'question_index' in source:
            target['question_index'] = source['question_index']

    def batch_generate(self, retrieval_results: List[Dict[str, Any]],
                       output_file: Optional[str] = None) -> List[Dict[str, Any]]:
        """Synchronously run ``batch_generate_parallel``.

        Uses ``asyncio.run``, so it cannot be called from inside a running
        event loop. Async callers should await ``batch_generate_parallel``
        directly.
        """
        return asyncio.run(self.batch_generate_parallel(retrieval_results, output_file))

    async def batch_generate_parallel(
            self, retrieval_results: List[Dict[str, Any]],
            output_file: Optional[str] = None) -> List[Dict[str, Any]]:
        """Generate answers for all records, ``config.concurrent_requests`` at a time.

        A failure for one item does not stop the run: it is recorded as an
        error record (with ``error`` and a ``Generation failed: ...`` answer).

        Args:
            retrieval_results: Retrieval records (see the class docstring).
            output_file: If given, each batch's records are appended to this
                JSONL file as soon as the batch completes.

        Returns:
            One record per input item, in input order.
        """
        total = len(retrieval_results)
        # A zero or negative setting would make range() raise ValueError.
        batch_size = max(1, self.config.concurrent_requests)
        all_results: List[Dict[str, Any]] = []

        if self.config.verbose:
            print(f"Batch generating {total} items")

        for batch_start in range(0, total, batch_size):
            batch = retrieval_results[batch_start:batch_start + batch_size]

            if self.config.verbose:
                print(f"Batch generating items {batch_start}-{batch_start + len(batch) - 1} of {total}")

            batch_results = await asyncio.gather(
                *(self._generate_single_with_error_handling(record, batch_start + offset)
                  for offset, record in enumerate(batch)),
                return_exceptions=True,
            )

            batch_output = []
            for offset, (record, result) in enumerate(zip(batch, batch_results)):
                # BaseException: gather also returns CancelledError, which is not an Exception.
                if isinstance(result, BaseException):
                    if self.config.verbose:
                        print(f"Error generating for item {batch_start + offset}: {result}")
                    result = self._build_error_result(record, result)
                batch_output.append(result)

            all_results.extend(batch_output)
            if output_file:
                self._append_jsonl(output_file, batch_output)

        if self.config.verbose:
            print("Batch generated")
        return all_results

    @staticmethod
    def _append_jsonl(path: str, records: List[Dict[str, Any]]) -> None:
        """Append *records* to *path* as JSON lines, opening the file once."""
        with open(path, 'a', encoding='utf-8') as f:
            f.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in records)

    async def _generate_single_with_error_handling(
            self, retrieval_result: Dict[str, Any], index: int) -> Dict[str, Any]:
        """Run ``generate_single_async`` for one item.

        Exceptions propagate unchanged. ``batch_generate_parallel`` collects
        them via ``asyncio.gather(..., return_exceptions=True)`` and turns
        them into error records.
        """
        return await self.generate_single_async(retrieval_result, index)

    @staticmethod
    def _iter_jsonl(path: str) -> Iterator[Any]:
        """Yield each parsed JSON value in *path*, skipping blank and malformed lines."""
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                yield record

    def generate_from_file(self, retrieval_file: str,
                           output_file: Optional[str] = None) -> List[Dict[str, Any]]:
        """Load retrieval records from a JSONL file and run ``batch_generate`` on them.

        Returns ``[]`` without generating anything if the file cannot be read.
        The error is printed only when ``config.verbose`` is set.
        """
        if self.config.verbose:
            print("Generating from file")

        try:
            retrieval_results = list(self._iter_jsonl(retrieval_file))
        except Exception as e:  # best-effort: unreadable input means no results
            if self.config.verbose:
                print(f"Error loading retrieval results: {e}")
            return []

        if self.config.verbose:
            print(f"Loaded {len(retrieval_results)} retrieval results")
        return self.batch_generate(retrieval_results, output_file)

    @staticmethod
    def load_retrieval_results(retrieval_file: str) -> List[Dict[str, Any]]:
        """Read retrieval records from a JSONL file.

        Blank and malformed lines are skipped. On an I/O or decoding error the
        error is printed and the records read so far are returned.
        """
        results = []
        try:
            for record in GenerationPipeline._iter_jsonl(retrieval_file):
                results.append(record)
        except Exception as e:
            print(f"Error loading retrieval results: {e}")
        return results

    def create_generation_experiment(self,
                                     retrieval_file: str,
                                     experiment_name: Optional[str] = None,
                                     output_dir: str = "./results") -> str:
        """Run generation over *retrieval_file* and write the results to a JSONL file.

        The output path is ``<output_dir>/<experiment_name>.jsonl``. If no name
        is given, one is derived from the model name and the current local
        time. Results are appended, so reusing an existing name adds to that
        file.

        Returns:
            The path of the output JSONL file, as a string.
        """
        if experiment_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_name = self.config.llm_model.replace('/', '_').replace('-', '_')
            experiment_name = f"generation_{model_name}_{timestamp}"

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        output_file = output_path / f"{experiment_name}.jsonl"

        if self.config.verbose:
            print("Creating generation experiment")

        self.generate_from_file(retrieval_file, str(output_file))

        if self.config.verbose:
            print("Generated from file")
        return str(output_file)
