import os
import sys

sys.path.append(
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
        "inference",
    )
)

from collections import defaultdict
from pathlib import Path

from benchmark import Dataset, Experiment
from evaluation import ValueMatch

if __name__ == "__main__":
    # Build a filtered copy of the benchmark dataset that keeps only the
    # benchmarks none of the listed experiments solved, and print a summary
    # of how many benchmarks each experiment solved.

    root = Path("E:/Papers/jqBench") / "data" / "stackoverflow"

    # defaults from input
    dataset_path = root / "4_collected" / "stackoverflowValid.json"
    dataset_out = root / "4_collected" / "stackoverflowFiltered.json"
    experiments = [
        "E:/Papers/jqBench/results/kind=InputOutput-which=AllButOne/stackoverflowValid-JqSampler-Gpt41.json",
        "E:/Papers/jqBench/results/kind=InputOutput-which=AllButOne/stackoverflowValid-JqSampler-Gpt41Mini.json",
        "E:/Papers/jqBench/results/kind=InputOutput-which=AllButOne/stackoverflowValid-JqSampler-Phi4.json",
    ]
    dataset_name = "stackoverflowFiltered"

    # load data
    with open(dataset_path, encoding="utf-8") as f:
        dataset = Dataset.model_validate_json(f.read())

    # "<model>-<input>" -> identifiers of the solutions that experiment solved.
    # A solution counts as solved when at least one of its predictions has a
    # value_match other than ValueMatch.No.
    solved = defaultdict(set)
    for experiment_file in experiments:
        with open(experiment_file, encoding="utf-8") as f:
            experiment = Experiment.model_validate_json(f.read())
        name = f"{experiment.model.name}-{experiment.input.name}"
        # Touch the entry up front so an experiment that solves nothing is
        # still registered: it must appear in the report and must empty the
        # intersection rather than being silently left out of it.
        ids = solved[name]
        for solution in experiment.solutions:
            if any(
                p.metrics.value_match != ValueMatch.No for p in solution.predictions
            ):
                ids.add(solution.identifier)

    solved_sets = list(solved.values())
    # set().union(...) / the explicit guard keep this working when there are
    # no experiments at all (set.union/set.intersection need >= 1 argument).
    union = set().union(*solved_sets)
    intersection = set.intersection(*solved_sets) if solved_sets else set()

    print(f"ℹ️  {len(intersection)} solved by all.")
    for name, ids in sorted(solved.items(), key=lambda x: len(x[1]), reverse=True):
        print(f"ℹ️  + {len(ids - intersection)} solved by {name}.")
    print(f"ℹ️  {len(union)} solved by at least one.")
    print(f"ℹ️  {len(dataset.benchmarks) - len(union)} unsolved.")

    dataset_out.parent.mkdir(parents=True, exist_ok=True)
    # Keep only benchmarks that no experiment solved.
    dataset_new = Dataset(
        name=dataset_name,
        benchmarks=[p for p in dataset.benchmarks if p.identifier not in union],
    )

    # Never overwrite an existing filtered dataset.
    if not dataset_out.exists():
        print(
            f"➡️  Writing {len(dataset_new.benchmarks)} benchmarks to {dataset_out.relative_to(root)}."
        )
        with open(dataset_out, "w", encoding="utf-8") as f:
            f.write(dataset_new.model_dump_json(indent=2))
    else:
        print(f"⚠️  File {dataset_out.relative_to(root)} already exists.")
