import pandas as pd
import os
from pathlib import Path

def filter_math_dataset(input_dir, output_dir, filename, level):
    # Define paths
    input_path = input_dir / filename
    output_path = output_dir / filename
    
    # Create output directory if it doesn't exist
    output_dir.mkdir(parents=True, exist_ok=True)
    
    print(f"Reading dataset from: {input_path}")
    
    # Read the parquet file
    try:
        df = pd.read_parquet(input_path)
        print(f"Original dataset shape: {df.shape}")
        print(f"Columns: {list(df.columns)}")
        
        # Check if 'level' column exists
        if 'level' not in df.columns:
            print("Available columns:", list(df.columns))
            raise ValueError("'level' column not found in the dataset")
        
        # Display level distribution
        print("\nLevel distribution:")
        print(df['level'].value_counts().sort_index())
        
        # Extract numeric level from string format (e.g., "Level 1" -> 1)
        def extract_level_number(level_str):
            if isinstance(level_str, str):
                if level_str.startswith("Level "):
                    try:
                        return int(level_str.split()[1])
                    except (IndexError, ValueError):
                        return None
            return None
        
        # Create numeric level column
        df['level_numeric'] = df['level'].apply(extract_level_number)
        
        # Filter for level >= 3 (excluding None values which are "Level ?" entries)
        filtered_df = df[(df['level_numeric'].notna()) & (df['level_numeric'] >= level)].copy()
        
        # Remove the temporary numeric column
        filtered_df = filtered_df.drop('level_numeric', axis=1)
        
        print(f"\nFiltered dataset shape: {filtered_df.shape}")
        print(f"Samples removed: {len(df) - len(filtered_df)}")
        
        # Display filtered level distribution
        print("\nFiltered level distribution:")
        print(filtered_df['level'].value_counts().sort_index())
        
        # Save the filtered dataset
        filtered_df.to_parquet(output_path, index=False)
        print(f"\nFiltered dataset saved to: {output_path}")
        
        return filtered_df
        
    except FileNotFoundError:
        print(f"Error: File not found at {input_path}")
        print("Please make sure the file exists and the path is correct.")
        return None
    except Exception as e:
        print(f"Error processing dataset: {e}")
        return None

if __name__ == "__main__":
    # Minimum level passed to filter_math_dataset for every split.
    min_level = 3

    # Source and destination directories share one root under the home dir.
    data_root = Path.home() / "data"
    source_dir = data_root / "math"
    target_dir = data_root / "math-hard"

    # Both splits are processed with the same threshold.
    for split_file in ("train.parquet", "test.parquet"):
        filter_math_dataset(source_dir, target_dir, split_file, min_level)
