from argparse import Namespace
from typing import Any, Dict, Tuple

from huggingface_hub import login

from lm_understanding.prompting import (Completer, SingleWordAnswerPrompter,
                                        ThrottledCompleter)
from lm_understanding.prompting.prompting import FreeResponsePrompter, Prompter

from .gpt import CHATGPT_MODELS, GPT3, ChatCompleter
from .hf_model import HF_MODELS, LLAMA_MODELS, HFModel

# Relative path of the text file holding the Hugging Face access token.
# load_model reads it and passes the contents to huggingface_hub.login when
# the requested model is listed in LLAMA_MODELS.
HUGGINGFACE_TOKEN_PATH = 'huggingface_token.txt'

def make_full_model_name(model_name: str) -> str:
    """Resolve a short model name to its fully qualified ``repo/name`` form.

    A trailing ``-cot`` marker is removed first. The remaining name is then
    tried under every repository prefix found in ``HF_MODELS`` (the text
    before the first ``/`` of each entry that contains one), in iteration
    order. The first ``repo/name`` combination that is itself an entry of
    ``HF_MODELS`` is returned; if none matches, the name without the ``-cot``
    marker is returned unchanged.
    """
    cot_suffix = '-cot'
    if model_name.endswith(cot_suffix):
        base_name = model_name[:-len(cot_suffix)]
    else:
        base_name = model_name

    for known_model in HF_MODELS:
        if '/' not in known_model:
            # Entries without a repository part contribute no prefix.
            continue
        repo = known_model.split('/')[0]
        candidate = f'{repo}/{base_name}'
        if candidate in HF_MODELS:
            return candidate
    return base_name


def make_prompter(model_config: Namespace, free_response: bool) -> Prompter:
    """Create the prompter matching the configured prompting mode.

    The prompt template is ``model_config.free_response_prompt`` when
    ``free_response`` is true and ``model_config.prompt`` otherwise; the
    literal ``[question]`` placeholder in it is replaced by the question's
    ``text`` attribute.

    The prompter type is chosen in this order:
      * ``model_config.cot`` set: ``SingleWordAnswerPrompter`` built with
        ``model_config.score_prefix``;
      * ``free_response`` set: ``FreeResponsePrompter``;
      * otherwise: ``SingleWordAnswerPrompter`` built with an empty string.
    """
    if free_response:
        template = model_config.free_response_prompt
    else:
        template = model_config.prompt

    def fill_template(question):
        # Substitute the question text into the configured template.
        return template.replace('[question]', question.text)

    # The chain-of-thought check deliberately comes before free_response.
    if model_config.cot:
        return SingleWordAnswerPrompter(fill_template, model_config.score_prefix)
    if free_response:
        return FreeResponsePrompter(fill_template)
    return SingleWordAnswerPrompter(fill_template, '')


def make_model_kwargs(model_config: Namespace, model_name: str, free_response: bool) -> Dict:
    """Assemble the keyword arguments used to construct the model.

    The result always contains ``model_name``, ``temperature``,
    ``multiple_choice`` and a token-limit entry. The token-limit key is
    ``max_tokens`` for names in ``CHATGPT_MODELS`` and ``max_new_tokens`` for
    everything else.

    Settings per mode (``free_response`` is checked before
    ``model_config.cot``):
      * free response: temperature 0.0, not multiple choice, limit 1024;
      * chain of thought: temperature 1.0, not multiple choice, limit 1024;
      * otherwise: multiple choice, limit 1, temperature 1.0 for
        ``CHATGPT_MODELS`` names and 0.0 for all others.
    """
    is_chat_model = model_name in CHATGPT_MODELS
    token_limit_key = 'max_tokens' if is_chat_model else 'max_new_tokens'

    kwargs: Dict[str, Any] = {'model_name': model_name}
    if free_response:
        kwargs.update({
            'temperature': 0.0,
            'multiple_choice': False,
            token_limit_key: 1024,
        })
    elif model_config.cot:
        kwargs.update({
            'temperature': 1.0,
            'multiple_choice': False,
            token_limit_key: 1024,
        })
    else:
        kwargs.update({
            'temperature': 1.0 if is_chat_model else 0.0,
            token_limit_key: 1,
            'multiple_choice': True,
        })
    return kwargs


def make_chat_kwargs(model_name: str, multiple_choice: bool) -> Dict:
    """Return the per-model request settings for a supported chat model.

    Args:
        model_name: Either ``'gpt-3.5-turbo'`` or ``'gpt-4'``.
        multiple_choice: Only affects ``'gpt-3.5-turbo'``, where it selects a
            ``timeout_seconds`` of 10 (true) instead of 40 (false).

    Returns:
        A dict with the keys ``model_name``, ``max_concurrent_tasks``,
        ``timeout_seconds`` and ``max_tries``.

    Raises:
        NotImplementedError: If ``model_name`` is not one of the two
            supported names. The message names the rejected model.
    """
    if model_name == 'gpt-3.5-turbo':
        return dict(
            model_name=model_name,
            max_concurrent_tasks=20,
            timeout_seconds=(10 if multiple_choice else 40),
            max_tries=5,
        )
    if model_name == 'gpt-4':
        return dict(
            model_name=model_name,
            max_concurrent_tasks=10,
            timeout_seconds=120,
            max_tries=5,
        )
    # Name the offending model so the failure is diagnosable from the traceback.
    raise NotImplementedError(f'No chat settings are defined for model {model_name!r}')


def load_model(model_config: Namespace, free_response: bool = False) -> Tuple[Completer, Prompter]:
    """Build the completer and prompter described by ``model_config``.

    The configured name is expanded with ``make_full_model_name``, and the
    prompter and constructor arguments come from ``make_prompter`` and
    ``make_model_kwargs``. The completer is then chosen by model name:

      * names in ``HF_MODELS``: an ``HFModel`` wrapped in a
        ``ThrottledCompleter`` using ``model_config.batch_size``;
      * names in ``CHATGPT_MODELS``: a ``ChatCompleter`` whose arguments are
        extended with ``make_chat_kwargs``;
      * names in ``GPT3.model_names()``: a ``GPT3`` wrapped in a
        ``ThrottledCompleter``.

    For names in ``LLAMA_MODELS`` the Hugging Face token stored at
    ``HUGGINGFACE_TOKEN_PATH`` is read and passed to ``login`` first.

    Args:
        model_config: Namespace providing at least ``name`` plus the
            attributes read by the helper functions listed above.
        free_response: Whether to configure the model and prompter for
            free-response generation.

    Returns:
        A ``(model, prompter)`` tuple.

    Raises:
        NotImplementedError: If the model name matches none of the known
            model collections.
    """
    model_name = make_full_model_name(model_config.name)
    prompter = make_prompter(model_config, free_response)
    model_kwargs = make_model_kwargs(model_config, model_name, free_response)
    if model_name in LLAMA_MODELS:
        with open(HUGGINGFACE_TOKEN_PATH, 'r') as f:
            # Token files commonly end with a newline; strip surrounding
            # whitespace so only the token itself is handed to login().
            token = f.read().strip()
        login(token=token)
    if model_name in HF_MODELS:
        model = HFModel(**model_kwargs)
        model = ThrottledCompleter(model, completion_batch_size=model_config.batch_size)
    elif model_name in CHATGPT_MODELS:
        model_kwargs.update(make_chat_kwargs(model_name, model_kwargs['multiple_choice']))
        model = ChatCompleter(**model_kwargs)  # type: ignore
    elif model_name in GPT3.model_names():
        model = GPT3(**model_kwargs)
        model = ThrottledCompleter(model, retry_pause_seconds=30, completion_batch_size=20)
    else:
        raise NotImplementedError(f'{model_name} is not a valid model')
    return model, prompter