import json
import os

# Default location of the evaluation results to summarize when no path is
# passed on the command line. To inspect a different run, comment out the
# active assignment at the bottom and uncomment one of the alternatives.
#
# Base model:
# result_folder = "folder_evaluation_results/results_base_model_hypothesis_composition"
# LoRA training:
# result_folder = "folder_evaluation_results/results_lora_training_hypothesis_composition_196k_42_seed"
# LoRA training with curriculum learning:
# result_folder = "folder_evaluation_results/results_lora_training_hypothesis_composition_196k_curriculum_learning_42_seed"
# Full training (currently active):
result_folder = (
    "folder_evaluation_results/"
    "results_full_training_hypothesis_composition_196k_curriculum_learning"
)

def _format_stat(value, spec, suffix=""):
    """Format *value* with format spec *spec* plus *suffix*, or return "N/A" if it is None.

    Only ``None`` counts as missing: a legitimate value of 0 (for example a
    minimum score of 0.0) is printed rather than reported as "N/A".
    """
    if value is None:
        return "N/A"
    return f"{value:{spec}}{suffix}"


def _mean(values):
    """Return the arithmetic mean of *values*, or None when *values* is empty."""
    return sum(values) / len(values) if values else None


def _truncate(text, limit=200):
    """Return *text* cut to *limit* characters, appending "..." only if it was cut.

    Empty or missing text yields "N/A".
    """
    if not text:
        return "N/A"
    return text if len(text) <= limit else f"{text[:limit]}..."


def _load_json(path):
    """Load and return the JSON document stored at *path*, decoded as UTF-8."""
    # Pass the encoding explicitly: the platform default (e.g. cp1252 on
    # Windows) can mis-decode or reject non-ASCII text in generated hypotheses.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _print_header(title):
    """Print *title* framed by two 60-character '=' rules."""
    print("=" * 60)
    print(title)
    print("=" * 60)


def _summarize_folder(result_path):
    """Print the summary of a results folder (summary/metrics/generations JSON files)."""
    summary_path = os.path.join(result_path, 'summary.json')
    metrics_path = os.path.join(result_path, 'metrics.json')
    generations_path = os.path.join(result_path, 'generations.json')

    # Pre-computed summary statistics, if present.
    if os.path.exists(summary_path):
        summary = _load_json(summary_path)

        _print_header("EVALUATION SUMMARY")
        print(f"Average weighted score: {_format_stat(summary.get('average_weighted_score'), '.4f')}")

        # Prefer the new field name and fall back to the old one for
        # backwards compatibility. Compare against None rather than relying on
        # truthiness, so a real value of 0 is not replaced by the legacy field.
        avg_components = summary.get('average_num_gt_components')
        if avg_components is None:
            avg_components = summary.get('average_length_components')
        print(f"Average num GT components: {_format_stat(avg_components, '.2f')}")

        print(f"Average hypothesis length: {_format_stat(summary.get('average_hypothesis_length'), '.1f', ' words')}")
        print(f"Number of results: {summary.get('total_evaluations', 'N/A')}")
        print(f"Min score: {_format_stat(summary.get('min_score'), '.4f')}")
        print(f"Max score: {_format_stat(summary.get('max_score'), '.4f')}")
        print(f"Extraction failures: {summary.get('extraction_failures', 0)}")
        print(f"Total evaluations attempted: {summary.get('total_evaluations_attempted', 'N/A')}")

    # Also show per-file information.
    print()
    _print_header("FILE INFORMATION")
    if os.path.exists(metrics_path):
        metrics = _load_json(metrics_path)
        print(f"metrics.json: {len(metrics)} successful evaluations")

    if os.path.exists(generations_path):
        generations = _load_json(generations_path)
        # Entries without a truthy 'extraction_failed' flag count as successful;
        # everything else is a failure.
        successful = [g for g in generations if not g.get('extraction_failed', False)]
        failed_count = len(generations) - len(successful)
        print(f"generations.json: {len(generations)} total entries ({failed_count} with extraction failures)")

        # Show the first successful sample.
        if successful:
            first = successful[0]
            print("\nSample generation (first successful entry):")
            print(f"  File: {first.get('file', 'N/A')}")
            print(f"  Score: {first.get('weighted_score', 'N/A')}")
            print(f"  Generated (truncated): {_truncate(first.get('generated_hypothesis'))}")
            print(f"  Ground truth (truncated): {_truncate(first.get('ground_truth_hypothesis'))}")


def _summarize_legacy_file(result_path):
    """Compute and print averages from a legacy single-file list of result records."""
    results = _load_json(result_path)

    valid_scores = [r['weighted_score'] for r in results if r.get('weighted_score') is not None]
    valid_components = [len(r['eval_results']) for r in results if r.get('eval_results')]
    # Hypothesis length in whitespace-separated words; failed extractions are skipped.
    valid_hyp_lengths = [
        len(r['generated_hypothesis'].split())
        for r in results
        if r.get('generated_hypothesis') and not r.get('extraction_failed', False)
    ]

    _print_header("EVALUATION SUMMARY (Legacy Format)")
    print(f"Number of results: {len(results)}")
    print(f"Average weighted score: {_format_stat(_mean(valid_scores), '.4f')}")
    print(f"Average num GT components: {_format_stat(_mean(valid_components), '.2f')}")
    print(f"Average hypothesis length: {_format_stat(_mean(valid_hyp_lengths), '.1f', ' words')}")


def summarize_results(result_path):
    """
    Summarize results from either the new folder structure or legacy single file.

    New folder structure contains:
      - metrics.json: evaluation metrics for each sample
      - generations.json: generated and ground truth hypotheses
      - summary.json: pre-computed summary statistics

    A legacy file is a single JSON list of result records. The records are read
    for 'weighted_score', 'eval_results', 'generated_hypothesis' and
    'extraction_failed'.

    Args:
        result_path: Path to a results folder or a legacy results JSON file.

    Output goes to stdout. If the path does not exist, an error line is printed
    instead of raising.
    """
    if os.path.isdir(result_path):
        _summarize_folder(result_path)
    elif os.path.isfile(result_path):
        _summarize_legacy_file(result_path)
    else:
        print(f"Error: Path does not exist: {result_path}")


if __name__ == "__main__":
    import sys

    # A path given as the first command-line argument takes precedence;
    # without one, fall back to the module-level default ``result_folder``.
    cli_args = sys.argv[1:]
    result_path = cli_args[0] if cli_args else result_folder

    summarize_results(result_path)
