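"""
Add a ``python`` column listing the syntactically valid Python solutions from
each row's ``solutions`` list, for the parquet files found under an input
directory.
"""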
import os
import ast
import gc
import argparse
from typing import Optional

import ray
import pandas as pd
from tqdm import tqdm

from utils import find_parquet_files, save_parquet


def is_valid_python(code: str) -> bool:
    """Return True if ``code`` mentions "solution" and parses as Python.

    The "solution" check is a case-insensitive substring match anywhere in the
    text (identifiers, comments, or string literals alike); it does not verify
    that a ``Solution`` class or ``solution`` function is actually defined.
    Syntax is checked with ``ast.parse`` only; the code is never executed.

    Args:
        code: Candidate source text. Non-string values (e.g. a None entry)
            are rejected instead of raising, so one bad entry cannot fail the
            caller's whole batch.

    Returns:
        True when both checks pass, False otherwise.
    """
    if not isinstance(code, str):
        return False
    if "solution" not in code.lower():
        return False
    try:
        ast.parse(code)
    except (SyntaxError, ValueError, RecursionError, MemoryError):
        # SyntaxError also covers IndentationError/TabError; ValueError is
        # raised for NUL bytes on Python < 3.12; pathologically nested input
        # can exhaust the parser's recursion or memory limits.
        return False
    return True


@ray.remote(num_cpus=1)
def process_file(input_path: str, output_path: str) -> Optional[str]:
    """Add a ``python`` column of valid solutions to one parquet file.

    Reads ``input_path``. It then sets ``python`` on every row to the entries
    of that row's ``solutions`` accepted by ``is_valid_python``. Finally it
    writes the whole frame to ``output_path`` via ``save_parquet``.

    Args:
        input_path: Parquet file to read.
        output_path: Destination file; its parent directory is created if
            it does not exist.

    Returns:
        None on success, or a ``"Error: <input_path>: <exception>"`` string on
        failure. Errors are returned rather than raised so the caller can
        collect them per file.
    """

    def keep_valid(sols) -> list:
        # Without this guard, a null cell (None, or a float NaN in an object
        # column) makes the comprehension raise TypeError. That would fail
        # the entire file, so the row is treated as having no solutions.
        if sols is None or isinstance(sols, float):
            return []
        return [s for s in sols if is_valid_python(s)]

    try:
        out_dir = os.path.dirname(output_path)
        if out_dir:  # os.makedirs("") raises for a bare file name
            os.makedirs(out_dir, exist_ok=True)
        df = pd.read_parquet(input_path)

        df["python"] = df["solutions"].apply(keep_valid)
        save_parquet(df, output_path)
        return None

    except Exception as e:
        # Task boundary: report the failure instead of raising it.
        return f"Error: {input_path}: {e}"


def main():
    """Command-line entry point: run ``process_file`` on every file in parallel.

    Each name returned by ``find_parquet_files(--input)`` is joined onto
    ``--input`` (source) and ``--output`` (destination) and handed to a Ray
    task. Failures are collected and listed after all tasks finish.
    """
    parser = argparse.ArgumentParser(description="Filter Python solutions")
    parser.add_argument("--input", required=True, help="Input directory")
    parser.add_argument("--output", required=True, help="Output directory")
    parser.add_argument("--cpus", type=int, default=64, help="Number of CPUs")
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)

    files = find_parquet_files(args.input)
    print(f"Found {len(files)} files")

    ray.init(ignore_reinit_error=True, num_cpus=args.cpus)

    errors = []
    try:
        # Remember which input each ObjectRef belongs to. A failure raised
        # by ray.get itself (e.g. a crashed worker) can then still be
        # attributed to a file.
        pending = {}
        for f in files:
            input_path = os.path.join(args.input, f)
            ref = process_file.remote(input_path, os.path.join(args.output, f))
            pending[ref] = input_path

        with tqdm(total=len(files), desc="Filtering") as pbar:
            refs = list(pending)
            while refs:
                done, refs = ray.wait(refs, num_returns=1)
                for ref in done:
                    try:
                        result = ray.get(ref)
                    except ray.exceptions.RayError as e:
                        # process_file reports its own errors as strings. This
                        # catches failures outside its try (worker death,
                        # OOM kill, lost object) without aborting the run.
                        result = f"Error: {pending[ref]}: {e}"
                    if result:
                        errors.append(result)
                pbar.update(len(done))
                gc.collect()
    finally:
        ray.shutdown()

    if errors:
        print(f"\nFailed: {len(errors)} files")
        for e in errors:
            print(f"  - {e}")
    else:
        print("\nDone")


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
