import json
import os
from pathlib import Path
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed

# Extend the module search path with the `inference` directory that lives two
# levels above this file's own directory. This has to run before the imports
# below it.
# NOTE(review): presumably this is what makes `redacted` and/or `utilities`
# importable — confirm which of them actually live under `inference`.
sys.path.append(
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
        "inference",
    )
)

from jsonargparse import ArgumentParser  # type: ignore
from redacted import (
    ChatModel,
    ChatRequest,
    JsonCache,
    Message,
    ModelSupports,
    ResponseType,
    Role,
    Client,
    Models,
)
from tqdm import tqdm
from utilities import get_root


def format_question(post: dict) -> str:
    """Render a forum post as the pretty-printed JSON document given to the model.

    Args:
        post: Raw post mapping. The keys read are ``title``, ``body_markdown``
            and ``answers`` (an iterable of mappings, each read for its own
            ``body_markdown``). Missing keys become ``null`` in the output;
            a missing ``answers`` key becomes an empty list.

    Returns:
        A JSON string indented by two spaces with non-ASCII characters kept
        verbatim, containing the keys ``title``, ``question`` and ``answers``.
    """
    answer_bodies = []
    for answer in post.get("answers", []):
        answer_bodies.append(answer.get("body_markdown"))

    payload = dict(
        title=post.get("title"),
        question=post.get("body_markdown"),
        answers=answer_bodies,
    )
    return json.dumps(payload, indent=2, ensure_ascii=False)


# System prompt sent as the first message of every extraction request (see
# `process_post`). It asks the model for a single JSON object shaped like the
# `Benchmark` TypeScript interface below, or an empty object when no benchmark
# can be extracted. Surrounding whitespace is removed by the trailing
# `.strip()`.
# NOTE(review): this text is runtime data; editing it changes what the model
# receives and presumably invalidates previously cached responses — confirm
# against how `JsonCache` keys its entries.
SYSTEM = """
You are given a raw forum question and its answers, all formatted in JSON.
Your task is to extract a single benchmark that reflects the core problem and solution discussed in the thread, specifically related to querying data using `jq`.

Return your result as a JSON object that conforms to the following TypeScript interface:
```typescript
interface Benchmark {
    context: string[];       // Direct quotes from the forum post and answers that capture the key problem and solution.
                             // - Do not paraphrase or summarize.
    utterance: string;       // A concise, precise description of the user's intent.
                             // - Avoid repeating \"JSON\" or \"jq\".
                             // - Include specific values or conditions mentioned in the query.
                             // - Do not describe how to solve the problem—just what the user wants.
    expressions: string[];   // All `jq` expressions that satisfy the user's request.
    data?: Data[];           // Optional: input JSON and expected output if available.
}

interface Data {
   input: any;              // The input JSON data.
   output?: any;            // If mentioned, the expected result of the query.
}
```

⚠️ Ensure that:
- The benchmark reflects a real-world use case.
- The utterance is specific and testable.
- The expressions are valid and complete.

⚠️️️ If there is no benchmark available, return an empty JSON object.
""".strip()


def process_post(post: dict, model: ChatModel, request: ChatRequest) -> dict:
    """Ask the model to extract a benchmark from one post and tag it with the post id.

    Args:
        post: Raw forum post; serialized with ``format_question`` for the user
            message and read for its ``question_id``.
        model: Chat model whose ``chat`` method is called with the messages.
        request: Request settings forwarded unchanged to ``model.chat``.

    Returns:
        The parsed JSON of the model response (an empty dict when the response
        carries no JSON), with an ``identifier`` key set to the post's
        ``question_id`` (``None`` if the post has none).
    """
    conversation = [
        Message(role=Role.System, content=SYSTEM),
        Message(role=Role.User, content=format_question(post)),
    ]
    reply = model.chat(conversation, request)

    # A falsy payload (missing or empty JSON) is replaced by a fresh dict so
    # the identifier can always be attached.
    benchmark = reply.json or {}
    benchmark["identifier"] = post.get("question_id")
    return benchmark


if __name__ == "__main__":
    # Script entry point: for every raw StackOverflow page file, extract one
    # benchmark per post with the chat model and write the results to the
    # matching file under `1_extracted`. Posts whose id already appears in an
    # existing output file are skipped unless --force is given.

    # fmt: off
    parser = ArgumentParser()
    parser.add_argument("-p", "--page", type=int, default=0)
    parser.add_argument("-m", "--model", type=str, default="Gpt41", choices=Models.names(ModelSupports.Chat))
    parser.add_argument("-c", "--cache", type=str, default="auto")
    parser.add_argument("--force", action="store_true", default=False)
    parser.add_argument("--workers", type=int, default=8)
    args = parser.parse_args()
    # fmt: on

    root = get_root()
    root_so = root / "data" / "stackoverflow_"

    # "auto" selects a per-model cache file; an empty string disables caching.
    if args.cache == "auto":
        args.cache = (
            Path("E:/Papers/jqBench/Caches")
            / "extract"
            / args.model.lower()
            / "cache.json"
        )

    # load model
    model_spec = Models.by_name(args.model)
    model_cache = JsonCache(args.cache, autosave=10) if args.cache else None
    model = ChatModel(model=model_spec, client=Client(), cache=model_cache)
    model_request = ChatRequest(response_format=ResponseType.JsonObject())

    # paths
    raw_dir = root_so / "0_raw"
    out_dir = root_so / "1_extracted"
    raw_files = sorted(raw_dir.glob("*.json"))
    if args.page > 0:
        # Restrict the run to the file(s) whose stem names the requested page.
        raw_files = [f for f in raw_files if f"page-{args.page:02d}" in f.stem]

    for file in tqdm(raw_files, desc="files", position=0, leave=True):
        file_target = out_dir / file.name
        if file_target.exists() and not args.force:
            # Resume: start from the benchmarks extracted by an earlier run.
            with open(file_target, encoding="utf-8") as f:
                data_target = json.load(f)
        else:
            data_target = []
            file_target.parent.mkdir(parents=True, exist_ok=True)
        with open(file, encoding="utf-8") as f:
            data = json.load(f)

        # Build the set of processed ids once per file instead of rescanning
        # `data_target` for every post.
        done_ids = {p.get("identifier") for p in data_target}
        todo = [post for post in data if post.get("question_id") not in done_ids]
        if not todo:
            continue

        with ThreadPoolExecutor(max_workers=args.workers) as executor:
            futures = {
                executor.submit(process_post, post, model, model_request): post
                for post in todo
            }
            # Results are appended in completion order, not input order.
            # `future.result()` re-raises any exception from the worker, which
            # aborts the run before this file's output is written.
            for future in tqdm(
                as_completed(futures), total=len(futures), position=1, leave=False
            ):
                data_target.append(future.result())
        with open(file_target, "w", encoding="utf-8") as f:
            o = json.dumps(data_target, indent=2, ensure_ascii=True)
            f.write(o)
        if model.cache:
            # Persist the cache after each file rather than relying on autosave.
            model.cache.save()

    # NOTE(review): the effect of this assignment depends on JsonCache;
    # presumably it makes any later cache write flush immediately — confirm.
    if model.cache:
        model.cache.autosave = 1
