from .PaLM2 import PaLM2
from .Vicuna import Vicuna
from .GPT import GPT
from .Llama import Llama
from .Llama3 import Llama3
from .Qwen import Qwen
from .Gemma import Gemma
import json

def load_json(file_path):
    """Read the JSON document at ``file_path`` and return the parsed value.

    The file is decoded explicitly as UTF-8 rather than with the
    platform-default locale encoding. That way configs containing non-ASCII
    text load the same way on every host (JSON interchange is defined as
    UTF-8 by RFC 8259).

    Args:
        file_path: Path (str or path-like) of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's top-level value
        (typically a dict for config files).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(file_path, encoding="utf-8") as file:
        results = json.load(file)
    return results

def create_model(config_path, args):
    """
    Build and return the LLM wrapper selected by a JSON config file.

    The config at ``config_path`` is loaded. Its
    ``["model_info"]["provider"]`` entry is lower-cased and used to choose a
    wrapper class. That class is then called as ``cls(config, args)``,
    receiving the full parsed config and ``args`` unchanged.

    Raises:
        ValueError: if the lower-cased provider name is not one of the
            supported wrappers.
    """
    config = load_json(config_path)
    provider = config["model_info"]["provider"].lower()

    # Lower-cased provider name -> wrapper class to instantiate.
    registry = {
        'palm2': PaLM2,
        'vicuna': Vicuna,
        'gpt': GPT,
        'llama': Llama,
        'llama3': Llama3,
        'qwen': Qwen,
        'gemma': Gemma,
    }

    model_cls = registry.get(provider)
    if model_cls is None:
        raise ValueError(f"ERROR: Unknown provider {provider}")
    return model_cls(config, args)