#!/usr/bin/env python3
"""
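Score LLM responses against a dataset and write per-key average scores to CSV.

Responses are read from a shelve database that maps each key to a list holding
one response per dataset row. The dataset is loaded with
`datasets.load_from_disk`. The scoring function is looked up by name in
scoring_functions.py, which must live in the same directory as this script.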

Usage:
    python score_responses.py <dataset_path> <responses_path> <scoring_function_name> <output_csv>
"""

import argparse
import csv
import importlib.util
import os
import shelve
import sys
from typing import Dict, List, Any, Optional
import numpy as np
from datasets import load_from_disk


def load_scoring_function(function_name: str):
    """
    Load a scoring function from scoring_functions.py next to this script.

    The module is looked up in the directory containing this file (not the
    current working directory). It is registered in ``sys.modules`` under the
    name ``scoring_functions``, following the importlib recipe for importing
    a source file directly.

    Args:
        function_name: Name of the function to load

    Returns:
        The scoring function

    Raises:
        FileNotFoundError: If scoring_functions.py does not exist.
        ImportError: If no import spec/loader can be created for the file.
        AttributeError: If the module has no attribute named ``function_name``.
        TypeError: If the attribute exists but is not callable.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    module_path = os.path.join(script_dir, "scoring_functions.py")

    if not os.path.exists(module_path):
        raise FileNotFoundError(f"scoring_functions.py not found in {script_dir}")

    spec = importlib.util.spec_from_file_location("scoring_functions", module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {module_path}")

    module = importlib.util.module_from_spec(spec)
    # Register before executing so code inside the module can import itself by name.
    sys.modules["scoring_functions"] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the import system: never leave a half-initialized module cached.
        if sys.modules.get("scoring_functions") is module:
            del sys.modules["scoring_functions"]
        raise

    if not hasattr(module, function_name):
        raise AttributeError(f"Function '{function_name}' not found in scoring_functions.py")

    scoring_function = getattr(module, function_name)
    # Fail fast here rather than having every individual scoring call fail later.
    if not callable(scoring_function):
        raise TypeError(
            f"'{function_name}' in scoring_functions.py is not callable "
            f"(got {type(scoring_function).__name__})"
        )
    return scoring_function


def score_responses(dataset, responses: Dict[str, List], scoring_function) -> Dict[str, float]:
    """
    Compute the mean score of each key's response list.

    Each response list is paired element-wise with the dataset rows, and
    ``scoring_function(response, observation)`` is applied to every pair.
    Keys whose list length differs from the dataset are reported and left
    out of the result. Scores of ``None`` and calls that raise are skipped;
    failures are printed.

    Args:
        dataset: The loaded dataset (sized and iterable row by row)
        responses: Dictionary mapping keys to response lists
        scoring_function: Callable returning a numeric score or ``None``

    Returns:
        Dictionary mapping each scored key to its mean score, or 0.0 when no
        response for that key produced a usable score
    """
    averages = {}

    for key, candidates in responses.items():
        n_rows = len(dataset)
        if len(candidates) != n_rows:
            print(f"Warning: Number of responses for key '{key}' ({len(candidates)}) "
                  f"doesn't match dataset size ({len(dataset)})")
            continue

        collected = []
        for idx, (reply, example) in enumerate(zip(candidates, dataset)):
            try:
                value = scoring_function(reply, example)
                if value is None:
                    continue
                collected.append(float(value))
            except Exception as err:
                print(f"Error scoring response {idx} for key '{key}': {err}")

        # An empty list would make np.mean warn and return NaN; report 0.0 instead.
        averages[key] = np.mean(collected) if collected else 0.0

    return averages


def write_results_to_csv(results: Dict[str, float], output_path: str):
    """
    Write the scoring results to a CSV file.

    The file has a ``key,average_score`` header row, followed by one row per
    key in sorted key order. Missing parent directories are created.

    Args:
        results: Dictionary mapping keys to average scores
        output_path: Path to write the CSV file. A bare file name (no
            directory component) is written to the current working directory.
    """
    parent_dir = os.path.dirname(output_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one (not for e.g. "results.csv").
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # newline="" is what the csv module requires; explicit UTF-8 keeps
    # non-ASCII keys writable regardless of the platform's locale encoding.
    with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["key", "average_score"])
        writer.writerows(sorted(results.items()))


def main():
    """
    Command-line entry point.

    Parses the four positional arguments, loads the dataset, the shelved
    responses and the named scoring function, then scores every key. The
    averages are written to the output CSV and echoed as a summary.
    """
    cli = argparse.ArgumentParser(description="Score LLM responses against a dataset")
    positional_args = (
        ("dataset_path", "Path to the dataset (loadable with load_from_disk)"),
        ("responses_path", "Path to the shelve database with responses"),
        ("scoring_function", "Name of the scoring function in scoring_functions.py"),
        ("output_csv", "Path to write the output CSV file"),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, help=arg_help)
    opts = cli.parse_args()

    print(f"Loading dataset from {opts.dataset_path}...")
    dataset = load_from_disk(opts.dataset_path)
    print(f"Dataset loaded with {len(dataset)} observations")

    print(f"Loading responses from {opts.responses_path}...")
    # Copy everything out of the shelf so it can be closed before scoring.
    with shelve.open(opts.responses_path, 'r') as store:
        responses = dict(store)
    print(f"Loaded responses for {len(responses)} keys")

    print(f"Loading scoring function '{opts.scoring_function}'...")
    scorer = load_scoring_function(opts.scoring_function)

    print("Scoring responses...")
    averages = score_responses(dataset, responses, scorer)

    print(f"Writing results to {opts.output_csv}...")
    write_results_to_csv(averages, opts.output_csv)

    print("Done!")

    print("\nSummary:")
    for key in sorted(averages):
        print(f"  {key}: {averages[key]:.4f}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
