
import functools
import os

import torch
from transformers import (AutoModelForCausalLM,
                          AutoModelForSequenceClassification, AutoTokenizer,
                          OPTForCausalLM)

from _settings import MODEL_PATH



# Hugging Face Hub repository for every supported causal-LM short name.
_CAUSAL_LM_REPOS = {
    'llama-8b-instruct': "meta-llama/Meta-Llama-3-8B-Instruct",
    'mistral-7b': "mistralai/Mistral-7B-v0.3",
    'qwen-8b': "Qwen/Qwen3-8B",
    'qwen-4b': "Qwen/Qwen3-4B",
    'qwen-1b': "Qwen/Qwen3-1.7B",
    'qwen-0b': "Qwen/Qwen3-0.6B",
}

# Sequence-classification checkpoint; its short name is also its Hub id.
_SEQ_CLS_MODEL_NAME = 'roberta-large-mnli'


@functools.lru_cache()
def _load_pretrained_model(model_name, device, torch_dtype=torch.float16):
    """Load a pretrained model by short name and move it to ``device``.

    Results are memoized with ``functools.lru_cache``, so repeated calls with
    the same arguments return the same model object instead of reloading it.

    Args:
        model_name: One of the keys of ``_CAUSAL_LM_REPOS`` (loaded with
            ``AutoModelForCausalLM``) or ``'roberta-large-mnli'`` (loaded with
            ``AutoModelForSequenceClassification``).
        device: Target passed to ``model.to``. Must be hashable because it is
            part of the cache key.
        torch_dtype: dtype forwarded to ``from_pretrained`` for the causal
            LMs. It is not applied to ``'roberta-large-mnli'``, which is
            loaded with the library's default dtype.

    Returns:
        The loaded model, already moved to ``device``.

    Raises:
        ValueError: If ``model_name`` is not a supported short name.
    """
    if model_name in _CAUSAL_LM_REPOS:
        # Honor the caller's dtype instead of always forcing float16.
        # trust_remote_code=True is kept from the original call sites; it
        # allows model code shipped with the checkpoint to be executed.
        model = AutoModelForCausalLM.from_pretrained(
            _CAUSAL_LM_REPOS[model_name],
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )
    elif model_name == _SEQ_CLS_MODEL_NAME:
        model = AutoModelForSequenceClassification.from_pretrained(_SEQ_CLS_MODEL_NAME)
    else:
        # Fail with a clear message rather than an UnboundLocalError below.
        supported = sorted(_CAUSAL_LM_REPOS) + [_SEQ_CLS_MODEL_NAME]
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {supported}"
        )
    model.to(device)
    return model


# Tokenizers whose short name is already the Hugging Face Hub id.
_NLI_TOKENIZER_NAMES = ("microsoft/deberta-large-mnli", "roberta-large-mnli")

# Hub repository for each supported Qwen3 tokenizer short name.
_QWEN_TOKENIZER_REPOS = {
    'qwen-8b': "Qwen/Qwen3-8B",
    'qwen-4b': "Qwen/Qwen3-4B",
    'qwen-1b': "Qwen/Qwen3-1.7B",
    'qwen-0b': "Qwen/Qwen3-0.6B",
}


@functools.lru_cache()
def _load_pretrained_tokenizer(model_name, use_fast=False):
    """Load the tokenizer that matches a model short name.

    Results are memoized with ``functools.lru_cache``, so callers with the
    same arguments share one tokenizer object.

    Args:
        model_name: ``"microsoft/deberta-large-mnli"``,
            ``"roberta-large-mnli"``, ``"llama-8b-instruct"``,
            ``"mistral-7b"`` or one of the keys of ``_QWEN_TOKENIZER_REPOS``.
        use_fast: Forwarded to ``AutoTokenizer.from_pretrained`` for the
            Llama, Mistral and Qwen tokenizers. The two NLI tokenizers are
            loaded with the library default.

    Returns:
        The loaded tokenizer. For Llama and Mistral the pad token is set to
        the EOS token.

    Raises:
        ValueError: If ``model_name`` is not a supported short name.
    """
    if model_name in _NLI_TOKENIZER_NAMES:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    elif model_name == "llama-8b-instruct":
        # NOTE(review): cache_dir is a hard-coded, machine-specific path.
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", cache_dir='/juice2/scr2/syu03', use_fast=use_fast)
        # Keep the BOS/EOS tokens shipped with the checkpoint. Forcing ids
        # 1 and 2 is a Llama-1/2 convention; in the Llama 3 vocabulary those
        # ids are ordinary tokens, so the override corrupted BOS/EOS/PAD.
        tokenizer.pad_token = tokenizer.eos_token
    elif model_name == "mistral-7b":
        # NOTE(review): cache_dir is a hard-coded, machine-specific path.
        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3", cache_dir='/juice2/scr2/syu03', use_fast=use_fast)
        tokenizer.eos_token_id = 2
        tokenizer.bos_token_id = 1
        tokenizer.eos_token = tokenizer.decode(tokenizer.eos_token_id)
        tokenizer.bos_token = tokenizer.decode(tokenizer.bos_token_id)
        # Reuse EOS as the padding token.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token
    elif model_name in _QWEN_TOKENIZER_REPOS:
        tokenizer = AutoTokenizer.from_pretrained(_QWEN_TOKENIZER_REPOS[model_name], use_fast=use_fast)
    else:
        # Fail with a clear message rather than an UnboundLocalError below.
        supported = (
            list(_NLI_TOKENIZER_NAMES)
            + ["llama-8b-instruct", "mistral-7b"]
            + sorted(_QWEN_TOKENIZER_REPOS)
        )
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {supported}"
        )
    return tokenizer