class BaseConfig:
    """Base for config objects that lets attributes be read/written by subscript.

    ``cfg[key]`` delegates to ``getattr(cfg, key)`` and ``cfg[key] = value``
    delegates to ``setattr(cfg, key, value)``. Consequently a key that is not
    an attribute raises ``AttributeError`` (not ``KeyError``).
    """

    def __getitem__(self, key):
        # Attribute lookup doubles as key lookup.
        attr_value = getattr(self, key)
        return attr_value

    def __setitem__(self, key, value):
        # Assigning a key creates/overwrites the attribute of the same name.
        setattr(self, key, value)
import numpy as np

class OptimizerConfig(BaseConfig):
    """Container for optimizer settings.

    Args:
        optimizer_cls: Stored unchanged as ``self.optimizer_cls`` (default
            ``None``); presumably the optimizer class/factory to build.
        optimizer_kwargs: Stored unchanged as ``self.optimizer_kwargs``
            (default ``None``); presumably keyword arguments for
            ``optimizer_cls`` — confirm against the consumer.
    """

    def __init__(self, optimizer_cls=None, optimizer_kwargs=None):
        self.optimizer_cls, self.optimizer_kwargs = optimizer_cls, optimizer_kwargs


class InputConfig(BaseConfig):
    """Per-argument input specifications for the model.

    The field names must match the input argument names of the Flax model.

    Args:
        x: Input used by MLP-based models (default ``None``).
        context: Input used by diffusion-based models (default ``None``).
        hidden_states: Input used by diffusion-based models (default ``None``).
    """

    def __init__(self, x: tuple = None, context: tuple = None, hidden_states: tuple = None):
        # MLP-based models
        self.x = x
        # Diffusion-based models
        self.context, self.hidden_states = context, hidden_states
        # TODO: consider supporting sequence inputs.

class ModelConfig(BaseConfig):
    """Container for model construction settings.

    Args:
        model_cls: Stored unchanged as ``self.model_cls`` (default ``None``);
            presumably the model class to instantiate.
        model_kwargs: Stored unchanged as ``self.model_kwargs`` (default
            ``None``); presumably keyword arguments for ``model_cls`` —
            confirm against the consumer.
    """

    def __init__(self, model_cls=None, model_kwargs=None):
        self.model_cls, self.model_kwargs = model_cls, model_kwargs

class ExpConfig(BaseConfig):
    """Experiment-level settings; every argument is stored as a same-named attribute.

    Args:
        phase_epoch: Default 10000. Presumably training epochs per phase.
        eval_epoch: Default 5000. Presumably the evaluation interval in epochs.
        batch_size: Default 1024.
        eval_env: Default True. Presumably toggles environment evaluation.
        base_path: Default ``'./data/l2ms/ewctest'``.
        phase_optim: Default ``'re_initialize'``; presumably how the optimizer
            is handled between phases — confirm against the consumer.
        replay_method: Default ``'random'``.
        init_model_path: Default ``None``; optional path to initial weights
            (assumed — confirm).
    """

    def __init__(
        self,
        phase_epoch: int = 10000,
        eval_epoch: int = 5000,
        batch_size: int = 1024,
        eval_env: bool = True,
        base_path: str = './data/l2ms/ewctest',
        phase_optim: str = 're_initialize',
        replay_method: str = 'random',
        init_model_path: str = None,
    ):
        # Schedule
        self.phase_epoch, self.eval_epoch = phase_epoch, eval_epoch
        # Data / evaluation
        self.batch_size, self.eval_env = batch_size, eval_env
        # Paths and continual-learning strategy
        self.base_path = base_path
        self.phase_optim = phase_optim
        self.replay_method = replay_method
        self.init_model_path = init_model_path
        
class ScenarioConfig(BaseConfig):
    """Groups the sub-configurations that describe a scenario.

    Args:
        dataloader_config: Stored unchanged as ``self.dataloader_config``
            (default ``None``).
        phase_config: Stored unchanged as ``self.phase_config`` (default ``None``).
        evaluator_config: Stored unchanged as ``self.evaluator_config``
            (default ``None``).
    """

    def __init__(self, dataloader_config=None, phase_config=None, evaluator_config=None):
        self.dataloader_config, self.phase_config, self.evaluator_config = (
            dataloader_config,
            phase_config,
            evaluator_config,
        )

from clus.env.offline import *

def get_scenario(options:dict=None) :
    """Build a ScenarioConfig for the selected scenario.

    Args:
        options: Scenario-selection overrides; may be ``None``. Keys not
            supplied fall back to ``env='kitchen'``, ``spec='full'`` and
            ``unlearn=False``.

    Returns:
        ScenarioConfig whose ``dataloader_config`` is a dict holding
        ``'dataloader_cls'`` and ``'dataloader_kwargs'``.
    """
    # Defaults; caller-supplied options override them instead of being
    # silently discarded (the previous version overwrote the argument).
    defaults = {
        'env' : 'kitchen', # kitchen or mmworld or clalfred
        'spec' : 'full', # full, sparce, corrupted (by data sparsity)
        'unlearn' : False, # True or False
    }
    options = {**defaults, **(options or {})}
    # TODO(review): the resolved ``options`` are not yet used to select the
    # dataloader / phase / evaluator settings below.

    # ScenarioConfig has no dataloader_cls/dataloader_kwargs parameters
    # (passing them raised TypeError); they belong inside dataloader_config.
    sc = ScenarioConfig(
        dataloader_config={
            'dataloader_cls' : BaseDataloader,
            'dataloader_kwargs' : {
                'skill_embedding_path' : 'data/continual_dataset/evolving_world/mm_lang_embedding.pkl',
                'skill_exclude' : None,
                'semantic_flag' : False,
            },
        },
    )
    return sc



if __name__ == "__main__" :
    # Smoke test: build each config with its defaults and exercise the
    # subscript get/set accessors.
    model_config = ModelConfig()
    optimizer_config = OptimizerConfig()
    input_config = InputConfig()
    exp_config = ExpConfig()
    # 'hidden_size' is not a ModelConfig field, so subscripting it before it
    # is assigned raises AttributeError; read it with a default instead.
    print(getattr(model_config, 'hidden_size', None))
    print(optimizer_config['optimizer_cls'])
    print(input_config['x'])
    print(exp_config['phase_epoch'])
    # Subscript assignment creates the attribute, so it is readable afterwards.
    model_config['hidden_size'] = 1024
    print(model_config['hidden_size'])
    print("Done!")