import pandas as pd

from src.utils.eval import DFAggregator

def test_simple_aggregation():
    """Check aggregation of two seed results that use a default flat index.

    Two 2x2 float frames are appended to a fresh ``DFAggregator``; the
    aggregate it returns must equal the frame holding the element-wise
    mean of the two inputs.
    """
    seed_results = [
        pd.DataFrame(data={"a": [1, 2], "b": [3, 4]}, dtype=float),
        pd.DataFrame(data={"a": [3, 2], "b": [1, 2]}, dtype=float),
    ]
    # Element-wise mean of the two seed frames above.
    expected = pd.DataFrame(data={"a": [2, 2], "b": [2, 3]}, dtype=float)

    aggregator = DFAggregator()
    for seed_df in seed_results:
        aggregator.append_seed_result(seed_df)

    assert expected.equals(aggregator.get_aggregate())

def test_multilevel_index_aggregation():
    """Check aggregation of seed results that share a two-level MultiIndex.

    A frame filled with 2.0 and a frame filled with 0.0 are appended to a
    fresh ``DFAggregator``; the aggregate it returns must equal a frame
    filled with 1.0 on the same index and columns.
    """
    column_names = ["c1", "c2"]
    index = pd.MultiIndex.from_product([["a", "b"], [1, 2, 3]])

    def constant_frame(fill_value):
        # Every frame in this test differs only in its constant fill value.
        return pd.DataFrame(
            fill_value, index=index, columns=column_names, dtype=float
        )

    aggregator = DFAggregator()
    for fill_value in (2, 0):
        aggregator.append_seed_result(constant_frame(fill_value))

    assert constant_frame(1).equals(aggregator.get_aggregate())
