#!/usr/bin/env python3
"""
Script to create a subset of a CSV dataset with K top examples based on a metric.
Supports optional filtering of responses containing certain keywords.
"""

import argparse
import csv
import math
import os
import sys
from typing import Dict, List, Optional, Tuple

def load_csv_data(filename: str) -> Tuple[List[Dict], List[str]]:
    """Load a CSV file fully into memory.

    Args:
        filename: Path of the CSV file to read. The first line is treated
            as the header row.

    Returns:
        A ``(rows, column_names)`` tuple. ``rows`` is a list with one dict
        per data line, keyed by column name. ``column_names`` is the header
        in file order, or an empty list when the file has no header at all.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # newline='' hands line-ending handling to the csv module, so "\r\n"
    # sequences embedded in quoted fields are preserved instead of being
    # rewritten by universal-newline translation.
    # 'utf-8-sig' reads plain UTF-8 as well, but drops a leading BOM that
    # would otherwise become part of the first column name.
    with open(filename, 'r', encoding='utf-8-sig', newline='') as f:
        reader = csv.DictReader(f)
        # fieldnames is None for an empty file; normalise to a list so the
        # declared return type always holds.
        column_names = list(reader.fieldnames or [])
        data = list(reader)
    return data, column_names

def get_metric_columns(column_names: List[str]) -> List[str]:
    """Return the columns that can serve as metrics.

    Every name is kept, in its original order, except the two text
    columns 'prompt' and 'response' (exact, case-sensitive match).
    """
    text_columns = frozenset(('prompt', 'response'))
    metric_columns = []
    for name in column_names:
        if name in text_columns:
            continue
        metric_columns.append(name)
    return metric_columns

def filter_response_keywords(row: Dict, keywords: Optional[List[str]] = None) -> bool:
    """
    Check if the first 10 words of the response contain any of the keywords.

    Matching is case-insensitive and substring-based, and runs of whitespace
    are collapsed to one space so multi-word keywords such as 'not able'
    still match. Typographic apostrophes are treated like the ASCII one, so
    "can’t" is caught by the "can't" keyword.

    Args:
        row: A data row; only its 'response' entry is inspected.
        keywords: Keywords to look for. When None, the default list
            'sorry', "can't", 'cannot', 'apologize', 'not able', 'unable'
            is used. An empty list keeps every row.

    Returns:
        True if the row should be KEPT (doesn't contain keywords, or the
        response is missing/empty).
        False if the row should be FILTERED OUT (contains keywords).
    """
    if keywords is None:
        keywords = ['sorry', "can't", 'cannot', 'apologize', 'not able', 'unable']

    response = row.get('response', '')
    if not response:
        return True  # Keep empty responses

    def _normalize(text: str) -> str:
        # Lower-case and map curly single quotes (U+2018/U+2019) to "'" so
        # that contractions match however the apostrophe was typed.
        return text.lower().replace('\u2019', "'").replace('\u2018', "'")

    # Only the opening of the response is inspected.
    first_10_words = ' '.join(_normalize(response).split()[:10])

    return not any(_normalize(keyword) in first_10_words for keyword in keywords)

def sort_and_filter_data(
    data: List[Dict],
    metric: str,
    ascending: bool = True,
    filter_keywords: bool = False
) -> List[Dict]:
    """Sort data by metric and optionally filter by keywords.

    Args:
        data: Rows as dicts keyed by column name.
        metric: Name of the column whose value is used as the sort key.
        ascending: Smallest value first when True, largest first when False.
        filter_keywords: When True, rows rejected by
            ``filter_response_keywords`` are dropped.

    Returns:
        The surviving rows ordered by their numeric metric value. Rows are
        skipped when the metric is missing, is None, cannot be parsed as a
        float, or parses to NaN. Rows with equal values keep their input
        order in both sort directions.
    """
    scored_rows = []

    for row in data:
        # Skip rows with invalid metric values. TypeError covers None, which
        # csv.DictReader stores for rows shorter than the header.
        try:
            value = float(row[metric])
        except (ValueError, TypeError, KeyError):
            continue

        # NaN compares false against everything, which would leave the
        # sorted order (and therefore the top-K selection) undefined.
        if math.isnan(value):
            continue

        # Apply keyword filter if requested
        if filter_keywords and not filter_response_keywords(row):
            continue

        scored_rows.append((value, row))

    # Sort on the numeric value only; the row dicts are never compared.
    scored_rows.sort(key=lambda pair: pair[0], reverse=not ascending)

    return [row for _, row in scored_rows]

def write_csv_output(data: List[Dict], column_names: List[str], output_filename: str):
    """Write rows to a UTF-8 CSV file, header first.

    The file at ``output_filename`` is created or overwritten. Columns are
    emitted in the order given by ``column_names``; each entry of ``data``
    becomes one line.
    """
    # newline='' leaves line endings entirely to the csv writer.
    with open(output_filename, 'w', encoding='utf-8', newline='') as out_file:
        dict_writer = csv.DictWriter(out_file, fieldnames=column_names)
        dict_writer.writeheader()
        for record in data:
            dict_writer.writerow(record)

def main():
    """Command-line entry point.

    Parses the arguments, loads the input CSV, sorts the rows by the chosen
    metric (optionally dropping rows whose response opens with a refusal
    keyword), keeps the first K rows and writes them to the output CSV.

    Progress messages go to stdout and errors to stderr. The process exits
    with status 1 on any error, and with status 0 if the user declines the
    confirmation prompt shown when sorting by 'prompt' or 'response'.
    """
    parser = argparse.ArgumentParser(
        description='Create a subset of CSV dataset with top K examples based on a metric',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Get top 100 examples by metric (descending, the default)
  python create_subset.py input.csv -m metric -k 100 -o output.csv

  # Get top 50 examples by loss_normal (descending) with keyword filtering
  python create_subset.py input.csv -m loss_normal -k 50 -d -f -o output.csv

  # Get bottom 200 examples (ascending) without filtering
  python create_subset.py input.csv -m metric -k 200 -a -o output.csv
        """
    )

    parser.add_argument('input_file', help='Input CSV file')
    parser.add_argument('-m', '--metric', required=True,
                       help='Metric/field name to sort by')
    parser.add_argument('-k', '--top-k', type=int, required=True,
                       help='Number of top examples to select (K)')
    parser.add_argument('-o', '--output', required=True,
                       help='Output CSV file path')
    parser.add_argument('-a', '--ascending', action='store_true', default=False,
                       help='Sort in ascending order (default: descending)')
    parser.add_argument('-d', '--descending', action='store_true', default=False,
                       help='Sort in descending order (default if neither -a nor -d is specified)')
    parser.add_argument('-f', '--filter-keywords', action='store_true',
                       help='Filter out examples where first 10 words of response contain "sorry", "can\'t", "cannot", "apologize", "not able", or "unable"')

    args = parser.parse_args()

    # Validate K before any I/O: a negative K would make the slice below
    # drop rows from the end instead of selecting the first K, and K=0
    # would silently produce a header-only file.
    if args.top_k <= 0:
        print(f"Error: --top-k must be a positive integer (got {args.top_k}).", file=sys.stderr)
        sys.exit(1)

    # Validate input file exists
    if not os.path.exists(args.input_file):
        print(f"Error: Input file '{args.input_file}' not found.", file=sys.stderr)
        sys.exit(1)

    # Determine sort order
    if args.ascending and args.descending:
        print("Error: Cannot specify both --ascending and --descending", file=sys.stderr)
        sys.exit(1)
    # Descending is the default, so only an explicit -a switches the order.
    ascending = args.ascending

    # Load data
    print(f"Loading data from {args.input_file}...")
    try:
        data, column_names = load_csv_data(args.input_file)
    except Exception as e:
        print(f"Error loading file: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"Loaded {len(data)} rows")

    # An empty file yields no header; there is nothing to sort by.
    if not column_names:
        print(f"Error: Input file '{args.input_file}' has no header row.", file=sys.stderr)
        sys.exit(1)

    # Validate metric exists
    available_metrics = get_metric_columns(column_names)
    if args.metric not in column_names:
        print(f"Error: Metric '{args.metric}' not found in CSV.", file=sys.stderr)
        print(f"Available columns: {', '.join(column_names)}", file=sys.stderr)
        sys.exit(1)

    if args.metric not in available_metrics:
        print(f"Warning: Metric '{args.metric}' is 'prompt' or 'response'. Are you sure?")
        try:
            response = input("Continue anyway? [y/N]: ").strip().lower()
        except EOFError:
            # No interactive stdin (e.g. piped or closed): take the default "N".
            response = ''
        if response != 'y':
            sys.exit(0)

    # Sort and filter
    print(f"Sorting by '{args.metric}' ({'ascending' if ascending else 'descending'})...")
    if args.filter_keywords:
        print("Filtering out responses with keywords 'sorry', 'can't', 'cannot', 'apologize', 'not able', or 'unable' in first 10 words...")

    sorted_data = sort_and_filter_data(data, args.metric, ascending, args.filter_keywords)

    if not sorted_data:
        print("Error: No valid data found after filtering.", file=sys.stderr)
        sys.exit(1)

    print(f"Found {len(sorted_data)} valid examples after filtering")

    # Select top K
    k = args.top_k
    if k > len(sorted_data):
        print(f"Warning: Requested K={k} examples, but only {len(sorted_data)} available.")
        print(f"Using all {len(sorted_data)} examples.")
        k = len(sorted_data)

    subset = sorted_data[:k]
    print(f"Selecting top {k} examples...")

    # Write output
    print(f"Writing output to {args.output}...")
    try:
        write_csv_output(subset, column_names, args.output)
        print(f"Successfully created subset with {len(subset)} examples!")
        print(f"Output saved to: {args.output}")
    except Exception as e:
        print(f"Error writing output file: {e}", file=sys.stderr)
        sys.exit(1)

# Run the command-line interface only when this file is executed as a
# script; importing it as a module has no side effects.
if __name__ == "__main__":
    main()
