import dataclasses
import datetime
import pickle
import socket
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Tuple, Union

import pynvml
import torch
import torch.distributed as dist

from torch.distributed import ProcessGroup, ReduceOp, TCPStore


def build_key_size_numel_dict(names, weights):
    """Map each parameter name to its shape and element count.

    Args:
        names: sequence of parameter names, aligned element-wise with
            ``weights``.
        weights: sequence of tensors (anything exposing ``.size()`` that
            yields an iterable of ints, e.g. ``torch.Tensor``).

    Returns:
        A tuple ``(key_size, key_numel, total_numel)`` where
        ``key_size`` maps name -> list of dimension sizes,
        ``key_numel`` maps name -> number of elements, and
        ``total_numel`` is the sum of all element counts.

    Note:
        Unlike the earlier sentinel-buffer implementation, this version
        supports tensors of any rank (no hard 6-dim cap) and correctly
        reports ``numel == 0`` for tensors containing a zero-sized
        dimension instead of truncating the shape at the first zero.
        Scalar (0-dim) tensors yield ``size == []`` and ``numel == 1``,
        matching the previous behavior.
    """
    key_size = {}
    key_numel = {}
    total_numel = 0
    for name, weight in zip(names, weights):
        size = list(weight.size())
        # Product of an empty list is 1, which is the correct element
        # count for a 0-dim (scalar) tensor.
        numel = 1
        for dim in size:
            numel *= dim
        key_size[name] = size
        key_numel[name] = numel
        total_numel += numel
    return key_size, key_numel, total_numel


def flatten_weights(names, weights):
    """Concatenate weights into one 1-D tensor plus per-name shape metadata.

    Args:
        names: sequence of parameter names, aligned element-wise with
            ``weights``.
        weights: non-empty sequence of tensors. All tensors must live on
            the same device (``torch.cat`` requires this), so the result
            is produced on that shared device.

    Returns:
        A tuple ``(flatten_weight, key_size, key_numel, total_numel)``:
        the concatenated 1-D tensor followed by the three metadata values
        from :func:`build_key_size_numel_dict`.
    """
    key_size, key_numel, total_numel = build_key_size_numel_dict(names, weights)
    # torch.cat already places its output on the inputs' (shared) device,
    # so the former explicit `.to(device=weights[0].device)` was a no-op
    # and has been dropped, along with an unused `datatype` local.
    flatten_weight = torch.cat(
        [weight.contiguous().view(-1) for weight in weights], dim=0)
    return flatten_weight, key_size, key_numel, total_numel