import marimo

# marimo version that generated this notebook file, and the app object that
# every `@app.cell` below registers with.
__generated_with = "0.12.10"
app = marimo.App(width="medium")


@app.cell
def _():
    # Notebook-wide imports. marimo hands the returned names to any other cell
    # that references them, so they are imported once here.
    from dataclasses import asdict

    import altair as alt
    import marimo as mo
    import polars as pl

    from notebooks import plotting_utils as pu
    return alt, asdict, mo, pl, pu


@app.cell
def _():
    # The evaluation sweep configuration; its models, tasks and experiments
    # drive the selectors and tables in the cells below.
    from configs.evaluation_by_model import sweep
    return (sweep,)


@app.cell
def _(mo, sweep):
    # Model selector; starts on the first model listed in the sweep.
    _first_model = sweep.models[0]
    model_dropdown = mo.ui.dropdown(value=_first_model, options=sweep.models)

    # Last expression: marimo renders the dropdown as this cell's output.
    model_dropdown
    return (model_dropdown,)


@app.cell
def _(mo, sweep):
    # Task selector; starts on the first task listed in the sweep.
    _first_task = sweep.tasks[0]
    task_dropdown = mo.ui.dropdown(value=_first_task, options=sweep.tasks)

    # Last expression: marimo renders the dropdown as this cell's output.
    task_dropdown
    return (task_dropdown,)


@app.cell
def _(sweep):
    # Index every experiment in the sweep by (model_name, task_name) so the
    # current dropdown selections can be looked up directly. If two
    # experiments share a key, the later one in the sweep wins.
    experiments = {}
    for _exp in sweep.experiments:
        experiments[(_exp.model_name, _exp.task_name)] = _exp
    return (experiments,)


@app.cell
def _(experiments, model_dropdown, task_dropdown):
    # The experiment for the currently selected model/task pair. Raises
    # KeyError if the sweep has no experiment for that combination.
    _selection = (model_dropdown.value, task_dropdown.value)
    e = experiments[_selection]
    return (e,)


@app.cell
def _(e, mo):
    # Selector for which hop of the selected experiment to inspect. The table
    # cell unpacks the selected value into two names (query, prediction), so
    # each hop is presumably a 2-element sequence -- TODO confirm against
    # the experiment's `hops` definition.
    #
    # `mo.ui.dropdown` documents `options` as a sequence of strings or a dict
    # mapping string labels to arbitrary values. The hops are not strings, so
    # give each one a readable string label and keep the hop itself as the
    # value; `hop_dropdown.value` therefore still returns the raw hop.
    _labels = [" -> ".join(map(str, _hop)) for _hop in e.hops]
    hop_dropdown = mo.ui.dropdown(
        options=dict(zip(_labels, e.hops)),
        # Default to the last hop; with no hops at all, leave the dropdown
        # unselected rather than raising IndexError here.
        value=_labels[-1] if _labels else None,
    )

    # Last expression: marimo renders the dropdown as this cell's output.
    hop_dropdown
    return (hop_dropdown,)


@app.cell
def _(asdict, e, hop_dropdown, pl):
    # Side-by-side table of the experiment's sampled queries and its
    # predictions for the selected hop. The horizontal concat pairs rows by
    # position only, so this assumes both lists come back in the same order
    # -- TODO confirm against the experiment code.
    _query_name, _pred_name = hop_dropdown.value

    # Query side: keep the composition columns, then drop any column that is
    # null in every row.
    _query_rows = [asdict(_icq.query) for _icq in e.step_result("sampled_queries")]
    _query_table = pl.DataFrame(_query_rows).select(["x", "Fx", "GFx", "Gx", "FGx"])
    _populated = [
        _series.name
        for _series in _query_table.get_columns()
        if not _series.is_null().all()
    ]
    _query_table = _query_table.select(_populated)

    # Prediction side: one row per prediction, plus a `correct` flag.
    _step = e.step_result(f"results_{_query_name}_{_pred_name}")
    _prediction_rows = []
    for _p in _step["predictions"]:
        _row = asdict(_p)
        _row["correct"] = _p.pred == _p.label
        _prediction_rows.append(_row)
    _prediction_table = pl.DataFrame(_prediction_rows).select(
        ["pred", "label", "correct", "prompt"]
    )

    # Last expression: marimo renders the combined table as this cell's output.
    pl.concat([_query_table, _prediction_table], how="horizontal")
    return


# Allow running the notebook as a plain script (`python <file>`).
if __name__ == "__main__":
    app.run()
