import os
import torch
 

from transformers import MistralForCausalLM, MixtralForCausalLM, MistralConfig, MixtralConfig 



def get_model(args, tokenizer):
    """Build a causal language model from scratch (no pretrained weights).

    The architecture is chosen by ``args.model``:

    * ``"sparse"`` -> ``MixtralForCausalLM`` (mixture-of-experts, top-2 routing)
    * ``"dense"``  -> ``MistralForCausalLM``

    Both variants share the same vocabulary, special-token ids, width, depth
    and head count. The model is constructed directly from a config object;
    nothing is loaded from a checkpoint.

    Args:
        args: Namespace-like object providing ``model``, ``hidden_size``,
            ``layers`` and ``heads``. The ``"sparse"`` variant also reads
            ``num_experts``.
        tokenizer: Tokenizer whose ``len()`` sets the vocabulary size and whose
            ``bos_token_id`` / ``eos_token_id`` / ``pad_token_id`` are copied
            into the model config.

    Returns:
        The newly constructed ``MixtralForCausalLM`` or ``MistralForCausalLM``.

    Raises:
        ValueError: If ``args.model`` is neither ``"sparse"`` nor ``"dense"``.
    """
    # Validate first so an unknown type fails with a clear message instead of
    # an UnboundLocalError on the final return.
    if args.model not in ("sparse", "dense"):
        raise ValueError(
            f"Unknown model type {args.model!r}; expected 'sparse' or 'dense'."
        )

    # Settings shared by both architectures, kept in one place so the sparse
    # and dense models stay directly comparable.
    # NOTE(review): intermediate_size mirrors hidden_size (as before) --
    # confirm this is intended rather than a larger FFN width.
    common = dict(
        vocab_size=len(tokenizer),
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        hidden_size=args.hidden_size,
        intermediate_size=args.hidden_size,
        num_hidden_layers=args.layers,
        num_attention_heads=args.heads,
        num_key_value_heads=args.heads,
    )

    if args.model == "sparse":
        config = MixtralConfig(
            **common,
            num_local_experts=args.num_experts,
            num_experts_per_tok=2,  # each token is routed to 2 experts
        )
        return MixtralForCausalLM(config)

    # "dense": previously the config was built but the model was never
    # instantiated, so this path always crashed on `return model`.
    config = MistralConfig(**common)
    return MistralForCausalLM(config)



 
