

from clus.models.model.basic import BasicModel

class InstructionMappingModel(BasicModel):
    '''
    Lookup-table "model" mapping (instruction, input index) to a stored entry.

    Training records every distinct ``labels`` value seen for an instruction:
    * a new instruction creates a new entry holding its labels;
    * an existing instruction with unseen labels appends them as an
      additional alternative;
    * an existing instruction with already-seen labels changes nothing.

    Evaluation looks up the instruction and indexes its stored list with
    ``inputs``; unknown instructions and out-of-range indices yield ``None``.

    Intended usage (from the original design notes):
    sub_goal ~ M(instruction, input)
    instruction: str = instruction str or embedding
    input: int = sub_goal done count
    sub_goal = sub_goal str or embedding

    NOTE(review): the original notes said alternative outputs are chosen
    randomly, but no random selection is implemented; ``eval_model`` indexes
    the list of stored ``labels`` entries directly. Confirm which is intended.
    NOTE(review): instructions are used as dict keys, so they must be
    hashable (e.g. str or tuple); unhashable embeddings (e.g. arrays) would
    presumably need converting to a hashable key first.
    '''
    input_config = {
        'instruction' : None,
        'inputs' : None,
        'labels' : None,
    }

    def __init__(self):
        # like other flax model it behaves like BasicModel
        # NOTE(review): BasicModel.__init__ is not called here; confirm that
        # BasicModel needs no initialisation of its own.
        # instruction -> list of distinct ``labels`` values, in insertion order
        self.mapping_model: dict = {}

    def train_model(self, batch):
        '''
        Record ``batch['labels']`` as a known mapping for ``batch['instruction']``.

        Expected batch (prototype):
        instruction = str (used as a dict key, so it must be hashable)
        inputs = int (index of label; not read during training)
        labels = list (of str or embedding)

        Labels already stored for the same instruction (compared with ``==``)
        are not stored again, so re-training on the same data is a no-op.

        Always returns 0 (presumably a placeholder loss value matching the
        BasicModel training interface — confirm).
        '''
        known = self.mapping_model.setdefault(batch['instruction'], [])
        # ``in`` compares with ``==``, so identical label lists are kept once.
        if batch['labels'] not in known:
            known.append(batch['labels'])
        return 0

    def eval_model(self, batch):
        '''
        Return the stored entry for ``batch['instruction']`` at ``batch['inputs']``.

        Expected batch (prototype):
        instruction = str (must be hashable)
        inputs = int (index into the instruction's stored entries)

        Returns ``None`` when the instruction was never trained, or when
        ``inputs`` lies outside ``[0, number of stored entries)``.
        '''
        known = self.mapping_model.get(batch['instruction'])
        if known is None:
            return None
        index = batch['inputs']
        # Reject negative indices explicitly: Python would otherwise wrap
        # around and silently return an entry from the end of the list.
        # Out-of-range indices are "no mapping", same as an unknown instruction.
        if not 0 <= index < len(known):
            return None
        return known[index]


        