import os
from transformer_models import (
        GPTNeoXAlibiForCausalLM,
        GPTNeoXNoPEForCausalLM,
        )
from transformers import  GPTNeoXForCausalLM, GPTNeoXConfig

# NOTE: get_model uses 0 as the BOS/EOS token id by default; these ids may need to be changed to match the tokenizer's special tokens.

# Code adapted from Jelassi et al. "Repeat After Me: Transformers are 
# Better than State Space Models at Copying"
# https://github.com/sjelassi/transformers_ssm_copy/tree/main

def get_model(args, vocab_size, *, bos_token_id=0, eos_token_id=0):
    """Build a freshly initialised GPT-NeoX causal LM for the variant named in ``args.model``.

    Supported variants:
      * ``"T_rope"``  -> the stock Hugging Face ``GPTNeoXForCausalLM``.
      * ``"T_nope"``  -> ``GPTNeoXNoPEForCausalLM`` from ``transformer_models``
        (presumably GPT-NeoX without positional encoding; see that module).
      * ``"T_alibi"`` -> ``GPTNeoXAlibiForCausalLM`` from ``transformer_models``
        (presumably GPT-NeoX with ALiBi attention biases; see that module).

    The model is constructed directly from a config; no pretrained weights are loaded.

    Args:
        args: Namespace-like object. Reads ``model`` (variant name), ``n_hid``
            (hidden size), ``heads`` (attention heads) and ``n_layers``
            (number of transformer layers).
        vocab_size: Size of the token vocabulary.
        bos_token_id: Beginning-of-sequence token id (defaults to 0, as before).
        eos_token_id: End-of-sequence token id (defaults to 0, as before).

    Returns:
        The constructed causal-LM model instance.

    Raises:
        NotImplementedError: If ``args.model`` is not one of the supported variants.
    """
    supported = ("T_nope", "T_rope", "T_alibi")
    # Validate before touching any model/config class so a bad name fails fast
    # with a message that says what is accepted.
    if args.model not in supported:
        raise NotImplementedError(
            f"Unsupported model {args.model!r}; expected one of {supported}"
        )

    config = GPTNeoXConfig(
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        hidden_size=args.n_hid,
        intermediate_size=args.n_hid * 4,  # conventional 4x MLP expansion
        num_attention_heads=args.heads,
        num_hidden_layers=args.n_layers,
        vocab_size=vocab_size,
    )

    # Single dispatch table: every name in `supported` maps to a class, so no
    # branch can fall through and leave the model unassigned.
    model_classes = {
        "T_rope": GPTNeoXForCausalLM,
        "T_nope": GPTNeoXNoPEForCausalLM,
        "T_alibi": GPTNeoXAlibiForCausalLM,
    }
    return model_classes[args.model](config)
