import numpy as np
import pandas as pd

from utils import memorization_order as mem_ord


def test_stable_memorization_order():
    """Check ``stable_memorization_order`` on a small hand-written fixture.

    The fixture covers three epochs (0-2) of a single string (id 0) whose
    five character positions read "cabca"; the ``correct`` column flags each
    position per epoch.  The expected series is indexed by character, with
    one entry per position of the string.
    """
    characters = "cabca"
    # One row of flags per epoch, in the same position order as `characters`.
    correct_by_epoch = [
        [False, False, False, True, True],
        [False, True, False, False, False],
        [False, True, False, True, False],
    ]

    index = pd.MultiIndex.from_tuples(
        [
            (epoch, 0, character)
            for epoch in range(len(correct_by_epoch))
            for character in characters
        ],
        names=["epoch", "string", "character"],
    )
    flags = [flag for epoch_flags in correct_by_epoch for flag in epoch_flags]
    result = pd.DataFrame({"correct": flags}, index=index)

    # NOTE(review): -1 presumably marks positions that are not correct in the
    # final epoch (never stably memorized) -- confirm against the
    # implementation in utils.memorization_order.
    expected = pd.Series(
        [-1, 1, -1, 2, -1],
        index=pd.Index(list(characters), name="character"),
        name="stable_mem_epoch",
        dtype=int,
    )

    pd.testing.assert_series_equal(
        mem_ord.stable_memorization_order(result), expected
    )


def test_initial_memorization_order():
    """Check ``initial_memorization_order`` on a small hand-written fixture.

    The fixture covers three epochs (0-2) of a single string (id 0) with the
    five character positions "c", "a", "b", "c", "a"; the ``correct`` column
    flags each position per epoch.  The expected series is indexed by
    character, with one entry per position of the string.
    """
    positions = ["c", "a", "b", "c", "a"]
    # Maps each epoch to its per-position flags (same order as `positions`).
    correct_by_epoch = {
        0: [False, False, False, True, True],
        1: [False, True, False, False, False],
        2: [False, True, False, True, False],
    }

    index_tuples = []
    correct = []
    for epoch, epoch_flags in correct_by_epoch.items():
        for character, flag in zip(positions, epoch_flags):
            index_tuples.append((epoch, 0, character))
            correct.append(flag)

    result = pd.DataFrame(
        {"correct": correct},
        index=pd.MultiIndex.from_tuples(
            index_tuples, names=["epoch", "string", "character"]
        ),
    )

    # NOTE(review): the non-negative values coincide with the first epoch in
    # which each position is flagged correct, and -1 with positions that are
    # never correct -- confirm this reading against the implementation.
    expected = pd.Series(
        [-1, 1, -1, 0, 0],
        index=pd.Index(positions, name="character"),
        name="initial_mem_epoch",
        dtype=int,
    )

    pd.testing.assert_series_equal(
        mem_ord.initial_memorization_order(result), expected
    )


def test_prefix_agreement():
    """Check ``prefix_agreement`` for prefix lengths 1 and 2.

    The fixture covers two epochs (0 and 1) of a single string (id 0) whose
    six character positions read "abcbca", with every ``correct`` flag set to
    False.  For each prefix length the expected frame has integer
    ``agreement`` and ``disagreement`` columns and a ``token`` index holding
    the string's characters after the first ``prefix_length`` positions.
    """
    characters = "abcbca"
    index = pd.MultiIndex.from_tuples(
        [(epoch, 0, character) for epoch in (0, 1) for character in characters],
        names=["epoch", "string", "character"],
    )
    result = pd.DataFrame(
        {"correct": np.zeros(len(index), dtype=bool)},
        index=index,
    )

    # (prefix length, expected tokens, expected agreement, expected disagreement)
    cases = [
        (1, ["b", "c", "b", "c", "a"], [0, 1, 0, 1, 0], [0, 0, 1, 0, 1]),
        (2, ["c", "b", "c", "a"], [0, 0, 0, 0], [0, 1, 0, 1]),
    ]
    for prefix_length, tokens, agreement, disagreement in cases:
        expected = pd.DataFrame(
            {
                "agreement": agreement,
                "disagreement": disagreement,
            },
            index=pd.Index(tokens, name="token"),
            dtype=int,
        )
        actual = mem_ord.prefix_agreement(result, prefix_length)
        pd.testing.assert_frame_equal(actual, expected)
