"""Utilities for going from a list of arrays/Tensors to a single array/Tensor and back again."""
import dataclasses
from typing import List, Sequence, Tuple, Union

import numpy as np
import torch


###############################################################################


@dataclasses.dataclass
class FlatPacker:
    """Map between one flat 1-D vector and a list of tensors with fixed shapes.

    Each entry of ``shapes`` is assigned a contiguous half-open range
    ``[start, end)`` of the flat vector, in order, whose length equals that
    shape's number of elements. ``flat_size`` is the total length of the flat
    vector (the sum of all element counts).

    An empty ``shapes`` sequence is allowed and yields ``flat_size == 0``.
    """

    # Will be converted to Tuple[torch.Size, ...] in __post_init__.
    shapes: Sequence[Union[torch.Size, Sequence[int]]]

    def __post_init__(self):
        """Normalize ``shapes`` and precompute per-tensor sizes and offsets."""
        # Ensure consistent type.
        self.shapes = tuple(torch.Size(s) for s in self.shapes)

        # Number of elements each shape occupies in the flat vector.
        self._sizes = [s.numel() for s in self.shapes]

        # Start offset of each tensor in the flat vector. The running total
        # after the last tensor is the flat size; computing it this way
        # (instead of reading self._sizes[-1]) also works for empty shapes.
        self._offsets = []
        total = 0
        for size in self._sizes:
            self._offsets.append(total)
            total += size
        self.flat_size = total

    def get_range_for_tensor_by_index(self, tensor_index: int) -> Tuple[int, int]:
        """Return the half-open ``(start, end)`` range of one tensor.

        Args:
            tensor_index: Position of the tensor in ``shapes``. Indexing
                follows Python list semantics, so negative indices count
                from the end and out-of-range indices raise ``IndexError``.

        Returns:
            ``(start, end)`` such that ``flat[start:end]`` holds the tensor's
            elements.
        """
        start = self._offsets[tensor_index]
        end = start + self._sizes[tensor_index]
        return start, end

    def unpack_vector(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Split a flat vector into tensors with the configured shapes.

        Args:
            x: 1-D tensor of length ``flat_size``.

        Returns:
            One tensor per entry of ``shapes``, in order. Each is produced
            with ``Tensor.view`` on a slice of ``x``, so it shares storage
            with ``x`` (writes through it are visible in ``x``).

        Raises:
            AssertionError: If ``x`` is not 1-D or its length differs from
                ``flat_size``. NOTE: these checks are skipped under ``python -O``.
        """
        assert len(x.shape) == 1, 'Must be vector.'
        assert x.shape[0] == self.flat_size, 'Vector has invalid size.'

        ret = []
        for i, shape in enumerate(self.shapes):
            start, end = self.get_range_for_tensor_by_index(i)
            # Splitting a single dimension into `shape` is always expressible
            # as a view, so no copy is made here.
            ret.append(x[start:end].view(shape))

        return ret
