"""
Extract code solutions from markdown responses.
"""
import os
import re
import gc
import argparse
from typing import List, Optional

import ray
import pandas as pd
from tqdm import tqdm

from utils import find_parquet_files, save_parquet


def extract_code_blocks(
    text: str,
    language: Optional[str] = None,
    mode: str = "all"
) -> List[str]:
    """
    Extract fenced code blocks from markdown text.

    Before matching, literal two-character escape sequences (backslash + n/r/t)
    are converted to real control characters and HTML comments are removed.
    A block starts at a line beginning with three backticks and ends at the
    next line beginning with three backticks.

    Args:
        text: Markdown text. Non-string values (e.g. None/NaN cells) yield [].
        language: Optional language filter (e.g., 'python'); the fence tag must
            equal it exactly (case-insensitive). Without a filter any fence
            tag, or none, is accepted.
        mode: 'first', 'last', or 'all'. Any other value behaves like 'all'.

    Returns:
        List of extracted code strings, without the fences and without the
        newline preceding the closing fence. An empty fenced block is
        returned as an empty string.
    """
    # Missing responses arrive as None/NaN from pandas; treat them as "no code"
    # instead of failing on str methods.
    if not isinstance(text, str) or not text:
        return []

    text = text.replace("\\n", "\n").replace("\\r", "\r").replace("\\t", "\t")
    text = re.sub(r"<!--.*?-->", "", text, flags=re.DOTALL)

    # The closing fence is anchored at a line start and the newline before it
    # is optional, so an empty block ("```py\n```") closes on its own fence
    # rather than running on to the next fence in the text.
    if language:
        pattern = rf"^```{re.escape(language)}[ \t]*\r?\n(.*?)(?:\r?\n)?^```"
        flags = re.DOTALL | re.MULTILINE | re.IGNORECASE
    else:
        pattern = r"^```[^\r\n]*\r?\n(.*?)(?:\r?\n)?^```"
        flags = re.DOTALL | re.MULTILINE

    matches = re.findall(pattern, text, flags)

    if not matches:
        return []
    if mode == "first":
        return [matches[0]]
    if mode == "last":
        return [matches[-1]]
    return matches


@ray.remote(num_cpus=1)
def process_file(input_path: str, output_path: str) -> Optional[str]:
    """Extract code blocks from one parquet file and save the result.

    Reads ``input_path``, adds a ``solutions`` column containing the result of
    ``extract_code_blocks`` for each ``response_content`` value, and writes
    the frame to ``output_path`` with ``save_parquet``.

    Args:
        input_path: Parquet file to read; must have a ``response_content``
            column.
        output_path: Destination path; its parent directory is created if
            needed.

    Returns:
        None on success, otherwise an error message naming ``input_path``.
    """
    try:
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        df = pd.read_parquet(input_path)
        df["solutions"] = df["response_content"].apply(extract_code_blocks)
        save_parquet(df, output_path)
        return None

    except Exception as e:
        # Failures are returned rather than raised so the caller can collect
        # them per file.
        return f"Error: {input_path}: {e}"


def main():
    """Command-line entry point: process every parquet file found under --input.

    Schedules one ``process_file`` Ray task per file returned by
    ``find_parquet_files``, shows a progress bar as tasks finish, and prints
    a summary of any per-file errors at the end.
    """
    parser = argparse.ArgumentParser(description="Extract code from responses")
    parser.add_argument("--input", required=True, help="Input directory")
    parser.add_argument("--output", required=True, help="Output directory")
    parser.add_argument("--cpus", type=int, default=64, help="Number of CPUs")
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)

    files = find_parquet_files(args.input)
    print(f"Found {len(files)} files")

    ray.init(ignore_reinit_error=True, num_cpus=args.cpus)

    errors = []
    try:
        # Map each task handle to its input path so a failure raised by Ray
        # itself can still be attributed to a file.
        # NOTE(review): joining with args.input/args.output assumes
        # find_parquet_files returns paths relative to args.input -- confirm.
        task_inputs = {}
        for f in files:
            input_path = os.path.join(args.input, f)
            task = process_file.remote(input_path, os.path.join(args.output, f))
            task_inputs[task] = input_path

        pending = list(task_inputs)
        with tqdm(total=len(files), desc="Extracting") as pbar:
            while pending:
                done, pending = ray.wait(pending, num_returns=1)
                for task in done:
                    try:
                        result = ray.get(task)
                    except ray.exceptions.RayError as e:
                        # process_file reports its own exceptions as strings;
                        # this covers failures outside it (e.g. a worker
                        # dying), which would otherwise abort the whole run.
                        result = f"Error: {task_inputs[task]}: {e}"
                    if result:
                        errors.append(result)
                pbar.update(len(done))
                gc.collect()
    finally:
        # Always release the Ray runtime, even if scheduling or waiting fails.
        ray.shutdown()

    if errors:
        print(f"\nFailed: {len(errors)} files")
        for e in errors:
            print(f"  - {e}")
    else:
        print("\nDone")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
