import argparse
import json
import math
import os
from collections import defaultdict

def calculate_wiscore(consistency, realism, aesthetic_quality):
    """Return the WiScore of a single sample.

    The three component scores are combined with fixed weights
    (consistency 0.7, realism 0.2, aesthetic quality 0.1) and the weighted
    sum is halved.  The halving presumably rescales components graded on a
    0-2 scale into 0-1 -- confirm against the scoring rubric.
    """
    weighted_sum = 0.7 * consistency + 0.2 * realism + 0.1 * aesthetic_quality
    return weighted_sum / 2

# Define expected prompt ID ranges at a global level for easy access
EXPECTED_PROMPT_RANGES = {
    "culture": range(1, 401),
    "space-time": range(401, 701), # Covers both TIME (401-567) and SPACE (568-700)
    "science": range(701, 1001), # Covers BIOLOGY (701-800), PHYSICS (801-900), CHEMISTRY (901-1000)
    "all": range(1, 1001) # Full range for combined evaluation
}

# Inclusive prompt_id bounds of each fine-grained category, checked in order.
_PROMPT_ID_CATEGORY_BOUNDS = (
    ('CULTURE', 1, 400),
    ('TIME', 401, 567),
    ('SPACE', 568, 700),
    ('BIOLOGY', 701, 800),
    ('PHYSICS', 801, 900),
    ('CHEMISTRY', 901, 1000),
)


def _category_for_prompt_id(prompt_id):
    """Return the category containing *prompt_id*, or None if it is in no range."""
    for name, low, high in _PROMPT_ID_CATEGORY_BOUNDS:
        if low <= prompt_id <= high:
            return name
    return None


def _is_finite_number(value):
    """Return True if *value* is a finite int or float; bools are rejected."""
    # bool subclasses int, so JSON true/false would otherwise count as 1/0.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return False
    try:
        # json.loads accepts NaN/Infinity, which would poison every average.
        return math.isfinite(value)
    except OverflowError:
        # Integer too large to convert to float; it cannot be scored either.
        return False


def process_jsonl_file_segment(file_path, category_arg=None):
    """
    Parse one JSONL score file and bucket per-line WiScores by category.

    Each line is expected to be a JSON object with an integer ``prompt_id``
    and numeric ``consistency``, ``realism`` and ``aesthetic_quality``
    fields.  Malformed lines are reported with a warning and skipped; blank
    lines are ignored silently.

    Args:
        file_path: Path to the JSONL file.
        category_arg: Optional ``--category`` value.  When it names a specific
            category (anything other than ``'all'``), every prompt_id in that
            category's expected range must appear in the file.

    Returns:
        A dict with keys ``'scores'`` (category name -> list of WiScores),
        ``'present_prompt_ids'`` (set of integer prompt_ids seen, including
        lines whose scores were unusable) and ``'file_path'``; or ``None`` if
        the file is missing, cannot be read, or fails prompt_id validation.
    """
    segment_scores = defaultdict(list)
    segment_present_prompt_ids = set()

    if not os.path.exists(file_path):
        print(f"Error: File '{file_path}' not found.")
        return None

    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for line_num, line in enumerate(file, 1):
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline

                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    print(f"Warning: File '{file_path}', Line {line_num}: Invalid JSON format. Skipping this line.")
                    continue

                # Valid JSON that is not an object (array, number, string)
                # has no .get(); skip it instead of aborting the whole file.
                if not isinstance(data, dict):
                    print(f"Warning: File '{file_path}', Line {line_num}: Expected a JSON object. Skipping this line.")
                    continue

                prompt_id = data.get('prompt_id')
                if prompt_id is None:
                    print(f"Warning: File '{file_path}', Line {line_num}: Missing 'prompt_id'. Skipping this line.")
                    continue

                # Reject bools explicitly: True == 1 would mark prompt 1 present.
                if isinstance(prompt_id, bool) or not isinstance(prompt_id, int):
                    print(f"Warning: File '{file_path}', Line {line_num}: 'prompt_id' is not an integer. Skipping this line.")
                    continue

                # Recorded before score validation: presence is tracked even
                # when the scores on this line turn out to be unusable.
                segment_present_prompt_ids.add(prompt_id)

                consistency = data.get('consistency')
                realism = data.get('realism')
                aesthetic_quality = data.get('aesthetic_quality')

                if not all(_is_finite_number(val) for val in (consistency, realism, aesthetic_quality)):
                    print(f"Warning: File '{file_path}', Line {line_num}: One or more score values are not numeric. Skipping this line for category calculation.")
                    continue

                category = _category_for_prompt_id(prompt_id)
                if category is None:
                    print(f"Warning: File '{file_path}', Line {line_num}: prompt_id {prompt_id} is outside defined categories. Skipping this line.")
                    continue

                segment_scores[category].append(
                    calculate_wiscore(consistency, realism, aesthetic_quality)
                )
    except (OSError, UnicodeDecodeError) as e:
        # Unreadable path (directory, permissions) or a non-UTF-8 file.
        print(f"Error reading file '{file_path}': {e}")
        return None

    # --- Single-file prompt_id validation logic ---
    if category_arg and category_arg != 'all' and category_arg in EXPECTED_PROMPT_RANGES:
        expected_ids_for_this_category = set(EXPECTED_PROMPT_RANGES[category_arg])
        missing_ids_in_segment = expected_ids_for_this_category - segment_present_prompt_ids

        if missing_ids_in_segment:
            print(f"Error: File '{file_path}': When evaluating as '--category {category_arg}', "
                  f"missing the following prompt_ids: {sorted(missing_ids_in_segment)}")
            return None  # the file does not cover its declared category

    return {
        'scores': segment_scores,
        'present_prompt_ids': segment_present_prompt_ids,
        'file_path': file_path,
    }

# Display order of the fine-grained categories in every report.
_REPORT_CATEGORY_ORDER = ('CULTURE', 'TIME', 'SPACE', 'BIOLOGY', 'PHYSICS', 'CHEMISTRY')

# Weight of each category in the overall WiScore; the weights sum to 1.0.
_OVERALL_WISCORE_WEIGHTS = {
    'CULTURE': 0.4,
    'TIME': 0.167,
    'SPACE': 0.133,
    'BIOLOGY': 0.1,
    'PHYSICS': 0.1,
    'CHEMISTRY': 0.1,
}


def _print_section_header(title):
    """Print *title* framed above and below by 50-character '=' rules."""
    print("\n" + "=" * 50)
    print(title)
    print("=" * 50 + "\n")


def _print_category_scores(categories, averages, sample_counts):
    """Print the average WiScore and sample count for each category in *categories*."""
    for category in categories:
        print(f"  Category: {category}")
        print(f"    Average WiScore: {averages.get(category, 0):.2f}")
        print(f"    Number of samples: {sample_counts.get(category, 0)}\n")


def main():
    """
    Command-line entry point: score JSONL files and print per-file and overall reports.

    Steps:
      1. Parse each FILE with ``process_jsonl_file_segment``.  A single FILE
         is validated against ``--category`` there.
      2. Merge per-file scores and check that the combined prompt_ids cover
         the range expected for ``--category``; abort if any are missing.
      3. Print a report per file, the aggregated per-category averages and,
         for ``--category all`` with samples in every category, the weighted
         overall WiScore.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate JSONL files for model performance, categorizing scores by prompt_id."
    )
    parser.add_argument(
        'files',
        metavar='FILE',
        nargs='+',  # Accepts one or more file paths
        help="Path(s) to the JSONL file(s) to be evaluated (e.g., cultural_common_sense_ModelName_scores.jsonl)"
    )
    parser.add_argument(
        '--category',
        type=str,
        choices=['culture', 'space-time', 'science', 'all'],
        default='all',
        help="Specify the category of the JSONL file(s) for specific prompt_id validation. Choose from 'culture', 'space-time', 'science', or 'all' (default). If evaluating a single category file, use the corresponding category."
    )

    args = parser.parse_args()

    # A lone file is validated against --category while it is parsed.  With
    # several files each may legitimately hold only part of the range, so the
    # check is deferred to the combined prompt_ids (Step 1 below).
    per_file_category = args.category if len(args.files) == 1 else None

    all_raw_results = []
    for file_path in args.files:
        print(f"\n--- Processing file: {file_path} ---")
        results = process_jsonl_file_segment(file_path, per_file_category)
        if results:
            all_raw_results.append(results)
        else:
            print(f"Could not process '{file_path}'. Please check previous warnings/errors.")

    if not all_raw_results:
        print("No valid data processed from any of the provided files. Exiting.")
        return  # nothing usable was parsed

    # Merge scores across files, and compute each file's own averages/counts.
    aggregated_scores = defaultdict(list)
    combined_present_prompt_ids = set()
    final_file_reports = {}  # file_path -> (averages, sample_counts)
    for file_data in all_raw_results:
        combined_present_prompt_ids.update(file_data['present_prompt_ids'])
        file_averages = {}
        file_sample_counts = {}
        for category, scores_list in file_data['scores'].items():
            aggregated_scores[category].extend(scores_list)
            if scores_list:  # only categories with samples appear in the file report
                file_averages[category] = sum(scores_list) / len(scores_list)
                file_sample_counts[category] = len(scores_list)
        final_file_reports[file_data['file_path']] = (file_averages, file_sample_counts)

    # --- Step 1: Validate the combined prompt_ids against --category ---
    # For 'all' this is the only completeness check.  For a specific category
    # it also covers the multi-file case, which the per-file check skips.
    expected_prompt_ids = set(EXPECTED_PROMPT_RANGES[args.category])
    missing_prompt_ids = expected_prompt_ids - combined_present_prompt_ids
    if missing_prompt_ids:
        print(f"\nError: When '--category {args.category}' is specified, the combined files are missing the following prompt_ids:")
        print(f"Missing IDs: {sorted(missing_prompt_ids)}")
        print("\nAborting overall evaluation due to incomplete data.")
        return

    # --- Step 2: Display individual file reports ---
    _print_section_header("                 Individual File Reports")
    for file_path, (file_averages, file_sample_counts) in final_file_reports.items():
        header = f"--- Evaluation Results for File: {file_path} ---"
        print(header)
        categories_to_print = [c for c in _REPORT_CATEGORY_ORDER if c in file_sample_counts]
        if categories_to_print:
            _print_category_scores(categories_to_print, file_averages, file_sample_counts)
        else:
            print("  No scores found for any defined categories in this file.")
        print("-" * len(header) + "\n")  # underline matches the header width

    # --- Step 3: Calculate and display the overall summary ---
    _print_section_header("                 Overall Evaluation Summary")

    overall_num_samples = {category: len(scores) for category, scores in aggregated_scores.items()}
    overall_avg_scores = {
        category: sum(scores) / len(scores)
        for category, scores in aggregated_scores.items()
        if scores
    }

    overall_categories_to_print = [
        c for c in _REPORT_CATEGORY_ORDER if overall_num_samples.get(c, 0) > 0
    ]
    if overall_categories_to_print:
        print("Aggregated Category Scores:")
        _print_category_scores(overall_categories_to_print, overall_avg_scores, overall_num_samples)
    else:
        print("No valid scores found for any categories in the aggregated data.")

    # The overall WiScore needs '--category all' and samples in every category.
    all_categories_have_overall_samples = all(
        overall_num_samples.get(c, 0) > 0 for c in _REPORT_CATEGORY_ORDER
    )
    if args.category != 'all':
        print("\nOverall WiScore calculation skipped. To calculate overall score, use '--category all' and provide files covering all prompt IDs.")
    elif not all_categories_have_overall_samples:
        print("\nOverall WiScore cannot be calculated: Not all categories have samples in the aggregated data when '--category all' is specified.")
    else:
        (cultural_score, time_score, space_score,
         biology_score, physics_score, chemistry_score) = (
            overall_avg_scores[c] for c in _REPORT_CATEGORY_ORDER
        )
        overall_wiscore = sum(
            _OVERALL_WISCORE_WEIGHTS[c] * overall_avg_scores[c] for c in _REPORT_CATEGORY_ORDER
        )

        print("\n--- Overall WiScore Across All Categories ---")
        print(f"Overall WiScore: {overall_wiscore:.2f}")
        print("Cultural\tTime\tSpace\tBiology\tPhysics\tChemistry\tOverall")
        print(f"{cultural_score:.2f}\t\t{time_score:.2f}\t{space_score:.2f}\t{biology_score:.2f}\t{physics_score:.2f}\t{chemistry_score:.2f}\t\t{overall_wiscore:.2f}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
