# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement base data transfer protocol between any two functions, modules.
We can subclass Protocol to define more detailed batch info with specific keys
"""

from dataclasses import dataclass, field
import numpy as np

import torch
import torch.distributed



__all__ = ["Agent_DataProto"]

from verl import DataProto

def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
    """Transpose a list of dicts into a dict of lists.

    The keys of the first dict define the output keys (in that order). Each
    value from every dict is appended to the list for its key.

    Args:
        list_of_dict: Dicts to transpose.

    Returns:
        dict mapping each key of the first dict to the list of values seen for
        that key across all dicts; ``{}`` when the input is empty.

    Raises:
        AssertionError: if a later dict contains a key absent from the first.
    """
    if not list_of_dict:
        return {}

    columns = {name: [] for name in list_of_dict[0]}
    for record in list_of_dict:
        for name, value in record.items():
            # Every key must already have been declared by the first dict.
            assert name in columns
            columns[name].append(value)
    return columns

@dataclass
class Agent_DataProto(DataProto):
    """``DataProto`` subclass with helpers used for agent rollouts.

    Adds a concatenation routine that rebuilds non-tensor fields as
    ``dtype=object`` numpy arrays, and a helper that splits a DataProto in two
    with a boolean mask.
    """

    @staticmethod
    def concat_array(data: list["DataProto"], array_keys=None) -> "DataProto":
        """Concat a list of DataProto. The batch is concatenated along dim=0.

        Non-tensor fields are rebuilt as ``dtype=object`` numpy arrays:

        * If a field's value in the first DataProto is a ``np.ndarray``, the
          values from all DataProtos are flattened along their first axis. When
          every input value is one-dimensional, the result is always a 1-D
          object array with one slot per sample, even if the per-sample items
          are equal-length sequences (``np.array(..., dtype=object)`` would
          silently turn those into a 2-D array).
        * Otherwise the per-DataProto values are packed with
          ``np.array(values, dtype=object)``.

        The meta_info is assumed to be identical; the first one is used.

        Args:
            data (List[DataProto]): non-empty list of DataProto to concatenate.
            array_keys: Unused; kept only for backward compatibility.

        Returns:
            DataProto: concatenated DataProto, of the same type as ``data[0]``.

        Raises:
            ValueError: if ``data`` is empty.
        """
        if len(data) == 0:
            raise ValueError("concat_array requires at least one DataProto")

        def _object_array_1d(items):
            # Fill slot by slot so numpy stores each item as an opaque object
            # instead of broadcasting equal-length nested sequences into extra
            # dimensions.
            out = np.empty(len(items), dtype=object)
            for idx, item in enumerate(items):
                out[idx] = item
            return out

        first = data[0]
        # Tensor part: concatenated only when the first DataProto has a batch.
        new_batch = torch.cat([d.batch for d in data], dim=0) if first.batch is not None else None

        non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])
        for key, pieces in non_tensor_batch.items():
            if isinstance(pieces[0], np.ndarray):
                # Flatten along the first axis of every per-DataProto value.
                items = [item for piece in pieces for item in piece]
                if all(getattr(piece, "ndim", 1) == 1 for piece in pieces):
                    non_tensor_batch[key] = _object_array_1d(items)
                else:
                    # Multi-dimensional inputs keep the previous construction.
                    non_tensor_batch[key] = np.array(items, dtype=object)
            else:
                non_tensor_batch[key] = np.array(pieces, dtype=object)

        return type(first)(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=first.meta_info)

    @staticmethod
    def split_patial_rollout(data_proto: "DataProto", filter_mask) -> tuple["DataProto", "DataProto"]:
        """
        Split a DataProto into two based on a boolean mask.

        The misspelled name is kept for backward compatibility.

        Args:
            data_proto: The DataProto to split.
            filter_mask: Per-sample mask as a torch tensor, numpy array, list or
                tuple. Items where the mask is True go to the first DataProto.
                Non-bool masks (e.g. 0/1 integers) are converted to bool, so
                they are never misinterpreted as integer indices.

        Returns:
            Tuple[DataProto, DataProto]: First DataProto with items where mask is True,
                                        Second DataProto with items where mask is False
        """
        # Normalize to a bool tensor: `~` on an integer tensor is bitwise NOT
        # (1 -> -2, 0 -> -1), and select_idxs would then treat it as indices.
        if isinstance(filter_mask, torch.Tensor):
            mask = filter_mask.to(torch.bool)
        else:
            # ascontiguousarray also copes with negative-stride views, which
            # torch.from_numpy rejects.
            mask = torch.from_numpy(np.ascontiguousarray(filter_mask, dtype=bool))

        first_proto = data_proto.select_idxs(mask)
        second_proto = data_proto.select_idxs(~mask)
        return first_proto, second_proto

    # Correctly spelled alias of `split_patial_rollout`.
    split_partial_rollout = split_patial_rollout