import pytest
import numpy as np
from collections import defaultdict
from aiau.data.data_manager import DataManager
from aiau.oracles.oracle import Oracle

class DummyOracle(Oracle):
    """Deterministic stand-in oracle for tests.

    Every queried value is the target stored in the data manager's
    ``targets_dict`` shifted by a constant ``0.1``.
    """

    def query_target_value(self, data_manager, idx):
        """Return ``targets_dict[i] + 0.1`` for each ``i`` in ``idx``, as a list.

        ``idx`` is iterated over, so it must be an iterable of keys
        present in ``data_manager.targets_dict``.
        """
        shifted_targets = []
        for sample_id in idx:
            shifted_targets.append(data_manager.targets_dict[sample_id] + 0.1)
        return shifted_targets


@pytest.fixture
def synthetic_data():
    """Provide a tiny four-sample dataset for the DataManager tests.

    Returns a tuple ``(indices, observations, targets, initially_labelled)``:
    four integer indices, a 4x2 observation array, four target values,
    and the list of indices treated as labelled from the start.
    """
    sample_ids = np.array([0, 1, 2, 3])
    features = np.array(
        [
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
        ]
    )
    true_targets = np.array([0.5, 0.6, 0.7, 0.8])
    seed_labels = [0, 2]
    return sample_ids, features, true_targets, seed_labels


def test_initialisation(synthetic_data):
    """A freshly built DataManager exposes the data it was constructed with."""
    sample_ids, features, true_targets, seed_labels = synthetic_data
    manager = DataManager(sample_ids, features, true_targets, seed_labels)

    # The labelled set matches the seed labels, irrespective of order.
    assert set(manager.labelled_indices) == set(seed_labels)
    # The manager's length is the number of samples.
    assert len(manager) == len(sample_ids)
    # The full feature/target arrays keep the shapes of the inputs.
    assert manager.full_X.shape == features.shape
    assert manager.full_y.shape == true_targets.shape


def test_initialise_adds_noisy_targets_and_labels(synthetic_data):
    """``initialise`` records one oracle answer per initially labelled index.

    NOTE(review): the test name also mentions labels, but only
    ``noisy_targets_dict`` is asserted on here.
    """
    sample_ids, features, true_targets, seed_labels = synthetic_data
    manager = DataManager(sample_ids, features, true_targets, seed_labels)
    stub_oracle = DummyOracle()

    manager.initialise(stub_oracle)

    # Each seed index should hold exactly one value: its target shifted by
    # the DummyOracle's constant 0.1.
    for labelled_id in seed_labels:
        recorded = manager.noisy_targets_dict[labelled_id]
        assert len(recorded) == 1
        expected_value = true_targets[labelled_id] + 0.1
        assert np.isclose(manager.noisy_targets_dict[labelled_id][0], expected_value)


def test_update_noisy_targets(synthetic_data):
    """``update_noisy_targets`` stores one oracle answer for a new index.

    NOTE(review): a bare int is passed here, whereas
    ``test_update_labelled_indices`` passes a list -- confirm that
    ``update_noisy_targets`` is meant to accept both forms.
    """
    sample_ids, features, true_targets, seed_labels = synthetic_data
    manager = DataManager(sample_ids, features, true_targets, seed_labels)
    stub_oracle = DummyOracle()
    manager.initialise(stub_oracle)

    # Index 1 is not among the initially labelled indices [0, 2].
    queried_id = 1
    manager.update_noisy_targets(stub_oracle, queried_id)

    recorded = manager.noisy_targets_dict[queried_id]
    assert len(recorded) == 1
    expected_value = true_targets[queried_id] + 0.1
    assert np.isclose(manager.noisy_targets_dict[queried_id][0], expected_value)


def test_construct_noisy_dataset(synthetic_data):
    """The noisy dataset has two feature columns and one target per stored answer."""
    sample_ids, features, true_targets, seed_labels = synthetic_data
    manager = DataManager(sample_ids, features, true_targets, seed_labels)
    stub_oracle = DummyOracle()
    manager.initialise(stub_oracle)

    noisy_X, noisy_y = manager.construct_noisy_dataset()

    # The fixture's observations have two columns.
    assert noisy_X.shape[1] == 2
    # One row per noisy answer held for the two seed indices (0 and 2).
    expected_rows = sum(len(manager.noisy_targets_dict[key]) for key in (0, 2))
    assert noisy_y.shape[0] == expected_rows


def test_update_labelled_indices(synthetic_data):
    """Newly labelled indices join the labelled set and are logged per iteration."""
    sample_ids, features, true_targets, seed_labels = synthetic_data
    manager = DataManager(sample_ids, features, true_targets, seed_labels)
    stub_oracle = DummyOracle()
    manager.initialise(stub_oracle)

    freshly_labelled = [1]
    round_number = 1
    # Query the oracle for the new index before marking it as labelled.
    manager.update_noisy_targets(stub_oracle, freshly_labelled)
    manager.update_labelled_indices(freshly_labelled, iteration=round_number)

    assert set(freshly_labelled).issubset(manager.labelled_indices)
    # The history entry for this iteration is the list that was passed in.
    assert manager.labelling_history[round_number] == freshly_labelled