# Copyright 2024 Anonymous Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Replace DataProto with raw TensorDict
"""

import copy
import random

import numpy as np
import pytest
import torch

from verl.utils import tensordict_utils as tu


def test_union_tensor_dict():
    """Check which inputs ``tu.union_tensor_dict`` accepts and which it rejects.

    Two tensordicts sharing the very same ``obs`` tensor and equal
    ``data_sources`` are expected to union without error. Each conflicting
    case below is expected to raise ``AssertionError``.
    """
    batch = 100
    shared_obs = torch.randn(batch, 10)

    raw_left = {"obs": shared_obs, "act": torch.randn(batch, 3), "data_sources": ["gsm8k"] * batch}
    raw_right = {
        "obs": shared_obs,
        "next_obs": torch.randn(batch, 10),
        "rew": torch.randn(batch),
        "data_sources": ["gsm8k"] * batch,
    }
    raw_cloned = {"obs": shared_obs.clone(), "next_obs": torch.randn(batch, 10), "rew": torch.randn(batch)}

    left = tu.get_tensordict(tensor_dict=raw_left)
    right = tu.get_tensordict(tensor_dict=raw_right)
    cloned = tu.get_tensordict(raw_cloned)

    # Compatible inputs: no exception expected.
    tu.union_tensor_dict(left, right)

    # conflict in tensor values
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(left, cloned)

    # Non-tensor metadata present on only one side is still accepted.
    left = tu.assign_non_tensor_dict(left, {"top_p": 0.8})
    tu.union_tensor_dict(left, right)

    # conflict in NonTensorData: both sides carry "top_p" with different values
    right = tu.assign_non_tensor_dict(right, {"top_p": 0.9})
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(left, right)

    # Remove the conflicting metadata before testing a per-sample conflict.
    for td in (left, right):
        td.pop("top_p")

    # conflict in NonTensorData: one per-sample entry differs
    right["data_sources"][0] = "math"
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(left, right)


def test_tensor_dict_constructor():
    """Build a tensordict via ``tu.get_tensordict`` and check batch size, indexing and metadata."""
    batch = 100
    td = tu.get_tensordict(
        tensor_dict={
            "obs": torch.ones(batch, 10),
            "act": torch.zeros(batch, 10, 3),
            "data_source": ["gsm8k"] * batch,
        },
        non_tensor_dict={"name": "abdce"},
    )

    assert td.batch_size == torch.Size([100])

    # Integer indexing yields a single sample.
    first = td[0]
    assert torch.eq(first["obs"], torch.ones(10)).all().item()
    assert torch.eq(first["act"], torch.zeros(10, 3)).all().item()
    assert first["data_source"] == "gsm8k"

    # Slice indexing keeps a leading batch dimension.
    head = td[0:2]
    assert torch.eq(head["obs"], torch.ones(2, 10)).all().item()
    assert torch.eq(head["act"], torch.zeros(2, 10, 3)).all().item()
    assert head["data_source"] == ["gsm8k"] * 2

    # Entries passed through non_tensor_dict are readable by key.
    assert td["name"] == "abdce"


def test_tensordict_with_images():
    """Store jagged token ids next to per-sample lists of differently sized images."""
    # each sample contains a sequence with multiple images of different sizes
    vocab_size = 128
    seq_a = torch.randint(low=0, high=vocab_size, size=(11,))
    seq_b = torch.randint(low=0, high=vocab_size, size=(13,))
    nested_ids = torch.nested.as_nested_tensor([seq_a, seq_b], layout=torch.jagged)

    def random_image(side):
        # Square uint8 image with 3 channels, converted to a numpy array.
        return torch.randint(low=0, high=255, size=(3, side, side), dtype=torch.uint8).numpy()

    # must be numpy
    # TODO(vermouth1992). We may use nested tensor too. But this requires nested over nested
    a_images = [random_image(side) for side in (256, 128)]
    b_images = [random_image(side) for side in (256, 128, 64)]

    data = tu.get_tensordict({"input_ids": nested_ids, "images": [a_images, b_images]})

    first_sample = data[0]
    assert np.all(np.equal(first_sample["images"][0], a_images[0]))
    assert torch.all(torch.eq(data[0]["input_ids"], seq_a))


def test_tensordict_with_packing():
    """Check offsets, indexing and chunking of a jagged ``input_ids`` entry."""
    vocab_size = 128
    sequences = [torch.randint(low=0, high=vocab_size, size=(length,)) for length in (11, 13)]
    nested_ids = torch.nested.as_nested_tensor(sequences, layout=torch.jagged)

    data = tu.get_tensordict({"input_ids": nested_ids})

    # test cu_seqlens: cumulative lengths of the 11- and 13-token sequences
    expected_offsets = torch.tensor([0, 11, 24])
    assert torch.all(torch.eq(expected_offsets, data["input_ids"].offsets()))

    # test index: index the nested tensor itself ...
    for row, seq in enumerate(sequences):
        assert torch.eq(data["input_ids"][row], seq).all()

    # ... and index the tensordict first, then read the entry
    for row, seq in enumerate(sequences):
        assert torch.eq(data[row]["input_ids"], seq).all()

    # Each of the two chunks is expected to hold one of the sequences.
    chunks = data.chunk(2)
    for row, seq in enumerate(sequences):
        assert torch.eq(chunks[row]["input_ids"][0], seq).all()


def test_tensordict_eq():
    """Check ``tu.assert_tensordict_eq`` passes on equal inputs and raises on any difference."""

    def build():
        # Fresh tensordict with fresh inputs on every call, so no state is shared.
        return tu.get_tensordict(
            {
                "obs": torch.tensor([1, 2, 3, 4, 5, 6]),
                "data_sources": ["abc", "def", "abc", "def", "pol", "klj"],
            },
            non_tensor_dict={"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}},
        )

    def bump_obs(td):
        td["obs"][0] += 1

    def rename_source(td):
        td["data_sources"][0] = "math"

    def change_top_p(td):
        td["train_sample_kwargs"]["top_p"] = 0.9

    reference = build()
    twin = build()

    # Independently built but equal tensordicts compare equal.
    tu.assert_tensordict_eq(reference, twin)

    # A difference in a tensor, a per-sample non-tensor entry, or nested
    # metadata must each be detected.
    for mutate in (bump_obs, rename_source, change_top_p):
        altered = copy.deepcopy(twin)
        mutate(altered)
        with pytest.raises(AssertionError):
            tu.assert_tensordict_eq(reference, altered)


def test_tensor_dict_make_iterator():
    """Check the batches produced by ``tu.make_iterator``.

    Without shuffling, two epochs over six samples with ``mini_batch_size=2``
    are expected to yield the three consecutive slices twice, in order. With
    shuffling, two iterators created with the same seed are expected to yield
    equal batches.
    """
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
    non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
    dataset = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)

    dataloader = tu.make_iterator(
        dataset, mini_batch_size=2, epochs=2, seed=0, dataloader_kwargs={"shuffle": False, "drop_last": False}
    )

    expected_tensor_dict = [dataset[0:2], dataset[2:4], dataset[4:6], dataset[0:2], dataset[2:4], dataset[4:6]]

    # Materialize the batches and check the count explicitly. A bare loop over
    # the iterator would pass vacuously if it yielded too few batches.
    batches = list(dataloader)
    assert len(batches) == len(expected_tensor_dict)
    for batch, expected in zip(batches, expected_tensor_dict, strict=True):
        tu.assert_tensordict_eq(batch, expected)

    shuffle_kwargs = {"shuffle": True}
    data_list_1 = list(
        tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs=shuffle_kwargs)
    )
    data_list_2 = list(
        tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs=shuffle_kwargs)
    )

    # Guard against two empty iterators, which would compare equal trivially.
    assert data_list_1, "make_iterator yielded no batches"
    for data1, data2 in zip(data_list_1, data_list_2, strict=True):
        tu.assert_tensordict_eq(data1, data2)


def test_reorder():
    """Index a tensordict with a permutation tensor and check every entry follows it."""
    permutation = torch.tensor([3, 4, 2, 0, 1, 5])

    original = tu.get_tensordict(
        tensor_dict={
            "obs": torch.tensor([1, 2, 3, 4, 5, 6]),
            "labels": ["a", "b", "c", "d", "e", "f"],
        },
        non_tensor_dict={"name": "abdce"},
    )
    reordered = original[permutation]

    # Tensor and per-sample entries are permuted; metadata is unchanged.
    want_obs = torch.tensor([4, 5, 3, 1, 2, 6])
    want_labels = np.array(["d", "e", "c", "a", "b", "f"])
    assert torch.eq(reordered["obs"], want_obs).all()
    assert np.all(reordered["labels"] == want_labels)
    assert reordered["name"] == "abdce"


def test_chunk_concat():
    """Check ``tensor_split``, ``chunk`` and ``torch.cat`` on a tensordict.

    Splitting six samples into five sections is expected to give one piece of
    two samples followed by four single-sample pieces. Chunking into two and
    concatenating the chunks is expected to reproduce the original entries.
    """
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    labels = ["a", "b", "c", "d", "e", "f"]
    data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"name": "abcde"})

    data_split = data.tensor_split(indices_or_sections=5, dim=0)

    expected_idx_lst = [[0, 1], [2], [3], [4], [5]]

    # strict=True: a wrong number of pieces must fail the test instead of
    # being silently truncated by zip.
    for d, expected_idx in zip(data_split, expected_idx_lst, strict=True):
        tu.assert_tensordict_eq(d, data[expected_idx])

    data_split = data.chunk(2)
    assert len(data_split) == 2
    assert torch.all(torch.eq(data_split[0]["obs"], torch.tensor([1, 2, 3])))
    assert np.all(data_split[0]["labels"] == np.array(["a", "b", "c"]))
    assert data_split[0]["name"] == "abcde"

    assert torch.all(torch.eq(data_split[1]["obs"], torch.tensor([4, 5, 6])))
    assert np.all(data_split[1]["labels"] == np.array(["d", "e", "f"]))
    assert data_split[1]["name"] == "abcde"

    # Concatenating the chunks along the batch dimension restores the data.
    concat_data = torch.cat(data_split, dim=0)
    assert torch.all(torch.eq(concat_data["obs"], data["obs"]))
    assert np.all(concat_data["labels"] == data["labels"])
    assert concat_data["name"] == data["name"]


def test_pop():
    """Check that ``tu.pop`` moves the requested keys into the returned tensordict."""
    dataset = tu.get_tensordict(
        {"obs": torch.randn(100, 10), "act": torch.randn(100, 3)},
        non_tensor_dict={"2": 2, "1": 1},
    )

    popped = tu.pop(dataset, keys=["obs", "2"])

    # The popped part keeps the batch dimension and holds exactly the requested keys.
    assert popped.batch_size[0] == 100
    assert popped.keys() == {"obs", "2"}

    # The source is left with the remaining keys only.
    assert dataset.keys() == {"act", "1"}


def test_repeat():
    """Check ``repeat_interleave`` and ``repeat`` on a tensordict with tensor and non-tensor entries."""
    data = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": ["a", "b", "c"]},
        non_tensor_dict={"info": "test_info"},
    )

    # Interleaved: every sample is duplicated in place.
    interleaved = data.repeat_interleave(repeats=2)
    want_obs = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]])
    want_labels = ["a", "a", "b", "b", "c", "c"]

    assert torch.eq(interleaved["obs"], want_obs).all()
    assert interleaved["labels"] == want_labels
    assert interleaved["info"] == "test_info"

    # Not interleaved: the whole batch is repeated back to back.
    tiled = data.repeat(2)
    want_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]])
    want_labels = ["a", "b", "c", "a", "b", "c"]

    assert torch.eq(tiled["obs"], want_obs).all()
    assert tiled["labels"] == want_labels
    assert tiled["info"] == "test_info"


def test_dataproto_pad_unpad():
    """Check ``tu.pad_to_divisor`` and ``tu.unpad`` for several divisors.

    For each divisor the test checks the reported pad size and the padded
    contents, which are expected to repeat the samples from the start of the
    batch. It then checks that unpadding restores the original data.
    """
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})

    # (size_divisor, expected pad_size, expected padded obs, expected padded labels)
    cases = [
        (2, 1, [[1, 2], [3, 4], [5, 6], [1, 2]], ["a", "b", "c", "a"]),
        (3, 0, [[1, 2], [3, 4], [5, 6]], ["a", "b", "c"]),
        (
            7,
            4,
            [[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]],
            ["a", "b", "c", "a", "b", "c", "a"],
        ),
    ]

    for divisor, want_pad, want_obs, want_labels in cases:
        padded, pad_size = tu.pad_to_divisor(data, size_divisor=divisor)
        assert pad_size == want_pad

        assert torch.all(torch.eq(padded["obs"], torch.tensor(want_obs)))
        assert padded["labels"] == want_labels
        assert padded["info"] == "test_info"

        # Unpadding must return the original, unpadded content.
        restored = tu.unpad(padded, pad_size=pad_size)
        assert torch.all(torch.eq(restored["obs"], obs))
        assert restored["labels"] == labels
        assert restored["info"] == "test_info"


def test_torch_save_data_proto():
    """Round-trip a tensordict through ``torch.save`` / ``torch.load``.

    The file is written inside a temporary directory so the test does not
    touch the current working directory and leaves nothing behind, even when
    saving or loading raises.
    """
    import os
    import tempfile

    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})

    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, "test_data.pt")
        torch.save(data, filename)
        # weights_only=False: the saved object is a full tensordict, not plain weights.
        loaded_data = torch.load(filename, weights_only=False)

    assert torch.all(torch.eq(loaded_data["obs"], data["obs"]))
    assert loaded_data["labels"] == data["labels"]
    assert loaded_data["info"] == data["info"]


def test_len():
    """Check ``len()`` of tensordicts built with and without tensor entries."""
    label_array = np.array(["a", "b", "c"], dtype=object)
    meta = {"info": "test_info"}

    # Tensor plus per-sample labels.
    with_tensor = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": label_array.tolist()},
        non_tensor_dict=meta,
    )
    assert len(with_tensor) == 3

    # Per-sample labels only, no tensor entry.
    labels_only = tu.get_tensordict({"labels": label_array.tolist()}, non_tensor_dict=meta)
    assert len(labels_only) == 3

    # A single indexed sample reports length 0.
    single = labels_only[0]
    assert len(single) == 0

    # Metadata only: no batch entries at all.
    meta_only = tu.get_tensordict({}, non_tensor_dict=meta)
    assert len(meta_only) == 0


def test_dataproto_index():
    """Index a tensordict with integer and boolean indexers of several container types.

    Covered here: integer numpy array, integer torch tensor, list of ints and
    boolean torch tensor. Boolean numpy arrays and boolean lists are not
    exercised by this test.
    """
    data_len = 100
    idx_num = 10

    obs = torch.randn(data_len, 10)
    labels = [random.choice(["abc", "cde"]) for _ in range(data_len)]
    data = tu.get_tensordict({"obs": obs, "labels": labels})
    labels_np = np.array(labels)

    def check(selected, want_rows, want_obs, want_labels):
        # Same keys as the source, the expected number of rows, and contents
        # matching direct indexing of the raw inputs.
        assert selected.keys() == data.keys()
        assert selected["obs"].shape[0] == want_rows
        assert len(selected["labels"]) == want_rows
        assert np.array_equal(selected["obs"].cpu().numpy(), want_obs.cpu().numpy())
        assert np.array_equal(selected["labels"], want_labels)

    # Integer numpy array.
    idx_np_int = np.random.randint(0, data_len, size=(idx_num,))
    check(data[idx_np_int], idx_num, obs[idx_np_int], labels_np[idx_np_int])

    # Integer torch tensor.
    idx_torch_int = torch.randint(0, data_len, size=(idx_num,))
    check(data[idx_torch_int], idx_num, obs[idx_torch_int], labels_np[idx_torch_int.cpu().numpy()])

    # Plain Python list of ints.
    idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)]
    check(data[idx_list_int], idx_num, obs[idx_list_int], labels_np[idx_list_int])

    # Boolean torch mask: the number of rows equals the number of True entries.
    idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool)
    kept = idx_torch_bool.sum().item()
    check(data[idx_torch_bool], kept, obs[idx_torch_bool], labels_np[idx_torch_bool])


def test_select():
    """Check that ``select`` keeps only the requested tensor and non-tensor keys."""
    dataset = tu.get_tensordict(
        {"obs": torch.randn(100, 10), "act": torch.randn(100, 3)},
        non_tensor_dict={"2": 2, "1": 1},
    )

    picked = dataset.select("obs", "2")

    # Selected entries match the source ...
    assert torch.eq(picked["obs"], dataset["obs"]).all()
    assert picked["2"] == dataset["2"]

    # ... and everything else is absent.
    remaining_keys = picked.keys()
    assert "act" not in remaining_keys
    assert "1" not in picked.keys()


def test_dataproto_no_batch():
    """Check ``select`` and ``tu.pop`` on a tensordict that holds no tensor entries."""
    labels = ["a", "b", "c"]
    data = tu.get_tensordict(tensor_dict={"labels": labels}, non_tensor_dict={"info": "test_info"})

    # select returns the labels without removing them from the source.
    only_labels = data.select("labels")
    assert only_labels["labels"] == labels

    # pop returns the labels and removes them from the source.
    popped = tu.pop(data, keys=["labels"])
    assert popped["labels"] == labels
    assert "labels" not in data


def test_sample_level_repeat():
    """Check ``repeat_interleave`` with a per-sample repeat count given as a tensor."""
    data = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": ["a", "b", "c"]},
        non_tensor_dict={"info": "test_info"},
    )

    # (per-sample repeats, expected obs, expected labels)
    cases = [
        (
            torch.tensor([3, 1, 2]),
            torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]]),
            ["a", "a", "a", "b", "c", "c"],
        ),
        (
            torch.tensor([1, 2, 3]),
            torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]]),
            ["a", "b", "b", "c", "c", "c"],
        ),
    ]

    for repeats, want_obs, want_labels in cases:
        repeated = data.repeat_interleave(repeats=repeats)

        assert torch.all(torch.eq(repeated["obs"], want_obs))
        assert repeated["labels"] == want_labels
        # Metadata is carried over unchanged.
        assert repeated["info"] == "test_info"


def test_dataproto_chunk_after_index():
    """Check that indexing always yields a ``torch.Size`` batch size made of plain ints.

    The indexers used are boolean and integer torch tensors and boolean and
    integer Python lists. The tensor variants appear twice, matching the
    original sequence of checks.
    """
    data_len = 4
    data = tu.get_tensordict(
        tensor_dict={
            "obs": torch.randn(data_len, 4),
            "labels": [f"label_{i}" for i in range(data_len)],
        },
        non_tensor_dict={"name": "abc"},
    )

    indexers = [
        torch.tensor([True, False, True, False]),  # boolean torch tensor
        torch.tensor([0, 2]),  # integer torch tensor
        [True, False, True, False],  # boolean list
        [0, 2],  # integer list
        torch.tensor([True, False, True, False]),  # boolean torch tensor (again)
        torch.tensor([0, 2]),  # integer torch tensor (again)
    ]

    for indexer in indexers:
        picked = data[indexer]
        assert isinstance(picked.batch_size, torch.Size)
        # Every dimension must be a plain int.
        assert all(isinstance(dim, int) for dim in picked.batch_size)
