import os
from pathlib import Path
import sys

sys.path.append(
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
        "inference",
    )
)
import json
import re
from typing import Any

import jq
from jsonargparse import ArgumentParser
from redacted import (
    ChatModel,
    Function,
    JsonCache,
    ModelSupports,
    StringProperty,
    Client,
    Models,
)
from agent import Agent, Environment, Tool, ToolResult
from rich.console import Console
from tqdm import tqdm


def format_question(
    expressions: list[str],
    utterance: str,
    data: list,
) -> str:
    """Build the user prompt asking the agent to turn `expressions` into tested jq.

    Args:
        expressions: Candidate command(s). A single string is treated as a
            one-element list and ``None`` as an empty list.
        utterance: Natural-language description of the task.
        data: Example records. Only dict entries with a non-``None``
            ``"input"`` are used. Entries that also have a non-``None``
            ``"output"`` become ``input -> output`` examples; the rest become
            input-only examples. ``None`` is treated as an empty list.

    Returns:
        The prompt text. Each example section is appended only when it has at
        least one entry.
    """
    if expressions is None:
        expressions = []
    elif isinstance(expressions, str):
        expressions = [expressions]
    if data is None:
        data = []

    # Built outside the f-string: a backslash inside an f-string replacement
    # field is a SyntaxError before Python 3.12 (PEP 701).
    command_list = "\n".join(f"- `{e}`" for e in expressions)
    formatted = f"""
Task: "{utterance}"

The following commands potentially achieve the task:
{command_list}
You can (and are encouraged to) change them before running the tests.
""".strip()

    examples = []
    inputs = []
    for element in data:
        if not isinstance(element, dict):
            continue
        i = element.get("input", None)
        o = element.get("output", None)
        if i is None:
            continue
        if o is None:
            inputs.append(f"- `{json.dumps(i)}`")
        else:
            examples.append(f"- `{json.dumps(i)}` -> `{json.dumps(o)}`")

    # Both optional sections share the same framing; only the intro differs.
    sections = (
        (
            "The following (input -> output) examples can serve as inspiration to create test cases. ",
            examples,
        ),
        (
            "The following inputs can serve as inspiration to create test cases. ",
            inputs,
        ),
    )
    for intro, lines in sections:
        if not lines:
            continue
        formatted += "\n\n"
        formatted += intro + "Note that these are not guaranteed to be correct.\n"
        formatted += "\n".join(lines)
        formatted += "\nIf applicable, make sure that objects in test cases include all required fields."
    return formatted


class JsonPathEnvironment(Environment):
    """State shared by the agent's tools during a single conversion.

    Attributes:
        tests: Current test cases by name, each ``{"input": ..., "output": ...}``.
        tests_old: Previous versions of overwritten tests by name, in the
            order they were replaced.
        candidates: Distinct expressions registered via ``add_candidate``,
            in first-seen order.
    """

    def __init__(self):
        super().__init__(None)
        self.tests: dict[str, dict[str, Any]] = {}
        self.tests_old: dict[str, list[dict[str, Any]]] = {}
        self.candidates: list = []

    def add_test(self, name: str, i: Any, o: Any) -> None:
        """Store test ``name``, archiving any previous version in ``tests_old``."""
        if name in self.tests:
            self.tests_old.setdefault(name, []).append(self.tests[name])
        self.tests[name] = {"input": i, "output": o}

    def add_candidate(self, expression: str) -> None:
        """Remember ``expression`` unless it was already recorded."""
        if expression in self.candidates:
            return
        self.candidates.append(expression)


class AddTestTool(Tool[JsonPathEnvironment]):
    """Tool that lets the agent add or replace a named (input, expected output) test."""

    def definition(self) -> Function:
        return Function(
            name="add_test",
            description="Add or update a test case for a jq expression.",
            parameters={
                "name": StringProperty(description="A unique name for the test case."),
                "input": StringProperty(
                    description="A serialized string representation of the input JSON."
                ),
                "output": StringProperty(
                    description=(
                        "A serialized string representation of the expected output JSON, following Python's `jq.all(expression, data)` output format."
                    )
                ),
            },
        )

    def execute(self, arguments, environment, *args, **kwargs):
        """Parse the JSON-encoded input/output and store them as test ``name``.

        Returns:
            ToolResult whose values are ``{"errors": [...]}`` when the name is
            missing or the input/output are missing or not valid JSON;
            otherwise ``{"success": True, "update": "added" | "updated"}``.
        """
        a_name = arguments.get("name", None)
        errors = list()
        if not a_name:
            errors.append("Missing test name.")
        parsed = dict()
        for key in ("input", "output"):
            raw = arguments.get(key, None)
            if raw is None:
                # json.loads(None) raises TypeError, which previously escaped
                # the tool instead of being reported back to the agent.
                errors.append(f"Missing {key} JSON.")
                continue
            try:
                parsed[key] = json.loads(raw)
            except (TypeError, json.JSONDecodeError) as e:
                # TypeError: a non-string value (e.g. an already-decoded object).
                errors.append(f"Error decoding {key} JSON: {e}")
        if errors:
            return ToolResult(values={"errors": errors})
        update = "updated" if a_name in environment.tests else "added"
        environment.add_test(a_name, parsed["input"], parsed["output"])
        return ToolResult(values={"success": True, "update": update})

    def compile(self, result):
        """Render an ``execute`` result as text for the agent."""
        errors = result.values.get("errors")
        if errors:
            return "Test not added.\n" + "".join(f"- {error}\n" for error in errors)
        update = result.values.get("update")
        if not result.values.get("success", False) or not update:
            return "Test not added."
        return f"Test case {update} successfully."


class RunTestTool(Tool[JsonPathEnvironment]):
    """Tool that compiles a jq expression and runs it on every stored test.

    Every per-test result string starts with ``Success``, ``Partial success:``,
    ``Failed:`` or ``Error:``. ``convert`` keeps only the text before the first
    ``:``, so these prefixes must stay stable.
    """

    def definition(self) -> Function:
        return Function(
            name="run_tests",
            description="Execute a jq expression on the tests.",
            parameters={
                "jq": StringProperty(
                    description=(
                        "The jq expression `e` to evaluate on the input `i` as `i | jq e` or `jq.all(e, i)`."
                    )
                )
            },
        )

    def execute(
        self,
        arguments: dict[str, Any],
        environment: JsonPathEnvironment,
        *args,
        **kwargs,
    ) -> ToolResult:
        """Run ``arguments["jq"]`` on each test in ``environment.tests``.

        Returns:
            ToolResult with ``{"error": ...}`` if the expression does not
            compile; otherwise a mapping from test name to an outcome string.
            Expressions that compile are recorded via ``add_candidate``.
        """
        expression = arguments.get("jq", None)
        try:
            compiled = jq.compile(expression)
        except Exception as e:  # any compile failure is reported to the agent
            return ToolResult(
                values={"error": f"Error compiling jq expression '{expression}': {e}"}
            )
        environment.add_candidate(expression)
        result = dict()
        for name, test in environment.tests.items():
            # NOTE(review): `.all()` materializes the whole output stream, so an
            # expression with an infinite stream (e.g. `repeat(.)`) would block
            # here; there is no timeout.
            try:
                values = compiled.input_value(test["input"]).all()
                actual = json.dumps(values, sort_keys=True)
                expected = json.dumps(test["output"], sort_keys=True)
                if actual == expected:
                    result[name] = "Success"
                    continue
                wrapped = json.dumps([test["output"]], sort_keys=True)
                if actual == wrapped:
                    # The expression is right; the test's expected output lacks
                    # the outer list that jq.all() always produces.
                    result[name] = (
                        f"Partial success: the expected output {expected} should be wrapped in an additional array to match `jq.all()` syntax ({wrapped})."
                    )
                    continue
                # Report both sides as JSON (not Python repr) so they match
                # the format the tests are written in.
                result[name] = (
                    f"Failed: output `{actual}` does not match the expected result ({expected})."
                )
            except Exception as e:  # runtime jq errors are reported per test
                result[name] = f"Error: {e}"
        return ToolResult(values=result)

    def compile(self, result: ToolResult) -> str:
        return json.dumps(result.values, indent=2)


# System prompt for the jq conversion agent. It describes the two tools
# (`add_test`, `run_tests`), the testing rules, and the `jq.all(...)` output
# convention (results wrapped in a list) that RunTestTool compares against.
# NOTE(review): this text is sent to the model verbatim; editing it changes
# every request and presumably invalidates previously cached responses, so
# confirm before touching. Grammar nit kept as-is for that reason:
# "bindings ... which slightly differs" should read "which differ slightly".
system = """
# Instructions

You are an expert at jq.
You return all jq expressions (wrapped in <jq></jq>) that achieve the intended goal.

You can add or update test cases by using the `add_test` tool.
You can execute an expression on all test cases using the `run_tests` tool.

# Testing

- Make sure to add at least two and at most six tests.
- Make sure to run the tests before making a final suggestion.
- If an expression is correct but the test is incorrect, you need to update the test case to match the correct output.
- Make sure that new test cases follow the same JSON schema as the provided examples.
  - If the input contains an object, make sure that new objects follow the same structure and include all required fields.
  - If the input looks like it will be consistent, you should not test for edge cases.
  
# `jq`

You use the Python bindings for jq, which slightly differs in the output format.
More specifically, we assume that expression `e` is invoked on a JSON object `i` as `jq.all(e, i)` or `jq.compile(e).input_value(i).all()` (both are equivalent).
This causes the output to be wrapped in a JSON list, even if the result is a single value.

Some examples:
- `jq.all(".foo", {"foo": 1}) == [1]`
- `jq.all("{b: .a + 1}", {"a": 1}) == [{"b": 2}]`
- `jq.all(".[0]", [{"a": 1}, {"b": 2}]) == [{"a": 1}]`
- `jq.all("map(.[0])", [[1, 2], [3, 4]]) == [[1, 3]]`
""".strip()


def convert(
    expressions: list[str],
    utterance: str,
    data: list[dict] = None,
    iterations: int = 8,
    verbose: bool = True,
    chat_model: "ChatModel | None" = None,
) -> dict:
    """Ask the jq agent to convert candidate expressions into tested jq expressions.

    Args:
        expressions: Candidate expression(s) to start from.
        utterance: Natural-language description of the task.
        data: Example records forwarded to ``format_question``; ``None``
            means no examples.
        iterations: Forwarded to ``Agent(iterations=...)`` (presumably the
            maximum number of agent steps; confirm in ``agent``).
        verbose: When False, the agent's console is created with ``quiet=True``.
        chat_model: Model used by the agent. Defaults to the module-level
            ``model`` created by the ``__main__`` block.

    Returns:
        A dict with these keys:

        - ``"environment"``: the candidates that compiled during ``run_tests``,
          the final tests, and the superseded tests.
        - ``"response"``: the agent's final message and the expressions found
          in its ``<jq>...</jq>`` tags.
        - ``"executions"``: for each distinct candidate, a map from test name
          to the outcome prefix (``Success``/``Partial success``/``Failed``/
          ``Error``). If the candidate did not compile, the map is instead
          ``{"error": <compile message up to its first ":">}``.
    """
    environment = JsonPathEnvironment()
    agent = Agent(
        # `model` is the global defined under __main__; only looked up when no
        # explicit model is given, so importing this module stays possible.
        model=chat_model if chat_model is not None else model,
        system=system,
        tools=[AddTestTool(), RunTestTool()],
        console=Console(quiet=not verbose),
        iterations=iterations,
    )
    response = agent.run(
        format_question(
            expressions,
            utterance,
            data=data if data is not None else [],
        ),
        environment=environment,
    )
    response_message = response.message or ""
    # DOTALL so that expressions spanning several lines are still extracted.
    response_candidates: list[str] = [
        e.strip()
        for e in re.findall(r"<jq>(.+?)</jq>", response_message, flags=re.DOTALL)
    ]

    # Re-run every candidate on the final tests (tests may have changed after
    # a candidate was first tried). The dict is shared read-only.
    executions = dict()
    execution_environment = JsonPathEnvironment()
    execution_tool = RunTestTool()
    execution_environment.tests = environment.tests
    # dict.fromkeys: order-preserving de-duplication, so each candidate runs once.
    for candidate in dict.fromkeys(environment.candidates + response_candidates):
        candidate_result = execution_tool.execute(
            {"jq": candidate}, execution_environment
        ).values
        executions[candidate] = {
            test_name: test_value.split(":")[0]
            for test_name, test_value in candidate_result.items()
        }
    return {
        "environment": {
            "candidates": environment.candidates,
            "tests": environment.tests,
            "tests_old": environment.tests_old,
        },
        "response": {
            "message": response_message,
            "candidates": response_candidates,
        },
        "executions": executions,
    }


def is_valid(annotated: dict):
    """Return True when an annotated record is eligible for conversion.

    A record is rejected if any value in ``invalid`` is truthy, or if
    ``other_task`` or ``other_environment`` is set (not ``None``). Checks run
    in that order and stop at the first rejection.
    """
    return not (
        any(annotated["invalid"].values())
        or annotated["other_task"] is not None
        or annotated["other_environment"] is not None
    )


if __name__ == "__main__":

    # fmt: off
    parser = ArgumentParser()
    parser.add_argument("--page", type=int, default=1)
    parser.add_argument("--model", type=str, default="Gpt41", choices=Models.names(ModelSupports.Functions))
    parser.add_argument("--cache", type=str, default="auto")
    parser.add_argument("--failed", action="store_true", default=False)
    parser.add_argument("--iterations", type=int, default=32)
    parser.add_argument("--debug", type=str, default=None)
    parser.add_argument("--force", action="store_true", default=False)
    parser.add_argument("--root", type=str, default="E:/Papers/jqBench")
    args = parser.parse_args()
    # fmt: on

    debugging = args.debug not in {None, False, 0, "0", "false", "False"}
    root = Path(args.root)
    root_so = root / "data" / "stackoverflow"

    # load cache
    if args.cache == "auto":
        args.cache = (
            root
            / "Caches"
            / "convert"
            / f"{args.model.lower()}"
            / f"page-{args.page:02d}.json"
        )

    # get input and output file
    if not args.failed:
        file_i = root_so / "2_annotated" / f"page-{args.page:02d}.json"
        file_o = root_so / "3_converted" / f"page-{args.page:02d}.json"
    else:
        file_i = root_so / "3b_failed" / f"page-{args.page:02d}.json"
        file_o = root_so / "4b_converted" / f"page-{args.page:02d}.json"

    # load model (module-level `model` is what `convert` uses by default)
    model_spec = Models.by_name(args.model)
    model_cache = JsonCache(args.cache, autosave=32) if args.cache else None
    model = ChatModel(model=model_spec, client=Client(), cache=model_cache)

    # load and filter to-do
    if file_o.exists() and not args.force and not debugging:
        with open(file_o, encoding="utf-8") as f:
            data_target = json.load(f)
        done = {t["identifier"] for t in data_target}
    else:
        data_target = list()
        done = set()
        file_o.parent.mkdir(parents=True, exist_ok=True)
    with open(file_i, encoding="utf-8") as f:
        todo = json.load(f)
    if done:
        todo = [p for p in todo if p["identifier"] not in done]
    if debugging:
        todo = [p for p in todo if str(p["identifier"]) == args.debug]
    todo = [p for p in todo if is_valid(p["annotated"])]

    if not todo:
        print(f"Page {args.page} is all done.")
        sys.exit(0)

    # process
    pbar = tqdm(total=len(todo), desc=f"Processing {args.page}", disable=debugging)
    try:
        for e in todo:
            pbar.set_description(f"{e['identifier']}")
            e["converted"] = convert(
                e.get("expressions", list()),
                e.get("utterance", None),
                data=e.get("data", list()),
                # previously parsed but never forwarded, so convert's default (8) applied
                iterations=args.iterations,
                verbose=args.debug is not None,
            )
            data_target.append(e)
            if not debugging:
                # Write to a sibling temp file and swap it in, so an interrupt
                # mid-write cannot leave a truncated output file behind.
                tmp = file_o.with_name(file_o.name + ".tmp")
                with open(tmp, "w", encoding="utf-8") as f:
                    f.write(
                        json.dumps(
                            sorted(data_target, key=lambda x: x["identifier"]),
                            indent=2,
                        )
                    )
                os.replace(tmp, file_o)
            pbar.update(1)
    finally:
        pbar.close()
        # model_cache is None when caching is disabled (e.g. --cache "").
        if model_cache is not None:
            model_cache.save()
