import json
import os
import argparse
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from typing import List, Dict, Any
from pathlib import Path


class ChronoPlayRetrievalEvaluator:
    """Compute Recall / F1 / nDCG at fixed cut-offs for ChronoPlay retrieval runs.

    Input is JSONL produced by a retrieval step. Each record carries the
    question, the documents returned by the retriever (``retrieved_docs``), the
    retrieval configuration (``retrieval_config``) and the source QA record
    (``original_qa_data``). The ``retrieved_docs`` list stored inside the QA
    record is used as the ground-truth set. Documents are matched on
    ``doc['metadata']['id']``.

    Recognised ``config`` keys:
        game_name: game identifier used in result file names
            (default ``'dyinglight2'``).
        target_segment_id: when not ``None``, batch evaluation only looks at
            this segment.
        k_values: optional iterable of cut-offs (default ``(1, 3, 5)``).
    """

    # Cut-offs reported when the config does not override them.
    DEFAULT_K_VALUES = (1, 3, 5)
    # Metric names, in the order they are stored and printed.
    METRIC_NAMES = ('recall', 'f1', 'ndcg')

    def __init__(self, config: Dict):
        self.config = config
        self.game_name = config.get('game_name', 'dyinglight2')
        self.target_segment_id = config.get('target_segment_id', None)
        # Sorted and de-duplicated so reports always list cut-offs in ascending order.
        self.k_values = sorted({int(k) for k in (config.get('k_values') or self.DEFAULT_K_VALUES)})

    # ------------------------------------------------------------------ loading

    @staticmethod
    def _extract_doc_ids(docs: Any) -> List[Any]:
        """Return ``doc['metadata']['id']`` for every well-formed entry of ``docs``.

        Non-list input yields ``[]``. Entries that are not dicts, or whose
        ``metadata`` is not a dict containing ``'id'``, are skipped.
        """
        if not isinstance(docs, list):
            return []
        return [
            doc['metadata']['id']
            for doc in docs
            if isinstance(doc, dict)
            and isinstance(doc.get('metadata'), dict)
            and 'id' in doc['metadata']
        ]

    def _build_eval_item(self, item: Dict) -> Dict:
        """Convert one raw JSONL record into the dict shape used by the metric code."""
        # ``or {}`` / ``or []`` also cover keys that are present but JSON ``null``.
        original_qa_data = item.get('original_qa_data') or {}
        ground_truth_docs = original_qa_data.get('retrieved_docs') or []
        retrieved_docs = item.get('retrieved_docs') or []
        retrieval_config = item.get('retrieval_config') or {}

        retrieval_method = retrieval_config.get('retrieval_method', 'vector')
        embedding_model = retrieval_config.get('embedding_model', 'unknown')
        embedding_service = retrieval_config.get('embedding_service', 'unknown')
        top_k = retrieval_config.get('top_k', 5)

        # BM25 runs have no embedding model; they are keyed by method instead.
        if retrieval_method == 'bm25':
            model_key = f"bm25_k{top_k}"
            model_name = "BM25"
            service_name = "BM25"
        else:
            model_key = f"{embedding_model}_k{top_k}"
            model_name = embedding_model
            service_name = embedding_service

        return {
            'question': item.get('question', ''),
            'retrieved_docs': retrieved_docs,
            'retrieved_doc_ids': self._extract_doc_ids(retrieved_docs),
            'ground_truth_doc_ids': self._extract_doc_ids(ground_truth_docs),
            'ground_truth_docs': ground_truth_docs,
            'retrieval_time': item.get('retrieval_time') or 0,
            'config': retrieval_config,
            'metadata': {
                'question_type': original_qa_data.get('question_type', 'unknown'),
                'difficulty': original_qa_data.get('difficulty', 'unknown'),
                'segment_id': retrieval_config.get('segment_id'),
                'k': top_k,
                'retrieval_method': retrieval_method,
                'embedding_model': model_name,
                'embedding_service': service_name,
                'model_key': model_key
            }
        }

    def load_retrieval_results(self, retrieval_results_path: str) -> List[Dict]:
        """Load a retrieval-results JSONL file into normalized evaluation items.

        Blank lines are skipped. Each returned dict has the keys ``question``,
        ``retrieved_docs``, ``retrieved_doc_ids``, ``ground_truth_doc_ids``,
        ``ground_truth_docs``, ``retrieval_time``, ``config`` and ``metadata``.
        ``metadata['model_key']`` is used to group records by retriever.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        retrieval_data = []
        with open(retrieval_results_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                retrieval_data.append(self._build_eval_item(json.loads(line)))
        return retrieval_data

    # ------------------------------------------------------------------ metrics

    @staticmethod
    def _clean_ids(item: Dict) -> tuple:
        """Return ``(retrieved ids in rank order, set of ground-truth ids)``.

        Falsy ids (``None``, ``''``, ...) are dropped from both sides.
        """
        retrieved_ids = [doc_id for doc_id in item.get('retrieved_doc_ids', []) if doc_id]
        relevant_ids = {doc_id for doc_id in item.get('ground_truth_doc_ids', []) if doc_id}
        return retrieved_ids, relevant_ids

    def _compute_metrics_at_k(self, retrieved_ids: List[Any], relevant_ids_set: set,
                              k: int) -> Dict[str, float]:
        """Return ``{'recall', 'f1', 'ndcg'}`` for the top-``k`` of ``retrieved_ids``.

        Recall is measured against the number of *unique* relevant ids. Precision
        (used only for F1) is measured against the number of ids actually
        retrieved in the top ``k``, which may be fewer than ``k``.
        """
        retrieved_at_k = retrieved_ids[:k]
        hits = len(set(retrieved_at_k) & relevant_ids_set)
        precision = hits / len(retrieved_at_k) if retrieved_at_k else 0.0
        recall = hits / len(relevant_ids_set) if relevant_ids_set else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        ndcg = self._calculate_ndcg(retrieved_at_k, relevant_ids_set, k)
        return {'recall': recall, 'f1': f1, 'ndcg': ndcg}

    def evaluate_retrieval_metrics(self, retrieval_data: List[Dict]) -> Dict[str, Any]:
        """Average Recall/F1/nDCG at each configured k over ``retrieval_data``.

        Queries without any ground-truth id are skipped; they count toward
        ``total_queries`` but not ``valid_queries``.

        Returns:
            Dict with ``total_queries``, ``valid_queries``, ``coverage``
            (valid / total) and ``{metric}_at_{k}`` for every metric and k
            (0.0 when there are no valid queries).
        """
        scores = {k: {name: [] for name in self.METRIC_NAMES} for k in self.k_values}
        valid_queries = 0

        for item in retrieval_data:
            retrieved_ids, relevant_ids = self._clean_ids(item)
            if not relevant_ids:
                continue
            valid_queries += 1
            for k in self.k_values:
                for name, value in self._compute_metrics_at_k(retrieved_ids, relevant_ids, k).items():
                    scores[k][name].append(value)

        total_queries = len(retrieval_data)
        results = {
            'total_queries': total_queries,
            'valid_queries': valid_queries,
            'coverage': valid_queries / total_queries if total_queries else 0.0
        }
        for k in self.k_values:
            for name in self.METRIC_NAMES:
                values = scores[k][name]
                results[f'{name}_at_{k}'] = float(np.mean(values)) if values else 0.0
        return results

    def _calculate_ndcg(self, retrieved_ids: List[str], relevant_ids_set: set, k: int) -> float:
        """Binary-relevance nDCG@k of ``retrieved_ids`` against ``relevant_ids_set``.

        Each relevant id earns gain only at its first occurrence in the ranking,
        so repeated ids cannot push the score above 1.0. The ideal DCG assumes
        ``min(k, len(relevant_ids_set))`` relevant documents at the top ranks.
        """
        if not retrieved_ids or not relevant_ids_set:
            return 0.0

        dcg = 0.0
        credited = set()
        for rank, doc_id in enumerate(retrieved_ids[:k]):
            if doc_id in relevant_ids_set and doc_id not in credited:
                credited.add(doc_id)
                # Rank is 0-based, so the discount is log2(rank + 2).
                dcg += 1.0 / np.log2(rank + 2)

        ideal_hits = min(k, len(relevant_ids_set))
        idcg = sum(1.0 / np.log2(rank + 2) for rank in range(ideal_hits))
        if idcg == 0:
            return 0.0
        return float(dcg / idcg)

    # ----------------------------------------------------------------- analysis

    def analyze_performance(self, retrieval_data: List[Dict]) -> Dict[str, Any]:
        """Break metrics down by question type, difficulty and segment, plus timing.

        Per-query metrics come from the same helper as
        ``evaluate_retrieval_metrics``, so category averages are consistent with
        the overall numbers. Only positive ``retrieval_time`` values enter the
        timing statistics.
        """
        k_values = self.k_values
        type_performance = defaultdict(lambda: defaultdict(list))
        difficulty_performance = defaultdict(lambda: defaultdict(list))
        segment_performance = defaultdict(lambda: defaultdict(list))
        retrieval_times = []

        for item in retrieval_data:
            metadata = item.get('metadata', {})
            question_type = metadata.get('question_type', 'unknown')
            difficulty = metadata.get('difficulty', 'unknown')
            segment_key = str(metadata.get('segment_id', 'unknown'))

            retrieved_ids, relevant_ids = self._clean_ids(item)
            if relevant_ids:
                for k in k_values:
                    performance = self._compute_metrics_at_k(retrieved_ids, relevant_ids, k)
                    type_performance[question_type][k].append(performance)
                    difficulty_performance[difficulty][k].append(performance)
                    segment_performance[segment_key][k].append(performance)

            # ``or 0`` tolerates a JSON null timing value.
            retrieval_time = item.get('retrieval_time') or 0
            if retrieval_time > 0:
                retrieval_times.append(retrieval_time)

        return {
            'by_question_type': self._aggregate_performance_multi_k(type_performance, k_values),
            'by_difficulty': self._aggregate_performance_multi_k(difficulty_performance, k_values),
            'by_segment': self._aggregate_performance_multi_k(segment_performance, k_values),
            'timing_analysis': {
                'avg_retrieval_time': float(np.mean(retrieval_times)) if retrieval_times else 0.0,
                'median_retrieval_time': float(np.median(retrieval_times)) if retrieval_times else 0.0,
                'total_samples': len(retrieval_times)
            }
        }

    @staticmethod
    def _summarize_samples(samples: List[Dict]) -> Dict[str, Any]:
        """Average a non-empty list of ``{'recall', 'f1', 'ndcg'}`` samples."""
        return {
            'count': len(samples),
            'avg_recall': float(np.mean([s['recall'] for s in samples])),
            'avg_f1': float(np.mean([s['f1'] for s in samples])),
            'avg_ndcg': float(np.mean([s['ndcg'] for s in samples]))
        }

    def _aggregate_performance(self, performance_dict: Dict) -> Dict:
        """Summarize ``{category: [samples]}`` (single-k form); empty categories are dropped."""
        return {
            category: self._summarize_samples(samples)
            for category, samples in performance_dict.items()
            if samples
        }

    def _aggregate_performance_multi_k(self, performance_dict: Dict, k_values: List[int]) -> Dict:
        """Summarize ``{category: {k: [samples]}}`` into ``{category: {'k_<k>': ..., 'total_count': n}}``."""
        aggregated = {}
        for category, k_data in performance_dict.items():
            if not k_data:
                continue
            category_summary = {}
            for k in k_values:
                samples = k_data.get(k, [])
                if samples:
                    category_summary[f'k_{k}'] = self._summarize_samples(samples)
            # Every valid query contributes one sample per k, so the count at the
            # first k equals the number of queries in this category.
            category_summary['total_count'] = len(k_data.get(k_values[0], [])) if k_values else 0
            aggregated[category] = category_summary
        return aggregated

    # ------------------------------------------------------------ batch driver

    def _discover_segment_ids(self, results_dir: Path) -> List[int]:
        """Return the sorted, unique segment ids found in result file names.

        Files follow ``retrieval_{game}_segment_{id}_*.jsonl``. Files whose id
        part is not an integer are ignored.
        """
        found = set()
        for file in results_dir.glob(f"retrieval_{self.game_name}_segment_*_*.jsonl"):
            parts = file.stem.split('_')
            # parts[:-1] guarantees parts[i + 1] exists.
            for i, part in enumerate(parts[:-1]):
                if part == 'segment':
                    try:
                        found.add(int(parts[i + 1]))
                    except ValueError:
                        pass
                    break
        return sorted(found)

    def _evaluate_model_group(self, model_key: str, model_data: List[Dict], segment_id: Any,
                              output_dir: str = None) -> tuple:
        """Evaluate one model's records for one segment; optionally write a JSON report.

        Returns:
            ``(results, output_file)``. ``output_file`` is the written path, or
            ``None`` when ``output_dir`` is falsy.
        """
        first_metadata = model_data[0]['metadata']
        results = {
            'metrics': self.evaluate_retrieval_metrics(model_data),
            'analysis': self.analyze_performance(model_data),
            'model_info': {
                'retrieval_method': first_metadata['retrieval_method'],
                'embedding_model': first_metadata['embedding_model'],
                'embedding_service': first_metadata['embedding_service'],
                'top_k': first_metadata['k'],
                'model_key': model_key
            }
        }

        output_file = None
        if output_dir:
            model_output_dir = Path(output_dir) / model_key
            model_output_dir.mkdir(parents=True, exist_ok=True)
            output_file = model_output_dir / \
                f"{self.game_name}_segment_{segment_id}_retrieval_evaluation.json"
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(results, f, ensure_ascii=False, indent=2)
        return results, output_file

    def evaluate_segments(self, retrieval_results_dir: str, output_dir: str = None,
                          segment_ids: List[int] = None) -> Dict[str, Any]:
        """Evaluate every model found in the per-segment result files of a directory.

        Files are matched by ``retrieval_{game_name}_segment_{id}_*.jsonl``.
        Records inside a file are grouped by ``metadata['model_key']`` and each
        group is evaluated separately.

        Args:
            retrieval_results_dir: directory holding the JSONL result files.
            output_dir: if given, one JSON report per model and segment, one
                summary per model and an overall summary are written below it.
            segment_ids: segments to evaluate; discovered from file names when
                ``None``. Ignored when ``self.target_segment_id`` is set.

        Returns:
            ``{model_key: {segment_id: {segment_id, data_count, metrics,
            analysis, model_info, output_file}}}``.
        """
        results_dir = Path(retrieval_results_dir)
        if self.target_segment_id is not None:
            # ``is not None`` rather than truthiness so segment 0 can be targeted.
            segment_ids = [self.target_segment_id]
        elif segment_ids is None:
            segment_ids = self._discover_segment_ids(results_dir)

        model_results = {}
        all_segments_summary = {}

        for segment_id in tqdm(segment_ids, desc="Retrieval segment evaluation"):
            # Sorted so that, when two files share a model_key, the file that
            # wins (the later one) does not depend on filesystem listing order.
            retrieval_files = sorted(results_dir.glob(
                f"retrieval_{self.game_name}_segment_{segment_id}_*.jsonl"))

            if not retrieval_files:
                all_segments_summary[segment_id] = {
                    'segment_id': segment_id,
                    'error': 'No retrieval results file found'
                }
                continue

            segment_model_results = {}
            for retrieval_file in retrieval_files:
                # Best effort: a broken file is reported and skipped so the
                # remaining files and segments are still evaluated.
                try:
                    retrieval_data = self.load_retrieval_results(str(retrieval_file))

                    model_groups = defaultdict(list)
                    for item in retrieval_data:
                        model_groups[item['metadata']['model_key']].append(item)

                    for model_key, model_data in model_groups.items():
                        results, output_file = self._evaluate_model_group(
                            model_key, model_data, segment_id, output_dir)

                        model_results.setdefault(model_key, {})[segment_id] = {
                            'segment_id': segment_id,
                            'data_count': len(model_data),
                            'metrics': results.get('metrics', {}),
                            'analysis': results.get('analysis', {}),
                            'model_info': results.get('model_info', {}),
                            'output_file': str(output_file) if output_file is not None else None
                        }
                        segment_model_results[model_key] = {
                            'data_count': len(model_data),
                            'metrics': results.get('metrics', {}),
                            'model_info': results.get('model_info', {})
                        }
                except Exception as e:
                    print(f"Error evaluating retrieval metrics for segment {segment_id} "
                          f"({retrieval_file.name}): {e}")

            if segment_model_results:
                all_segments_summary[segment_id] = {
                    'segment_id': segment_id,
                    'models': segment_model_results,
                    'total_models': len(segment_model_results)
                }
            else:
                all_segments_summary[segment_id] = {
                    'segment_id': segment_id,
                    'error': 'No valid retrieval data found'
                }

        if output_dir:
            for model_key, model_data in model_results.items():
                model_summary_file = Path(output_dir) / model_key / \
                    f"{self.game_name}_retrieval_evaluation_summary.json"
                model_summary_file.parent.mkdir(parents=True, exist_ok=True)
                with open(model_summary_file, 'w', encoding='utf-8') as f:
                    json.dump(model_data, f, ensure_ascii=False, indent=2)

            overall_summary_file = Path(output_dir) / f"{self.game_name}_retrieval_evaluation_overall_summary.json"
            with open(overall_summary_file, 'w', encoding='utf-8') as f:
                json.dump(all_segments_summary, f, ensure_ascii=False, indent=2)

        self._print_model_summary(model_results, all_segments_summary)
        return model_results

    # ---------------------------------------------------------------- printing

    @staticmethod
    def _collect_metric_scores(summaries) -> Dict[str, List[float]]:
        """Gather numeric ``metrics`` values from summaries that have no ``'error'`` key."""
        collected = defaultdict(list)
        for summary in summaries:
            if 'error' in summary:
                continue
            for metric, score in summary.get('metrics', {}).items():
                if isinstance(score, (int, float)):
                    collected[metric].append(score)
        return collected

    @staticmethod
    def _average_scores(collected: Dict[str, List[float]]) -> Dict[str, float]:
        """Mean of each non-empty score list."""
        return {metric: float(np.mean(scores)) for metric, scores in collected.items() if scores}

    def _format_k_scores(self, metric_values: Dict[str, float], k: int) -> List[str]:
        """``['RECALL@k: x.xxxx', ...]`` for the metrics present in ``metric_values``."""
        return [
            f"{name.upper()}@{k}: {metric_values[f'{name}_at_{k}']:.4f}"
            for name in self.METRIC_NAMES
            if f'{name}_at_{k}' in metric_values
        ]

    def _format_k_results(self, metrics: Dict[str, float]) -> str:
        """Compact one-line rendering of all metrics at every k (missing values shown as 0)."""
        return " | ".join(
            f"R@{k}:{metrics.get(f'recall_at_{k}', 0):.3f} "
            f"F1@{k}:{metrics.get(f'f1_at_{k}', 0):.3f} "
            f"NDCG@{k}:{metrics.get(f'ndcg_at_{k}', 0):.3f}"
            for k in self.k_values
        )

    def _print_model_summary(self, model_results: Dict, all_segments_summary: Dict):
        """Print cross-segment averages and per-segment scores for every model."""
        successful_segments = sum(1 for s in all_segments_summary.values() if 'error' not in s)
        print(f"Models evaluated: {len(model_results)}, segments with results: "
              f"{successful_segments}/{len(all_segments_summary)}")

        for model_key, model_data in model_results.items():
            averages = self._average_scores(self._collect_metric_scores(model_data.values()))
            for k in self.k_values:
                k_scores = self._format_k_scores(averages, k)
                if k_scores:
                    print(f"{model_key} - {' | '.join(k_scores)}")

            for segment_id, segment_data in sorted(model_data.items()):
                if 'error' in segment_data:
                    print(f"{segment_id} - Error in retrieval metrics")
                else:
                    print(f"{model_key} - segment {segment_id} "
                          f"(n={segment_data.get('data_count', 0)}): "
                          f"{self._format_k_results(segment_data.get('metrics', {}))}")

    def _print_summary(self, results_summary: Dict):
        """Print per-segment scores and averages for a ``{segment_id: summary}`` dict.

        Entries containing an ``'error'`` key are reported as failures and
        excluded from the averages.
        """
        successful = [s for s in results_summary.values() if 'error' not in s]
        total_data_count = sum(s.get('data_count', 0) for s in successful)
        print(f"Segments evaluated: {len(successful)}/{len(results_summary)}, "
              f"queries: {total_data_count}")

        for segment_id, summary in sorted(results_summary.items()):
            if 'error' in summary:
                print(f"{segment_id} - Error in retrieval metrics")
            else:
                print(f"{segment_id} - {self._format_k_results(summary.get('metrics', {}))}")

        averages = self._average_scores(self._collect_metric_scores(results_summary.values()))
        for k in self.k_values:
            k_scores = self._format_k_scores(averages, k)
            if k_scores:
                print(f"Average - {' | '.join(k_scores)}")

    # ------------------------------------------------------------- single file

    def run_single_evaluation(self, retrieval_results_path: str, output_path: str = None) -> Dict:
        """Evaluate one JSONL file; optionally write ``{'metrics', 'analysis'}`` to ``output_path``.

        Returns ``{}`` when the file has no records.
        """
        retrieval_data = self.load_retrieval_results(retrieval_results_path)
        if not retrieval_data:
            return {}

        results = {
            'metrics': self.evaluate_retrieval_metrics(retrieval_data),
            'analysis': self.analyze_performance(retrieval_data)
        }

        if output_path:
            output_dir = os.path.dirname(output_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, ensure_ascii=False, indent=2)

        self._print_single_summary(results)
        return results

    def _print_single_summary(self, results: Dict):
        """Print metrics at each k and timing statistics for a single-file result."""
        metrics = results.get('metrics', {})
        if metrics:
            print(f"Queries with ground truth: {metrics.get('valid_queries', 0)}/"
                  f"{metrics.get('total_queries', 0)}")
        for k in self.k_values:
            k_scores = self._format_k_scores(metrics, k)
            if k_scores:
                print(" | ".join(k_scores))

        timing = results.get('analysis', {}).get('timing_analysis')
        if timing:
            print(f"Retrieval time - avg: {timing.get('avg_retrieval_time', 0.0):.4f}, "
                  f"median: {timing.get('median_retrieval_time', 0.0):.4f}, "
                  f"samples: {timing.get('total_samples', 0)}")


def main():
    """Command-line entry point.

    Evaluates either a single retrieval-results file (``--retrieval_results``)
    or every segment found in ``--results_dir``. Reports go to ``--output_dir``.
    Exits with status 2 when the input path does not exist, 130 when
    interrupted, and 1 when evaluation raises.
    """
    parser = argparse.ArgumentParser(
        description='ChronoPlay Retrieval Effectiveness Evaluation',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Usage examples:
  # Automatically evaluate retrieval effectiveness for all segments
  python retrieval_evaluator.py

  # Evaluate specific segment
  python retrieval_evaluator.py --segment_id 1

  # Evaluate all segments for specific game
  python retrieval_evaluator.py --game dyinglight2

  # Single file evaluation
  python retrieval_evaluator.py --retrieval_results ./results/retrieval_dyinglight2_segment_1_*.jsonl
        """
    )

    # Selection of what to evaluate.
    parser.add_argument('--game', type=str, default='dyinglight2',
                        help='Game name (default: dyinglight2)')
    parser.add_argument('--segment_id', type=int,
                        help='Target segment ID (if not specified, evaluate all available segments)')

    # Input / output locations.
    parser.add_argument('--retrieval_results', type=str,
                        help='Single retrieval result file path (if specified, enter single file mode)')
    parser.add_argument('--results_dir', type=str,
                        default='./retrieval_results',
                        help='Retrieval results directory (default: ./retrieval_results)')
    parser.add_argument('--output_dir', type=str,
                        default='./retrieval_evaluation',
                        help='Evaluation results output directory (default: ./retrieval_evaluation)')

    args = parser.parse_args()

    # Fail loudly (parser.error prints usage and exits with status 2) instead of
    # returning silently with a success exit code.
    if args.retrieval_results:
        if not os.path.exists(args.retrieval_results):
            parser.error(f"retrieval results file not found: {args.retrieval_results}")
    elif not os.path.isdir(args.results_dir):
        parser.error(f"results directory not found: {args.results_dir}")

    config = {
        'game_name': args.game,
        'target_segment_id': args.segment_id
    }
    evaluator = ChronoPlayRetrievalEvaluator(config)

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    try:
        if args.retrieval_results:
            output_file = os.path.join(output_dir, f"{args.game}_single_retrieval_evaluation.json")
            evaluator.run_single_evaluation(args.retrieval_results, output_file)
        else:
            evaluator.evaluate_segments(args.results_dir, output_dir)
    except KeyboardInterrupt:
        print("Evaluation interrupted by user")
        raise SystemExit(130)
    except Exception as e:
        print(f"Error during evaluation: {e}")
        # Non-zero status so shell scripts / CI notice the failure.
        raise SystemExit(1)


if __name__ == '__main__':
    main()
