from concurrent.futures import ThreadPoolExecutor, as_completed
import json
from pathlib import Path
from typing import Optional

from benchmark import (
    Benchmark,
    Experiment,
    InputConfiguration,
    InputKind,
    InputSet,
    ModelConfiguration,
    Prediction,
    Solution,
    load,
)
from evaluation import evaluate
from jsonargparse import ArgumentParser, lazy_instance
from llm import ConfiguredModel
from redacted.client import RateLimitError
from rich.console import Console
from rich.table import Table
from settings import Settings, Verbosity
from solvers.solver import Solver
from solvers.solver_jq import JqSampler
from tqdm import tqdm
from utilities import get_root

if __name__ == "__main__":

    # Command-line interface of the experiment runner. How each option is
    # consumed further down in this script:
    #   --benchmarks   dataset path handed to `load`; a bundled StackOverflow
    #                  JSON path is used when it is not given
    #   --select       forwarded to `load` as `debug=`; when set, already-solved
    #                  benchmarks are not filtered out and debug output is forced
    #   --output       "auto" derives the file under --output-root; any other
    #                  string not ending in ".json" is treated as a directory
    #   --output-root  base directory for "auto" output (default: <root>/results)
    #   --solver       solver instance (a lazily created JqSampler by default)
    #   --input        input configuration used to build each problem
    #   --model        model configuration passed to ConfiguredModel
    #   --suffix       appended to the solver name
    #   --workers      size of the thread pool that runs the solver
    #   --cache        "auto" derives the file under --cache-root; the value is
    #                  passed to ConfiguredModel
    #   --cache-root   base directory for "auto" cache (default: <root>/caches)
    #   --force        ignore an existing output file and do not skip benchmarks
    #   --take         keep only the first N benchmarks of the work list
    #   --verbosity    stored in Settings; the console is quiet when it is Off
    #   --debug        consulted by the error handler of the result loop
    # fmt: off
    parser = ArgumentParser()
    parser.add_argument("--benchmarks", type=Path, default=None)
    parser.add_argument("--select", type=str, default=None)
    parser.add_argument("--output", type=str, default=None)
    parser.add_argument("--output-root", type=Path, default=None)
    parser.add_argument("--solver", type=Solver, default=lazy_instance(JqSampler))
    parser.add_argument("--input", type=InputConfiguration, default=InputConfiguration(kind=InputKind.InputOutput, which=InputSet.AllButOne))
    parser.add_argument("--model", type=ModelConfiguration, default=ModelConfiguration(name="Gpt41Mini"))
    parser.add_argument("--suffix", type=str, default="")
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--cache", type=str, default=None)
    parser.add_argument("--cache-root", type=Path, default=None)
    parser.add_argument("--force", action="store_true")
    parser.add_argument("--take", type=Optional[int], default=None)
    parser.add_argument("--verbosity", type=Verbosity, default=Verbosity.Off)
    parser.add_argument("--debug", action="store_true")
    # `args` keeps the raw parsed values; `args_init` is the result of
    # `instantiate_classes` and is where the solver, input and model objects
    # are read from below.
    args = parser.parse_args()
    args_init = parser.instantiate_classes(args)
    # fmt: on

    # Apply settings: publish the verbosity and mute the shared console when
    # verbosity is Off.
    Settings.verbosity = args.verbosity
    silent = Settings.verbosity == Verbosity.Off
    Settings.console = Console(quiet=silent)

    # Configuration objects instantiated from the command line.
    solver: Solver = args_init.solver
    input_configuration: InputConfiguration = args_init.input
    model_configuration: ModelConfiguration = args_init.model

    # The suffix is part of the solver's name everywhere it is reported.
    solver_name = f"{solver.name}{args.suffix}"
    solver_parameters = {"name": solver_name, **solver.parameters}

    # Load the benchmarks, falling back to the bundled dataset path.
    benchmarks_file = args.benchmarks
    if not benchmarks_file:
        benchmarks_file = Path(
            "data/stackoverflow/6_filtered/stackoverflowFiltered.json"
        )
    dataset = load(benchmarks_file, debug=args.select)

    # Derived locations look like <root>/<input name>/<dataset>-<solver>-<model>.json
    group = input_configuration.name
    filename = f"{dataset.name}-{solver_name}-{model_configuration.name}.json"

    if args.cache == "auto":
        if not args.cache_root:
            args.cache_root = get_root() / "caches"
        args.cache = args.cache_root / group / filename

    if args.output == "auto":
        if not args.output_root:
            args.output_root = get_root() / "results"
        args.output = args.output_root / group / filename
    elif isinstance(args.output, str) and not args.output.endswith(".json"):
        # A plain string without the .json extension names a directory.
        args.output = Path(args.output) / group / filename

    print(f"Using cache file: {args.cache}")
    print(f"Using output file: {args.output}")

    # Model wrapper built from the model configuration and the cache location.
    model = ConfiguredModel(model_configuration, args.cache)

    # load existing results
    #
    # The experiment described by this invocation. It is the starting point of
    # a new run and the reference an existing result file is compared against.
    fresh_experiment = Experiment(
        dataset=dataset.name,
        solver={"name": solver_name, **solver.parameters},
        input=input_configuration,
        model=model_configuration,
    )

    output = None
    experiment = None
    if args.output is not None:
        output = Path(args.output)
        if output.exists() and not args.force:
            with open(output, encoding="utf-8") as f:
                experiment = Experiment.model_validate_json(f.read())
            # Resuming is only sound when the stored run was produced with the
            # same dataset, solver (name and parameters) and input
            # configuration; a mismatch in ANY of them is an error. The
            # comparison uses the JSON-mode dump so that values which change
            # representation when written to disk compare equal to their
            # reloaded form.
            identity = {"dataset", "solver", "input"}
            stored = experiment.model_dump(mode="json", include=identity)
            expected = fresh_experiment.model_dump(mode="json", include=identity)
            if stored != expected:
                raise ValueError("Incompatible experiment configurations.")
        else:
            # New (or forced) run: make sure the target directory exists.
            output.parent.mkdir(parents=True, exist_ok=True)
    if experiment is None:
        experiment = fresh_experiment

    # filter out completed benchmarks
    #
    # Unless --force or --select is given, benchmarks that already have a
    # stored solution are skipped.
    todo = dataset.benchmarks
    if not args.force and (args.select is None):
        done = {r.identifier for r in experiment.solutions}
        todo = [d for d in todo if d.identifier not in done]
    if args.take:
        todo = todo[: args.take]

    # remove to-do from experiment solutions
    #
    # Anything about to be (re)solved is dropped from the stored solutions so
    # the new result replaces it. The identifier set is built once rather than
    # once per stored solution.
    pending = {b.identifier for b in todo}
    experiment.solutions = [
        s for s in experiment.solutions if s.identifier not in pending
    ]

    if len(todo) >= 10 and args.output is None:
        # SystemExit is raised directly: the `exit` builtin is injected by the
        # `site` module and is not guaranteed to exist in every interpreter.
        raise SystemExit("Refusing to run large experiments without output file.")

    if args.select is not None:
        # Selecting benchmarks turns on full debug output and the --debug flag.
        Settings.verbosity = Verbosity.Debug
        Settings.console.quiet = False
        args.debug = True

    # define solver
    def solve(benchmark: Benchmark, solver: Solver, model: ConfiguredModel) -> Solution:
        """Solve one benchmark and evaluate every candidate program.

        Args:
            benchmark: The benchmark to solve.
            solver: Solver whose ``solve(problem, model)`` returns the
                candidate programs together with metadata.
            model: Model handed to the solver.

        Returns:
            A ``Solution`` for the benchmark holding one ``Prediction``
            (program plus evaluation metrics) per candidate, and the solver's
            metadata.
        """
        problem = benchmark.problem(input_configuration)
        candidates, meta = solver.solve(problem, model)

        # Benchmarks without inline inputs keep them in a JSON file under the
        # "data" key. The file is decoded as UTF-8 explicitly, like every
        # other file this script reads or writes, instead of relying on the
        # platform default encoding.
        inputs = benchmark.inputs or [
            json.loads(benchmark.inputfile.read_text(encoding="utf-8")).get("data", {})
        ]
        predictions = [
            Prediction(
                program=candidate,
                metrics=evaluate(
                    program=candidate,
                    inputs=inputs,
                    solutions=benchmark.expressions,
                    settings=benchmark.settings,
                ),
            )
            for candidate in candidates
        ]
        return Solution(
            identifier=benchmark.identifier,
            inputs=benchmark.inputs,
            inputfile=benchmark.inputfile,
            solutions=benchmark.expressions,
            predictions=predictions,
            metadata=meta,
        )

    # initialize progress bar
    progress = tqdm(
        total=len(todo),
        disable=len(todo) <= 1,
        desc=f"{solver_name} - {model_configuration.name} - {input_configuration.name}",
    )

    # Results are written only when an output file is configured and the run
    # is not a --select run. The same rule applies to the periodic checkpoints
    # and to the final write, so a --select run never touches the result file.
    persist = (output is not None) and (args.select is None)

    def save_experiment() -> None:
        """Write the current experiment to the output file as indented JSON."""
        # Serialise completely before opening the file so that a failing dump
        # cannot leave a truncated result file behind.
        payload = experiment.model_dump_json(indent=2)
        with open(output, "w", encoding="utf-8") as f:
            f.write(payload)

    # Solve all pending benchmarks on a thread pool and collect the results in
    # completion order on the main thread.
    with ThreadPoolExecutor(max_workers=args.workers) as executor:
        futures = {
            executor.submit(solve, benchmark, solver, model): benchmark
            for benchmark in todo
        }
        for future in as_completed(futures):
            benchmark = futures[future]
            try:
                solution = future.result()
            except RateLimitError as e:
                # Further requests would fail as well: cancel what has not
                # started yet and abort the run.
                executor.shutdown(wait=False, cancel_futures=True)
                print(f"Rate limit error solving {benchmark.identifier}: {e}")
                raise
            except Exception as e:
                progress.update(1)
                # Abort in debug mode and on overload errors; any other failure
                # is reported and the benchmark is skipped. It stays without a
                # stored solution, so a later run picks it up again.
                if args.debug or "overloaded_error" in str(e):
                    raise
                print(f"Error solving {benchmark.identifier}: {e}")
                continue
            experiment.solutions.append(solution)
            if Settings.verbosity == Verbosity.Debug:
                # One row per candidate program with its evaluation metrics.
                table = Table(
                    "Program", "Compiles", "Executes", "Value Match", "Exact Match"
                )
                for p in solution.predictions:
                    table.add_row(
                        p.program.code,
                        str(p.metrics.compiles),
                        str(p.metrics.executes),
                        str(p.metrics.value_match),
                        str(p.metrics.exact_match),
                    )
                    table.add_section()
                Settings.console.print("> 💯")
                Settings.console.print(table)
            # Checkpoint whenever the progress counter is a multiple of ten.
            if persist and (progress.n % 10 == 0):
                save_experiment()
            progress.update(1)

    if persist:
        save_experiment()
    model.model.cache.save()
