import json
from types import SimpleNamespace

class ModelConfig:
    """Model hyperparameters built from an attribute namespace and stored as JSON.

    Every field except ``grad_checkpointing`` is copied verbatim from the
    object handed to the constructor. ``grad_checkpointing`` always starts
    out as ``False``.
    """

    def __init__(self, args):
        """Copy the hyperparameters off ``args``.

        Args:
            args: any object exposing the required fields as attributes
                (for example ``argparse.Namespace`` or ``types.SimpleNamespace``).

        Raises:
            AttributeError: if ``args`` is missing one of the required fields.
                Fields are read in order, so the first missing one is reported.
        """
        # The sequence below also fixes the key order of the JSON written by
        # save_config, so grad_checkpointing deliberately sits between
        # "threshold" and "crop".
        before_checkpointing = (
            "hyena", "no_rope", "dropout", "scale_factor",
            "interleave_ipa", "prepend_ipa", "oracle", "num_layers",
            "embed_dim", "mha_heads", "ipa_heads", "ipa_head_dim",
            "ipa_qk", "ipa_v", "add_auxiliary_loss", "time_multiplier",
            "abs_pos_emb", "abs_time_emb", "time_model", "latent_dim",
            "num_frames", "threshold",
        )
        after_checkpointing = ("crop", "seq_emb")

        for field in before_checkpointing:
            setattr(self, field, getattr(args, field))
        # Never taken from args: every newly built config starts with it off.
        self.grad_checkpointing = False
        for field in after_checkpointing:
            setattr(self, field, getattr(args, field))

    def save_config(self, file_path) -> None:
        """Serialize every instance attribute to ``file_path`` as JSON.

        Any attribute added after construction is written too. All values
        must be JSON-serializable.
        """
        with open(file_path, 'w') as handle:
            json.dump(vars(self), handle)

    @classmethod
    def load_config(cls, file_path) -> "ModelConfig":
        """Rebuild a config from a JSON file written by ``save_config``.

        The stored mapping is turned into a namespace and passed through the
        constructor. As a result, extra keys in the file are ignored, and
        ``grad_checkpointing`` comes back as ``False`` whatever was saved.
        """
        with open(file_path, 'r') as handle:
            namespace = SimpleNamespace(**json.load(handle))
        return cls(namespace)