# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import pytest
import tensordict
import torch
from packaging.version import parse as parse_version
from tensordict import TensorDict

from verl import DataProto
from verl.protocol import (
    deserialize_single_tensor,
    deserialize_tensordict,
    serialize_single_tensor,
    serialize_tensordict,
    union_numpy_dict,
    union_tensor_dict,
)


def test_union_tensor_dict():
    """Merging TensorDicts that share the very same ``obs`` tensor succeeds; a cloned ``obs`` pairing is rejected."""
    shared_obs = torch.randn(100, 10)

    left = TensorDict({"obs": shared_obs, "act": torch.randn(100, 3)}, batch_size=[100])
    right = TensorDict(
        {"obs": shared_obs, "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100]
    )

    # Same values under "obs", but a distinct tensor object.
    right_with_cloned_obs = TensorDict(
        {"obs": shared_obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100]
    )

    union_tensor_dict(left, right)
    # NOTE(review): if union_tensor_dict adds right's keys into `left`, this failure may be
    # triggered by the freshly drawn next_obs/rew rather than by the cloned obs — confirm intent.
    with pytest.raises(AssertionError):
        union_tensor_dict(left, right_with_cloned_obs)


def test_union_numpy_dict():
    """Check union_numpy_dict on plain, N-d, object-dtype, NaN-bearing and self-referential arrays.

    In every scenario an equal pair must merge without error. Where a
    counter-example is given, the unequal pair must raise ``AssertionError``.
    """
    # Warm-up: the same 3-D array on both sides, then two equal object-dtype arrays.
    cube = np.arange(8).reshape((2, 2, 2))
    union_numpy_dict({"a": cube}, {"a": cube})
    ragged_left = np.array([1, "hello", np.array([2, 3])], dtype=object)
    ragged_right = np.array([1, "hello", np.array([2, 3])], dtype=object)
    union_numpy_dict({"a": ragged_left}, {"a": ragged_right})

    # Scenario 1: a float column plus an object column that mixes float NaN with the string "nan".
    floats = np.random.random(100)
    nan_objects = np.array([float("nan") for _ in range(99)] + ["nan"], dtype=object)

    reference = {"a": floats, "b": nan_objects}
    equal_copy = {"a": floats.copy(), "b": nan_objects.copy()}
    other_floats = {"a": np.random.random(100)}

    union_numpy_dict(reference, equal_copy)
    with pytest.raises(AssertionError):
        union_numpy_dict(reference, other_floats)

    # Scenario 2: plain integer 3-D arrays; the mismatch must carry the documented message.
    grid = np.arange(24, dtype=np.int32).reshape((2, 3, 4))
    grid_ref = {"nd_array": grid}
    grid_shifted = {"nd_array": grid + 1}

    union_numpy_dict(grid_ref, {"nd_array": grid.copy()})
    with pytest.raises(AssertionError, match="`nd_array` in tensor_dict1 and tensor_dict2 are not the same object."):
        union_numpy_dict(grid_ref, grid_shifted)

    # Scenario 3: object-dtype arrays built from sub-arrays, strings and None.
    first = np.array([1, 2])
    second = np.array([3.0, 4.0])

    obj_2d = np.array([[first, "text"], [second, None]], dtype=object)
    obj_2d_changed = np.array([[first, "text"], [second, "other"]], dtype=object)
    union_numpy_dict({"data": obj_2d}, {"data": obj_2d.copy()})
    with pytest.raises(AssertionError):
        union_numpy_dict({"data": obj_2d}, {"data": obj_2d_changed})

    # A deeper nesting of the same sub-arrays, to stress recursive comparison.
    obj_deep = np.array([[[[first]]], [[[second]]]], dtype=object)
    obj_deep_changed = np.array([[[[first]]], [[[np.array([9, 9])]]]], dtype=object)
    union_numpy_dict({"data": obj_deep}, {"data": obj_deep.copy()})
    with pytest.raises(AssertionError):
        union_numpy_dict({"data": obj_deep}, {"data": obj_deep_changed})

    # Scenario 4: NaN at the same position counts as equal; a different value or position does not.
    nan_ref = {"data": np.array([1.0, np.nan, 3.0])}
    union_numpy_dict(nan_ref, {"data": np.array([1.0, np.nan, 3.0])})
    for mismatch in (np.array([1.0, 2.0, 3.0]), np.array([np.nan, 1.0, 3.0])):
        with pytest.raises(AssertionError):
            union_numpy_dict(nan_ref, {"data": mismatch})

    # Scenario 5: self-referencing object arrays must be compared without a RecursionError.
    def self_referencing():
        arr = np.array([None], dtype=object)
        arr[0] = arr
        return arr

    loop_a = self_referencing()
    loop_b = self_referencing()
    union_numpy_dict({"data": loop_a}, {"data": loop_b})

    # A cycle versus a plain [None] array is a genuine mismatch.
    with pytest.raises(AssertionError):
        union_numpy_dict({"data": loop_a}, {"data": np.array([None], dtype=object)})


def test_tensor_dict_constructor():
    """from_dict yields batch_size [100] by default; num_batch_dims of 2 or 3 is expected to raise."""
    obs = torch.randn(100, 10)
    act = torch.randn(100, 10, 3)

    proto = DataProto.from_dict(tensors={"obs": obs, "act": act})
    assert proto.batch.batch_size == torch.Size([100])

    for bad_dims in (2, 3):
        with pytest.raises(AssertionError):
            DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=bad_dims)


def test_tensor_dict_make_iterator():
    """Two iterators built with the same seed must yield identical mini-batches.

    Each mini-batch pair is compared on both the tensor batch ("obs") and the
    non-tensor batch ("labels"). Any mismatch fails the test.
    """
    obs = torch.randn(100, 10)
    labels = [random.choice(["abc", "cde"]) for _ in range(100)]
    dataset = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels})

    data_list_1 = list(dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1))
    data_list_2 = list(dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1))

    # strict=True also guarantees both runs produced the same number of mini-batches.
    for step, (data1, data2) in enumerate(zip(data_list_1, data_list_2, strict=True)):
        assert isinstance(data1, DataProto)
        assert isinstance(data2, DataProto)

        obs1, obs2 = data1.batch["obs"], data2.batch["obs"]
        assert torch.all(torch.eq(obs1, obs2)), f"obs mismatch at mini-batch {step}:\n{obs1}\n{obs2}"

        # Reproducibility covers the non-tensor batch too, so a label mismatch must fail.
        labels1 = data1.non_tensor_batch["labels"]
        labels2 = data2.non_tensor_batch["labels"]
        assert np.array_equal(labels1, labels2), f"labels mismatch at mini-batch {step}:\n{labels1}\n{labels2}"


def test_reorder():
    """reorder permutes tensor and non-tensor rows in place and leaves meta_info untouched."""
    proto = DataProto.from_dict(
        tensors={"obs": torch.tensor([1, 2, 3, 4, 5, 6])},
        non_tensors={"labels": ["a", "b", "c", "d", "e", "f"]},
        meta_info={"name": "abdce"},
    )
    proto.reorder(torch.tensor([3, 4, 2, 0, 1, 5]))

    expected_obs = torch.tensor([4, 5, 3, 1, 2, 6])
    expected_labels = np.array(["d", "e", "c", "a", "b", "f"])
    assert torch.all(torch.eq(proto.batch["obs"], expected_obs))
    assert np.all(proto.non_tensor_batch["labels"] == expected_labels)
    assert proto.meta_info == {"name": "abdce"}


def test_chunk_concat():
    """chunk splits rows into equal parts (sharing meta_info) and concat reassembles them."""
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    labels = ["a", "b", "c", "d", "e", "f"]
    proto = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"})

    # A chunk count that does not divide the 6 rows is expected to be rejected.
    with pytest.raises(AssertionError):
        proto.chunk(5)

    halves = proto.chunk(2)
    assert len(halves) == 2

    expected_halves = [
        (torch.tensor([1, 2, 3]), np.array(["a", "b", "c"])),
        (torch.tensor([4, 5, 6]), np.array(["d", "e", "f"])),
    ]
    for half, (want_obs, want_labels) in zip(halves, expected_halves, strict=True):
        assert torch.all(torch.eq(half.batch["obs"], want_obs))
        assert np.all(half.non_tensor_batch["labels"] == want_labels)
        assert half.meta_info == {"name": "abdce"}

    merged = DataProto.concat(halves)
    assert torch.all(torch.eq(merged.batch["obs"], proto.batch["obs"]))
    assert np.all(merged.non_tensor_batch["labels"] == proto.non_tensor_batch["labels"])
    assert merged.meta_info == proto.meta_info


def test_pop():
    """pop moves the requested batch and meta_info keys out of the source into a new DataProto."""
    source = DataProto.from_dict(
        {"obs": torch.randn(100, 10), "act": torch.randn(100, 3)},
        meta_info={"2": 2, "1": 1},
    )
    popped = source.pop(batch_keys=["obs"], meta_info_keys=["2"])

    # The returned object holds exactly the popped keys...
    assert popped.batch.keys() == {"obs"}
    assert popped.meta_info.keys() == {"2"}

    # ...and the source keeps only what was left behind.
    assert source.batch.keys() == {"act"}
    assert source.meta_info.keys() == {"1"}


def test_repeat():
    """repeat duplicates rows either interleaved (a, a, b, b, ...) or tiled (a, b, ..., a, b, ...)."""
    data = DataProto.from_dict(
        tensors={"obs": torch.tensor([[1, 2], [3, 4], [5, 6]])},
        non_tensors={"labels": ["a", "b", "c"]},
        meta_info={"info": "test_info"},
    )

    cases = (
        (
            True,
            torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]]),
            ["a", "a", "b", "b", "c", "c"],
        ),
        (
            False,
            torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]]),
            ["a", "b", "c", "a", "b", "c"],
        ),
    )
    for interleave, want_obs, want_labels in cases:
        repeated = data.repeat(repeat_times=2, interleave=interleave)
        assert torch.all(torch.eq(repeated.batch["obs"], want_obs))
        assert (repeated.non_tensor_batch["labels"] == want_labels).all()
        assert repeated.meta_info == {"info": "test_info"}


def test_dataproto_pad_unpad():
    """pad_dataproto_to_divisor pads by cycling rows from the start; unpad_dataproto restores the original."""
    from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto

    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})

    # (size_divisor, expected pad_size, expected padded obs rows, expected padded labels)
    cases = [
        (2, 1, [[1, 2], [3, 4], [5, 6], [1, 2]], ["a", "b", "c", "a"]),
        (3, 0, [[1, 2], [3, 4], [5, 6]], ["a", "b", "c"]),
        (
            7,
            4,
            [[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]],
            ["a", "b", "c", "a", "b", "c", "a"],
        ),
    ]
    for divisor, want_pad, want_obs, want_labels in cases:
        padded, pad_size = pad_dataproto_to_divisor(data, size_divisor=divisor)
        assert pad_size == want_pad
        assert torch.all(torch.eq(padded.batch["obs"], torch.tensor(want_obs)))
        assert (padded.non_tensor_batch["labels"] == want_labels).all()
        assert padded.meta_info == {"info": "test_info"}

        restored = unpad_dataproto(padded, pad_size=pad_size)
        assert torch.all(torch.eq(restored.batch["obs"], obs))
        assert (restored.non_tensor_batch["labels"] == labels).all()
        assert restored.meta_info == {"info": "test_info"}


def test_dataproto_fold_unfold():
    """fold_batch_dim groups rows under a new leading dim; unfold_batch_dim flattens them back."""
    from verl.protocol import DataProto, fold_batch_dim, unfold_batch_dim

    base = DataProto.from_dict(
        tensors={"obs": torch.tensor([[1, 2], [3, 4], [5, 6]])},
        non_tensors={"labels": ["a", "b", "c"]},
        meta_info={"info": "test_info"},
    )
    doubled = base.repeat(repeat_times=2, interleave=True)  # rows: a, a, b, b, c, c

    folded = fold_batch_dim(doubled, new_batch_size=3)
    torch.testing.assert_close(
        folded.batch["obs"],
        torch.tensor([[[1, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]),
    )
    assert (folded.non_tensor_batch["labels"] == [["a", "a"], ["b", "b"], ["c", "c"]]).all()

    # Rotate the folded groups so the flattened order becomes b, c, a.
    folded.reorder(indices=torch.tensor([1, 2, 0]))

    unfolded = unfold_batch_dim(folded, batch_dims=2)
    torch.testing.assert_close(
        unfolded.batch["obs"],
        torch.tensor([[3, 4], [3, 4], [5, 6], [5, 6], [1, 2], [1, 2]]),
    )
    assert (unfolded.non_tensor_batch["labels"] == ["b", "b", "c", "c", "a", "a"]).all()
    assert unfolded.meta_info == {"info": "test_info"}


def test_torch_save_data_proto():
    """A DataProto written with save_to_disk is restored unchanged by load_from_disk.

    The file is written inside a temporary directory, so nothing is left in the
    current working directory even when an assertion fails.
    """
    import os
    import tempfile

    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test_data.pt")
        data.save_to_disk(path)
        loaded_data = DataProto.load_from_disk(path)

        assert torch.all(torch.eq(loaded_data.batch["obs"], data.batch["obs"]))
        assert (loaded_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"]).all()
        assert loaded_data.meta_info == data.meta_info


def test_len():
    """len() reports the row count with or without a tensor batch, and 0 when there is no data."""
    labels = np.array(["a", "b", "c"], dtype=object)

    with_tensors = DataProto.from_dict(
        tensors={"obs": torch.tensor([[1, 2], [3, 4], [5, 6]])},
        non_tensors={"labels": labels},
        meta_info={"info": "test_info"},
    )
    assert len(with_tensors) == 3

    # With batch=None the length can only come from non_tensor_batch (or be 0 if it is empty/None).
    for non_tensor_batch, expected_len in (({"labels": labels}, 3), ({}, 0), (None, 0)):
        proto = DataProto(batch=None, non_tensor_batch=non_tensor_batch, meta_info={"info": "test_info"})
        assert len(proto) == expected_len


def test_dataproto_index():
    """Indexing with int or bool indices (numpy, torch or list) selects the same rows in every batch."""
    data_len = 100
    idx_num = 10

    obs = torch.randn(data_len, 10)
    labels = [random.choice(["abc", "cde"]) for _ in range(data_len)]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels})
    labels_np = np.array(labels)

    def check(index, expected_rows, expected_obs, expected_labels):
        # Keys must be preserved and both batches must contain exactly the indexed rows.
        picked = data[index]
        assert picked.batch.keys() == data.batch.keys()
        assert picked.non_tensor_batch.keys() == data.non_tensor_batch.keys()
        assert picked.batch["obs"].shape[0] == expected_rows
        assert picked.non_tensor_batch["labels"].shape[0] == expected_rows
        assert np.array_equal(picked.batch["obs"].cpu().numpy(), expected_obs)
        assert np.array_equal(picked.non_tensor_batch["labels"], expected_labels)

    # Integer indices.
    idx_np_int = np.random.randint(0, data_len, size=(idx_num,))
    check(idx_np_int, idx_num, obs[idx_np_int].numpy(), labels_np[idx_np_int])

    idx_torch_int = torch.randint(0, data_len, size=(idx_num,))
    check(idx_torch_int, idx_num, obs[idx_torch_int].cpu().numpy(), labels_np[idx_torch_int.cpu().numpy()])

    idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)]
    check(idx_list_int, idx_num, obs[idx_list_int].cpu().numpy(), labels_np[idx_list_int])

    # Boolean masks.
    idx_np_bool = np.random.randint(0, 2, size=(data_len,), dtype=bool)
    check(idx_np_bool, idx_np_bool.sum(), obs[idx_np_bool].cpu().numpy(), labels_np[idx_np_bool])

    idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool)
    check(
        idx_torch_bool,
        idx_torch_bool.sum().item(),
        obs[idx_torch_bool].cpu().numpy(),
        labels_np[idx_torch_bool],
    )

    idx_list_bool = [np.random.randint(0, 2, dtype=bool) for _ in range(data_len)]
    check(idx_list_bool, sum(idx_list_bool), obs[idx_list_bool].cpu().numpy(), labels_np[idx_list_bool])


def test_old_vs_new_from_single_dict():
    """from_single_dict must build an instance of the calling subclass, not a bare DataProto."""

    class CustomProto(DataProto):
        """Inherits the current from_single_dict unchanged."""

        pass

    class OriginProto(DataProto):
        """Reproduces the legacy from_single_dict, which ignored ``cls``."""

        @classmethod
        def from_single_dict(cls, data, meta_info=None, auto_padding=False):
            tensors = {k: v for k, v in data.items() if torch.is_tensor(v)}
            non_tensors = {k: v for k, v in data.items() if not torch.is_tensor(v)}
            # Hard-coded DataProto instead of `cls` is the legacy behaviour under test.
            return DataProto.from_dict(
                tensors=tensors,
                non_tensors=non_tensors,
                meta_info=meta_info,
                auto_padding=auto_padding,
            )

    sample = {"x": torch.tensor([0])}

    # Legacy behaviour: the subclass is lost and a plain DataProto comes back.
    legacy = OriginProto.from_single_dict(sample)
    assert type(legacy) is DataProto
    assert type(legacy) is not OriginProto

    # Current behaviour: the subclass is respected.
    current = CustomProto.from_single_dict(sample)
    assert type(current) is CustomProto


def test_dataproto_no_batch():
    """select and pop work on a DataProto built from non-tensor data only."""
    labels = ["a", "b", "c"]
    proto = DataProto.from_dict(non_tensors={"labels": labels}, meta_info={"info": "test_info"})

    chosen = proto.select(non_tensor_batch_keys=["labels"])
    assert (chosen.non_tensor_batch["labels"] == labels).all()

    removed = proto.pop(non_tensor_batch_keys=["labels"])
    assert (removed.non_tensor_batch["labels"] == labels).all()

    # Popping the only non-tensor key leaves the source empty.
    assert proto.non_tensor_batch == {}


def test_sample_level_repeat():
    """sample_level_repeat repeats each row by its own count, given as a list or a tensor."""
    data = DataProto.from_dict(
        tensors={"obs": torch.tensor([[1, 2], [3, 4], [5, 6]])},
        non_tensors={"labels": ["a", "b", "c"]},
        meta_info={"info": "test_info"},
    )

    cases = [
        (
            [3, 1, 2],
            torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]]),
            ["a", "a", "a", "b", "c", "c"],
        ),
        (
            torch.tensor([1, 2, 3]),
            torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]]),
            ["a", "b", "b", "c", "c", "c"],
        ),
    ]
    for counts, want_obs, want_labels in cases:
        repeated = data.sample_level_repeat(repeat_times=counts)
        assert torch.all(torch.eq(repeated.batch["obs"], want_obs))
        assert (repeated.non_tensor_batch["labels"] == want_labels).all()
        assert repeated.meta_info == {"info": "test_info"}


def test_dataproto_unfold_column_chunks():
    """unfold_column_chunks(2) splits `split_keys` columns into rows and repeats the other keys to match."""

    def check(obs1, obs2, labels, split_keys, expect_obs1, expect_obs2, expect_labels):
        # Build, unfold into 2 column chunks, and compare every field plus meta_info.
        data = DataProto.from_dict(
            tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"}
        )
        ret = data.unfold_column_chunks(2, split_keys=split_keys)
        assert torch.all(torch.eq(ret.batch["obs1"], expect_obs1))
        assert torch.all(torch.eq(ret.batch["obs2"], expect_obs2))
        assert (ret.non_tensor_batch["labels"] == expect_labels).all()
        assert ret.meta_info == {"name": "abc"}

    # 2-D obs1 split; obs2 and scalar labels are repeated.
    check(
        torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]),
        torch.tensor([[1, 2], [5, 6], [9, 10]]),
        ["a", "b", "c"],
        ["obs1"],
        torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]),
        torch.tensor([[1, 2], [1, 2], [5, 6], [5, 6], [9, 10], [9, 10]]),
        ["a", "a", "b", "b", "c", "c"],
    )

    # Non-tensor labels are split alongside obs1.
    check(
        torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]),
        torch.tensor([[1, 2], [5, 6], [9, 10]]),
        [["a1", "a2"], ["b1", "b2"], ["c1", "c2"]],
        ["obs1", "labels"],
        torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]),
        torch.tensor([[1, 2], [1, 2], [5, 6], [5, 6], [9, 10], [9, 10]]),
        [["a1"], ["a2"], ["b1"], ["b2"], ["c1"], ["c2"]],
    )

    # 3-D obs1 split along its second dim; trailing dims are carried through.
    check(
        torch.tensor(
            [
                [[1, 1], [2, 2], [3, 3], [4, 4]],
                [[5, 5], [6, 6], [7, 7], [8, 8]],
                [[9, 9], [10, 10], [11, 11], [12, 12]],
            ]
        ),
        torch.tensor([[[1, 1], [2, 2]], [[5, 5], [6, 6]], [[9, 9], [10, 10]]]),
        ["a", "b", "c"],
        ["obs1"],
        torch.tensor(
            [
                [[1, 1], [2, 2]],
                [[3, 3], [4, 4]],
                [[5, 5], [6, 6]],
                [[7, 7], [8, 8]],
                [[9, 9], [10, 10]],
                [[11, 11], [12, 12]],
            ]
        ),
        torch.tensor(
            [
                [[1, 1], [2, 2]],
                [[1, 1], [2, 2]],
                [[5, 5], [6, 6]],
                [[5, 5], [6, 6]],
                [[9, 9], [10, 10]],
                [[9, 9], [10, 10]],
            ]
        ),
        ["a", "a", "b", "b", "c", "c"],
    )


def test_dataproto_chunk_after_index():
    """Indexing must leave a plain-int ``torch.Size`` batch size, so the selection can then be chunked.

    Every supported index flavour (numpy/torch/list, bool mask or int positions)
    selects rows 0 and 2. For each flavour the test checks the batch size type,
    then chunks the selection into two single-row pieces.
    """
    data_len = 4
    obs = torch.randn(data_len, 4)
    labels = [f"label_{i}" for i in range(data_len)]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abc"})

    indices = {
        "numpy bool mask": np.array([True, False, True, False]),
        "numpy int array": np.array([0, 2]),
        "bool list": [True, False, True, False],
        "int list": [0, 2],
        "torch bool mask": torch.tensor([True, False, True, False]),
        "torch int tensor": torch.tensor([0, 2]),
    }
    expected_obs = obs[[0, 2]]

    for desc, index in indices.items():
        selected = data[index]
        batch_size = selected.batch.batch_size
        assert isinstance(batch_size, torch.Size), desc
        # Entries must be Python ints; tensor-valued sizes break later batch ops such as chunk.
        assert all(isinstance(d, int) for d in batch_size), desc

        # Exercise what the test name promises: chunking the indexed selection.
        chunks = selected.chunk(2)
        assert len(chunks) == 2, desc
        assert all(len(piece) == 1 for piece in chunks), desc
        assert torch.equal(torch.cat([piece.batch["obs"] for piece in chunks]), expected_obs), desc


@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict():
    """to_tensordict exposes tensors, non-tensor columns and meta_info under their own keys."""
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    labels = ["a", "b", "c", "d", "e", "f"]
    proto = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"})

    flat = proto.to_tensordict()

    assert torch.all(torch.eq(flat["obs"], obs)).item()  # tensor column
    assert flat["labels"] == labels  # non-tensor column
    assert flat["name"] == "abdce"  # meta_info entry


def test_serialize_deserialize_single_tensor():
    """A tensor survives a serialize_single_tensor -> deserialize_single_tensor round trip."""
    source = torch.randn(3, 4, 5)

    dtype_part, shape_part, payload = serialize_single_tensor(source)
    restored = deserialize_single_tensor((dtype_part, shape_part, payload))

    # Values, shape and dtype must all match the source.
    assert torch.allclose(source, restored)
    assert source.shape == restored.shape
    assert source.dtype == restored.dtype


def test_serialize_deserialize_tensordict_regular_tensors():
    """A dense TensorDict with a 2-D batch size round-trips through serialize/deserialize_tensordict."""
    batch_size = (5, 3)
    float_values = torch.randn(*batch_size, 4)
    int_values = torch.randint(0, 10, (*batch_size, 2))
    source = TensorDict({"tensor1": float_values, "tensor2": int_values}, batch_size=batch_size)

    size_part, device_part, items_part = serialize_tensordict(source)
    restored = deserialize_tensordict((size_part, device_part, items_part))

    assert source.batch_size == restored.batch_size
    assert set(source.keys()) == set(restored.keys())

    # Every entry must keep its values, shape and dtype.
    for key in source.keys():
        before = source[key]
        after = restored[key]
        assert torch.allclose(before, after)
        assert before.shape == after.shape
        assert before.dtype == after.dtype


def test_serialize_deserialize_tensordict_nested_tensors():
    """A TensorDict holding a nested tensor with differently-shaped components plus a dense tensor round-trips intact."""
    components = [torch.randn(2, 3), torch.randn(3, 4), torch.randn(1, 5)]
    nested = torch.nested.as_nested_tensor(components)
    dense = torch.randn(3, 4, 5)
    source = TensorDict({"nested": nested, "regular": dense}, batch_size=(3,))

    size_part, device_part, items_part = serialize_tensordict(source)
    restored = deserialize_tensordict((size_part, device_part, items_part))

    assert source.batch_size == restored.batch_size
    assert set(source.keys()) == set(restored.keys())

    # Dense entry: values, shape and dtype are preserved.
    dense_before = source["regular"]
    dense_after = restored["regular"]
    assert torch.allclose(dense_before, dense_after)
    assert dense_before.shape == dense_after.shape
    assert dense_before.dtype == dense_after.dtype

    # Nested entry: still nested, same layout, and identical components.
    nested_before = source["nested"]
    nested_after = restored["nested"]
    assert nested_before.is_nested
    assert nested_after.is_nested
    assert nested_before.layout == nested_after.layout

    parts_before = nested_before.unbind()
    parts_after = nested_after.unbind()
    assert len(parts_before) == len(parts_after)

    # Lengths are already asserted equal, so strict zipping cannot change the outcome.
    for part_before, part_after in zip(parts_before, parts_after, strict=True):
        assert torch.allclose(part_before, part_after)
        assert part_before.shape == part_after.shape
        assert part_before.dtype == part_after.dtype


def test_serialize_deserialize_tensordict_mixed_types():
    """Test serialization and deserialization of TensorDict with mixed tensor types.

    Round-trips a TensorDict holding float32/float64/int32/int64/bool/bfloat16
    tensors, a nested (ragged) tensor and, when this PyTorch build can create
    one, an FP8 tensor. Every entry must come back with the same shape, dtype
    and values.
    """
    tensordict_data = {
        "float": torch.randn(2, 3).float(),
        "double": torch.randn(2, 3).double(),
        "int": torch.randint(0, 10, (2, 3)).int(),
        "long": torch.randint(0, 10, (2, 3)).long(),
        "bool": torch.tensor([[True, False], [False, True]]),
        "bfloat16": torch.randn(2, 3).bfloat16(),
        # Nested tensor whose two components have different shapes.
        "nested": torch.nested.as_nested_tensor([torch.randn(2, 3), torch.randn(3, 4)]),
    }

    # FP8 dtypes are not present in every PyTorch version: collect whichever exist,
    # preferring float8_e5m2 over float8_e4m3fn for the test tensor.
    fp8_dtypes = tuple(getattr(torch, name) for name in ("float8_e5m2", "float8_e4m3fn") if hasattr(torch, name))
    if fp8_dtypes:
        try:
            tensordict_data["fp8"] = torch.randn(2, 3).to(fp8_dtypes[0])
        except Exception:
            # Best effort: skip the FP8 entry if this build cannot materialize it.
            pass

    original_tensordict = TensorDict(tensordict_data, batch_size=(2,))

    # Serialize
    batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict)

    # Deserialize
    reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items))

    # Verify results
    assert original_tensordict.batch_size == reconstructed_tensordict.batch_size
    assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())

    for key in original_tensordict.keys():
        original_tensor = original_tensordict[key]
        reconstructed_tensor = reconstructed_tensordict[key]

        if original_tensor.is_nested:
            # Nested tensors must stay nested; compare component-by-component.
            assert reconstructed_tensor.is_nested, key
            original_unbind = original_tensor.unbind()
            reconstructed_unbind = reconstructed_tensor.unbind()

            assert len(original_unbind) == len(reconstructed_unbind), key

            for orig, recon in zip(original_unbind, reconstructed_unbind, strict=True):
                # Check metadata first so a mismatch is reported clearly instead of
                # surfacing as a broadcasting error in the value comparison.
                assert orig.shape == recon.shape, key
                assert orig.dtype == recon.dtype, key
                assert torch.allclose(orig, recon, equal_nan=True), key
        else:
            assert original_tensor.shape == reconstructed_tensor.shape, key
            assert original_tensor.dtype == reconstructed_tensor.dtype, key
            if original_tensor.dtype in fp8_dtypes:
                # Upcasting FP8 to float32 is exact. The comparison then does not
                # depend on float8 comparison-kernel support in this PyTorch build.
                original_tensor = original_tensor.float()
                reconstructed_tensor = reconstructed_tensor.float()
            # A serialization round-trip must be exact for every dtype.
            assert torch.equal(original_tensor, reconstructed_tensor), key


def test_serialize_deserialize_tensordict_with_device():
    """Round-trip a device-tagged TensorDict and check that its device is preserved.

    The TensorDict is placed on CUDA when available, otherwise on CPU. Tensor
    values are compared after moving both sides to CPU.
    """
    batch_size = (2, 3)

    # Build the TensorDict on the best available device.
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    original_tensordict = TensorDict(
        {
            "tensor1": torch.randn(*batch_size, 4),
            "tensor2": torch.randint(0, 10, (*batch_size, 2)),
        },
        batch_size=batch_size,
        device=target_device,
    )

    # Round-trip through the serialization helpers.
    serialized_batch_size, serialized_device, serialized_items = serialize_tensordict(original_tensordict)
    reconstructed_tensordict = deserialize_tensordict((serialized_batch_size, serialized_device, serialized_items))

    # Container-level metadata must survive unchanged.
    assert original_tensordict.batch_size == reconstructed_tensordict.batch_size
    assert str(original_tensordict.device) == str(reconstructed_tensordict.device)
    assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())

    # Per-entry check: values (compared on CPU), shape and dtype.
    for name in original_tensordict.keys():
        expected = original_tensordict[name]
        actual = reconstructed_tensordict[name]

        assert torch.allclose(expected.cpu(), actual.cpu())
        assert expected.shape == actual.shape
        assert expected.dtype == actual.dtype
