import json
from enum import Enum
from typing import Optional

from rich.table import Table
from execution import Program, ProgramKind, compiles, run
from pydantic import BaseModel, JsonValue
from settings import Settings


class EvaluationSettings(BaseModel):
    """Options controlling how ``equals`` compares an expected value with an actual one."""

    # When False, ``equals`` discards dict keys and compares only the dicts'
    # values (sorted by their JSON serialization).
    keys: bool = True
    # Only consulted when ``keys`` is False: if False, list elements are sorted
    # before comparison, so element order is ignored.
    order: bool = False


class ValueMatch(str, Enum):
    """Outcome of comparing an expected value with an actual one (see ``equals``)."""

    # Values differ, or could not be serialized for comparison.
    No = "No"
    # Serialized values are identical.
    Exact = "Exact"
    # ``expected`` equals ``actual`` wrapped in a one-element list.
    Wrapped = "Wrapped"
    # ``actual`` equals ``expected`` wrapped in a one-element list.
    Unwrapped = "Unwrapped"


class Metrics(BaseModel):
    """Per-program evaluation result produced by ``evaluate``."""

    # Whether ``compiles(program)`` reported success.
    compiles: bool
    # Whether running the program yielded a non-None result for every input.
    executes: bool
    # Best value comparison found across all reference solutions.
    value_match: ValueMatch
    # Whether the program's source text equals one of the reference solutions.
    exact_match: bool


# String enum with one member per ``Metrics`` field, in declaration order.
# Each member's name and value are both the field name, so
# ``getattr(metrics, member.value)`` reads the corresponding metric.
MetricNames = Enum(
    "MetricNames",
    {field_name: field_name for field_name in Metrics.model_fields},
    type=str,
)


def aggregate(metrics: list[Metrics]) -> dict[MetricNames, float]:
    """Average every metric over a collection of evaluation results.

    Boolean metrics are reduced to the fraction of results where they are True.
    ``ValueMatch`` metrics are reduced to the fraction of results whose match is
    anything other than ``ValueMatch.No``.

    Args:
        metrics: Per-program results to aggregate.

    Returns:
        A mapping from metric name to its mean value in ``[0, 1]``. An empty
        ``metrics`` list yields an empty dict; there is nothing to average, and
        dividing by zero would otherwise raise ``ZeroDivisionError``.
    """
    aggregated: dict[MetricNames, float] = {}
    if not metrics:
        return aggregated

    total = len(metrics)
    for name in MetricNames:
        values = [getattr(m, name.value) for m in metrics]
        if all(isinstance(v, bool) for v in values):
            aggregated[name] = sum(values) / total
        elif all(isinstance(v, ValueMatch) for v in values):
            # Exact, Wrapped and Unwrapped all count as a successful match.
            aggregated[name] = sum(v != ValueMatch.No for v in values) / total
    return aggregated


# Preference order used to pick the best match across reference solutions.
_VALUE_MATCH_RANK = {
    ValueMatch.No: 0,
    ValueMatch.Unwrapped: 1,
    ValueMatch.Wrapped: 2,
    ValueMatch.Exact: 3,
}

# Maximum number of characters of serialized JSON shown per table cell.
_CELL_CHARS = 256


def _print_comparison(
    inputs: list[JsonValue],
    expected: list[JsonValue],
    outputs: list[JsonValue],
    matches: list[ValueMatch],
) -> None:
    """Render a per-input expected-vs-actual table on the shared console."""
    table = Table("Input", "Expected", "Output", "Match")
    for i, e, o, m in zip(inputs, expected, outputs, matches):
        table.add_row(
            json.dumps(i)[:_CELL_CHARS],
            json.dumps(e)[:_CELL_CHARS],
            json.dumps(o)[:_CELL_CHARS],
            m,
        )
        table.add_section()
    Settings.console.print(table)


def evaluate(
    program: Program,
    inputs: list[JsonValue],
    solutions: list[str],
    settings: Optional[EvaluationSettings] = None,
) -> Metrics:
    """Score ``program`` against reference jq ``solutions`` on ``inputs``.

    The checks run in order, and evaluation stops at the first failure:

    1. ``compiles``: whether ``compiles(program)`` succeeds.
    2. ``executes``: whether ``run(program, i)`` is non-None for every input.
    3. ``exact_match``: whether the program's code equals any solution verbatim.
    4. ``value_match``: the best outcome across solutions. For each solution, the
       per-input ``equals`` results must all agree, otherwise that solution
       scores ``ValueMatch.No``.

    A comparison table is printed to ``Settings.console`` for every solution.

    Args:
        program: Candidate program to evaluate.
        inputs: JSON inputs to feed to the candidate and to each solution.
        solutions: Reference jq programs (source text).
        settings: Comparison options; defaults to ``EvaluationSettings()``.

    Returns:
        The populated ``Metrics``. When ``solutions`` is empty, ``value_match``
        is ``ValueMatch.No``; previously this case raised ``IndexError``.
    """
    settings = settings or EvaluationSettings()
    result = Metrics(
        compiles=False, executes=False, value_match=ValueMatch.No, exact_match=False
    )

    result.compiles = compiles(program)
    if not result.compiles:
        return result

    # NOTE(review): a None result is treated as "did not execute"; if ``run``
    # can legitimately return None for a JSON ``null`` output, such programs
    # are counted as failures. Confirm against ``run``'s contract.
    output = [run(program, i) for i in inputs]
    result.executes = all(o is not None for o in output)
    if not result.executes:
        return result

    result.exact_match = any(program.code == solution for solution in solutions)

    per_solution: list[ValueMatch] = []
    for solution in solutions:
        reference = Program(kind=ProgramKind.Jq, code=solution)
        expect = [run(reference, i) for i in inputs]
        match = [equals(e, o, settings) for e, o in zip(expect, output)]
        _print_comparison(inputs, expect, output, match)
        # A solution only counts if every input produced the same kind of match.
        per_solution.append(match[0] if len(set(match)) == 1 else ValueMatch.No)

    result.value_match = max(
        per_solution, key=_VALUE_MATCH_RANK.__getitem__, default=ValueMatch.No
    )
    return result


def equals(
    expected: JsonValue,
    actual: JsonValue,
    settings: Optional[EvaluationSettings] = None,
) -> ValueMatch:
    """Compare two JSON values, allowing for one level of list wrapping.

    Both values are serialized with ``json.dumps(..., sort_keys=True)``, so dict
    key order never matters. When ``settings.keys`` is False, the values are
    first canonicalized: dicts are replaced by the sorted list of their
    (canonicalized) values, and list elements are sorted unless
    ``settings.order`` is True.

    Args:
        expected: Reference value.
        actual: Value under test.
        settings: Comparison options; defaults to ``EvaluationSettings()``.

    Returns:
        ``Exact`` if the serializations are identical; ``Wrapped`` if
        ``expected`` equals ``[actual]``; ``Unwrapped`` if ``actual`` equals
        ``[expected]``; otherwise ``No``. Values that cannot be serialized also
        yield ``No`` rather than raising.
    """
    settings = settings or EvaluationSettings()

    def sort_key(value: JsonValue) -> str:
        return json.dumps(value, sort_keys=True)

    def canonical(o: JsonValue) -> JsonValue:
        if isinstance(o, dict):
            return sorted((canonical(ov) for ov in o.values()), key=sort_key)
        if isinstance(o, list):
            c = [canonical(ov) for ov in o]
            if not settings.order:
                c.sort(key=sort_key)
            return c
        return o

    try:
        # NOTE(review): ``settings.order`` only takes effect on this path; with
        # the default ``keys=True``, lists are always compared in order.
        if not settings.keys:
            # Inside the ``try`` because the sort key serializes values too.
            expected = canonical(expected)
            actual = canonical(actual)
        want = json.dumps(expected, sort_keys=True)
        got = json.dumps(actual, sort_keys=True)
    except (TypeError, ValueError):
        # TypeError: not JSON-serializable; ValueError: circular reference.
        return ValueMatch.No

    # With default separators, json.dumps([x]) is exactly "[" + json.dumps(x) + "]",
    # so the wrapped forms can be derived without re-serializing.
    if want == got:
        return ValueMatch.Exact
    if want == f"[{got}]":
        return ValueMatch.Wrapped
    if f"[{want}]" == got:
        return ValueMatch.Unwrapped
    return ValueMatch.No
