import argparse
import os
import sys
from pathlib import Path

sys.path.append(
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
        "inference",
    )
)

import json
from collections import defaultdict

from benchmark import Benchmark, Dataset
from execution import JqExecutionEngine
from pydantic import JsonValue
from tqdm import tqdm


def is_success(converted) -> bool:
    """Decide whether a converted item is usable and prune it in place.

    ``converted["executions"]`` maps each expression to a mapping of
    test name -> status string. An expression is kept when either

    * every status is ``"Success"`` or ``"Partial success"``, or
    * exactly two distinct statuses occur, more than one test is
      ``"Success"`` (or more than one is ``"Partial success"``), and a single
      test is ``"Failed"`` or ``"Error"``; that single test is then scheduled
      for removal.

    Expressions without any results are always dropped.

    Returns ``False`` (leaving ``converted`` untouched) when every expression
    is dropped or when more than one distinct test would have to be removed.
    Otherwise the dropped expressions and the outlier test are removed from
    ``converted["executions"]`` and ``converted["environment"]["tests"]`` and
    ``True`` is returned.
    """
    passing = {"Success", "Partial success"}
    executions = converted["executions"]
    dropped_expressions: set[str] = set()
    dropped_tests: set[str] = set()

    for expression, results in executions.items():
        if len(results) == 0:
            dropped_expressions.add(expression)
            continue

        # Group the test names by the status they produced.
        by_status = defaultdict(list)
        for test_name, status in results.items():
            by_status[status].append(test_name)

        accepted = passing.issuperset(by_status)

        if len(by_status) == 2:
            several_success = len(by_status.get("Success", [])) > 1
            several_partial = len(by_status.get("Partial success", [])) > 1
            if several_success or several_partial:
                # A lone failing/erroring test is treated as a bad test
                # rather than a bad expression.
                for bad_status in ("Failed", "Error"):
                    outliers = by_status.get(bad_status, [])
                    if len(outliers) == 1:
                        accepted = True
                        dropped_tests.update(outliers)

        if not accepted:
            dropped_expressions.add(expression)

    # Nothing is mutated before these two rejection checks.
    if len(dropped_expressions) == len(executions):
        return False
    if len(dropped_tests) > 1:
        return False

    for expression in dropped_expressions:
        executions.pop(expression, None)
    for results in executions.values():
        for test_name in dropped_tests:
            results.pop(test_name, None)
    for test_name in dropped_tests:
        converted["environment"]["tests"].pop(test_name, None)
    return True


def _collect_leaves(o: JsonValue) -> set:
    leaves = set()
    if isinstance(o, dict):
        for v in o.values():
            leaves.update(_collect_leaves(v))
    elif isinstance(o, list):
        for v in o:
            leaves.update(_collect_leaves(v))
    else:
        leaves.add(o)
    return leaves


def _collect_keys(o: JsonValue) -> set:
    keys = set()
    if isinstance(o, dict):
        for k, v in o.items():
            keys.add(k)
            keys.update(_collect_keys(v))
    elif isinstance(o, list):
        for v in o:
            keys.update(_collect_keys(v))
    return keys


def _is_subset(i: JsonValue, o: JsonValue) -> bool:
    """Check if output `o` is a subset of input `i`."""
    if isinstance(i, dict) and isinstance(o, dict):
        return all(k in i and _is_subset(i[k], v) for k, v in o.items())
    if isinstance(i, list) and isinstance(o, list):
        return all(any(_is_subset(ie, oe) for ie in i) for oe in o)
    return i == o


def _is_extract(i: JsonValue, o: JsonValue) -> bool:
    """Check that every non-null leaf value of output `o` also occurs in input `i`."""
    available = _collect_leaves(i)
    # None leaves are ignored on the output side; a non-None leaf can never
    # match a None in `available`, so no filtering is needed on the input side.
    return all(leaf is None or leaf in available for leaf in _collect_leaves(o))


def _is_transform(i: JsonValue, o: JsonValue) -> bool:
    """Check if `o` has exactly the same structure as `i` while allowing any leaf values to differ."""
    if isinstance(i, dict):
        if not isinstance(o, dict):
            return False
        if i.keys() != o.keys():
            return False
        return all(_is_transform(i[k], o[k]) for k in i.keys())
    if isinstance(i, list):
        if not isinstance(o, list):
            return False
        if len(i) != len(o):
            return False
        return all(_is_transform(iv, ov) for iv, ov in zip(i, o))
    return not isinstance(o, (dict, list))


def _is_keys(i: JsonValue, o: JsonValue) -> bool:
    """Check that every non-null leaf of `o` is a string naming some key of `i`.

    Keys of `i` are gathered at every nesting depth. An output without any
    non-null leaves satisfies the check vacuously.
    """
    values = _collect_leaves(o)
    values.discard(None)
    known_keys = _collect_keys(i)
    for value in values:
        if not (isinstance(value, str) and value in known_keys):
            return False
    return True


def get_task(inputs: list[JsonValue], outputs: list[JsonValue]) -> str:
    """Classify the relationship between paired inputs and outputs.

    If every output is a one-element list, each output is first replaced by
    its single element. The remaining checks are tried in order and the first
    one that holds for *all* (input, output) pairs determines the label:

    * ``"N/A"``       - `inputs` and `outputs` differ in length
    * ``"subset"``    - each output is contained in its input (`_is_subset`)
    * ``"superset"``  - each input is contained in its output
    * ``"extract"``   - each output's leaves all occur in its input
    * ``"augment"``   - each input's leaves all occur in its output
    * ``"transform"`` - output has the input's structure, leaves may differ
    * ``"keys"``      - output leaves are keys of the input
    * ``"other"``     - none of the above

    Note that with no pairs at all every check holds vacuously, so two empty
    lists yield ``"subset"``.
    """
    # Unwrap only genuine single-element lists. Without the isinstance guard,
    # a scalar output (None, int, bool) raises TypeError in len() and a
    # single-key dict raises KeyError on o[0].
    if all(isinstance(o, list) and len(o) == 1 for o in outputs):
        outputs = [o[0] for o in outputs]
    if len(inputs) != len(outputs):
        return "N/A"
    pairs = list(zip(inputs, outputs))
    if all(_is_subset(i, o) for i, o in pairs):
        return "subset"
    if all(_is_subset(o, i) for i, o in pairs):
        return "superset"
    if all(_is_extract(i, o) for i, o in pairs):
        return "extract"
    if all(_is_extract(o, i) for i, o in pairs):
        return "augment"
    if all(_is_transform(i, o) for i, o in pairs):
        return "transform"
    if all(_is_keys(i, o) for i, o in pairs):
        return "keys"
    return "other"


if __name__ == "__main__":
    # Collect converted items into a single benchmark dataset.
    #
    #   --converted  directory holding the converted ``*.json`` files
    #   --to         path of the dataset file to write
    #   --failed     also dump the items rejected by ``is_success``
    #   --debug      only process the item with this identifier, print its
    #                task, and do not write the dataset file

    parser = argparse.ArgumentParser()
    parser.add_argument("--converted", type=Path, default=None)
    parser.add_argument("--to", type=Path, default=None)
    parser.add_argument("--failed", action="store_true")
    parser.add_argument("--debug", type=int, default=None)
    args = parser.parse_args()

    # NOTE(review): the default root is a machine-specific absolute path; it
    # is only used for whichever of --converted / --to is not given and for
    # the failed-items directory.
    root = Path("E:/Papers/jqBench") / "data" / "stackoverflow"
    files_i = list((args.converted or (root / "3_converted")).glob("*.json"))
    file_o = args.to or (root / "4_collected" / "stackoverflowValid.json")
    # parents=True so a custom --to inside a not-yet-existing tree works.
    file_o.parent.mkdir(parents=True, exist_ok=True)
    if args.failed:
        files_failed = root / "5_failed"
        files_failed.mkdir(parents=True, exist_ok=True)

    dataset = Dataset(name="stackoverflowValid", benchmarks=list())
    progress = tqdm(total=len(files_i))
    for file in files_i:
        file_failed = list()
        # Close the input file deterministically instead of leaking the handle.
        with file.open(encoding="utf-8") as f:
            data = json.load(f)
        for item in data:
            if args.debug is not None and item["identifier"] != args.debug:
                continue
            # At least two tests are required, both before and after pruning.
            if len(item["converted"]["environment"]["tests"]) < 2:
                continue
            # is_success prunes item["converted"] in place when it returns True.
            success = is_success(item["converted"])
            if not success or len(item["converted"]["environment"]["tests"]) < 2:
                if not success:
                    # Only outright rejections are recorded as failed; items
                    # that merely ran out of tests are skipped silently.
                    item.pop("converted", None)
                    file_failed.append(item)
                continue
            progress.set_description(str(item["identifier"]))
            inputs = [
                t["input"] for t in item["converted"]["environment"]["tests"].values()
            ]
            # Number of "Success" results per remaining expression.
            expressions = {
                e: list(results.values()).count("Success")
                for e, results in item["converted"]["executions"].items()
            }
            # max() keeps the first expression in case of a tie.
            expression_best = max(expressions, key=expressions.get)
            outputs = [JqExecutionEngine.execute(expression_best, i) for i in inputs]
            task = get_task(inputs, outputs)
            if args.debug is not None:
                print(
                    f"  {item['identifier']}: {task} ({len(expressions)} expressions)"
                )
            dataset.benchmarks.append(
                Benchmark(
                    identifier=item["identifier"],
                    utterance=item["utterance"],
                    expressions=list(expressions.keys()),
                    inputs=inputs,
                    tasks=[task],
                )
            )
        progress.update(1)
        progress.set_postfix(tests=len(dataset.benchmarks), failed=len(file_failed))
        if args.failed and len(file_failed) > 0:
            with (files_failed / file.name).open("w", encoding="utf-8") as f:
                json.dump(file_failed, f, indent=2)
    # Release the progress bar so later output starts on a clean line.
    progress.close()

    # NOTE(review): an existing output file is only detected after all the
    # processing above has run; it is never overwritten.
    if file_o.exists():
        print(f"⚠️ Error: output file {file_o} already exists.")
    else:
        if args.debug is None:
            with open(file_o, "w", encoding="utf-8") as f:
                f.write(dataset.model_dump_json(indent=2))
