from __future__ import annotations

import torch

@torch.no_grad()
def flatten_named_tensor(
    named_tensor: dict[str, torch.Tensor],
    restore_info_required=True,
)->tuple[torch.Tensor, list[tuple[str, torch.Size]]]:
    flat_tensor = []
    for pname in named_tensor.keys():
        flat_tensor.append(named_tensor[pname].view(-1).float())
    flat_tensor = torch.hstack(flat_tensor)

    if restore_info_required:
        restore_info = []
        for pname in named_tensor.keys():
            restore_info.append((pname, named_tensor[pname].size()))
    else:
        restore_info = None

    return flat_tensor, restore_info

@torch.no_grad()
def restore_tensor(
    flat_tensor: torch.Tensor,
    restore_info: list[tuple[str, torch.Size]],
) -> dict[str, torch.Tensor]:
    """Split a flat vector back into named tensors (inverse of ``flatten_named_tensor``).

    Args:
        flat_tensor: 1-D tensor of shape ``(D,)``, consumed front to back.
            Elements past the last entry described by ``restore_info`` are
            ignored.
        restore_info: ``(name, size)`` pairs in flattening order. ``size`` may
            be a ``torch.Size`` or any sequence of ints (e.g. a tuple, or a
            list loaded from JSON).

    Returns:
        Dict mapping each name to a tensor of the given size. Entries whose
        name contains ``'num_batches_tracked'`` are cast with ``.long()``
        (int64). The other entries are views into ``flat_tensor`` and share
        its memory.
    """
    restored_tensor = {}
    idx = 0
    for name, size in restore_info:
        # Normalise so plain tuples/lists work as well as torch.Size.
        size = torch.Size(size)
        n_components = size.numel()
        reshaped = flat_tensor[idx: idx + n_components].view(size)
        if 'num_batches_tracked' in name:
            # Counters were stored as float by the flattening step; restore the
            # integer dtype they are expected to have.
            restored_tensor[name] = reshaped.long()
        else:
            restored_tensor[name] = reshaped
        idx += n_components
    return restored_tensor

@torch.no_grad()
def flatten_named_tensors(
    named_tensors: list[dict[str, torch.Tensor]],
) -> tuple[torch.Tensor, list[tuple[str, torch.Size]]]:
    """Flatten several name->tensor dicts into the rows of one 2-D tensor.

    Every dict is flattened with ``flatten_named_tensor`` using the key order of
    the FIRST dict. Because only the first dict's restore info is returned,
    this keeps column ``j`` referring to the same parameter element in every
    row, even when later dicts were built with a different insertion order.

    Args:
        named_tensors: Dicts that all share the same set of keys. An empty
            list is not supported (``torch.vstack`` raises on it).

    Returns:
        ``(flat_tensors, restore_info)``. ``flat_tensors`` has one row per
        dict (shape ``(N, D)``). ``restore_info`` is the ``(name, size)`` list
        produced for the first dict.

    Raises:
        ValueError: If a dict's keys differ from those of the first dict.
    """
    flat_tensors = []
    restore_info = []
    key_order: list[str] | None = None
    for named_tensor in named_tensors:
        if key_order is None:
            # The first dict defines the column layout and the restore info.
            key_order = list(named_tensor)
            flat_tensor, restore_info = flatten_named_tensor(
                named_tensor=named_tensor, restore_info_required=True
            )
        else:
            if named_tensor.keys() != set(key_order):
                raise ValueError(
                    "all named tensors must share the same keys; got "
                    f"{sorted(named_tensor.keys())} vs {sorted(key_order)}"
                )
            # Re-key in the first dict's order so columns line up.
            aligned = {pname: named_tensor[pname] for pname in key_order}
            flat_tensor, _ = flatten_named_tensor(
                named_tensor=aligned, restore_info_required=False
            )
        flat_tensors.append(flat_tensor)
    flat_tensors = torch.vstack(flat_tensors)

    return flat_tensors, restore_info

@torch.no_grad()
def restore_tensors(
    flat_tensors: torch.Tensor, restore_info: list[tuple[str, torch.Size]],
) -> list[dict[str, torch.Tensor]]:
    """Undo ``flatten_named_tensors`` one row at a time.

    Args:
        flat_tensors: Tensor of shape ``(N, D)``. Iteration runs over the first
            dimension, and each row is handed to ``restore_tensor``.
        restore_info: ``(name, size)`` pairs shared by every row.

    Returns:
        A list with one restored name->tensor dict per row, in row order.
    """
    return [restore_tensor(row, restore_info) for row in flat_tensors]