#!/usr/bin/python3
"""
Ensure deterministic behavior of implemented equivalence relations.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import numpy as np
import pytest  # type: ignore[import-not-found]

import leon
from leon.core import get_equivalence_relation_options
from leon.embedding import get_embedder


@pytest.mark.parametrize("relation_cls", get_equivalence_relation_options())
@pytest.mark.parametrize("seed", [12345, 67890])
def test_relation_determinism(
    relation_cls: str,
    seed: int,
    task_name: str = "IWPCWarfarin-v0",
    embedder_name: str = "openai/text-embedding-3-small"
) -> None:
    """
    Test the determinism of implemented equivalence relations.

    Two quotient sets built on the same relation are populated with the same
    training items presented in two different (seeded) random orders; their
    equivalence classes and fractional occupancies must match. The same new
    item is then assigned to both sets (after an extra, different item is
    assigned to the second set only); it must land in the same class, and
    that class's `s_star` entry must equal 2.0 in both sets.
    Input:
        relation_cls: the name of the equivalence relation class in
            `leon.core` to test.
        seed: random seed used both for the task and for shuffling items.
        task_name: the name of the optimization task.
        embedder_name: the name of the embedder to use for the task.
    Returns:
        None.
    """
    embedder = get_embedder(embedder_name)
    task = leon.make(task_name, seed=seed)
    sim = getattr(leon.core, relation_cls)(task=task, embedder=embedder)
    qset1 = leon.core.QuotientSet(sim)
    qset2 = leon.core.QuotientSet(sim)

    # Use a generator seeded from `seed` (instead of the unseeded global
    # NumPy RNG) so a failing parametrization can be reproduced exactly.
    rng = np.random.default_rng(seed)
    num_items = min(256, len(task.train))
    # Two independent orderings of the same first `num_items` training items.
    idx1 = rng.permutation(num_items)
    idx2 = rng.permutation(num_items)
    x1, x2 = [task.train[i] for i in idx1], [task.train[i] for i in idx2]

    qset1.assign(task.reduce(x1), np.ones(num_items, dtype=np.float32))
    qset2.assign(task.reduce(x2), np.ones(num_items, dtype=np.float32))

    # Compare as plain lists: a differing number of classes yields a clean
    # assertion failure with a diff, rather than a NumPy broadcasting error,
    # and no string ufunc support is required.
    assert list(qset1.equivalence_classes) == list(qset2.equivalence_classes)
    # Same tolerances as `np.isclose`, but shapes must also agree.
    np.testing.assert_allclose(
        qset1.fractional_occupancy,
        qset2.fractional_occupancy,
        rtol=1e-5,
        atol=1e-8,
    )

    # Assign the same item (passed with value 1.0 + 1.0 = 2.0) to both sets;
    # qset2 first receives a different extra item, which must not change the
    # class the shared item is mapped to.
    class1 = qset1.assign(task.reduce([task.train[-1]]), 1.0 + np.ones(1))[0]
    qset2.assign(task.reduce([task.train[-2]]), np.ones(1))
    class2 = qset2.assign(task.reduce([task.train[-1]]), 1.0 + np.ones(1))[0]
    assert class1 == class2

    # Look each class up in its OWN quotient set: qset2 may now hold an extra
    # class, so indexing its `s_star` via qset1's class list could be wrong.
    s1 = qset1.s_star[qset1.equivalence_classes.index(str(class1))]
    s2 = qset2.s_star[qset2.equivalence_classes.index(str(class2))]
    assert s1 == s2 == 2.0
