import torch
import numpy as np

def get_id_pos(worker_ids,pos:list):
    '''
    Return the index inside ``pos`` of every entry of ``worker_ids``.

    Example: worker_ids=[2,7], pos=[2,5,7] -> [0,2]

    Lookup is done with ``pos.index``, so the first matching position is
    reported and an id that is absent from ``pos`` raises ``ValueError``.
    '''
    locate = pos.index
    return list(map(locate, worker_ids))

def flatten(tensors,keep_dim=0):
    '''
    Collapse each tensor past its leading dimensions and concatenate.

    Every tensor keeps its first ``keep_dim`` dimensions; all remaining
    dimensions are merged into a single trailing one. The reshaped tensors
    are then joined along that trailing dimension.

    Args:
        tensors: iterable of ``torch.Tensor``.
        keep_dim: number of leading dimensions to leave untouched.

    Returns:
        The result of ``torch.cat`` over the reshaped tensors along the
        last dimension (``keep_dim + 1`` dimensions in total).
    '''
    # reshape (not view) so non-contiguous inputs such as transposed
    # tensors are accepted; it still yields a view whenever one is possible.
    pieces = [t.reshape(*t.shape[:keep_dim], -1) for t in tensors]
    return torch.cat(pieces, -1)

def tensor_tuple_to_form_dict(tensor_tuple,keep_dim=0):
    '''
    Record the shape of every tensor, minus its leading dimensions.

    Returns a dict mapping each tensor's position in ``tensor_tuple`` to a
    ``torch.Size`` holding its shape with the first ``keep_dim`` dimensions
    dropped.
    '''
    return {
        position: torch.Size(tensor.size()[keep_dim:])
        for position, tensor in enumerate(tensor_tuple)
    }

class FlatModelPara():
    '''
    Wrapper around one flattened parameter tensor.

    ``tensors`` is sliced along its first dimension by ``get_tuple`` to
    recover the individual, shaped tensors.
    '''
    def __init__(self, tensors):
        self.tensors = tensors

    def norm(self):
        '''Return ``torch.norm`` of the stored tensor (default settings).'''
        return torch.norm(self.tensors)

    def get_tuple(self, form_dict):
        '''
        Cut the flat tensor back into shaped pieces.

        ``form_dict`` maps keys to shapes. Its values are consumed in
        insertion order (guaranteed for dicts on Python 3.7+); each shape
        takes the next ``prod(shape)`` elements of the flat tensor, viewed
        with that shape.

        Returns a tuple with one tensor per entry of ``form_dict``.
        '''
        pieces = []
        start = 0
        for shape in form_dict.values():
            stop = start + torch.Size(shape).numel()
            pieces.append(self.tensors[start:stop].view(shape))
            start = stop
        return tuple(pieces)


class FlatModelParaS():
    '''
    A stack of flattened parameter vectors stored in one tensor.

    ``tensors`` is indexed as ``[row, flat_position]``: reductions run over
    dim 0 / dim 1 and ``get_tuple`` slices along dim 1.
    '''
    def __init__(self, tensors):
        self.tensors = tensors

    def num(self):
        '''Return the length of the first dimension (number of rows).'''
        return len(self.tensors)

    def __getitem__(self, i):
        '''Return ``self.tensors[i]`` wrapped in a ``FlatModelPara``.'''
        return FlatModelPara(self.tensors[i])

    # ``__item__`` is not a Python special method, so on its own it never
    # enabled ``obj[i]``. It is kept as an alias for callers that invoke it
    # by name.
    __item__ = __getitem__

    def get_subset(self, i: torch.Tensor):
        '''Return a new ``FlatModelParaS`` holding ``self.tensors[i]``.'''
        return FlatModelParaS(self.tensors[i])

    def norms(self):
        '''Return ``torch.norm`` taken along dim 1 (one value per row).'''
        return torch.norm(self.tensors, dim=1)

    def mean(self):
        '''Return the mean over dim 0 (element-wise average of the rows).'''
        return torch.mean(self.tensors, dim=0)

    def get_tuple(self, form_dict):
        '''
        Cut every row back into shaped pieces.

        ``form_dict`` maps keys to shapes. Its values are consumed in
        insertion order (guaranteed for dicts on Python 3.7+); each shape
        takes the next ``prod(shape)`` columns, viewed as ``(-1, *shape)``.

        Returns a tuple with one tensor per entry of ``form_dict``.
        '''
        views = []
        start = 0
        for shape in form_dict.values():
            # Plain-int offsets: avoids slicing with 0-dim tensors.
            stop = start + torch.Size(shape).numel()
            views.append(self.tensors[:, start:stop].view(-1, *shape))
            start = stop
        return tuple(views)
    
    
def flattenToFlatModelPara(tensors)->FlatModelPara:
    '''Flatten ``tensors`` with ``keep_dim=0`` and wrap the result in a ``FlatModelPara``.'''
    flat_vector = flatten(tensors, 0)
    return FlatModelPara(flat_vector)

def flattenToFlatModelParaS(tensors)->FlatModelParaS:
    '''Flatten ``tensors`` with ``keep_dim=1`` (first dimension kept) and wrap the result in a ``FlatModelParaS``.'''
    stacked_flat = flatten(tensors, 1)
    return FlatModelParaS(stacked_flat)
    
