


import os
from typing import Any, Optional

import numpy as np
import pytest
import torch

from verl.protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto


def _get_data_proto(
    tensors: Optional[dict[str, list[Any]]] = None,
    non_tensors: Optional[dict[str, list[Any]]] = None,
    meta_info: Optional[dict[str, Any]] = None,
) -> DataProto:
    """Build a small ``DataProto`` fixture for the tests in this module.

    When neither ``tensors`` nor ``non_tensors`` is given, a default six-row
    fixture is used: ``obs = [1, ..., 6]`` and ``labels = ["a", ..., "f"]``.

    Args:
        tensors: Mapping from key to either a list (converted with
            ``torch.tensor``) or an existing ``torch.Tensor`` (used as-is).
        non_tensors: Mapping from key to either a list (converted to an
            object-dtype ``np.ndarray``) or an existing ``np.ndarray`` (used as-is).
        meta_info: Meta information dict. ``None`` selects the default
            ``{"info": "test_info"}``; any dict the caller passes explicitly,
            including an empty one, is used unchanged.

    Returns:
        The ``DataProto`` produced by ``DataProto.from_dict``.
    """
    if tensors is None and non_tensors is None:
        tensors = {"obs": [1, 2, 3, 4, 5, 6]}
        non_tensors = {"labels": ["a", "b", "c", "d", "e", "f"]}

    if tensors is not None:
        tensors = {k: v if isinstance(v, torch.Tensor) else torch.tensor(v) for k, v in tensors.items()}

    if non_tensors is not None:
        # dtype=object stores the entries as Python objects rather than numpy's
        # fixed-width string dtype.
        non_tensors = {
            k: v if isinstance(v, np.ndarray) else np.array(v, dtype=object) for k, v in non_tensors.items()
        }

    # Only fall back to the default when nothing was passed; `meta_info or ...`
    # would also discard an explicitly supplied empty dict.
    if meta_info is None:
        meta_info = {"info": "test_info"}
    return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)


def _assert_equal(data1: DataProto, data2: Optional[DataProto] = None):
    """Assert that two ``DataProto`` objects hold identical contents.

    Compares the tensor batch key-by-key, the non-tensor batch key-by-key and
    ``meta_info``. Each tensor/array pair must agree in shape as well as in
    values, so a length-1 entry cannot pass by broadcasting against a longer one.

    Args:
        data1: The object under test.
        data2: The expected object. ``None`` selects the default fixture
            returned by ``_get_data_proto()``.

    Raises:
        AssertionError: On any mismatch.
    """
    # Explicit None check: `data2 or ...` tests truthiness, which for a sized
    # DataProto depends on its length, so an empty expected object would be
    # silently replaced by the default fixture.
    if data2 is None:
        data2 = _get_data_proto()

    if data1.batch is not None:
        assert data2.batch is not None, "expected data has no tensor batch"
        assert data1.batch.keys() == data2.batch.keys()
        for key in data1.batch.keys():
            actual, expected = data1.batch[key], data2.batch[key]
            assert actual.shape == expected.shape, f"shape mismatch for tensor key {key!r}"
            assert torch.all(actual == expected), f"value mismatch for tensor key {key!r}"
    else:
        assert data2.batch is None

    if data1.non_tensor_batch is not None:
        assert data2.non_tensor_batch is not None, "expected data has no non-tensor batch"
        assert data1.non_tensor_batch.keys() == data2.non_tensor_batch.keys()
        for key in data1.non_tensor_batch.keys():
            actual, expected = data1.non_tensor_batch[key], data2.non_tensor_batch[key]
            assert np.shape(actual) == np.shape(expected), f"shape mismatch for non-tensor key {key!r}"
            assert np.all(actual == expected), f"value mismatch for non-tensor key {key!r}"
    else:
        assert data2.non_tensor_batch is None

    assert data1.meta_info == data2.meta_info


def test_tensor_dict_constructor():
    """``from_dict`` sizes the batch from tensors or from non-tensor arrays alone.

    ``num_batch_dims`` values of 2 and 3 are expected to be rejected with an
    ``AssertionError`` for these inputs.
    """
    observations = torch.randn(100, 10)
    actions = torch.randn(100, 10, 3)

    proto = DataProto.from_dict(tensors={"obs": observations, "act": actions})
    assert len(proto) == 100

    for batch_dims in (2, 3):
        with pytest.raises(AssertionError):
            DataProto.from_dict(tensors={"obs": observations, "act": actions}, num_batch_dims=batch_dims)

    # A DataProto built from non-tensor data only takes its length from that array.
    label_array = np.array(["a", "b", "c"], dtype=object)
    label_only = DataProto.from_dict(non_tensors={"labels": label_array})
    assert len(label_only) == 3


def test_getitem():
    """Integer, slice, list and tensor indexing each select the expected rows."""
    proto = _get_data_proto()

    # An integer index yields a single item whose fields are scalars.
    first = proto[0]
    assert first.batch["obs"] == torch.tensor(1)
    assert first.non_tensor_batch["labels"] == "a"

    # Each case: (index, expected obs values, expected labels).
    index_cases = [
        (slice(1, 3), [2, 3], ["b", "c"]),
        ([0, 2], [1, 3], ["a", "c"]),
        (torch.tensor([1]), [2], ["b"]),
    ]
    for index, expected_obs, expected_labels in index_cases:
        _assert_equal(proto[index], _get_data_proto({"obs": expected_obs}, {"labels": expected_labels}))


def test_select_pop():
    """``select`` leaves the source untouched; ``pop`` removes the chosen keys from it."""
    source = _get_data_proto(
        tensors={"obs": torch.randn(100, 10), "act": torch.randn(100, 3)},
        meta_info={"p": 1, "q": 2},
    )

    def check_keys(proto, batch_keys, meta_keys):
        # Compare the tensor-batch keys and the meta_info keys against expectations.
        assert proto.batch.keys() == batch_keys
        assert proto.meta_info.keys() == meta_keys

    picked = source.select(batch_keys=["obs"], meta_info_keys=["p"])
    check_keys(picked, {"obs"}, {"p"})
    check_keys(source, {"obs", "act"}, {"p", "q"})

    removed = source.pop(batch_keys=["obs"], meta_info_keys=["p"])
    check_keys(removed, {"obs"}, {"p"})
    check_keys(source, {"act"}, {"q"})


def test_chunk_concat_split():
    """``chunk`` divides into equal parts, ``concat`` undoes it, ``split`` uses a fixed part size."""
    full = _get_data_proto()

    # The default fixture has six rows; chunking it into 5 parts is expected to fail.
    with pytest.raises(AssertionError):
        full.chunk(5)

    halves = full.chunk(2)
    assert len(halves) == 2
    _assert_equal(halves[0], _get_data_proto({"obs": [1, 2, 3]}, {"labels": ["a", "b", "c"]}))

    # Concatenating the chunks reproduces the original data.
    _assert_equal(DataProto.concat(halves), full)

    pairs = full.split(2)
    assert len(pairs) == 3
    _assert_equal(pairs[0], _get_data_proto({"obs": [1, 2]}, {"labels": ["a", "b"]}))


def test_reorder():
    """``reorder`` permutes the rows of the DataProto in place."""
    proto = _get_data_proto()
    permutation = torch.tensor([3, 4, 2, 0, 1, 5])
    proto.reorder(permutation)

    # Row i of the result is row permutation[i] of the original fixture.
    expected = _get_data_proto(
        {"obs": [4, 5, 3, 1, 2, 6]},
        {"labels": ["d", "e", "c", "a", "b", "f"]},
    )
    _assert_equal(proto, expected)


@pytest.mark.parametrize("interleave", [True, False])
def test_repeat(interleave: bool):
    """``repeat`` duplicates rows either one by one (interleave) or as a whole block."""
    source = _get_data_proto({"obs": [1, 2]}, {"labels": ["a", "b"]})
    result = source.repeat(repeat_times=2, interleave=interleave)

    if interleave:
        # Each row is repeated back-to-back.
        obs_values, label_values = [1, 1, 2, 2], ["a", "a", "b", "b"]
    else:
        # The whole batch is tiled.
        obs_values, label_values = [1, 2, 1, 2], ["a", "b", "a", "b"]

    _assert_equal(result, _get_data_proto({"obs": obs_values}, {"labels": label_values}))


@pytest.mark.parametrize("size_divisor", [2, 3])
def test_dataproto_pad_unpad(size_divisor: int):
    """Padding to a multiple of ``size_divisor`` and then unpadding round-trips the data.

    With three rows, a divisor of 2 needs one padding row (equal to the first
    row), while a divisor of 3 needs no padding at all.
    """
    original = _get_data_proto({"obs": [1, 2, 3]}, {"labels": ["a", "b", "c"]})
    padded, pad_size = pad_dataproto_to_divisor(original, size_divisor=size_divisor)
    restored = unpad_dataproto(padded, pad_size=pad_size)

    if size_divisor != 2:
        assert pad_size == 0
        expected_padded = original
    else:
        assert pad_size == 1
        expected_padded = _get_data_proto({"obs": [1, 2, 3, 1]}, {"labels": ["a", "b", "c", "a"]})

    _assert_equal(padded, expected_padded)
    _assert_equal(restored, original)


def test_data_proto_save_load():
    """A DataProto saved with ``save_to_disk`` loads back unchanged via ``load_from_disk``.

    The file is written into a temporary directory that is removed even if
    saving or loading fails. This avoids leaving ``test_data.pt`` in the
    working directory and avoids collisions between concurrent test runs.
    """
    import tempfile

    data = _get_data_proto()
    with tempfile.TemporaryDirectory() as tmp_dir:
        filepath = os.path.join(tmp_dir, "test_data.pt")
        data.save_to_disk(filepath)
        loaded_data = DataProto.load_from_disk(filepath)
    _assert_equal(data, loaded_data)


def test_union_tensor_dict():
    """``union`` merges disjoint keys and rejects a shared key whose values differ."""
    obs = torch.randn(100, 10)
    act = torch.randn(100, 3)
    rew = torch.randn(100)

    data1 = _get_data_proto({"obs": obs, "act": act})
    data2 = _get_data_proto({"obs": obs, "rew": rew})
    merged = data1.union(data2)
    # Compatible inputs: the result carries every key from both sides with the
    # original values (previously this case asserted nothing).
    assert merged.batch.keys() == {"obs", "act", "rew"}
    assert torch.equal(merged.batch["obs"], obs)
    assert torch.equal(merged.batch["act"], act)
    assert torch.equal(merged.batch["rew"], rew)

    # Conflicting values for the shared "obs" key must be rejected.
    data1 = _get_data_proto({"obs": obs, "act": torch.randn(100, 3)})
    data2 = _get_data_proto({"obs": obs + 1, "rew": torch.randn(100)})
    with pytest.raises(ValueError):
        data1.union(data2)
