import pytest
import torch

from kge.dataset import OneToManyDataset, TripleDataset, TripleDatasetWithNeg


class TestDataset:
    """Tests for ``add_inverse_triples`` on the ``kge.dataset`` dataset classes."""

    # The fixture uses relation ids {0, 1}; inverse relations are shifted past
    # them, i.e. by the number of original relations.
    RELATION_OFFSET = 2

    @pytest.fixture
    def sample_triples(self):
        """Return four ``[subject, relation, object]`` triples over entities 0-2.

        Every triple has a distinct (subject, relation) pair, and so does every
        inverse (object, relation + RELATION_OFFSET) pair.
        """
        return torch.tensor(
            [
                [0, 0, 1],
                [1, 0, 2],
                [0, 1, 2],
                [2, 1, 0],
            ],
            dtype=torch.long,
        )

    @staticmethod
    def _expected_relations(original_relations, relation_offset):
        """Return ``original_relations`` plus each id shifted by ``relation_offset``."""
        return original_relations.union(
            {r + relation_offset for r in original_relations},
        )

    @staticmethod
    def _assert_inverse_triples_appended(dataset, original_triples, relation_offset):
        """Assert ``dataset.triples`` holds ``original_triples`` followed by inverses.

        The inverse of the i-th triple ``(s, r, o)`` must sit at position
        ``len(original_triples) + i`` as ``(o, r + relation_offset, s)``.
        """
        original_len = len(original_triples)
        expected = original_triples.tolist()

        assert len(dataset) == 2 * original_len
        # Original triples stay in place at the front ...
        assert dataset.triples[:original_len].tolist() == expected
        # ... and each inverse is appended in the same order.
        for i, (s, r, o) in enumerate(expected):
            inv_s, inv_r, inv_o = dataset.triples[original_len + i].tolist()
            assert inv_s == o  # subject becomes object
            assert inv_r == r + relation_offset  # relation id is offset
            assert inv_o == s  # object becomes subject

    def test_triple_dataset_add_inverse_triples(self, sample_triples):
        dataset = TripleDataset(sample_triples, split="test")
        original_len = len(dataset)
        original_relations = dataset.relations.copy()

        dataset.add_inverse_triples(self.RELATION_OFFSET)

        self._assert_inverse_triples_appended(
            dataset, sample_triples, self.RELATION_OFFSET
        )
        assert dataset.relations == self._expected_relations(
            original_relations, self.RELATION_OFFSET
        )

        # sr_to_objects must be rebuilt on access after the mutation; every
        # original and inverse triple has its own (s, r) key in this fixture.
        assert len(dataset.sr_to_objects) == 2 * original_len

    def test_one_to_many_dataset_add_inverse_triples(self, sample_triples):
        num_entities = 3  # entities are 0, 1, 2 in the fixture
        dataset = OneToManyDataset(sample_triples, num_entities, split="test")
        original_len = len(dataset)
        original_relations = dataset.relations.copy()

        # Snapshot (s, r) -> objects BEFORE mutating, so the baseline does not
        # depend on how add_inverse_triples reorders/extends the pair tensors.
        original_sr_to_o = {}
        for i in range(original_len):
            s = dataset.unique_sr_pairs[i, 0].item()
            r = dataset.unique_sr_pairs[i, 1].item()
            objects = torch.nonzero(dataset.o_masks[i], as_tuple=True)[0].tolist()
            original_sr_to_o[(s, r)] = objects
        assert original_sr_to_o, "fixture produced no (s, r) pairs"

        dataset.add_inverse_triples(self.RELATION_OFFSET)

        # New (o, r + offset) pairs must have been added.
        assert len(dataset) > original_len
        assert dataset.relations == self._expected_relations(
            original_relations, self.RELATION_OFFSET
        )

        # Index every (s, r) row after the update for O(1) lookups.
        pair_to_index = {
            (
                dataset.unique_sr_pairs[i, 0].item(),
                dataset.unique_sr_pairs[i, 1].item(),
            ): i
            for i in range(len(dataset))
        }

        for (s, r), objects in original_sr_to_o.items():
            # Original pairs must survive with their objects still marked.
            assert (s, r) in pair_to_index, f"Original pair ({s},{r}) missing"
            original_mask = dataset.o_masks[pair_to_index[(s, r)]]
            for o in objects:
                assert original_mask[o].item(), f"Object {o} lost for ({s},{r})"

                # For each original (s, r, o) there must be an inverse pair
                # (o, r + offset) whose object mask marks s.
                inverse_pair = (o, r + self.RELATION_OFFSET)
                assert (
                    inverse_pair in pair_to_index
                ), f"Inverse pair for ({s},{r},{o}) not found"
                inverse_mask = dataset.o_masks[pair_to_index[inverse_pair]]
                assert inverse_mask[s].item(), (
                    f"Subject {s} not marked for inverse pair {inverse_pair}"
                )

        # sr_to_objects must be rebuilt on access. The fixture has 4 distinct
        # original (s, r) keys and 4 distinct inverse (o, r + offset) keys.
        assert len(dataset.sr_to_objects) >= 2 * original_len

    def test_triple_dataset_with_neg_add_inverse_triples(self, sample_triples):
        num_triples = len(sample_triples)
        # Every triple gets the same negatives; inverse triples should end up
        # with the head negatives, original triples with the tail negatives.
        head_neg = torch.tensor([[3, 4]], dtype=torch.long).repeat(num_triples, 1)
        tail_neg = torch.tensor([[5, 6]], dtype=torch.long).repeat(num_triples, 1)

        dataset = TripleDatasetWithNeg(
            sample_triples,
            head_neg,
            tail_neg,
            split="test",
            prepare_inverse_negatives=True,
        )
        original_len = len(dataset)
        original_relations = dataset.relations.copy()

        dataset.add_inverse_triples(self.RELATION_OFFSET)

        self._assert_inverse_triples_appended(
            dataset, sample_triples, self.RELATION_OFFSET
        )
        assert dataset.relations == self._expected_relations(
            original_relations, self.RELATION_OFFSET
        )

        # First half: tail negatives for the original triples.
        # Second half: head negatives for the inverse triples.
        assert dataset.negative_samples.shape[0] == 2 * original_len
        assert torch.equal(dataset.negative_samples[:original_len], tail_neg)
        assert torch.equal(dataset.negative_samples[original_len:], head_neg)

        # __getitem__ must expose the (possibly inverted) triple and its negatives.
        for i in range(len(dataset)):
            s, r, o, neg = dataset[i]
            assert s == dataset.triples[i, 0]
            assert r == dataset.triples[i, 1]
            assert o == dataset.triples[i, 2]
            assert torch.equal(neg, dataset.negative_samples[i])

        # sr_to_objects must be rebuilt on access after the mutation; every
        # original and inverse triple has its own (s, r) key in this fixture.
        assert len(dataset.sr_to_objects) == 2 * original_len