from dataclasses import dataclass, field
from typing import List, Optional, Set

import torch


@dataclass
class CpuTensorChain:
    """
    Keeps a list of Tensors around (on the CPU and via non-blocking copies) and converts them to a single
    tensor on demand.

    Tensors are concatenated along dim 0 by ``torch.cat``. ``__len__`` therefore reports the total size of
    that leading dimension across all stored tensors.
    """

    tensors: List[torch.Tensor] = field(default_factory=list)
    # CUDA devices with device->host copies that may still be in flight. A ``non_blocking`` copy to the
    # CPU is not guaranteed to be complete when ``.to`` returns, so these must be synchronized before
    # the host data is read.
    _pending_devices: Set[torch.device] = field(
        default_factory=set, init=False, repr=False, compare=False
    )

    def append(self, tensor: torch.Tensor) -> None:
        """Detach ``tensor`` and queue an (asynchronous, if possible) copy of it onto the CPU.

        For a tensor already on the CPU, ``.to`` returns the same storage without copying. Later in-place
        changes the caller makes to that tensor will therefore be visible here.
        """
        if tensor.device.type == "cuda":
            self._pending_devices.add(tensor.device)
        self.tensors.append(tensor.detach().to(device="cpu", non_blocking=True))

    def _wait_for_pending_copies(self) -> None:
        """Block until every outstanding non-blocking device->host copy has finished."""
        for device in self._pending_devices:
            torch.cuda.synchronize(device)
        self._pending_devices.clear()

    def consolidate(self) -> None:
        """Merge all stored tensors into a single tensor (no-op when empty)."""
        if not self.tensors:
            return

        # Host buffers may still be being filled by async copies; never read them before that is done.
        self._wait_for_pending_copies()

        # An already-consolidated chain would otherwise be copied again on every call to ``get``.
        if len(self.tensors) > 1:
            self.tensors = [torch.cat(self.tensors)]

    def get(self) -> Optional[torch.Tensor]:
        """Return all stored data as one CPU tensor, or ``None`` if nothing has been appended."""
        self.consolidate()
        return self.tensors[0] if self.tensors else None

    def reset(self) -> None:
        """Drop all stored tensors."""
        self.tensors = []
        self._pending_devices.clear()

    @classmethod
    def create(cls) -> "CpuTensorChain":
        """Return a new, empty chain."""
        return cls([])

    def __len__(self) -> int:
        return sum(len(tensor) for tensor in self.tensors)
