from typing import Any, Dict, List

import einops
import torch
from transformers import AutoModelForCausalLM  # type: ignore
from transformers import AutoModelForSeq2SeqLM  # type: ignore
from transformers import AutoTokenizer  # type: ignore
from transformers import BitsAndBytesConfig

from lm_understanding.prompting import Completer, Completion

# Hugging Face checkpoint ids known to this module, grouped by model family.
# Some family lists are consulted by the helpers below (T5_MODELS picks the
# seq2seq loader, VICUNA_MODELS / GPT2_MODELS tune tokenizer options);
# HF_MODELS is the concatenation of every family list.

VICUNA_MODELS = [
    'TheBloke/stable-vicuna-13B-HF',
    'lmsys/vicuna-7b-v1.3',
    'lmsys/vicuna-13b-v1.3',
]

DOLLY_MODELS = [
    "databricks/dolly-v2-12b",
    "databricks/dolly-v2-3b",
]

MPT_MODELS = [
    'mosaicml/mpt-7b-instruct',
]

KOALA_MODELS = [
    'TheBloke/koala-13B-HF',
    'TheBloke/koala-7B-HF',
]

# Encoder-decoder checkpoints (loaded with the seq2seq auto-class).
T5_MODELS = [
    "google/flan-t5-small",
    "google/flan-t5-base",
    "google/flan-t5-large",
    "google/flan-t5-xl",
    "google/flan-t5-xxl",
    'google/flan-ul2',
    'declare-lab/flan-alpaca-base',
    'declare-lab/flan-alpaca-xl',
    'declare-lab/flan-alpaca-gpt4-xl',
    'declare-lab/flan-alpaca-xxl',
    '0-hero/flan-alpaca-ul2',
]

FALCON_MODELS = [
    'tiiuae/falcon-7b-instruct',
    'tiiuae/falcon-40b-instruct',
]

GPT2_MODELS = [
    'gpt2-xl',
]

LLAMA_MODELS = [
    'meta-llama/Llama-2-7b-hf',
    'meta-llama/Llama-2-13b-hf',
    'meta-llama/Llama-2-7b-chat-hf',
    'meta-llama/Llama-2-13b-chat-hf',
]

# Every supported checkpoint, in family order.
HF_MODELS = [
    *VICUNA_MODELS,
    *DOLLY_MODELS,
    *KOALA_MODELS,
    *T5_MODELS,
    *FALCON_MODELS,
    *GPT2_MODELS,
    *MPT_MODELS,
    *LLAMA_MODELS,
]


def model_class(model_name: str):
    """Return the ``transformers`` auto-class used to load ``model_name``.

    Names listed in ``T5_MODELS`` load through ``AutoModelForSeq2SeqLM``;
    every other name loads through ``AutoModelForCausalLM``.
    """
    is_seq2seq = model_name in T5_MODELS
    return AutoModelForSeq2SeqLM if is_seq2seq else AutoModelForCausalLM


def create_tokenizer(model_name: str):
    """Load and configure the tokenizer for ``model_name``.

    - Seq2seq (``T5_MODELS``) tokenizers pad on the right; all others pad on
      the left (presumably so batched decoder-only generation appends new
      tokens directly after each prompt).
    - ``VICUNA_MODELS`` use the slow tokenizer (``use_fast=False``).
    - ``GPT2_MODELS`` are created with ``add_prefix_space=True``.
    - When the tokenizer defines no pad token, the EOS token is reused.
    """
    is_seq2seq = model_class(model_name) is AutoModelForSeq2SeqLM
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        padding_side='right' if is_seq2seq else 'left',
        # NOTE(review): this passes an explicit False to non-GPT-2 tokenizers;
        # some newer fast tokenizers react to any explicit value — confirm intended.
        add_prefix_space=model_name in GPT2_MODELS,
        use_fast=model_name not in VICUNA_MODELS,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def get_input_ids(options, tokenizer) -> List[int]:
    """Return the distinct final-token ids of ``options``.

    Each option is encoded without special tokens and only its last token id
    is kept. Duplicates are dropped while keeping first-seen order, so the
    result is deterministic (a ``set`` would give an arbitrary order).

    Args:
        options: candidate answer strings.
        tokenizer: object exposing ``encode(text, add_special_tokens=...)``.

    Raises:
        ValueError: if an option encodes to no tokens at all.
    """
    last_token_ids = []
    for option in options:
        token_ids = tokenizer.encode(option, add_special_tokens=False)
        if not token_ids:
            raise ValueError(f'Option {option!r} encodes to no tokens')
        last_token_ids.append(token_ids[-1])
    # dict.fromkeys deduplicates while preserving insertion order.
    return list(dict.fromkeys(last_token_ids))

    
def clean_model_output(prompt: str, text: str) -> str:
    """Remove an echoed ``prompt`` from the start of ``text`` and trim it.

    If ``text`` begins with ``prompt``, that prefix is removed. Surrounding
    whitespace is stripped in every case, so outputs that echo the prompt and
    outputs that do not are normalized the same way.
    """
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text.strip()


def total_option_answer_prob(model_output, option_input_ids) -> torch.Tensor:
    """Return, per batch row, the probability mass placed on ``option_input_ids``.

    ``model_output['scores'][0]`` (the scores of the first generation step) is
    softmaxed over dimension 1, and the probabilities at the option ids are
    summed, giving one value per row.
    """
    first_step_scores = model_output['scores'][0]  # type: ignore
    vocab_probs = torch.softmax(first_step_scores, dim=1)
    option_probs = vocab_probs[:, option_input_ids]
    return torch.sum(option_probs, dim=1)


def yes_no_answer_logprobs(model_output, option_input_ids) -> torch.Tensor:
    """Return log P(yes) and log P(no) from the first generation step.

    ``option_input_ids`` must hold the "yes" token ids followed by the same
    number of "no" token ids. The first-step scores are restricted to those
    ids and renormalized, and each half is aggregated into one probability.
    The result has shape ``(batch, 2)``: column 0 is yes, column 1 is no.

    The aggregation is done in log space (``log_softmax`` then ``logsumexp``).
    This is mathematically equal to ``log(sum(softmax))`` but does not
    underflow to ``-inf`` when one answer group is extremely unlikely.

    Raises:
        ValueError: if ``option_input_ids`` is empty or has odd length.
    """
    n_options = len(option_input_ids)
    if n_options == 0 or n_options % 2 != 0:
        raise ValueError(
            f'option_input_ids must split into two equal non-empty halves, got {n_options} ids'
        )
    first_step_scores = model_output['scores'][0]  # type: ignore
    option_logprobs = first_step_scores[:, option_input_ids].log_softmax(dim=1)
    # (batch, 2 * p) -> (batch, 2, p): first half of the ids is "yes", second half "no".
    grouped = option_logprobs.reshape(option_logprobs.shape[0], 2, -1)
    return torch.logsumexp(grouped, dim=2)


class HFModel(Completer):
    """Completer backed by a locally loaded Hugging Face ``transformers`` model.

    Two modes are supported:

    * ``multiple_choice=True``: the model is loaded without quantization, and
      ``answer_yes_no`` labels each prompt ``'Yes'`` or ``'No'``. The label is
      chosen by comparing the probability the first generated token puts on
      "yes"-like versus "no"-like tokens.
    * ``multiple_choice=False``: the model is loaded with 4-bit NF4
      quantization, and ``answer_free_response`` returns generated text.

    Args:
        model_name: Hugging Face model id; family lists such as ``T5_MODELS``
            decide the loader class and tokenizer options.
        multiple_choice: selects the mode described above.
        **kwargs: extra arguments forwarded to ``model.generate``. They are
            also reported by ``info`` and stored in each completion's meta.
    """

    def __init__(self, model_name: str, multiple_choice: bool, **kwargs) -> None:
        self.generation_kwargs = kwargs
        # Kept as a public attribute. Inputs are sent to ``self.model.device``
        # instead, because ``device_map="auto"`` decides the actual placement.
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        self.model_name = model_name
        self._multiple_choice = multiple_choice
        auto_cls = model_class(model_name)
        if multiple_choice:
            self.model = auto_cls.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
        else:
            # Request NF4 explicitly: BitsAndBytesConfig's default 4-bit quant type is FP4.
            nf4_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type='nf4',
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            self.model = auto_cls.from_pretrained(
                model_name, device_map="auto", trust_remote_code=True, quantization_config=nf4_config
            )
        self.tokenizer = create_tokenizer(model_name)
        # Yes/no ids are only needed (and validated) for multiple-choice scoring,
        # so a tokenizer with unbalanced yes/no tokens cannot break free-response use.
        self.multiple_choice_option_input_ids = self._mc_option_input_ids() if multiple_choice else None

    @property
    def multiple_choice(self):
        return self._multiple_choice

    @property
    def name(self):
        """Model name without the organization prefix (text after the first '/')."""
        if '/' not in self.model_name:
            return self.model_name
        return self.model_name.split('/')[1]

    @property
    def info(self) -> Dict[str, Any]:
        """Model id, mode, and the user-supplied generation kwargs."""
        return dict(
            model=self.model_name,
            multiple_choice=self.multiple_choice,
            **self.generation_kwargs
        )

    def _mc_option_input_ids(self) -> List[int]:
        """Return the token ids used to score yes/no answers.

        The list holds every distinct "yes" token id followed by the same number
        of "no" token ids. ``yes_no_answer_logprobs`` splits the list into two
        equal halves, so the two groups must be the same size.

        Raises:
            ValueError: if the yes and no groups differ in size.
        """
        yes_options = ['Yes', 'yes', ' Yes', ' yes', "'Yes", "'yes"]
        no_options = ['No', 'no', ' No', ' no', "'No", "'no"]
        yes_input_ids = get_input_ids(yes_options, self.tokenizer)
        no_input_ids = get_input_ids(no_options, self.tokenizer)
        if len(yes_input_ids) != len(no_input_ids):
            raise ValueError(
                f'{self.model_name}: got {len(yes_input_ids)} distinct "yes" token ids '
                f'but {len(no_input_ids)} distinct "no" token ids; yes/no scoring '
                'needs equally sized groups'
            )
        return yes_input_ids + no_input_ids

    def _encode(self, prompts: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize ``prompts`` (padded to the longest) onto the model's device."""
        encoded = self.tokenizer(prompts, return_tensors='pt', padding='longest')
        # Every tensor goes to the same device. With device_map="auto" this is
        # self.model.device, which may differ from self.device.
        return {name: tensor.to(self.model.device) for name, tensor in encoded.items()}

    def _generate(self, inputs: Dict[str, torch.Tensor], generation_kwargs: Dict):
        """Run ``model.generate`` on already-encoded ``inputs``."""
        return self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
            pad_token_id=self.tokenizer.pad_token_id,
            **generation_kwargs
        )

    def _get_outputs(self, prompts: List[str], generation_kwargs: Dict):
        """Tokenize ``prompts`` and run ``model.generate`` with ``generation_kwargs``."""
        return self._generate(self._encode(prompts), generation_kwargs)

    def answer_yes_no(self, prompts: List[str]):
        """Label each prompt ``'Yes'`` or ``'No'`` using first-token scores.

        Generates one token per prompt with scores enabled. The completion text
        is whichever of the aggregated yes/no log-probs is larger; an exact tie
        goes to ``'Yes'``. Both log-probs are returned in ``logprobs``. The
        total probability on all yes/no tokens is stored in
        ``meta['total_answer_prob']``.

        User-supplied generation kwargs override the defaults below.
        """
        generation_kwargs = dict(
            max_new_tokens=1,
            output_scores=True,
            return_dict_in_generate=True,
        )
        generation_kwargs.update(self.generation_kwargs)
        option_ids = self.multiple_choice_option_input_ids
        if option_ids is None:
            # Instance was built in free-response mode; compute the ids on demand.
            option_ids = self.multiple_choice_option_input_ids = self._mc_option_input_ids()
        outputs = self._get_outputs(prompts, generation_kwargs)
        total_answer_prob = total_option_answer_prob(outputs, option_ids)
        yes_no_logprobs = yes_no_answer_logprobs(outputs, option_ids)
        completions = []
        for i, prompt in enumerate(prompts):
            prompt_logprobs = dict(Yes=yes_no_logprobs[i][0], No=yes_no_logprobs[i][1])
            # max() keeps the first maximal item, so an exact tie resolves to 'Yes'.
            text = max(prompt_logprobs.items(), key=lambda t: t[1])[0]  # type: ignore
            meta = dict(**generation_kwargs, total_answer_prob=total_answer_prob[i].item())
            completions.append(Completion(prompt, text, meta=meta, logprobs=[prompt_logprobs]))  # type: ignore
        return completions

    def answer_free_response(self, prompts: List[str]) -> List[Completion]:
        """Generate free-text completions for ``prompts``.

        For decoder-only models the generated sequences start with the prompt
        tokens. Those are removed by token position before decoding. Matching
        the decoded string against the prompt is not reliable, because decoding
        (for example its tokenization-space cleanup) may not reproduce the
        prompt exactly. Each continuation is stripped of surrounding whitespace.
        """
        inputs = self._encode(prompts)
        outputs = self._generate(inputs, self.generation_kwargs)
        if model_class(self.model_name) is AutoModelForCausalLM:
            # Causal-model prompts are left-padded (see create_tokenizer), so each
            # row's new tokens begin right after the padded prompt length.
            outputs = outputs[:, inputs["input_ids"].shape[1]:]
        texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)  # type: ignore
        return [
            Completion(prompt, text.strip(), meta=self.generation_kwargs, logprobs=None)
            for prompt, text in zip(prompts, texts)
        ]

    @torch.inference_mode()
    def __call__(self, prompts: List[str]):
        """Dispatch to ``answer_yes_no`` or ``answer_free_response`` by mode."""
        return self.answer_yes_no(prompts) if self.multiple_choice else self.answer_free_response(prompts)
