"""Input types for DB1 of different modalities and tasks"""
from dataclasses import asdict, dataclass, fields
from typing import List, Optional, Union

import torch

@dataclass
class GatoInputBase:
    """Container for the attributes shared by the inputs of every task.

    Each field is either ``None`` or a tensor. After ``merge_into_one`` the
    non-None fields of the returned object hold Python lists instead.
    """

    # Attributes common to all tasks.
    position_id: Optional[torch.Tensor]
    attention_mask: Optional[torch.Tensor]
    loss_mask: Optional[torch.Tensor]
    label: Optional[torch.Tensor]

    def _named_values(self):
        """Yield ``(field_name, value)`` for every dataclass field.

        ``dataclasses.asdict`` is deliberately avoided: it deep-copies every
        value, which duplicates each tensor and raises for non-leaf tensors
        that require grad.
        """
        for field in fields(self):
            yield field.name, getattr(self, field.name)

    def get_datasize(self):
        """Return the memory used by the four base tensors, in GiB.

        Only ``position_id``, ``attention_mask``, ``loss_mask`` and ``label``
        are counted; fields that are ``None`` contribute nothing and fields
        declared by subclasses are not included.
        """
        base_tensors = (
            self.position_id,
            self.attention_mask,
            self.loss_mask,
            self.label,
        )
        total_bytes = sum(
            t.element_size() * t.nelement()
            for t in base_tensors
            if t is not None
        )
        return total_bytes / (1024 * 1024 * 1024)

    def to(self, *args, **kwargs):
        """Convert every tensor field in place via ``torch.Tensor.to``.

        All positional and keyword arguments are forwarded unchanged, so both
        ``obj.to("cpu")`` and ``obj.to(device="cpu")`` work. Fields that are
        not tensors (``None``, lists, ...) are left untouched.
        """
        for name, value in self._named_values():
            if isinstance(value, torch.Tensor):
                setattr(self, name, value.to(*args, **kwargs))

    def apply(self, fn, *args, **kwargs):
        """Replace every non-None field with ``fn(value, *args, **kwargs)``.

        ``fn`` receives the stored value itself, not a copy.
        """
        for name, value in self._named_values():
            if value is not None:
                setattr(self, name, fn(value, *args, **kwargs))

    def append(self, other):
        """Concatenate ``other``'s fields onto this object's fields along dim 0.

        Fields that are ``None`` on ``self`` are skipped. For every other
        field, ``other`` must provide a value that ``torch.cat`` can join
        with the one stored here.
        """
        assert type(self).__name__ == type(other).__name__
        for name, value in self._named_values():
            if value is None:
                continue
            joined = torch.cat([value, getattr(other, name)], dim=0)
            setattr(self, name, joined)

    @staticmethod
    def merge_into_one(data2merge: List["GatoInputBase"]) -> "GatoInputBase":
        """Merge several inputs into one whose fields are lists of values.

        The first element of ``data2merge`` is modified in place and returned:
        each of its non-None fields becomes a one-element list, and the
        matching non-None field values of the remaining elements are appended
        to those lists. Fields that are ``None`` on the first element must be
        ``None`` on all the others as well.

        Raises:
            IndexError: if ``data2merge`` is empty.
        """
        merged = data2merge[0]
        merged.apply(lambda x: [x])
        for item in data2merge[1:]:
            for name, value in item._named_values():
                if value is not None:
                    getattr(merged, name).append(value)
        return merged


@dataclass
class RLTaskInput(GatoInputBase):
    """Input container that extends ``GatoInputBase`` with five extra fields.

    Every field declared here is required (no defaults) and is annotated as
    either a list or a tensor.
    """
    # NOTE(review): the meanings below are inferred from the field names
    # only; confirm them against the code that builds these objects.
    text_seq: Union[List, torch.Tensor]  # presumably the text-token part of the sequence
    tensor_seq: Union[List, torch.Tensor]  # presumably the non-text (tensor) part of the sequence
    obs_idxs: Union[List, torch.Tensor]  # presumably indices of observation tokens
    seq_len: Union[List, torch.Tensor]  # presumably the sequence length(s)
    prefix_mask: Union[List, torch.Tensor]  # presumably a mask marking prefix positions
