import pytest
import torch as t

from hypo_interp.config import ExperimentConfig
from hypo_interp.tasks import InductionTask, TracrProportionTask, TracrReverseTask
from hypo_interp.test_executor import TestExecutor

DEVICE = "cpu"


@pytest.fixture(scope="module")
def induction_task():
    """
    Module-scoped fixture yielding one shared ``InductionTask`` instance.

    The task is built once for the whole module instead of once per test,
    which should speed things up a bit. Because the instance is shared,
    the tests in this module call ``reset()`` on it before using it.
    """
    yield InductionTask(
        device=DEVICE,
        seq_len=50,
        num_examples=40,
        zero_ablation=True,
    )


@pytest.fixture(scope="module")
def tracr_proportion_task():
    """
    Module-scoped fixture yielding one shared ``TracrProportionTask`` instance.

    The task is built once for the whole module instead of once per test,
    which should speed things up a bit. Because the instance is shared,
    the test using it calls ``reset()`` before running.
    """
    yield TracrProportionTask(
        device=DEVICE,
        seq_len=50,
        num_examples=50,
    )


@pytest.fixture(scope="module")
def tracr_reverse_task():
    """
    Module-scoped fixture yielding one shared ``TracrReverseTask`` instance.

    The task is built once for the whole module instead of once per test,
    which should speed things up a bit. Because the instance is shared,
    the test using it calls ``reset()`` before running.
    """
    yield TracrReverseTask(
        device=DEVICE,
        seq_len=50,
        num_examples=6,
    )


def test_faithfulness_test(induction_task):
    """
    Faithfulness check on the induction task using its complete circuit.

    The complete circuit is passed as the candidate, so the test expects:
      * the real eval metric to be within 1e-5 of zero,
      * the p-value to be below 0.5,
      * every metric from the random circuits to be strictly greater than
        the real one.
    """
    tolerance = 1e-5

    with t.no_grad():
        config = ExperimentConfig(
            device=DEVICE,
            random_proportion=0.1,
            num_random_circuits=10,
            seed=42,
        )

        # The fixture instance is shared with other tests, so reset it first.
        induction_task.reset()

        executor = TestExecutor(
            task=induction_task,
            config=config,
            candidate_circuit=induction_task.complete_circuit,
        )
        p_val, real_metric, random_metrics, _ = executor.test_faithfulness()

        # The complete circuit should score (numerically) perfectly.
        assert abs(real_metric - 0.0) < tolerance
        assert p_val < 0.5, f"The p-value should be less than 0.5.  {p_val:.04f}"

        # Each random circuit has to be strictly worse than the complete one.
        for random_metric in random_metrics:
            assert (
                random_metric > real_metric
            ), "The random circuits should be much worse than the full circuit."


def test_faithfulness_test_tracr_proportion_test(tracr_proportion_task):
    """
    Faithfulness check on the tracr proportion task using its complete circuit.

    The complete circuit is passed as the candidate and the faithfulness test
    is run with ``quantile=0.1``. The test expects:
      * the real eval metric to be within 1e-5 of zero,
      * the p-value to be below 0.5,
      * every metric from the random circuits to be strictly greater than
        the real one.
    """
    tolerance = 1e-5

    with t.no_grad():
        config = ExperimentConfig(
            device=DEVICE,
            random_proportion=None,
            num_random_circuits=10,
            seed=42,
            invert=False,
        )

        # The fixture instance is shared with other tests, so reset it first.
        tracr_proportion_task.reset()

        executor = TestExecutor(
            task=tracr_proportion_task,
            config=config,
            candidate_circuit=tracr_proportion_task.complete_circuit,
        )
        p_val, real_metric, random_metrics, _ = executor.test_faithfulness(
            quantile=0.1
        )

        # The complete circuit should score (numerically) perfectly.
        assert abs(real_metric - 0.0) < tolerance
        assert (
            p_val < 0.5
        ), f"The p-value should be less than 0.5. p-value is {p_val:.04f}"

        # Each random circuit has to be strictly worse than the complete one.
        for random_metric in random_metrics:
            assert (
                random_metric > real_metric
            ), "The random circuits should be much worse than the full circuit."


def test_faithfulness_test_tracr_reverse_test(tracr_reverse_task):
    """
    Faithfulness check on the tracr reverse task using its complete circuit.

    The complete circuit is passed as the candidate, so the test expects:
      * the real eval metric to be within 1e-5 of zero,
      * the p-value to be below 0.5,
      * every metric from the random circuits to be strictly greater than
        the real one.
    """
    with t.no_grad():
        device = DEVICE
        config = ExperimentConfig(
            device=device,
            random_proportion=0.5,
            num_random_circuits=10,
            seed=42,
        )

        tracr_reverse_task.reset()  # reset the task because it is being used by other tests
        candidate_circuit = tracr_reverse_task.complete_circuit

        test_executor = TestExecutor(
            task=tracr_reverse_task, config=config, candidate_circuit=candidate_circuit
        )
        (
            p_val,
            real_eval_metric,
            simulated_eval_metrics,
            _,
        ) = test_executor.test_faithfulness()

        # The full circuit should be perfect
        assert abs(real_eval_metric - 0.0) < 1e-5
        # The failure message states the threshold that is actually checked
        # (0.5), matching the other faithfulness tests in this module.
        assert (
            p_val < 0.5
        ), f"The p-value should be less than 0.5. p-value is {p_val:.04f}"

        # All the eval metrics for the random circuits should be much worse
        for simulated_eval_metric in simulated_eval_metrics:
            assert (
                simulated_eval_metric > real_eval_metric
            ), "The random circuits should be much worse than the full circuit."


def test_permutation_test(induction_task):
    """
    Sufficiency check on the induction task using its complete circuit.

    The per-prompt candidate score is computed for the complete circuit, and
    a zero tensor of the same shape is then handed to ``test_sufficiency``.
    The resulting p-value is expected to be one (within 1e-9).
    """
    with t.no_grad():
        config = ExperimentConfig(
            device=DEVICE,
            random_proportion=1,
            num_random_circuits=10,
            seed=42,
        )

        # The fixture instance is shared with other tests, so reset it first.
        induction_task.reset()

        executor = TestExecutor(
            task=induction_task,
            config=config,
            candidate_circuit=induction_task.complete_circuit,
        )

        candidate_score, _ = executor.compute_candidate_score(per_prompt=True)
        zero_scores = t.zeros_like(candidate_score)
        p_val, _, _ = executor.test_sufficiency(zero_scores)

        assert p_val > 1 - 1e-9, "The p-value should be one."
