import json
from enum import Enum
from pathlib import Path
from typing import Annotated, Literal, Optional, Self, Union

from evaluation import Metrics, EvaluationSettings
from execution import JqExecutionEngine, Program
from llm import ModelConfiguration
from redacted import Usage
from agent import ToolCall
from pydantic import BaseModel, Field, JsonValue, model_validator
from utilities import count_values


class InputKind(Enum):
    """Kind of input attached to a benchmark's utterance when building a ``Problem``.

    ``Input``, ``InputOutput`` and ``Schema`` are also the ``kind`` discriminator
    tags of the ``ProblemInput`` models; ``Empty`` and ``Examples`` are handled in
    ``Benchmark.problem`` without a model of their own.
    """

    Empty = "Empty"  # utterance only; the problem carries no input
    Examples = "Examples"  # input/output pairs (an ``InputOutput``) with the utterance withheld
    Input = "Input"  # utterance plus a subset of the benchmark's inputs
    InputOutput = "InputOutput"  # utterance plus inputs and their expected outputs
    Schema = "Schema"  # utterance plus JSON schema / expected output / dataset from a file


class InputSet(Enum):
    """Which of a benchmark's ``inputs`` are shown to the solver."""

    # With three or more inputs: the input with the smallest ``count_values`` score
    # first, then the remaining inputs in order, len(inputs) - 1 in total.
    # With exactly two inputs: only the first one.
    AllButOne = "AllButOne"
    # Not applicable; forced by ``InputConfiguration`` for the Empty and Schema kinds.
    Na = "Na"
    One = "One"  # only the first input


class InputConfiguration(BaseModel):
    """Input setting of an experiment: the input ``kind`` and the input subset ``which``.

    ``which`` has no meaning for ``InputKind.Empty`` and ``InputKind.Schema``; the
    after-validator rewrites it to ``InputSet.Na`` for those kinds.
    """

    kind: InputKind
    which: InputSet

    @property
    def name(self) -> str:
        """Compact label such as ``kind=Input-which=One``.

        The ``-which=...`` suffix is left out for ``InputKind.Empty``.
        """
        label = f"kind={self.kind.value}"
        if self.kind != InputKind.Empty:
            label += f"-which={self.which.value}"
        return label

    @model_validator(mode="after")
    def validate(self) -> Self:
        """Normalise ``which`` to ``InputSet.Na`` for kinds that take no input subset."""
        takes_no_subset = self.kind == InputKind.Empty or self.kind == InputKind.Schema
        if takes_no_subset:
            self.which = InputSet.Na
        return self


class Input(BaseModel):
    """Problem input made of example inputs only (built for ``InputKind.Input``)."""

    kind: Literal[InputKind.Input] = InputKind.Input  # discriminator tag for ``ProblemInput``
    inputs: list[JsonValue]


class InputOutput(BaseModel):
    """Problem input made of example inputs and their outputs.

    Built by ``Benchmark.problem`` for ``InputKind.InputOutput`` and
    ``InputKind.Examples``; there ``outputs[k]`` is the result of executing the
    benchmark's first expression on ``inputs[k]``.
    """

    kind: Literal[InputKind.InputOutput] = InputKind.InputOutput  # discriminator tag
    inputs: list[JsonValue]
    outputs: list[JsonValue]


class Schema(BaseModel):
    """Problem input describing data by schema rather than by examples.

    In ``Benchmark.problem``, ``dataset`` is the ``"data"`` entry of the
    benchmark's ``inputfile`` JSON, or ``None`` when it has no input file.
    """

    kind: Literal[InputKind.Schema] = InputKind.Schema  # discriminator tag
    jsonschema: JsonValue
    jsonoutput: Optional[JsonValue] = None
    dataset: JsonValue


# Tagged union of the input payloads a Problem can carry; pydantic selects the
# concrete model from the value of its ``kind`` field.
ProblemInput = Annotated[Union[Input, InputOutput, Schema], Field(discriminator="kind")]


class Problem(BaseModel):
    """What a solver is given for one benchmark.

    ``Benchmark.problem`` leaves ``utterance`` as ``None`` for ``InputKind.Examples``
    and ``input`` as ``None`` for ``InputKind.Empty``.
    """

    utterance: Optional[str] = None
    input: Optional[ProblemInput] = None


class Benchmark(BaseModel):
    identifier: Union[str, int]
    utterance: str
    expressions: list[str]
    # these are for jqStack
    inputs: Optional[list[JsonValue]] = None
    tasks: list[str] = Field(default_factory=list)
    # these are for jqSpider
    inputfile: Optional[Path] = None
    jsonschema: Optional[JsonValue] = None
    jsonoutput: Optional[JsonValue] = None
    settings: Optional[EvaluationSettings] = None

    def problem(self, configuration: InputConfiguration) -> Problem:
        if configuration.kind is None:
            return Problem(utterance=self.utterance)
        match configuration.which:
            case InputSet.One:
                inputs = self.inputs[:1]
            case InputSet.AllButOne:
                if len(self.inputs) == 2:
                    inputs = self.inputs[:1]
                else:
                    ishort = min(
                        enumerate(self.inputs), key=lambda x: count_values(x[1])
                    )
                    inputs = [ishort[1]] + [
                        i for j, i in enumerate(self.inputs) if j != ishort[0]
                    ][: len(self.inputs) - 2]
        i = None
        u = self.utterance
        match configuration.kind:
            case InputKind.Input:
                i = Input(inputs=inputs)
            case InputKind.InputOutput | InputKind.Examples:
                i = InputOutput(
                    inputs=inputs,
                    outputs=[
                        JqExecutionEngine.execute(self.expressions[0], i)
                        for i in inputs
                    ],
                )
                if configuration.kind == InputKind.Examples:
                    u = None
            case InputKind.Schema:
                d = (
                    json.loads(self.inputfile.read_text())["data"]
                    if self.inputfile
                    else None
                )
                i = Schema(
                    jsonschema=self.jsonschema, dataset=d, jsonoutput=self.jsonoutput
                )
        return Problem(utterance=u, input=i)

    @model_validator(mode="after")
    def validate(self) -> Self:
        if self.inputs is not None and self.inputfile is not None:
            raise ValueError("Benchmark cannot have both 'inputs' and 'inputfile'.")
        return self


class Dataset(BaseModel):
    """A named collection of benchmarks, as read from disk by ``load``."""

    name: str
    benchmarks: list[Benchmark]


class Prediction(BaseModel):
    """A candidate program paired with its metrics.

    NOTE(review): presumably ``metrics`` is the evaluation of ``program`` against
    the benchmark; confirm in the ``evaluation`` module.
    """

    program: Program
    metrics: Metrics


class SolutionMetadata(BaseModel):
    """Auxiliary records collected while solving one benchmark.

    NOTE(review): the lists look like one entry per solver step/attempt; confirm
    with the solver that fills them.
    """

    usage: list[Usage] = Field(default_factory=list)
    tools: list[list[ToolCall]] = Field(default_factory=list)  # a list of tool calls per entry
    feedback: list[Union[str, bool]] = Field(default_factory=list)


class Solution(BaseModel):
    """Solver output for one benchmark.

    ``identifier``, ``inputs`` and ``inputfile`` have the same types as the
    corresponding ``Benchmark`` fields; presumably they are copied from the
    solved benchmark (TODO confirm).
    """

    identifier: Union[str, int]
    inputs: Optional[list[JsonValue]] = None
    inputfile: Optional[Path] = None
    solutions: list[str]
    predictions: list[Prediction]
    metadata: Optional[SolutionMetadata] = None


class Experiment(BaseModel):
    """One experimental run: dataset, solver and model settings, and the solutions produced."""

    dataset: str  # dataset name (a string, not a ``Dataset`` object)
    solver: dict[str, Union[str, int, float]]  # presumably solver parameters by name — confirm
    input: InputConfiguration
    model: ModelConfiguration
    solutions: list[Solution] = Field(default_factory=list)


def load(file: Path, debug: Optional[Union[str, int]] = None) -> Dataset:
    """Read and validate a :class:`Dataset` from a UTF-8 encoded JSON file.

    Args:
        file: Path of the JSON document to validate as a ``Dataset``.
        debug: When not ``None``, keep only the benchmarks whose identifier,
            compared as a string, equals ``str(debug)``. ``0`` is a valid
            identifier and is honoured (the previous truthiness check ignored it).

    Returns:
        The validated dataset, filtered to the requested benchmark when
        ``debug`` is given.
    """
    data = Dataset.model_validate_json(file.read_text(encoding="utf-8"))
    if debug is not None:
        # Compare as strings: identifiers may be ``str`` or ``int``.
        wanted = str(debug)
        data.benchmarks = [
            benchmark
            for benchmark in data.benchmarks
            if str(benchmark.identifier) == wanted
        ]
    return data
