from typing import Optional

import jax.numpy as jnp

from clus.models.peftpool.pool_basis import CosineSimilarityModule



class TaskModel:
    '''
    Task model - GC count based basic task model.

    A simple lookup table from task keys to stored values. Calling the
    model with a registered key returns ``value[task_obs]`` for that key;
    an unregistered key (including the default ``None``) yields a zero
    array of shape ``(1, out_dim)``.

    NOTE(review): the meaning of "GC count based" is not visible from this
    class alone - confirm against the callers / design docs.
    '''

    def __init__(
        self,
        out_dim: int = 512,
    ) :
        '''
        Args:
            out_dim: width of the zero array returned for unknown task keys.
        '''
        # task key -> {'meta_data': None, 'value': <value passed to train>}
        self.task_dict = {}
        self.out_dim = out_dim

    def train(
        self,
        key,
        value,
    ) -> None :
        '''
        Register ``value`` under ``key``, overwriting any previous entry.

        Args:
            key: hashable task identifier used later as ``task_key``.
            value: indexable container; ``__call__`` returns
                ``value[task_obs]``. Presumably a per-observation array of
                embeddings - TODO confirm with callers.
        '''
        self.task_dict[key] = {
            'meta_data' : None,  # placeholder; never populated here
            'value' : value,
        }

    def __call__(
        self,
        task_key: Optional[str] = None,
        task_obs: int = 0,
    ) :
        '''
        Look up the stored value for a task.

        Args:
            task_key: key previously passed to ``train``; ``None`` or an
                unknown key selects the zero fallback.
            task_obs: index into the stored value.

        Returns:
            ``value[task_obs]`` for a known key, otherwise
            ``jnp.zeros((1, self.out_dim))``.

        Raises:
            TypeError: if ``task_key`` is unhashable.
        '''
        # EAFP: a single dict lookup instead of membership test + index.
        try:
            entry = self.task_dict[task_key]
        except KeyError:
            return jnp.zeros((1, self.out_dim))
        return entry['value'][task_obs]

