"""
Batch code execution with Ray parallelization.
"""
import argparse
import gc
import os
import time
from typing import Optional

import pandas as pd
import ray
from tqdm import tqdm

from execution import check_correctness
from utils import stream_jsonl, find_parquet_files, save_parquet


@ray.remote
class ProgressTracker:
    """Ray actor that accumulates progress counters reported by worker tasks."""

    def __init__(self) -> None:
        # Both counters start at zero and only change through the methods below.
        self.completed = 0
        self.total = 0

    def register(self, count: int) -> None:
        """Add ``count`` to the total amount of expected work."""
        self.total = self.total + count

    def update(self, count: int = 1) -> None:
        """Add ``count`` (one by default) to the amount of finished work."""
        self.completed = self.completed + count

    def status(self):
        """Return the current counters as a ``(completed, total)`` tuple."""
        snapshot = (self.completed, self.total)
        return snapshot


@ray.remote(num_cpus=1)
def process_file(
    input_path: str,
    output_path: str,
    benchmark_path: str,
    tracker
) -> Optional[str]:
    """Check every candidate solution in one parquet file and save the result.

    The benchmark JSONL is indexed by ``task_id``. For each row of the input
    frame, every entry of its ``solutions`` column is embedded in a program
    built from the matching task's ``prompt``, ``starter_code``, ``test`` and
    ``entry_point`` fields and passed to ``check_correctness`` with a
    3-second timeout. Solutions whose result has a truthy ``"passed"`` value
    are collected into a new ``passed`` column, and the frame is written to
    ``output_path`` with ``save_parquet``.

    Args:
        input_path: Parquet file to read.
        output_path: Destination of the augmented parquet file; its parent
            directory is created when missing.
        benchmark_path: JSONL file with one benchmark task per record.
        tracker: ``ProgressTracker`` actor handle; the row count is
            registered once up front and one update is sent per row.

    Returns:
        ``None`` on success, otherwise a message describing the failure.
        Every exception, including a failure to load the benchmark, is
        reported through the return value rather than raised.
    """
    try:
        # Loaded inside the try so an unreadable benchmark is reported like
        # any other per-file failure instead of escaping the task.
        benchmark = {t["task_id"]: t for t in stream_jsonl(benchmark_path)}

        # dirname() is "" for a bare file name, and makedirs("") raises.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        df = pd.read_parquet(input_path)

        # Block until the total is registered so it is counted before any
        # per-row update from this task.
        ray.get(tracker.register.remote(len(df)))

        results = []
        for task_id, solutions in zip(df["task_id"], df["solutions"]):
            task = benchmark[task_id]

            passed = []
            for code in solutions:
                program = (
                    f"{task['prompt']}\n"
                    f"{task['starter_code']}\n"
                    f"        pass\n\n"
                    f"{code}\n"
                    f"{task['test']}\n"
                    f"check({task['entry_point']})"
                )
                if check_correctness(0, 0, program, timeout=3)["passed"]:
                    passed.append(code)

            results.append(passed)
            # Fire-and-forget: the worker does not wait for the actor.
            tracker.update.remote(1)

        df["passed"] = results
        save_parquet(df, output_path)
        return None

    except Exception as e:
        return f"Error processing {input_path}: {e}"


def _refresh_progress(pbar, tracker):
    """Copy the tracker actor's current counters onto the progress bar."""
    completed, total = ray.get(tracker.status.remote())
    pbar.n = completed
    pbar.total = total
    pbar.refresh()


def main():
    """Run correctness checks over a directory of parquet files with Ray.

    Command-line arguments:
        --input: directory searched for parquet files.
        --output: directory receiving the processed files (created if absent).
        --benchmark: benchmark JSONL path handed to every worker task.
        --cpus: CPU count passed to ``ray.init`` (default 40).

    One ``process_file`` task is submitted per file found. The driver polls
    for finished tasks, mirrors the shared ``ProgressTracker`` counters onto
    a tqdm bar, and finally prints the failed files or a success message.
    A task that raises (rather than returning an error string) is recorded
    as a failure for its input file instead of aborting the whole run.
    """
    parser = argparse.ArgumentParser(description="Batch code execution")
    parser.add_argument("--input", required=True, help="Input directory")
    parser.add_argument("--output", required=True, help="Output directory")
    parser.add_argument("--benchmark", required=True, help="Benchmark JSONL")
    parser.add_argument("--cpus", type=int, default=40, help="Number of CPUs")
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)

    # NOTE(review): entries are joined onto both the input and the output
    # directory, so they are presumably paths relative to args.input.
    files = find_parquet_files(args.input)
    print(f"Found {len(files)} parquet files")

    ray.init(ignore_reinit_error=True, num_cpus=args.cpus)

    errors = []
    pbar = None
    try:
        tracker = ProgressTracker.remote()

        # Remember which input each task reference belongs to, so a task
        # that raises can still be attributed to its file.
        sources = {}
        for f in files:
            input_path = os.path.join(args.input, f)
            ref = process_file.remote(
                input_path,
                os.path.join(args.output, f),
                args.benchmark,
                tracker,
            )
            sources[ref] = input_path
        tasks = list(sources)

        # No start-up sleep: the total is re-read on every poll below, so
        # the bar corrects itself as workers register their row counts.
        _, total = ray.get(tracker.status.remote())
        pbar = tqdm(total=total, desc="Processing")

        while tasks:
            # ray.wait returns (ready, not_ready); keep polling the rest.
            done, tasks = ray.wait(tasks, num_returns=1, timeout=0.5)
            _refresh_progress(pbar, tracker)

            for ref in done:
                try:
                    result = ray.get(ref)
                except Exception as e:
                    # The task itself failed (e.g. worker crash); record it
                    # and keep collecting the remaining results.
                    result = f"Error processing {sources[ref]}: {e}"
                if result:
                    errors.append(result)

            gc.collect()

        # Per-row updates are sent without waiting, so sync once more after
        # the last task has finished.
        _refresh_progress(pbar, tracker)
    finally:
        # Always release the bar and the Ray runtime, even on failure.
        if pbar is not None:
            pbar.close()
        ray.shutdown()

    if errors:
        print(f"\nFailed: {len(errors)} files")
        for e in errors:
            print(f"  - {e}")
    else:
        print("\nAll files processed successfully")

# Run the batch job only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
