import torch

from hypo_interp.tasks import TracrProportionTask, TracrReverseTask

DEVICE = "cpu"


def test_tracr_proportion_full_circuit():
    """
    Applying the complete circuit to the proportion task must not change its score.

    The baseline score of a freshly built task is compared against two scores:
    one taken right after ``set_circuit(complete_circuit)``, and one taken after
    ``reset()`` followed by setting the complete circuit again.
    """
    # NOTE: 1e-15 is much tighter than float32 epsilon (~1.2e-7), so if the
    # scores are float32 this effectively demands exact agreement.
    tolerance = 1e-15

    task = TracrProportionTask(device=DEVICE)
    baseline, _ = task.score()  # second element of score() is unused here

    circuit = task.complete_circuit

    # Case 1: full circuit applied directly on top of the fresh task.
    task.set_circuit(circuit)
    without_reset, _ = task.score()

    # Case 2: task reset first, then the full circuit applied again.
    task.reset()
    task.set_circuit(circuit)
    with_reset, _ = task.score()

    for candidate in (without_reset, with_reset):
        assert abs(baseline - candidate) < tolerance


def test_tracr_reverse_full_circuit():
    """
    Applying the complete circuit to the reverse task must not change its score.

    The baseline score of a freshly built task is compared against two scores:
    one taken right after ``set_circuit(complete_circuit)``, and one taken after
    ``reset()`` followed by setting the complete circuit again.
    """
    # NOTE: 1e-15 is much tighter than float32 epsilon (~1.2e-7), so if the
    # scores are float32 this effectively demands exact agreement.
    tolerance = 1e-15

    task = TracrReverseTask(device=DEVICE)
    baseline, _ = task.score()  # second element of score() is unused here

    circuit = task.complete_circuit

    # Case 1: full circuit applied directly on top of the fresh task.
    task.set_circuit(circuit)
    without_reset, _ = task.score()

    # Case 2: task reset first, then the full circuit applied again.
    task.reset()
    task.set_circuit(circuit)
    with_reset, _ = task.score()

    for candidate in (without_reset, with_reset):
        assert abs(baseline - candidate) < tolerance


def test_tracr_reverse_canonical_circuit():
    """
    Restricting the reverse task to its canonical circuit must keep the score.

    Both scores are computed with autograd disabled.
    """
    with torch.no_grad():
        task = TracrReverseTask(device=DEVICE)
        reference_score, _ = task.score()

        canonical = task.canonical_circuit
        task.set_circuit(canonical)
        canonical_score, _ = task.score()

        # TODO: understand why the canonical circuit shows zero difference
        # from the default (full) model score.
        gap = abs(reference_score - canonical_score)
        assert gap < 1e-15


def test_tracr_proportional_canonical_circuit():
    """
    Restricting the proportion task to its canonical circuit must keep the score.

    Both scores are computed with autograd disabled.
    """
    with torch.no_grad():
        task = TracrProportionTask(device=DEVICE)
        reference_score, _ = task.score()

        canonical = task.canonical_circuit
        task.set_circuit(canonical)
        canonical_score, _ = task.score()

        # TODO: understand why the canonical circuit shows zero difference
        # from the default (full) model score.
        gap = abs(reference_score - canonical_score)
        assert gap < 1e-15
