import torch 
import os 
# Set the cache location BEFORE importing transformers: presumably placed here
# because transformers reads this variable when it is first imported.
# The path is relative, so it resolves against the current working directory,
# not against this file's location.
# NOTE(review): newer transformers / huggingface_hub releases deprecate
# TRANSFORMERS_CACHE in favour of HF_HOME / HF_HUB_CACHE — confirm against the
# installed version.
os.environ['TRANSFORMERS_CACHE'] = '../cache/'
from transformers import AutoTokenizer, AutoModelForCausalLM
# Import-time side effect: echoes the configured cache directory to stdout.
print("Transformers Cache Directory:", os.environ.get('TRANSFORMERS_CACHE'))
from templates import getLlaMAExtractionStrs, getGemmaExtractionStrs, applyLlamaTemplate, applyGemmaTemplate
from hook import get_op_to_hook_info, get_op_to_hook_info_gemma

##########################################
################## LLMs ##################
##########################################

def getLLaMA3_Instruct_WithInfo():
    """Load Llama-3.1-8B-Instruct together with the metadata used to hook it.

    Returns:
        A 10-tuple ``(model, tokenizer, template, parseStrs, L_TOTAL,
        num_heads, head_dim, num_groups, num_kv, op_to_hookinfo_model)``:

        - ``model``, ``tokenizer``: the pair returned by ``getLLaMA3_Instruct()``.
        - ``template``: the ``applyLlamaTemplate`` function itself (not called).
        - ``parseStrs``: the result of ``getLlaMAExtractionStrs()``.
        - ``L_TOTAL``: 32, the layer total for the whole network. This is not
          necessarily the set of layers that will be injected on.
        - ``num_heads``: 32. ``head_dim``: 128.
        - ``num_groups``: 4. ``num_kv`` is derived as ``num_heads // num_groups``
          (= 8), presumably the key/value head count under grouped-query
          attention. Confirm against ``model.config.num_key_value_heads``.
        - ``op_to_hookinfo_model``: the ``get_op_to_hook_info`` function itself.

    The architecture constants are hard-coded, not read from ``model.config``.
    """
    model, tokenizer = getLLaMA3_Instruct()

    total_layers = 32      # whole network; injection may target only a subset
    n_heads = 32
    dim_per_head = 128
    heads_per_group = 4
    n_kv_heads = n_heads // heads_per_group

    return (
        model,
        tokenizer,
        applyLlamaTemplate,
        getLlaMAExtractionStrs(),
        total_layers,
        n_heads,
        dim_per_head,
        heads_per_group,
        n_kv_heads,
        get_op_to_hook_info,
    )


def getGemma2_Instruct_WithInfo():
    """Load Gemma-2-9b-it together with the metadata used to hook it.

    Returns:
        A 10-tuple ``(model, tokenizer, template, parseStrs, L_TOTAL,
        num_heads, head_dim, num_groups, num_kv, op_to_hookinfo_model)``:

        - ``model``, ``tokenizer``: the pair returned by ``getGemma2_Instruct()``.
        - ``template``: the ``applyGemmaTemplate`` function itself (not called).
        - ``parseStrs``: the result of ``getGemmaExtractionStrs()``.
        - ``L_TOTAL``: 42, the layer total for the whole network. This is not
          necessarily the set of layers that will be injected on.
        - ``num_heads``: 16. ``head_dim``: 256.
        - ``num_groups``: 2. ``num_kv`` is derived as ``num_heads // num_groups``
          (= 8), presumably the key/value head count under grouped-query
          attention. Confirm against ``model.config.num_key_value_heads``.
        - ``op_to_hookinfo_model``: the ``get_op_to_hook_info_gemma`` function
          itself.

    The architecture constants are hard-coded, not read from ``model.config``.
    """
    model, tokenizer = getGemma2_Instruct()

    total_layers = 42      # whole network; injection may target only a subset
    n_heads = 16
    dim_per_head = 256
    heads_per_group = 2
    n_kv_heads = n_heads // heads_per_group

    return (
        model,
        tokenizer,
        applyGemmaTemplate,
        getGemmaExtractionStrs(),
        total_layers,
        n_heads,
        dim_per_head,
        heads_per_group,
        n_kv_heads,
        get_op_to_hook_info_gemma,
    )


def getLLaMA3_Instruct():
    """Load the Llama-3.1-8B-Instruct causal LM and its tokenizer.

    If the tokenizer defines no pad token id, the EOS token id is assigned as
    the pad token id.

    Returns:
        ``(model, tokenizer)``, built with ``AutoModelForCausalLM`` and
        ``AutoTokenizer`` via ``from_pretrained``.
    """
    checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    # Fall back to EOS when no pad token is defined, presumably so that
    # padded/batched inputs can be built downstream.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return model, tokenizer

def getGemma2_Instruct():
    """Load the Gemma-2-9b-it causal LM and its tokenizer.

    If the tokenizer defines no pad token id, the EOS token id is assigned as
    the pad token id. This mirrors ``getLLaMA3_Instruct`` so both loaders hand
    back a tokenizer with a usable pad id. It is a no-op when the tokenizer
    already defines one.

    Returns:
        ``(model, tokenizer)``, built with ``AutoModelForCausalLM`` and
        ``AutoTokenizer`` via ``from_pretrained``.
    """
    checkpoint = "google/gemma-2-9b-it"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    # Same guard as the Llama loader: callers that pad batches should not need
    # to special-case the model family.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return model, tokenizer