import pytest
import torch
from copy import deepcopy

from datasets.dataset_utils.dataset_utils import concat_slices


def test_simple_concat_with_offset():
    """Concatenating two slice dicts offsets the second by the first's last boundary."""
    s1 = {"x": torch.tensor([0, 3, 5])}  # two graphs: size 3 and 2
    s2 = {"x": torch.tensor([0, 4])}  # one graph: size 4

    result = concat_slices(deepcopy(s1), deepcopy(s2))

    # Expect: s1 unchanged, s2 shifted by the last boundary of s1 (5): [0, 4] -> [5, 9].
    # The shifted leading 5 duplicates s1's final entry, so it is dropped, leaving [9].
    # => [0, 3, 5] + [9] => [0, 3, 5, 9]
    assert torch.equal(result["x"], torch.tensor([0, 3, 5, 9]))


def test_multiple_datasets_offset():
    """Each subsequent slice dict is offset by the last boundary accumulated so far."""
    parts = [
        {"x": torch.tensor([0, 2])},  # 2 nodes
        {"x": torch.tensor([0, 1, 3])},  # graphs: 1, 2 nodes
        {"x": torch.tensor([0, 5])},  # 5 nodes
    ]

    merged = concat_slices(*[deepcopy(part) for part in parts])

    # part 0 as-is:                      [0, 2]
    # part 1 + 2  -> [2, 3, 5], the leading 2 repeats the running end -> [3, 5]
    # part 2 + 5  -> [5, 10],   the leading 5 repeats the running end -> [10]
    expected = torch.tensor([0, 2, 3, 5, 10])
    assert torch.equal(merged["x"], expected)


def test_multiple_keys_supported():
    """Every key is concatenated independently, each with its own offset."""
    left = {"x": torch.tensor([0, 2]), "y": torch.tensor([0, 1, 3])}
    right = {"x": torch.tensor([0, 1]), "y": torch.tensor([0, 2])}

    merged = concat_slices(deepcopy(left), deepcopy(right))

    # "x": right shifted by 2 -> [2, 3]; the repeated 2 is dropped -> [0, 2, 3]
    expected_x = torch.tensor([0, 2, 3])
    # "y": right shifted by 3 -> [3, 5]; the repeated 3 is dropped -> [0, 1, 3, 5]
    expected_y = torch.tensor([0, 1, 3, 5])

    assert torch.equal(merged["x"], expected_x)
    assert torch.equal(merged["y"], expected_y)


def test_mismatched_keys_raises():
    """Slice dicts whose key sets differ are rejected with an AssertionError."""
    only_x = {"x": torch.tensor([0, 1])}
    only_y = {"y": torch.tensor([0, 2])}

    with pytest.raises(AssertionError):
        concat_slices(only_x, only_y)


def test_empty_input_raises():
    """Calling concat_slices with no slice dicts at all raises IndexError."""
    # NOTE(review): IndexError presumably comes from indexing the first argument
    # inside concat_slices — confirm; this pins an implementation detail.
    no_slices = ()
    with pytest.raises(IndexError):
        concat_slices(*no_slices)


def test_original_dicts_unchanged():
    """concat_slices must leave the "x" tensors of its input dicts untouched."""
    inputs = ({"x": torch.tensor([0, 2])}, {"x": torch.tensor([0, 3])})
    snapshots = [deepcopy(d) for d in inputs]

    # Pass the originals (not copies) so any in-place mutation would be observed.
    concat_slices(*inputs)

    for original, snapshot in zip(inputs, snapshots):
        assert torch.equal(original["x"], snapshot["x"])
