import os
from pathlib import Path
import sys

# Add `<this file's directory>/../../inference` to the import path so that the
# project-local modules imported below (presumably `redacted` and `utilities`
# — verify) can be found. Must run before those imports.
sys.path.append(
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
        "inference",
    )
)

import json
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

from jsonargparse import ArgumentParser
from redacted import (
    BooleanProperty,
    ChatModel,
    ChatRequest,
    JsonCache,
    JsonSchema,
    Message,
    ModelSupports,
    ObjectProperty,
    OptionalProperty,
    ResponseType,
    Role,
    StringProperty,
    Client,
    Models,
)
from pydantic import BaseModel
from tqdm import tqdm
from utilities import get_root

# System prompt for the annotation model. `annotate` sends it together with a
# user message, asks for a reply matching `schema`, and parses the reply into
# `Answers`.
# NOTE(review): the two rules at the end refer to "question (1)", to
# "questions (2-5)" and to choosing a "task", but the prompt only lists four
# unnumbered yes/no questions and offers no task choices. These look like
# leftovers from an earlier prompt/schema and should be reworded. Any edit
# changes the request text, so cached responses will likely no longer match
# (assuming the cache is keyed on the request — confirm).
system = """
You are an expert at jq: a utility for querying and editing JSON files.
Your task is to annotate benchmark problems about jq.

You are given a description of a task (and optionally example input and output) that a user wants to perform using jq.

Which of the following properties invalidate the benchmark?
 - Does the task mention variables as additional input to the jq command, either as inputs or to store results to?
 - Does the task mention streaming of large quantities of data?
 - Does the task consider multiple input or output files?
 - If input and output is given, is any information missing from the task description to make that transformation?

Additionally, if you spot any other related properties about the execution environment, mention them in the `other_task` or `other_environment` sections.
- Only add `other_task` if you did not choose any task in question (1).
- Only add `other_environment` if you answered `False` to all questions (2-5).

Return your result as a JSON object.
""".strip()


# Structured-output schema passed as the response format in `annotate`.
# It mirrors the `Answers` / `Invalid` pydantic models that validate the reply,
# so keep both in sync when adding, removing or renaming a field.
# The properties carry no descriptions: the model can only match each boolean
# to a question in `system` through its name.
schema = JsonSchema(
    name="Answers",
    schema_=ObjectProperty(
        properties={
            # one flag per invalidating property asked about in `system`
            "invalid": ObjectProperty(
                properties={
                    "variables": BooleanProperty(),
                    "stream": BooleanProperty(),
                    "multiple": BooleanProperty(),
                    "missing": BooleanProperty(),
                }
            ),
            # optional free-text remarks
            "other_task": OptionalProperty(property=StringProperty()),
            "other_environment": OptionalProperty(property=StringProperty()),
        }
    ),
)


class Invalid(BaseModel):
    """Yes/no flags for properties that invalidate a benchmark problem.

    Mirrors the ``"invalid"`` object of ``schema``. ``True`` means the model
    judged the property present. Each flag is matched to a question in the
    system prompt by name and order only.
    """

    # presumably: the task uses variables as extra input or to store results
    variables: bool
    # presumably: the task mentions streaming large quantities of data
    stream: bool
    # presumably: the task involves multiple input or output files
    multiple: bool
    # presumably: the description lacks information needed for the given input -> output
    missing: bool

class Answers(BaseModel):
    """The model's full reply for one benchmark problem.

    Mirrors the top-level object of ``schema``; ``annotate`` builds it with
    ``Answers.model_validate``.
    """

    invalid: Invalid
    # optional free-text remark about other task properties (None if absent)
    other_task: Optional[str] = None
    # optional free-text remark about the execution environment (None if absent)
    other_environment: Optional[str] = None


def annotate(utterance: str, data: Optional[list], model) -> "Answers":
    """Ask ``model`` which benchmark-invalidating properties apply to a jq task.

    Args:
        utterance: Natural-language description of the jq task.
        data: Optional example dicts appended to the user message:
            - an entry with both ``"input"`` and ``"output"`` is shown as an
              ``(io)`` example;
            - an entry with only ``"input"`` is shown as an ``(i)`` example;
            - any other entry is skipped.
            ``None`` or an empty list adds no example section.
        model: Chat model exposing ``chat(messages, request)``. The returned
            response's ``.json`` must validate against ``Answers``.

    Returns:
        The model's reply parsed into an ``Answers`` instance.

    Raises:
        pydantic.ValidationError: if the reply does not fit ``Answers``.
    """
    # Collect pieces and join once instead of repeated string `+=`.
    parts = [utterance]
    if data:  # also tolerates `"data": null` in the source entry
        parts.append(
            "\n\nThe following are potential `input` (i) or `input` -> `output` examples (io):"
        )
        for example in data:
            if "input" not in example:
                continue  # output-only examples carry no usable context
            shown_input = f"`{json.dumps(example['input'])}`"
            if "output" in example:
                parts.append(
                    f"\n- (io): {shown_input} -> `{json.dumps(example['output'])}`"
                )
            else:
                parts.append(f"\n- (i): {shown_input}")
    user = "".join(parts)

    response = model.chat(
        [
            Message(role=Role.System, content=system),
            Message(role=Role.User, content=user),
        ],
        # temperature 0 keeps annotations as reproducible as the backend allows
        ChatRequest(temperature=0.0, response_format=ResponseType.JsonSchema(schema)),
    )
    return Answers.model_validate(response.json)


if __name__ == "__main__":

    # fmt: off
    parser = ArgumentParser()
    parser.add_argument("-p", "--page", type=int, default=1)
    parser.add_argument("-m", "--model", type=str, default="Gpt41", choices=Models.names(ModelSupports.Functions))
    parser.add_argument("-c", "--cache", type=str, default="auto")
    parser.add_argument("--select", type=str, default=None)
    parser.add_argument("--force", action="store_true", default=False)
    args = parser.parse_args()
    # fmt: on

    # `--select <identifier>` annotates one entry and prints the result instead
    # of updating the output file.
    debugging = args.select is not None
    root = get_root()
    root_so = root / "data" / "stackoverflow_"

    # resolve the response-cache location
    if args.cache == "auto":
        # NOTE(review): machine-specific absolute path; pass `--cache` on other machines.
        args.cache = (
            Path("E:/Papers/jqBench/Caches")
            / "annotate"
            / args.model.lower()
            / f"page-{args.page:02d}.json"
        )

    # get input and output file
    file_i = root_so / "1_extracted" / f"page-{args.page:02d}.json"
    file_o = root_so / "2_annotated" / f"page-{args.page:02d}.json"

    # load model
    model_spec = Models.by_name(args.model)
    model_cache = JsonCache(args.cache, autosave=8) if args.cache else None
    model = ChatModel(model=model_spec, client=Client(), cache=model_cache)

    # Resume from earlier output unless forced (debug runs always start fresh).
    if file_o.exists() and not args.force and not debugging:
        with open(file_o, encoding="utf-8") as f:
            data_target = json.load(f)
        done = {t["identifier"] for t in data_target}
    else:
        data_target = []
        done = set()
        file_o.parent.mkdir(parents=True, exist_ok=True)
    with open(file_i, encoding="utf-8") as f:
        todo = json.load(f)
    if done:
        todo = [p for p in todo if p["identifier"] not in done]
    if debugging:
        todo = [p for p in todo if str(p["identifier"]) == args.select]
    # only entries that have an utterance can be annotated
    todo = [p for p in todo if p.get("utterance") is not None]

    if not todo:
        # `sys.exit` instead of `quit`: `quit` only exists when `site` is loaded.
        if debugging:
            sys.exit(f"Page {args.page} has no annotatable entry {args.select}.")
        sys.exit(f"Page {args.page} is all done.")

    def process_entry(entry, model):
        """Annotate ``entry`` in place (adds ``"annotated"``) and return it."""
        # `or []` tolerates `"data": null` in the extracted file
        answers = annotate(entry["utterance"], entry.get("data") or [], model)
        entry["annotated"] = answers.model_dump(mode="json")
        return entry

    def save(entries):
        """Rewrite ``file_o`` with ``entries`` sorted by identifier.

        The data is written to a temporary sibling file first and then swapped
        in with ``os.replace``, so an interrupted write leaves the previously
        saved progress intact.
        """
        tmp = file_o.with_name(file_o.name + ".tmp")
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(sorted(entries, key=lambda x: x["identifier"]), f, indent=2)
        os.replace(tmp, file_o)

    # process
    stat = Counter()  # running totals shown in the progress-bar postfix
    failed = []  # identifiers whose annotation raised

    try:
        with ThreadPoolExecutor(max_workers=8) as executor, tqdm(
            total=len(todo), desc=f"Page {args.page}", disable=debugging
        ) as pbar:
            futures = {executor.submit(process_entry, e, model): e for e in todo}
            for future in as_completed(futures):
                pbar.update(1)
                try:
                    e = future.result()
                except Exception as err:
                    # Report and keep going; the entry is not saved, so a rerun retries it.
                    identifier = futures[future].get("identifier")
                    failed.append(identifier)
                    tqdm.write(
                        f"[page {args.page}] failed to annotate {identifier}: {err!r}"
                    )
                    continue

                data_target.append(e)
                if debugging:
                    print(json.dumps(e["annotated"], indent=2))
                else:
                    save(data_target)  # checkpoint after every entry

                annotated = e["annotated"]
                for reason, invalid in annotated["invalid"].items():
                    stat[reason] += invalid
                for key in ("other_task", "other_environment"):
                    if isinstance(annotated.get(key), str):
                        stat[key] += 1
                pbar.set_postfix(**stat)
    finally:
        # Persist cached responses even if the run is interrupted.
        if model_cache is not None:
            model_cache.save()

    if failed:
        print(
            f"Page {args.page}: {len(failed)} entries failed {failed}; rerun to retry.",
            file=sys.stderr,
        )
