import marimo

# marimo version that last wrote this file; marimo reads this value statically.
__generated_with = "0.15.3"
# The notebook app; every `@app.cell` function below is registered on it.
app = marimo.App(width="medium")


@app.cell
def _():
    # Shared imports for the notebook. In marimo every non-underscore name a
    # cell defines becomes a notebook global, so only import what other cells
    # actually use (returned below). The previously imported `itertools` and
    # `os` were never referenced anywhere and have been dropped.
    import marimo as mo
    import altair as alt
    import polars as pl
    from notebooks import plotting_utils as pu
    return alt, mo, pl, pu


@app.cell
def _():
    # Load the sweep definition from `configs.llama_3_3b.lens`. Later cells read
    # its `tasks` and `experiments`. The module alias is underscore-prefixed
    # so it stays private to this cell.
    import configs.llama_3_3b.lens as _lens_cfg

    sweep = _lens_cfg.sweep
    return (sweep,)


@app.cell
def _(alt, pu):
    def build_chart(experiment, example_idx: int):
        """Plot reciprocal rank against layer for one example of an experiment.

        Draws one line per node type. Each line's color and dash pattern come
        from ``pu.node_properties`` for the experiment's task.

        Args:
            experiment: Experiment object. Its ``processing_signatures`` step
                result is indexed by ``example_idx``, and its ``task_name`` is
                used both as the chart title and to look up node styling.
            example_idx: Index of the example within ``processing_signatures``.

        Returns:
            An Altair line chart (600x400) with title, axis, and legend
            configuration applied and the legend placed on top.
        """
        df = experiment.step_result("processing_signatures")[example_idx]
        # Map raw node-type values to display labels. Presumably this is what
        # makes them line up with `node_labels`, the domain of the color/dash
        # scales below.
        df = df.with_columns(node_type=pu.map_column(col="node_type", map=pu.node_labels))

        node_labels, node_colors, node_dash = pu.node_properties(task_name=experiment.task_name)

        chart = (
            alt.Chart(df, title=experiment.task_name)
            .mark_line(strokeWidth=4)
            .encode(
                x=alt.X("layer", axis=alt.Axis(title="Layer")),
                y=alt.Y("reciprocal_rank", axis=alt.Axis(title="Reciprocal Rank")),
                color=alt.Color(
                    "node_type",
                    sort=node_labels,
                    title="",
                    scale=alt.Scale(
                        domain=node_labels,
                        range=node_colors,
                    ),
                    legend=alt.Legend(orient="top"),
                ),
                # The dash pattern needs the StrokeDash channel class. The
                # original used alt.StrokeWidth here, which is the wrong
                # channel type for `strokeDash=`.
                strokeDash=alt.StrokeDash(
                    "node_type",
                    sort=node_labels,
                    title="",
                    scale=alt.Scale(
                        domain=node_labels,
                        range=node_dash,
                    ),
                ),
                tooltip="reciprocal_rank",
            )
            .properties(width=600, height=400)
            .configure_title(fontSize=30)
            .configure_axis(labelFontSize=18, titleFontSize=28)
            .configure_legend(labelFontSize=24, symbolStrokeWidth=8)
        )

        return chart
    return (build_chart,)


@app.cell
def _(mo, sweep):
    # Task selector: one option per task in the sweep, defaulting to the first.
    task_dropdown = mo.ui.dropdown(
        options=sweep.tasks, value=sweep.tasks[0], label="Task"
    )
    # Correctness filter. mo.ui.dropdown takes string option names, either as
    # a sequence or as keys of a dict mapping names to values. `value` must be
    # one of those names. The downstream cell compares `.value` against
    # `experiment.correctness`, so map explicit string labels to real bools
    # instead of passing bare booleans as option names.
    correctness_dropdown = mo.ui.dropdown(
        options={"False": False, "True": True},
        value="True",
        label="Correctness",
    )
    task_dropdown, correctness_dropdown
    return correctness_dropdown, task_dropdown


@app.cell
def _(correctness_dropdown, sweep, task_dropdown):
    # Select the first experiment in the sweep that matches the chosen task and
    # correctness filter. If nothing matches, fail with a message naming the
    # combination, rather than the bare IndexError from `[...][0]`.
    task_experiment = next(
        (
            e
            for e in sweep.experiments
            if e.task_name == task_dropdown.value
            and e.correctness == correctness_dropdown.value
        ),
        None,
    )
    if task_experiment is None:
        raise ValueError(
            f"No experiment in sweep for task={task_dropdown.value!r}, "
            f"correctness={correctness_dropdown.value!r}"
        )
    task_queries = task_experiment.step_result("selected_queries")
    task_topk_tokens = task_experiment.step_result("top_k_tokens")
    return task_experiment, task_queries, task_topk_tokens


@app.cell
def _(mo, task_experiment):
    # The slider selects an example index in [0, number of examples - 1] within
    # the experiment's processing signatures.
    _n_examples = len(task_experiment.step_result("processing_signatures"))
    example_slider = mo.ui.slider(start=0, stop=_n_examples - 1)
    example_slider
    return (example_slider,)


@app.cell
def _(build_chart, example_slider, task_experiment):
    # Render the chart for the selected example with Altair's interactive
    # (pan/zoom) behavior enabled. It is the cell's last expression, so it is
    # what gets displayed.
    build_chart(task_experiment, example_idx=example_slider.value).interactive()
    return


@app.cell
def _(example_slider, task_queries):
    # Display the `query` attribute of the currently selected example.
    _selected = task_queries[example_slider.value]
    _selected.query
    return


@app.cell
def _(example_slider, task_topk_tokens):
    # Print the highest-probability token at each layer (over all positions)
    # of the selected example.
    # - Polars keeps row order within each group, so after the descending sort
    #   `.first()` picks each layer's top-probability row.
    # - `group_by` does not keep group order by default, so the result is
    #   re-sorted by layer so that list index i corresponds to the i-th layer.
    print(
        task_topk_tokens[example_slider.value]
        .sort("prob", descending=True)
        .group_by("layer")
        .first()
        .sort("layer")["token"]
        .to_list()
    )
    return


@app.cell
def _(example_slider, pl, task_topk_tokens):
    # Print the highest-probability token at each layer, restricted to the
    # selected example's rows with the largest `position` value.
    # - The max is computed inside the filter expression, so the example frame
    #   is only indexed once.
    # - Polars keeps row order within each group, so `.first()` after the
    #   descending sort picks each layer's top-probability row.
    # - Groups come back in arbitrary order, so the result is re-sorted by
    #   layer.
    _tokens = task_topk_tokens[example_slider.value]
    print(
        _tokens.filter(pl.col("position") == pl.col("position").max())
        .sort("prob", descending=True)
        .group_by("layer")
        .first()
        .sort("layer")["token"]
        .to_list()
    )
    return


# Allow the notebook to be executed as a plain Python script.
if __name__ == "__main__":
    app.run()
