from clus.models.model.basic import BasicModel
from clus.config.master_config import CLUSConfig



class HierarchicalBase(BasicModel) :
    '''
    # Hierarchical model class using flax #
    [included]
    base model's train, eval forward function can be different
    * If you want to use each model separately, use hl and ll model directly

    Pairs a high-level model (``hl_model``) with a low-level model
    (``ll_model``). Both are ``None`` here: subclasses must assign them and
    implement ``process_hl_inputs`` / ``process_ll_inputs``, which turn a raw
    ``input_dict`` into keyword arguments for each sub-model.
    '''
    def __init__(
            self,
            clus_config: CLUSConfig=None,
        ) :
        '''
        Store the configuration sections used by the hierarchical model.

        :param clus_config: master configuration; its ``model_config``,
            ``optimizer_config`` and ``input_config`` attributes are read.
        :raises ValueError: if ``clus_config`` is ``None``.
        '''
        # The ``None`` default is kept for backward compatibility, but a config
        # is required: fail with a clear message instead of an AttributeError
        # on the first attribute access below.
        if clus_config is None :
            raise ValueError(f'{type(self).__name__} requires a CLUSConfig, got None')
        # NOTE(review): BasicModel.__init__ is not called here; confirm that
        # BasicModel needs no initialisation of its own.
        self.clus_config = clus_config

        self.model_config = self.clus_config.model_config
        self.optimizer_config = self.clus_config.optimizer_config
        self.input_config = self.clus_config.input_config

        # NOTE(review): every hl/ll input and output config is the shared
        # ``input_config``; presumably a placeholder until per-model configs exist.
        self.hl_input_config = self.input_config
        self.hl_output_config = self.input_config
        self.hl_model = None

        self.ll_input_config = self.input_config
        self.ll_output_config = self.input_config
        self.ll_model = None


    def process_hl_inputs(self, input_dict, training=False) :
        '''
        Build the keyword arguments for the high-level model from ``input_dict``.

        :param input_dict: raw batch given to ``train_model`` / ``eval_model``.
        :param training: ``True`` when building training inputs, ``False`` for
            evaluation (train and eval inputs may differ).
        :return: dict that is unpacked as ``**kwargs`` into the high-level model.
        :raises NotImplementedError: must be implemented by subclasses.
        '''
        raise NotImplementedError(f'{type(self).__name__}.process_hl_inputs is not implemented')


    def process_ll_inputs(self, input_dict, training=False) :
        '''
        Build the keyword arguments for the low-level model from ``input_dict``.

        :param input_dict: raw batch given to ``train_model`` / ``eval_model``.
        :param training: ``True`` when building training inputs, ``False`` for
            evaluation (train and eval inputs may differ).
        :return: dict that is unpacked as ``**kwargs`` into the low-level model.
        :raises NotImplementedError: must be implemented by subclasses.
        '''
        raise NotImplementedError(f'{type(self).__name__}.process_ll_inputs is not implemented')


    def loss_fn(self, input_dict) :
        '''
        Compute the training loss.

        Must be implemented by each hierarchical subclass; a subclass may use
        its own signature (e.g. ``params, state, batch, rngs``).
        '''
        raise NotImplementedError(f'{type(self).__name__}.loss_fn is not implemented')


    def train_model(self, input_dict) :
        '''
        Train the hierarchical model on one batch. Must be implemented by subclasses.
        '''
        raise NotImplementedError(f'{type(self).__name__}.train_model is not implemented')


    def eval_model(self, input_dict) :
        '''
        Forward the high-level and low-level models on ``input_dict``.

        Each sub-model is called as ``forward(params, **kwargs)``. Here ``params``
        is the sub-model's current ``train_state.params``, and ``kwargs`` comes
        from the matching ``process_*_inputs(input_dict, training=False)`` hook.

        :return: ``{'hl_out': ..., 'll_out': ...}``.
        :raises NotImplementedError: while the processing hooks are unimplemented.
        '''
        hl_input_dict = self.process_hl_inputs(input_dict, training=False)
        hl_params = self.hl_model.train_state.params
        hl_out = self.hl_model.forward(hl_params, **hl_input_dict)

        ll_input_dict = self.process_ll_inputs(input_dict, training=False)
        ll_params = self.ll_model.train_state.params
        ll_out = self.ll_model.forward(ll_params, **ll_input_dict)

        output_dict = {
            'hl_out' : hl_out,
            'll_out' : ll_out,
        }
        return output_dict

from clus.models.model.mapping_model import InstructionMappingModel
from clus.models.model.basic import MLPModule
from clus.models.utils.utils import update_rngs
from clus.models.utils.loss import mse
import jax

class HierarchicalSeparated(HierarchicalBase) :
    '''
    # Oracle Hierarchical model class using flax #
    [included]
    * If you want to use each model separately, use hl and ll model directly

    The high-level model is an ``InstructionMappingModel`` and the low-level
    model is an ``MLPModule``. ``train_model`` trains each of them through its
    own ``train_model`` call. The per-model inputs come from
    ``process_hl_inputs`` / ``process_ll_inputs``, which are still to be
    implemented.
    '''
    def __init__(
            self,
            clus_config: CLUSConfig=None,
        ) :
        '''
        :param clus_config: master configuration, forwarded to ``HierarchicalBase``.
        '''
        # The base constructor already sets the hl/ll input and output configs;
        # only the concrete sub-models are chosen here.
        super().__init__(clus_config=clus_config)
        # NOTE(review): both sub-models are built without arguments; confirm
        # they do not need the model/optimizer configs held on ``self``.
        self.hl_model = InstructionMappingModel()
        self.ll_model = MLPModule()

    def process_hl_inputs(self, input_dict, training=False) :
        '''
        Build the keyword arguments for the high-level model from ``input_dict``.

        :param input_dict: raw batch given to ``train_model`` / ``eval_model``.
        :param training: ``True`` for the ``train_model`` call, ``False`` for
            the ``forward`` call made during evaluation.
        :return: dict that is unpacked as ``**kwargs`` into the high-level model.
        :raises NotImplementedError: not implemented yet.
        '''
        raise NotImplementedError(f'{type(self).__name__}.process_hl_inputs is not implemented')

    def process_ll_inputs(self, input_dict, training=False) :
        '''
        Build the keyword arguments for the low-level model from ``input_dict``.

        :param input_dict: raw batch given to ``train_model`` / ``eval_model``.
        :param training: ``True`` for the ``train_model`` call, ``False`` for
            the ``forward`` call made during evaluation.
        :return: dict that is unpacked as ``**kwargs`` into the low-level model.
        :raises NotImplementedError: not implemented yet.
        '''
        raise NotImplementedError(f'{type(self).__name__}.process_ll_inputs is not implemented')

    def loss_fn(self, params, state, batch, rngs=None) :
        '''
        Compute the joint hierarchical loss for one batch.

        :param params: parameters the loss must be computed from.
        :param state: training state (unused until implemented).
        :param batch: batch dict; expected to hold ``'hl_labels'`` and ``'ll_labels'``.
        :param rngs: optional RNG keys (unused until implemented).
        :raises NotImplementedError: not implemented yet.
        '''
        # TODO: forward both sub-models with the passed-in ``params`` rather
        # than ``self.*_model.train_state.params``. If this function is
        # differentiated with ``jax.grad``, a loss that does not depend on
        # ``params`` yields all-zero gradients. Planned objective: an ``mse``
        # reconstruction loss between the low-level output and
        # ``batch['ll_labels']``, plus a high-level term against
        # ``batch['hl_labels']``.
        raise NotImplementedError(f'{type(self).__name__}.loss_fn is not implemented')

    def train_model(self, input_dict) :
        '''
        Train the high-level and low-level models on one batch.

        The two sub-models are trained independently. Each one receives the
        kwargs built by its ``process_*_inputs(input_dict, training=True)`` hook.

        :return: ``{'hl_metric': ..., 'll_metric': ...}`` holding each
            sub-model's ``train_model`` result.
        '''
        hl_inputs = self.process_hl_inputs(input_dict, training=True)
        hl_metric = self.hl_model.train_model(**hl_inputs)

        ll_inputs = self.process_ll_inputs(input_dict, training=True)
        ll_metric = self.ll_model.train_model(**ll_inputs)

        metric_dict = {
            'hl_metric' : hl_metric,
            'll_metric' : ll_metric,
        }
        return metric_dict

    def eval_model(self, input_dict) :
        '''
        Forward the high-level and low-level models on ``input_dict``.

        Each sub-model is called as ``forward(params, **kwargs)``. Here ``params``
        is the sub-model's current ``train_state.params``, and ``kwargs`` comes
        from the matching ``process_*_inputs(input_dict, training=False)`` hook.

        :return: ``{'hl_out': ..., 'll_out': ...}``.
        '''
        hl_input_dict = self.process_hl_inputs(input_dict, training=False)
        hl_params = self.hl_model.train_state.params
        hl_out = self.hl_model.forward(hl_params, **hl_input_dict)

        ll_input_dict = self.process_ll_inputs(input_dict, training=False)
        ll_params = self.ll_model.train_state.params
        ll_out = self.ll_model.forward(ll_params, **ll_input_dict)

        output_dict = {
            'hl_out' : hl_out,
            'll_out' : ll_out,
        }
        return output_dict




if __name__ == '__main__':
    # This module only defines model classes; executing it directly does nothing.
    pass