import torch
class RedundentTensor:
    """Expose values stored once per unique key through many (possibly aliased) keys.

    ``map`` sends each public key to an underlying "unique" key; several public
    keys may share one unique key. Values are stored once per unique key via
    :meth:`set` and gathered back per public key via :meth:`get`, which combines
    them with ``stack`` (``torch.stack`` by default).

    Attributes:
        map: the public-key -> unique-key mapping given to the constructor.
        value_dict: unique key -> stored value, or ``None`` until :meth:`set`
            is called.
        stack: callable applied by :meth:`get` to the list of gathered values.
        keys: the keys of ``map``, in the map's iteration order.
        unique_keys: the distinct values of ``map``, in order of first
            appearance (deterministic, so :meth:`set` can rely on it).
        inverse_idx: for each entry of ``keys``, its unique key.
        inverse_pos: for each entry of ``keys``, the index of its unique key
            in ``unique_keys``.
    """

    def __init__(self, map, stack=torch.stack):
        self.map = map
        self.value_dict = None
        self.stack = stack
        self.keys = list(map.keys())
        self.unique_keys = self.get_unique_keys(self.keys)
        self.inverse_idx = [self.map[k] for k in self.keys]
        # Position lookup table: O(1) per key instead of list.index's O(u).
        position = {u: i for i, u in enumerate(self.unique_keys)}
        self.inverse_pos = [position[u] for u in self.inverse_idx]

    @staticmethod
    def get_unique_list(l):
        """Return the distinct items of *l* in order of first appearance.

        The order must be deterministic because :meth:`set` zips
        ``unique_keys`` against the caller's values. ``list(set(l))`` is not
        deterministic: string hashes are randomized per interpreter run.
        ``dict.fromkeys`` keeps insertion order instead.
        """
        return list(dict.fromkeys(l))

    def get_unique_keys(self, keys=None):
        """Return the distinct unique keys that *keys* map to, first-appearance order.

        Args:
            keys: public keys to look up in ``self.map``; defaults to
                ``self.keys``.
        """
        if keys is None:
            keys = self.keys
        return self.get_unique_list([self.map[k] for k in keys])

    def set(self, values, keys=None):
        """Store one value per unique key, replacing anything stored before.

        Args:
            values: iterable of values, consumed in step with *keys*.
            keys: unique keys to bind ``values`` to; defaults to
                ``self.unique_keys``.

        As with ``zip``, if *values* and *keys* differ in length the surplus
        items are silently ignored.
        """
        if keys is None:
            keys = self.unique_keys
        self.value_dict = dict(zip(keys, values))
        return

    def get(self, keys=None):
        """Gather the stored value for each public key and combine with ``self.stack``.

        Args:
            keys: public keys to gather, in output order; defaults to
                ``self.keys``. Keys sharing a unique key receive the same
                stored value.

        Returns:
            ``self.stack`` applied to the list of gathered values.

        Raises:
            KeyError: if a key is not in ``self.map`` or its unique key has no
                stored value.
            TypeError: if :meth:`set` has not been called yet
                (``value_dict`` is ``None``).
        """
        if keys is None:
            keys = self.keys
        return self.stack([self.value_dict[self.map[k]] for k in keys])

    def get_sub(self, keys=None):
        """Return a new RedundentTensor restricted to the public keys *keys*.

        Args:
            keys: public keys to keep, in the desired order; defaults to
                ``self.keys`` (a full copy of the mapping).

        Returns:
            A new instance sharing ``self.stack``. If values have been stored
            on ``self``, the entries for the sub-instance's unique keys are
            carried over, so ``get()`` works on it immediately. The values
            themselves are shared, not copied.
        """
        if keys is None:
            keys = self.keys
        sub = RedundentTensor({k: self.map[k] for k in keys}, self.stack)
        if self.value_dict is not None:
            sub.value_dict = {
                u: self.value_dict[u]
                for u in sub.unique_keys
                if u in self.value_dict
            }
        return sub
    

