from transformers import (
    StoppingCriteria,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    GenerationConfig,
    BitsAndBytesConfig,
)
import os
import openai
import time
import torch
import tqdm
import pandas as pd
try:
    from vllm import LLM, SamplingParams
    import vllm
except:
    print('vllm is not installed')

import utils
import data
import merge
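# Notes:
# 1. When evaluating someone else's model, run a quick generation first to check that its
#    special tokens are correct; e.g. wizard tokenizes '</s>' as the separate pieces [< / s >].
# 2. use_cache may have been set to False in the model config.
# 3. Batched generation cannot end each sequence on time, which makes it very slow; a batch
#    size of 1, on the other hand, does not saturate the GPU.
# Timings: transformers generate took 1h (batch=32) and 1h30min (batch=1); use_cache=False
# was noted for these runs.
# vllm: 20min (even though the GPU does not look fully utilized).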

# Module-level placeholder initialised to None.
# NOTE(review): not referenced anywhere else in this file as shown; possibly a leftover
# scratch/debug global — confirm no other module imports it before removing.
TMP = None

class StoppingWordCriteria(StoppingCriteria):
    """Stop batched ``generate`` once every sequence has produced one of the stop words.

    Only the newly generated tokens are inspected. The prompt width is taken from
    ``batch_mask`` (the attention mask of the left-padded prompt batch), so stop words
    that occur inside the prompt itself (e.g. a trailing "Response:") cannot trigger an
    immediate stop.

    The criterion returns a single bool for the whole batch. Generation therefore ends
    only when *all* sequences contain a stop word, and sequences that finished earlier
    keep producing tokens until then.
    """

    def __init__(self, tokenizer, batch_mask, stop_words):
        """
        Args:
            tokenizer: tokenizer used to decode ``input_ids``.
            batch_mask: attention mask of the prompt batch; only ``batch_mask.size(1)``
                (the padded prompt length) is used.
            stop_words: strings that mark the end of a completion. If empty or ``None``,
                ``['</s>']`` is used so the criterion is never a no-op.
        """
        self.tokenizer = tokenizer
        self.stop_words = list(stop_words) if stop_words else ['</s>']
        self.batch_mask = batch_mask

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when every sequence's generated text contains a stop word."""
        # Drop the (left-padded) prompt columns; only the continuation is matched.
        generated_ids = input_ids[:, self.batch_mask.size(1):]
        # Special tokens are kept in the decoded text so '</s>' stays detectable.
        generated_text = self.tokenizer.batch_decode(generated_ids)
        return all(
            any(word in text for word in self.stop_words)
            for text in generated_text
        )

    # Make a single instance behave like a one-element list of criteria (presumably so it
    # can be passed where ``generate`` expects a ``StoppingCriteriaList``).
    def __len__(self):
        return 1

    def __iter__(self):
        yield self

class ChatGPT():
    """Chat-completion client using the module-level ``openai.ChatCompletion`` API (pre-1.0 style).

    Model names noted as supported: gpt-3.5-turbo / gpt-3.5-turbo-0613,
    gpt-3.5-turbo-16k / gpt-3.5-turbo-16k-0613, gpt-4 / gpt-4-0613.
    """

    # Proxy endpoint used when nothing else is configured; the '/v1' suffix is required.
    DEFAULT_API_BASE = 'https://openkey.cloud/v1'

    def __init__(self, config):
        """Configure the global ``openai`` client.

        The API key comes from ``config.api_key`` if present, otherwise from the
        ``OPENAI_API_KEY`` environment variable. It is never hard-coded in source.
        The endpoint comes from ``config.api_base``, then ``OPENAI_API_BASE``, then
        ``DEFAULT_API_BASE``.

        Args:
            config: object with ``model_type`` (e.g. "gpt-3.5-turbo") and optionally
                ``api_key`` / ``api_base``.

        Raises:
            ValueError: if no API key can be found. Without one, every request would fail
                and ``generate`` would retry forever.
        """
        openai.api_base = (
            getattr(config, 'api_base', None)
            or os.environ.get('OPENAI_API_BASE')
            or self.DEFAULT_API_BASE
        )
        api_key = getattr(config, 'api_key', None) or os.environ.get('OPENAI_API_KEY')
        if not api_key:
            raise ValueError(
                'No OpenAI API key: set the OPENAI_API_KEY environment variable '
                'or provide config.api_key.'
            )
        openai.api_key = api_key
        self.model_type = config.model_type

    def generate(self, prompt, max_retries=None):
        """Send ``prompt`` as a single user message and return the reply text.

        On failure the call is retried. It waits 10 s when the model is reported as
        overloaded and 60 s for any other error. Quota exhaustion aborts immediately.

        Args:
            prompt: the user message content.
            max_retries: maximum number of retries after failures; ``None`` (default)
                retries indefinitely, as before.

        Returns:
            The assistant message content of the first choice.

        Raises:
            SystemExit: when the account or proxy quota is exhausted (non-zero status).
            Exception: the last API error, once ``max_retries`` is exceeded.
        """
        failures = 0
        while True:
            try:
                completion = openai.ChatCompletion.create(
                    model=self.model_type,
                    messages=[{"role": "user", "content": prompt}],
                )
                break
            except Exception as e:
                # Not every exception carries ``_message``; fall back to str(e).
                message = str(getattr(e, '_message', None) or e)
                # Second pattern is the proxy's Chinese "insufficient user quota" message.
                if 'You exceeded your current quota' in message or '用户额度不足' in message:
                    # Retrying cannot help once the quota is gone.
                    raise SystemExit(f'OpenAI quota exhausted: {message}') from e
                failures += 1
                if max_retries is not None and failures > max_retries:
                    raise
                if 'That model is currently overloaded with other requests.' in message:
                    time.sleep(10)
                else:
                    time.sleep(60)

        return completion["choices"][0]["message"]["content"]

class Seq2SeqLM():
    """Hugging Face encoder-decoder wrapper (fp16, optional 4-bit NF4) with batched generation."""

    def __init__(self, config=None, model=None, quant=None, **kwargs):
        """Load the model and tokenizer onto GPU 0.

        Args:
            config: legacy configuration object providing ``model_type`` (checkpoint
                name/path) and optionally ``quant`` (load in 4-bit NF4).
            model: checkpoint name/path; overrides ``config.model_type``. This is the
                keyword ``main`` uses when constructing engines.
            quant: overrides ``config.quant``; False when neither is given.
            **kwargs: otherwise ignored. A non-None ``model_param`` triggers a warning
                because in-memory parameters are not applied here.

        Raises:
            ValueError: if neither ``config`` nor ``model`` is given.
        """
        if config is None and model is None:
            raise ValueError('Seq2SeqLM needs either `config` or `model`.')
        if kwargs.get('model_param') is not None:
            import warnings
            warnings.warn(
                'Seq2SeqLM ignores `model_param`; weights are loaded from the checkpoint on disk.',
                stacklevel=2,
            )
        model_type = model if model is not None else config.model_type
        if quant is None:
            quant = bool(getattr(config, 'quant', False))

        self.config = config
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_type,
            device_map={'': 0},
            torch_dtype=torch.float16,
            quantization_config=None if not quant else BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                llm_int8_has_fp16_weight=True,
            )
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_type, model_max_length=8192)

    def generate(
        self,
        prompt,
        batch_size=32,
        generation_kwargs=None,
    ):
        """Generate one decoded output per prompt, ``batch_size`` prompts at a time.

        Args:
            prompt: list of input strings.
            batch_size: prompts per ``model.generate`` call (clamped to at least 1).
            generation_kwargs: extra keyword arguments for ``model.generate``. May be a dict,
                an object whose attributes hold them (its ``__dict__`` is used), or ``None``.
                ``max_new_tokens`` defaults to 2048 unless provided here.
                NOTE(review): a vLLM ``SamplingParams`` object has vLLM-specific fields
                that HF ``generate`` may reject — pass a dict of HF arguments instead.

        Returns:
            list of decoded outputs (special tokens skipped), one per prompt, in input order.
        """
        if not prompt:
            return []

        if generation_kwargs is None:
            gen_kwargs = {}
        elif isinstance(generation_kwargs, dict):
            gen_kwargs = dict(generation_kwargs)
        else:
            gen_kwargs = dict(vars(generation_kwargs))
        # Let callers override the limit instead of passing it twice (TypeError).
        gen_kwargs.setdefault('max_new_tokens', 2048)

        self.tokenizer.padding_side = "left"
        batch_size = max(1, min(len(prompt), batch_size))
        outputs = []
        for start in tqdm.trange(0, len(prompt), batch_size, desc='generating ', leave=False):
            end = min(len(prompt), start + batch_size)
            encoded = self.tokenizer(
                prompt[start:end], return_tensors='pt', padding=True, truncation=True
            ).to(self.model.device)
            # The seq2seq decoder output does not contain the prompt, so no slicing is needed.
            generated = self.model.generate(**encoded, **gen_kwargs)
            # Accumulate across batches (previously only the last batch was returned).
            outputs.extend(self.tokenizer.batch_decode(generated, skip_special_tokens=True))
        return outputs

class CausalLM():
    """Hugging Face decoder-only wrapper: bf16 weights, optional 4-bit NF4 quantization."""

    def __init__(self, model, quant=True, **kwargs):
        """Load the tokenizer and model onto GPU 0.

        Args:
            model: checkpoint name or path (loaded with ``trust_remote_code=True``).
            quant: when True, load the weights in 4-bit NF4 through bitsandbytes.
            **kwargs: accepted so ``main`` can construct every engine the same way, and
                otherwise ignored. A non-None ``model_param`` is NOT applied here, so a
                warning is issued.
        """
        if kwargs.get('model_param') is not None:
            import warnings
            warnings.warn(
                'CausalLM ignores `model_param`; weights are loaded from the checkpoint on disk.',
                stacklevel=2,
            )
        self.tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
        # Token id 0 is always used for padding.
        # NOTE(review): this overrides any pad token the tokenizer already defines; what id 0
        # decodes to depends on the vocabulary — confirm per model.
        self.tokenizer.pad_token_id = 0
        self.model = AutoModelForCausalLM.from_pretrained(
            model,
            device_map={'': 0},
            torch_dtype=torch.bfloat16,
            # use_flash_attention_2=True,
            trust_remote_code=True,
            quantization_config=None if not quant else BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                llm_int8_has_fp16_weight=True,
            )
        ).eval()
        # Force the KV cache on; a checkpoint's config may have it disabled.
        self.model.config.use_cache = True

    def generate_single(
        self,
        prompt,
        **kwargs,
    ):
        """Generate up to 128 new tokens for one prompt string and return the continuation.

        Uses the model's default generation settings. ``kwargs`` is accepted for interface
        compatibility and ignored. The prompt tokens are removed and special tokens skipped.
        """
        encoded = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)
        prompt_length = encoded["input_ids"].size(1)
        generated = self.model.generate(**encoded, max_new_tokens=128)
        texts = self.tokenizer.batch_decode(generated[:, prompt_length:], skip_special_tokens=True)
        return texts[0]

    def generate(
        self,
        prompt,
        generation_kwargs=None,
        batch_size=8,
    ):
        """Generate a continuation for each prompt, ``batch_size`` prompts at a time.

        Args:
            prompt: list of prompt strings.
            generation_kwargs: keyword arguments for ``model.generate``. Defaults to greedy
                search (``num_beams=1``, ``do_sample=False``) with ``max_new_tokens=1024``.
            batch_size: prompts per ``generate`` call (clamped to at least 1).

        Returns:
            list of decoded continuations (padding and prompt removed, special tokens
            skipped), in input order. Outputs are not truncated at the stop words.
        """
        if not prompt:
            return []
        if generation_kwargs is None:
            # greedy search
            generation_kwargs = {
                'num_beams': 1,
                'do_sample': False,
                'max_new_tokens': 1024,
            }
        stop_words = ['Instruction:', 'Instruction', 'Response:', 'Response', '</s>']

        # Left padding makes every prompt in a batch end at the same column.
        self.tokenizer.padding_side = "left"
        batch_size = max(1, min(len(prompt), batch_size))
        continuations = []
        for start in tqdm.trange(0, len(prompt), batch_size, desc='generating ', leave=False):
            end = min(len(prompt), start + batch_size)
            encoded = self.tokenizer(
                prompt[start:end], return_tensors='pt', padding=True, truncation=True
            ).to(self.model.device)
            prompt_width = encoded["attention_mask"].size(1)

            generated = self.model.generate(
                **encoded,
                stopping_criteria=StoppingWordCriteria(
                    self.tokenizer, encoded["attention_mask"], stop_words
                ),
                **generation_kwargs,
            )
            # One column slice strips padding + prompt for every row; the rows are appended
            # as individual 1-D tensors because batches may differ in length.
            continuations.extend(generated[:, prompt_width:])

        return self.tokenizer.batch_decode(continuations, skip_special_tokens=True)

def _run_engine(self, use_tqdm: bool):
    # Initialize tqdm.
    if use_tqdm:
        num_requests = self.llm_engine.get_num_unfinished_requests()
        pbar = tqdm.tqdm(total=num_requests, desc="Generating ")
    # Run the engine.
    outputs = []
    try:
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
    # If has error, save outputs first
    except:
        pass
    # Sort the outputs by request ID.
    # This is necessary because some requests may be finished earlier than
    # its previous requests.
    outputs = sorted(outputs, key=lambda x: int(x.request_id))
    return outputs

class VLLM():
    """Batched generation through vLLM, optionally loading in-memory (e.g. merged) weights."""

    def __init__(
        self,
        model,
        model_param = None,
        tensor_parallel_size=1,
        sampling_params=None,
    ):
        """Build the vLLM engine.

        Args:
            model: checkpoint path/name. vLLM only accepts a path here, and config and
                tokenizer are read from it even when ``model_param`` supplies the weights.
            model_param: optional mapping ``name -> tensor`` fed to vLLM instead of the
                checkpoint weights. Only vLLM's llama model module is patched, so only
                models loaded through it receive these parameters.
            tensor_parallel_size: forwarded to ``vllm.LLM``.
            sampling_params: optional ``SamplingParams``. The default is greedy decoding
                (temperature 0), ``max_tokens=2048`` and the instruction-format stop strings.

        Raises:
            RuntimeError: if ``model_param`` is given but the installed vLLM's llama module
                has no ``hf_model_weights_iterator`` to replace.
        """
        # Install the patched engine loop so partial outputs survive an interrupted run.
        setattr(LLM, '_run_engine', _run_engine)

        llama_module = None
        original_iterator = None
        if model_param is not None:
            # Goal: hand the merged parameters to vLLM directly, instead of saving them to
            # disk for vLLM to reload every time. vLLM's llama ``load_weights``
            # (vllm/model_executor/models/llama.py) reads weights through
            # ``hf_model_weights_iterator`` from vllm.model_executor.weight_utils.
            # The name must be replaced inside vllm.model_executor.models.llama. Overwriting
            # it in weight_utils does not work, presumably because llama holds its own
            # reference from a ``from ... import``.
            def _direct_model_param_iterator(
                model_name_or_path,
                cache_dir=None,
                *args,
                **kwargs,
            ):
                # Extra arguments (e.g. use_np_cache) are accepted and ignored.
                yield from model_param.items()
                torch.cuda.empty_cache()

            import vllm.model_executor.models.llama as llama_module
            if not hasattr(llama_module, 'hf_model_weights_iterator'):
                raise RuntimeError(
                    'This vLLM version has no llama.hf_model_weights_iterator; '
                    'cannot inject model_param.'
                )
            original_iterator = llama_module.hf_model_weights_iterator
            llama_module.hf_model_weights_iterator = _direct_model_param_iterator

        try:
            self.llm = LLM(
                model,  # only a path is supported, not an in-memory model
                tensor_parallel_size=tensor_parallel_size,
                trust_remote_code=True,
            )
        finally:
            if llama_module is not None:
                # Undo the patch. Later LLM instances then load from disk again, and the
                # closure (with the merged tensors it references) can be freed.
                llama_module.hf_model_weights_iterator = original_iterator

        if sampling_params is None:
            # temperature 0 -> greedy; top_p=1 / top_k=-1 -> consider all tokens.
            # Beam search would instead use use_beam_search=True, best_of>1, temperature=0.
            sampling_params = SamplingParams(
                temperature=0.0,
                top_p=1,
                top_k=-1,
                max_tokens=2048,
                stop=["Instruction:", "Instruction", "Response:", "Response","</s>"]
            )
        self.sampling_params = sampling_params

    def generate(
        self,
        prompt,
        **kwargs,
    ):
        """Return the first completion text for each prompt, using ``self.sampling_params``.

        vLLM schedules and batches all prompts itself. ``kwargs`` is accepted for interface
        compatibility and ignored. Because the patched ``_run_engine`` keeps partial results
        after an interruption, fewer completions than prompts may be returned.
        """
        results = self.llm.generate(prompt, self.sampling_params)
        return [result.outputs[0].text for result in results]


def main(
    *,
    merge_config: str = None, # whether to merge
    model: str = 'Qwen/Qwen-14B', # if not merge, it is the inference model; if merge, it gives config and tokenizer
    type: str = 'VLLM',
    test: str = 'gsm8k',
    num: int=-1,
    name: str =None,
):
    """Evaluate one model on one benchmark and append the metrics to the results table.

    Args:
        merge_config: merge configuration passed to ``merge.run_merge``. When given, the
            merged parameters are handed to the engine as ``model_param``.
        model: the inference model, or (when merging) the model providing config and
            tokenizer.
        type: name of an engine class defined in this module (e.g. ``VLLM``, ``CausalLM``).
        test: benchmark name; one of ``gsm8k``, ``MATH``, ``human_eval``, ``mbpp``.
        num: evaluate only the first ``num`` examples; ``-1`` means all.
        name: run name used for result files and the results-table row. Defaults to
            ``{model}-{test}`` (or ``{model}-{num}:{test}`` when ``num`` is set).

    Raises:
        ValueError: for an unsupported ``test``. The check runs before any model is loaded.
    """
    import sys

    math_tests = ('gsm8k', 'MATH')
    code_tests = ('human_eval', 'mbpp')
    # Only these benchmarks have a scoring branch below; fail before the expensive work.
    if test not in math_tests + code_tests:
        raise ValueError(f'unsupported test {test!r}; expected one of {math_tests + code_tests}')

    # Resolve the run name up front: the code-benchmark result path below needs it.
    if name is None:
        name = f'{model}-{test}' if num == -1 else f'{model}-{num}:{test}'
    # Names usually embed model ids like 'Qwen/Qwen-14B'; '/' must not create directories.
    file_stem = name.replace('/', '_')

    utils.fix_seed(0)

    tester = getattr(data, test)()
    # NOTE(review): assumes tester.read() returns a dict of per-example lists — confirm.
    dataset = tester.read()
    if num != -1:
        # Truncate all per-example columns together so outputs and references stay aligned.
        for key in ('prompts', 'labels', 'task_ids'):
            if key in dataset:
                dataset[key] = dataset[key][:num]

    engine = getattr(sys.modules[__name__], type)(
        model=model,
        model_param=None if not merge_config else merge.run_merge(merge_config)
    )

    outputs = engine.generate(dataset['prompts'])

    if test in math_tests:
        metrics = tester.test({
            'outputs': outputs,
            'labels': dataset['labels']
        })
    else:
        # Code benchmarks are keyed by task id; reuse it as the label column saved below.
        dataset['labels'] = dataset['task_ids']
        os.makedirs(f"results/{test}", exist_ok=True)
        metrics = tester.test({
            'outputs': outputs,
            'task_ids': dataset['task_ids'],
            'path': f"results/{test}/{file_stem}.jsonl"
        })

    print(metrics)

    os.makedirs('results', exist_ok=True)
    records = [
        {'prompts': prompt, 'labels': label, 'outputs': output}
        for prompt, label, output in zip(dataset['prompts'], dataset['labels'], outputs)
    ]
    utils.to_jsonl(records, f"results/{file_stem}.jsonl")

    # Append this run as one row to the cumulative Excel table and mirror it to Markdown.
    path = 'results/_result.xlsx'
    df = pd.DataFrame(metrics, index=[name])
    if os.path.exists(path):
        previous = pd.read_excel(path, index_col=0)
        df = pd.concat([previous, df])
    utils.to_markdown(df, 'results/_result.md')
    df.to_excel(path, index=True)


if __name__ == '__main__':
    import bdb
    import pdb
    import sys

    import defopt

    try:
        defopt.run(main)
    except SystemExit:
        # Normal termination (``--help``, argument errors, explicit exits) is not a crash;
        # previously it fell into the post-mortem debugger.
        raise
    except bdb.BdbQuit:
        # The user quit an interactive debugger session; exit quietly.
        sys.exit()
    except BaseException:
        # Any other failure (including Ctrl-C): report it and open a post-mortem debugger.
        # Names avoid rebinding the builtin ``type`` at module level.
        exc_type, exc_value, tb = sys.exc_info()
        print(exc_type, exc_value)
        pdb.post_mortem(tb)