import marimo

# marimo notebook header: the generator version and the App object that the
# `@app.cell` decorators below register cells on.
__generated_with = "0.14.13"
app = marimo.App(width="medium")


@app.cell
def _():
    # Third-party libraries: charts, notebook runtime, dataframes.
    import altair as alt
    import marimo as mo
    import polars as pl

    # Project-local plotting helpers (provides `merge_columns` and `node_labels`).
    from notebooks import plotting_utils as pu

    return (alt, pl, pu)


@app.cell
def _():
    # NOTE: `Task` is not referenced by any active cell; it is kept for the
    # planned filtering of tasks via `Task(task_name=...).correct_intermediate_nodes`.
    from logit_lens_compositionality.tasks import Task
    return


@app.cell
def _():
    # Sweep configuration from `configs.llama_3_1b.lens`; `sweep.results()`
    # is the raw data source for the analysis cells below.
    from configs.llama_3_1b.lens import sweep
    return (sweep,)


@app.cell
def _(pl, pu, sweep):
    # Collapse the per-node columns listed in `pu.node_labels` into
    # (node, value) pairs. Assumes `pu.merge_columns` behaves like an
    # unpivot/melt -- confirm against plotting_utils.
    _long = pu.merge_columns(sweep.results(), pu.node_labels.keys(), "node", "value")

    # One row per (task_name, node), one column per correctness outcome. The
    # pivot names the outcome columns "true"/"false"; give them readable names.
    _wide = _long.pivot(
        index=["task_name", "node"], on="correctness", values="value"
    ).rename({"true": "correct", "false": "incorrect"})

    # relative_rank = correct / incorrect - 1, kept only where both outcomes
    # exist. NOTE(review): a zero `incorrect` value would yield inf/NaN here.
    _ratio = pl.col("correct") / pl.col("incorrect")
    results = (
        _wide.drop_nulls(["correct", "incorrect"])
        .with_columns(relative_rank=_ratio - pl.lit(1))
        .drop("correct", "incorrect")
    )
    return (results,)


@app.cell
def _(results):
    # Render the processed table (task_name, node, relative_rank) as cell output.
    results
    return


@app.cell
def _(alt, pl, results):
    # Compare nodes "x" and "GFx". Bars show the mean relative_rank over all
    # tasks (rows are one per task_name/node); error bars show +/- one stdev.
    # Bug fix: the chart previously encoded a non-existent `delta` column,
    # which rendered an empty chart.
    _results = results.filter(pl.col("node").is_in(["x", "GFx"]))

    _bars = (
        alt.Chart(_results)
        .mark_bar()
        .encode(
            x="node:N",
            y=alt.Y("mean(relative_rank):Q", title="Mean relative rank"),
            color="node:N",
        )
    )
    _error_bars = (
        alt.Chart(_results)
        .mark_errorbar(extent="stdev")
        .encode(x="node:N", y="relative_rank:Q", color="node:N")
    )

    # The last expression is rendered as the cell's output.
    _bars + _error_bars
    return


@app.cell
def _(alt, results):
    # Per-task breakdown: one facet column per task and one bar per node,
    # showing that node's relative_rank (correct / incorrect - 1).
    # Bug fix: the chart previously encoded `value` and `correctness`, which
    # the upstream pivot consumed and dropped, so nothing was drawn.
    alt.Chart(results).mark_bar().encode(
        x="node:N",
        y="relative_rank:Q",
        color="node:N",
        column="task_name:N",
    ).properties(width=100)
    return


@app.cell
def _():
    # Analysis plan / TODOs:
    # - Across all tasks, compare node x between correct and incorrect answers.

    # - Compare tasks where both Fx and Gx are correct intermediate nodes.
    # - Compare tasks where Fx is correct but Gx is incorrect.
    # - Compare FGx against GFx.
    return


@app.cell
def _():
    # Empty scratch cell.
    return


# Allow executing the notebook as a plain script (`python <file>.py`).
if __name__ == "__main__":
    app.run()
