from .comlib import *


def flatten_to_vec(tup):
    """Concatenate every tensor in *tup* into a single 1-D tensor.

    Each tensor is flattened to one dimension and the pieces are joined
    in iteration order.

    Args:
        tup: Iterable of tensors.

    Returns:
        A 1-D tensor holding the elements of all input tensors.
    """
    # reshape(-1) rather than view(-1): view raises on non-contiguous
    # tensors (e.g. a transposed weight), while reshape still returns a
    # view when it can and only copies when it has to.
    return torch.cat([t.reshape(-1) for t in tup])

def unflatten_to_stat_dict_with_form_dict(vec, form_dict):
    """Split a flat tensor into named pieces shaped according to *form_dict*.

    Consecutive slices of *vec* are taken in the iteration order of
    *form_dict*; each slice is reshaped with ``.view`` to the shape stored
    under the corresponding key. Elements of *vec* beyond the total number
    consumed are ignored.

    Args:
        vec: 1-D tensor holding all values back to back.
        form_dict: Mapping from key to the target shape (an iterable of
            ints, e.g. ``torch.Size`` or a tuple) for that key.

    Returns:
        A dict with the same keys as *form_dict*, each mapped to the
        reshaped slice of *vec*.
    """
    stat_dict = {}
    pointer = 0
    for key, shape in form_dict.items():
        # Count elements with plain ints so the running offset and slice
        # bounds stay Python ints instead of 0-dim tensors. An empty shape
        # (scalar) yields a count of 1.
        num_param = 1
        for dim in shape:
            num_param *= int(dim)
        stat_dict[key] = vec[pointer:pointer + num_param].view(shape)
        pointer += num_param
    return stat_dict

def unflatten_to_tuple_with_form_dict(vec, form_dict):
    """Split *vec* according to *form_dict* and return the pieces as a tuple.

    Delegates to ``unflatten_to_stat_dict_with_form_dict`` and returns the
    values of the resulting dict, dropping the keys.
    """
    # Dicts preserve insertion order on Python 3.7+, so the tuple follows
    # the order in which the pieces were inserted into the dict.
    return tuple(
        unflatten_to_stat_dict_with_form_dict(vec, form_dict).values()
    )