import asyncio
import gc
import json
import multiprocessing
import os
import time
from importlib import import_module
from typing import Callable, Optional

import openai
import torch
import tqdm
from huggingface_hub import list_repo_files
from peft import PeftConfig, PeftModel
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import StoppingCriteria
from vllm import LLM, SamplingParams, RequestOutput

from src.dispatch_openai_requests import dispatch_openai_chat_requests


class KeyWordsCriteria(StoppingCriteria):
    """Stopping criterion that fires once every row of the batch ends with a stop sequence.

    ``stop_id_sequences`` is a list of token-id lists. The check only looks at the
    current tail of each row, and the result is a single bool for the whole batch.
    """

    def __init__(self, stop_id_sequences):
        assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
        self.stop_sequences = stop_id_sequences

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True only if each row of ``input_ids`` currently ends with some stop sequence."""

        def ends_with_stop(row) -> bool:
            # Compare the last len(stop) ids of the row with each stop sequence.
            return any(row[-len(stop):].tolist() == stop for stop in self.stop_sequences)

        return all(ends_with_stop(input_ids[row_idx]) for row_idx in range(input_ids.shape[0]))


@torch.no_grad()
def generate_completions(model, tokenizer, prompts, model_name_or_path: str, batch_size=1, stop_id_sequences=None, add_special_tokens=True,
                         disable_tqdm=False, **generation_kwargs):
    """
    Generate completions for ``prompts`` in batches using ``model.generate``.

    Args:
        model: object exposing ``device`` and ``generate``.
        tokenizer: used to encode the prompts and to decode prompts and outputs.
        prompts: list of prompt strings.
        model_name_or_path: only inspected for "llama-3"/"llama3", in which case token id
            128009 (``<|eot_id|>``) is added as an extra stop sequence.
        batch_size: number of prompts per ``generate`` call.
        stop_id_sequences: optional list of token-id lists; generation text is cut at the first
            occurrence of any of them. The caller's list is never modified.
        add_special_tokens: forwarded to the tokenizer when encoding prompts.
        disable_tqdm: when True no progress bar is created.
        **generation_kwargs: forwarded to ``model.generate``; ``num_return_sequences`` is also
            read here to size the output.

    Returns:
        A flat list with ``len(prompts) * num_return_sequences`` strings. A batch whose
        generation raises contributes empty strings instead of aborting the whole run.
    """
    generations = []
    progress = None if disable_tqdm else tqdm.tqdm(total=len(prompts), desc="Generating Completions")

    # Work on a private copy: appending the Llama-3 stop id below must not leak into the caller's list.
    stop_id_sequences = list(stop_id_sequences) if stop_id_sequences else []
    lowered_name = model_name_or_path.lower()
    if "llama-3" in lowered_name or "llama3" in lowered_name:
        # Llama3 uses special <|eot_id|> token for end of turn
        llama3_stop_id = [128009]
        if llama3_stop_id not in stop_id_sequences:
            stop_id_sequences.append(llama3_stop_id)

    num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i + batch_size]
        # Remember the real batch size; it drives both the error fallback and the progress bar.
        num_batch_prompts = len(batch_prompts)
        tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt",
                                      add_special_tokens=add_special_tokens)
        batch_input_ids = tokenized_prompts.input_ids
        attention_mask = tokenized_prompts.attention_mask

        if model.device.type == "cuda":
            batch_input_ids = batch_input_ids.cuda()
            attention_mask = attention_mask.cuda()

        try:
            batch_outputs = model.generate(
                input_ids=batch_input_ids,
                attention_mask=attention_mask,
                eos_token_id=tokenizer.eos_token_id,
                stopping_criteria=[KeyWordsCriteria(stop_id_sequences)] if stop_id_sequences else None,
                **generation_kwargs
            )

            # The stopping criteria is applied at batch level, so if other examples are not stopped,
            # the entire batch keeps generating. Some outputs therefore still contain the stop
            # sequence (and text after it), which we blank out with the pad token here.
            if stop_id_sequences:
                prompt_len = batch_input_ids.shape[1]
                for output_idx in range(batch_outputs.shape[0]):
                    for token_idx in range(prompt_len, batch_outputs.shape[1]):
                        if any(batch_outputs[output_idx,
                               token_idx: token_idx + len(stop_sequence)].tolist() == stop_sequence
                               for stop_sequence in stop_id_sequences):
                            batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id
                            break

            # Remove the prompt from the output.
            # We re-decode the prompt so special tokens are treated the same way as in the outputs.
            # Truncating the output token ids directly is avoided because some tokenizers (e.g., llama)
            # won't add a space token before the first token, and that space matters for some tasks
            # (e.g., code completion).
            decoded_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
            decoded_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
            # duplicate the prompts to match the number of return sequences
            decoded_prompts = [prompt for prompt in decoded_prompts for _ in range(num_return_sequences)]
            batch_generations = [
                output[len(prompt):] for prompt, output in zip(decoded_prompts, decoded_outputs)
            ]
        except Exception as e:
            print("Error when generating completions for batch:")
            print(batch_prompts)
            print("Error message:")
            print(e)
            print("Use empty string as the completion.")
            batch_generations = [""] * num_batch_prompts * num_return_sequences

        generations += batch_generations

        if progress is not None:
            progress.update(num_batch_prompts)

    if progress is not None:
        progress.close()

    assert len(generations) == len(
        prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
    return generations


@torch.no_grad()
def get_next_word_predictions(model, tokenizer, prompts, candidate_token_ids=None, batch_size=1,
                              return_token_predictions=False, add_special_tokens=True, disable_tqdm=False):
    """
    Predict the most likely next token for each prompt.

    The logits at the last position of each (padded) prompt are turned into probabilities.
    When ``candidate_token_ids`` is given, the probabilities are restricted to those ids and the
    returned index refers to a position inside ``candidate_token_ids``.

    Returns:
        ``(predictions, probs)`` where ``predictions`` holds one index (or one token string if
        ``return_token_predictions`` is True) per prompt and ``probs`` holds the matching
        probability rows as nested lists.
    """
    all_predictions = []
    all_probs = []
    bar = None if disable_tqdm else tqdm.tqdm(total=len(prompts), desc="Getting Predictions")
    restrict_to_candidates = candidate_token_ids is not None

    for offset in range(0, len(prompts), batch_size):
        chunk = prompts[offset: offset + batch_size]
        encoded = tokenizer(chunk, padding="longest", return_tensors="pt",
                            add_special_tokens=add_special_tokens)
        input_ids, mask = encoded.input_ids, encoded.attention_mask
        if model.device.type == "cuda":
            input_ids, mask = input_ids.cuda(), mask.cuda()

        # Only the final position matters: its logits describe the next token.
        last_logits = model(input_ids=input_ids, attention_mask=mask).logits[:, -1, :]
        next_probs = torch.softmax(last_logits, dim=-1)
        if restrict_to_candidates:
            next_probs = next_probs[:, candidate_token_ids]
        best = torch.argmax(next_probs, dim=-1)

        if not return_token_predictions:
            all_predictions += best.tolist()
        elif restrict_to_candidates:
            # Indices point into candidate_token_ids, so map them through the candidate tokens.
            names = tokenizer.convert_ids_to_tokens(candidate_token_ids)
            all_predictions += [names[idx] for idx in best]
        else:
            all_predictions += tokenizer.convert_ids_to_tokens(best)
        all_probs += next_probs.tolist()

        if bar is not None:
            bar.update(len(chunk))

    assert len(all_predictions) == len(prompts), "number of predictions should be equal to number of prompts"
    return all_predictions, all_probs


@torch.no_grad()
def score_completions(model, tokenizer, scoring_examples, batch_size=1, aggregation="sum", disable_tqdm=False):
    '''
    Score candidate completions by the log-probability the model assigns to their tokens.

    Each scoring example is a dict, which contains the following keys:
    - prompt: the prompt to score
    - completions: a list of completions to score

    A single space is inserted between prompt and completion unless the prompt is empty or
    already ends with a space or newline.

    `aggregation` selects how per-token log-probabilities are combined: "sum", "mean" or "max";
    any other value raises ValueError.

    Returns a dict mapping prompt -> {completion -> score}.
    '''

    # unroll the scoring examples into flat (prompt, completion) pairs
    unrolled_examples = [
        {"prompt": scoring_example["prompt"], "completion": completion}
        for scoring_example in scoring_examples
        for completion in scoring_example["completions"]
    ]

    progress = None if disable_tqdm else tqdm.tqdm(total=len(unrolled_examples), desc="Scoring Completions")

    scores = []
    for i in range(0, len(unrolled_examples), batch_size):
        batch = unrolled_examples[i:i + batch_size]
        batch_prompts = [example["prompt"] for example in batch]
        # str.endswith is safe on an empty prompt, where indexing prompt[-1] would raise IndexError;
        # an empty prompt has nothing to separate from, so no space is inserted for it.
        batch_examples = [
            (example["prompt"]
             if (not example["prompt"] or example["prompt"].endswith(("\n", " ")))
             else example["prompt"] + " ")
            + example["completion"] for example in batch
        ]
        tokenized_batch = tokenizer(batch_examples, padding="longest", return_tensors="pt")
        if model.device.type == "cuda":
            tokenized_batch = {
                key: value.cuda() for key, value in tokenized_batch.items()
            }
        tokenized_batch.pop("token_type_ids", None)
        outputs = model(**tokenized_batch)

        for example_idx, (prompt, example) in enumerate(zip(batch_prompts, batch_examples)):
            # Re-tokenize without padding to learn where the prompt ends and how long the example is.
            tokenized_prompt = tokenizer(prompt, padding=False, return_tensors="pt").input_ids.squeeze(0)
            tokenized_example = tokenizer(example, padding=False, return_tensors="pt").input_ids.squeeze(0)
            completion_ids = tokenized_example[len(tokenized_prompt):]

            # get the logits for the entire example, removing the padding logits
            if tokenizer.padding_side == "right":
                example_logits = outputs.logits[example_idx, :len(tokenized_example), :]
            else:
                example_logits = outputs.logits[example_idx, -len(tokenized_example):, :]

            # get the logits for the completion portion - note we need to shift the index left by 1
            # because logits are computed for the next token
            completion_logits = example_logits[len(tokenized_prompt) - 1:len(tokenized_example) - 1, :]
            completion_log_probs = torch.log_softmax(completion_logits, dim=-1)[
                range(len(completion_ids)), completion_ids]

            if aggregation == "sum":
                score = completion_log_probs.sum().item()
            elif aggregation == "mean":
                score = completion_log_probs.mean().item()
            elif aggregation == "max":
                score = completion_log_probs.max().item()
            else:
                raise ValueError("Invalid aggregation method: {}".format(aggregation))
            scores.append(score)

        if progress is not None:
            progress.update(len(batch_examples))

    # roll up the scores: prompt -> completion -> score
    rolled_up_scores = {}
    for unrolled_example, score in zip(unrolled_examples, scores):
        rolled_up_scores.setdefault(unrolled_example["prompt"], {})[unrolled_example["completion"]] = score

    return rolled_up_scores


def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
    """
    Tell whether ``model_name_or_path`` holds PEFT adapter weights.

    A location counts as an adapter when its file listing contains
    ``adapter_model.safetensors`` or ``adapter_model.bin``.

    An existing local directory is inspected directly, so a relative path that happens to look
    like a Hub repo id (e.g. ``"output/run1"``) is never resolved against the Hub. Otherwise the
    Hub repo is listed at ``revision``; if that fails, the local listing is attempted as a last
    resort and its ``OSError`` propagates to the caller.
    """
    if os.path.isdir(model_name_or_path):
        repo_files = os.listdir(model_name_or_path)
    else:
        try:
            repo_files = list_repo_files(model_name_or_path, revision=revision)
        except Exception:
            # Not a reachable Hub repo either; fall back to the local listing, which raises
            # the same OSError the caller would have seen before.
            repo_files = os.listdir(model_name_or_path)
    return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files


def load_hf_lm_and_tokenizer(
        model_name_or_path,
        revision="main",
        tokenizer_name_or_path=None,
        device_map="auto",
        torch_dtype="auto",
        load_in_8bit=False,
        convert_to_half=False,
        gptq_model=False,
        use_fast_tokenizer=True,
        padding_side="left",
        token=os.getenv("HF_TOKEN", None)
):
    """
    Load a causal LM and its tokenizer, ready for batched inference.

    Args:
        model_name_or_path: Hub id or local path of the model (or of a PEFT adapter).
        revision: revision passed to the model and tokenizer loaders.
        tokenizer_name_or_path: tokenizer location; defaults to ``model_name_or_path``.
        device_map: forwarded to ``from_pretrained`` for non-adapter, non-GPTQ models.
        torch_dtype: forwarded to ``from_pretrained`` in the plain (non-8bit) branch.
        load_in_8bit: load the model with ``load_in_8bit=True``.
        convert_to_half: call ``model.half()`` after loading (plain branch only).
        gptq_model: load through ``auto_gptq`` on ``cuda:0``.
        use_fast_tokenizer: preferred ``use_fast`` value for the tokenizer.
        padding_side: padding side set on the tokenizer (left by default for batch generation).
        token: Hub access token; the default is read from ``HF_TOKEN`` once, at import time.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode and a pad token guaranteed on the tokenizer.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM, GPTNeoXForCausalLM
    if is_adapter_model(model_name_or_path):
        # NOTE(review): `revision` is applied to the base model here while the adapter itself is
        # always read from its default revision, and `token` is not forwarded - confirm intent.
        peft_config = PeftConfig.from_pretrained(model_name_or_path)
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path, revision=revision
        )
        model = PeftModel.from_pretrained(
            base_model,
            model_name_or_path,
        )
        if torch.cuda.is_available():
            model = model.cuda()
    else:
        if gptq_model:
            from auto_gptq import AutoGPTQForCausalLM
            model_wrapper = AutoGPTQForCausalLM.from_quantized(
                model_name_or_path, device="cuda:0", use_triton=True
            )
            model = model_wrapper.model
        elif load_in_8bit:
            model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                device_map=device_map,
                load_in_8bit=True,
                token=token,
                revision=revision,
            )
        else:
            if device_map:
                model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map=device_map,
                                                             torch_dtype=torch_dtype, token=token, revision=revision)
            else:
                # NOTE(review): this branch runs when the caller passed a falsy device_map, yet it
                # still loads with device_map="auto" before calling .cuda() - confirm this is intended.
                model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch_dtype, token=token,
                                                             device_map="auto", revision=revision)
                if torch.cuda.is_available():
                    model = model.cuda()
            if convert_to_half:
                model = model.half()
    model.eval()

    if not tokenizer_name_or_path:
        tokenizer_name_or_path = model_name_or_path
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=use_fast_tokenizer, token=token, revision=revision)
    except Exception:
        # some tokenizers (e.g., GPTNeoXTokenizer) don't have the slow or fast version, so we just roll back to the default one.
        # Only Exception is caught so Ctrl-C / SystemExit are not swallowed by the retry.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, token=token, revision=revision)
    # set padding side to left for batch generation
    tokenizer.padding_side = padding_side
    # set pad token to eos token if pad token is not set (as is the case for llama models)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # for OPT and Pythia models, we need to set tokenizer.model_max_length to model.config.max_position_embeddings
    # to avoid wrong embedding index.
    if isinstance(model, (GPTNeoXForCausalLM, OPTForCausalLM)):
        tokenizer.model_max_length = model.config.max_position_embeddings
        print("Set tokenizer.model_max_length to model.config.max_position_embeddings: {}".format(
            model.config.max_position_embeddings))

    return model, tokenizer


def remove_repeated_patterns(s):
    """
    Strip every character that occurs more than 10 times in ``s``.

    This is exactly what the previous recursive substring search always reduced to: single
    characters were examined first, and once no character occurs more than 10 times no longer
    substring can either (each non-overlapping occurrence of a substring starts at a distinct
    occurrence of its first character). Computing that result directly avoids rescanning the
    string from scratch after every removal and avoids one recursion level per removed
    character, which could exceed the interpreter's recursion limit on inputs with many
    distinct frequent characters.

    NOTE(review): on ordinary prose this removes most common letters and whitespace - confirm
    that such aggressive stripping is what is wanted when retrying a "repetitive patterns" request.
    """
    overused = {ch for ch in set(s) if s.count(ch) > 10}
    if not overused:
        return s
    return "".join(ch for ch in s if ch not in overused)


def query_openai_chat_model(engine, instances, output_path=None, batch_size=10, retry_limit=5,
                            reuse_existing_outputs=True, **completion_kwargs):
    '''
    Query OpenAI chat model and save the results to output_path.
    `instances` is a list of dictionaries, each dictionary contains a key "prompt" and a key "id".

    Each instance is updated in place with a "response" (text of the first choice) and a
    "response_metadata" (full dumped response) entry. When `reuse_existing_outputs` is set and
    `output_path` already exists, batches whose ids are all present in that file are taken from
    it instead of being requested again. The output file is rewritten as JSON lines and is
    closed even when a request finally fails.

    Raises RuntimeError when a batch still fails after `retry_limit` consecutive attempts.
    '''
    existing_data = {}
    if reuse_existing_outputs and output_path is not None and os.path.exists(output_path):
        with open(output_path, "r") as f:
            for line in f:
                instance = json.loads(line)
                existing_data[instance["id"]] = instance

    # by default, we use temperature 0.0 to get the most likely completion.
    if "temperature" not in completion_kwargs:
        completion_kwargs["temperature"] = 0.0

    results = []
    retry_count = 0
    progress_bar = tqdm.tqdm(total=len(instances))
    # Existing outputs were read above, so truncating the file here does not lose them.
    fout = open(output_path, "w") if output_path is not None else None
    try:
        for i in range(0, len(instances), batch_size):
            batch = instances[i:i + batch_size]
            if all(x["id"] in existing_data for x in batch):
                cached = [existing_data[x["id"]] for x in batch]
                results.extend(cached)
                if fout is not None:
                    for instance in cached:
                        fout.write(json.dumps(instance) + "\n")
                        fout.flush()
                # The last batch may be shorter than batch_size.
                progress_bar.update(len(batch))
                continue
            messages_list = [[{"role": "user", "content": instance["prompt"]}] for instance in batch]
            while retry_count < retry_limit:
                try:
                    outputs = asyncio.run(
                        dispatch_openai_chat_requests(
                            messages_list=messages_list,
                            model=engine,
                            **completion_kwargs,
                        ))
                    retry_count = 0
                    break
                except Exception as e:
                    if isinstance(e, openai.BadRequestError) and 'repetitive patterns' in e.message:
                        # The API rejected the prompt as repetitive: strip the prompt before retrying.
                        print('Removing repetitive patterns from prompt')
                        for messages in messages_list:
                            user_msg = messages[0]
                            user_msg['content'] = remove_repeated_patterns(user_msg['content'])

                    retry_count += 1
                    print("Error while requesting OpenAI API.")
                    print(e)
                    print(f"Sleep for {30 * retry_count} seconds.")
                    time.sleep(30 * retry_count)
                    print(f"Retry for the {retry_count} time.")
            if retry_count == retry_limit:
                raise RuntimeError(f"Failed to get response from OpenAI API after {retry_limit} retries.")
            assert len(outputs) == len(batch)
            for instance, output in zip(batch, outputs):
                instance["response"] = output.choices[0].message.content
                instance["response_metadata"] = output.model_dump()
                results.append(instance)
                if fout is not None:
                    fout.write(json.dumps(instance) + "\n")
                    fout.flush()
            progress_bar.update(len(batch))
    finally:
        # Release the progress bar and the output file on success and on failure alike.
        progress_bar.close()
        if fout is not None:
            fout.close()
    return results


def query_openai_model(engine, instances, output_path=None, batch_size=10, retry_limit=5, reuse_existing_outputs=True,
                       **completion_kwargs):
    '''
    Query OpenAI completion (non-chat) model and save the results to output_path.
    `instances` is a list of dictionaries, each dictionary contains a key "prompt" and a key "id".

    Each instance is updated in place with an "output" (text of the first choice) and a
    "response_metadata" (full dumped response) entry. When `reuse_existing_outputs` is set and
    `output_path` already exists, batches whose ids are all present in that file are taken from
    it instead of being requested again. The output file is rewritten as JSON lines and is
    closed even when a request finally fails.

    Raises RuntimeError when a batch still fails after `retry_limit` consecutive attempts.
    '''
    # This module imports the chat dispatcher from `src.dispatch_openai_requests`; try the same
    # package path first and keep the original top-level import as a fallback.
    try:
        from src.dispatch_openai_requests import dispatch_openai_prompt_requests
    except ImportError:
        from dispatch_openai_requests import dispatch_openai_prompt_requests
    existing_data = {}
    if reuse_existing_outputs and output_path is not None and os.path.exists(output_path):
        with open(output_path, "r") as f:
            for line in f:
                instance = json.loads(line)
                existing_data[instance["id"]] = instance

    # by default, we use temperature 0.0 to get the most likely completion.
    if "temperature" not in completion_kwargs:
        completion_kwargs["temperature"] = 0.0

    results = []
    retry_count = 0
    progress_bar = tqdm.tqdm(total=len(instances))
    # Existing outputs were read above, so truncating the file here does not lose them.
    fout = open(output_path, "w") if output_path is not None else None
    try:
        for i in range(0, len(instances), batch_size):
            batch = instances[i:i + batch_size]
            if all(x["id"] in existing_data for x in batch):
                cached = [existing_data[x["id"]] for x in batch]
                results.extend(cached)
                if fout is not None:
                    for instance in cached:
                        fout.write(json.dumps(instance) + "\n")
                        fout.flush()
                # The last batch may be shorter than batch_size.
                progress_bar.update(len(batch))
                continue
            prompt_list = [instance["prompt"] for instance in batch]
            while retry_count < retry_limit:
                try:
                    outputs = asyncio.run(
                        dispatch_openai_prompt_requests(
                            prompt_list=prompt_list,
                            model=engine,
                            **completion_kwargs,
                        ))
                    retry_count = 0
                    break
                except Exception as e:
                    retry_count += 1
                    print("Error while requesting OpenAI API.")
                    print(e)
                    # Quadratic back-off between attempts.
                    print(f"Sleep for {30 * retry_count ** 2} seconds.")
                    time.sleep(30 * (retry_count ** 2))
                    print(f"Retry for the {retry_count} time.")
            if retry_count == retry_limit:
                raise RuntimeError(f"Failed to get response from OpenAI API after {retry_limit} retries.")
            assert len(outputs) == len(batch)
            for instance, output in zip(batch, outputs):
                instance["output"] = output.choices[0].text
                instance["response_metadata"] = output.model_dump()
                results.append(instance)
                if fout is not None:
                    fout.write(json.dumps(instance) + "\n")
                    fout.flush()
            progress_bar.update(len(batch))
    finally:
        # Release the progress bar and the output file on success and on failure alike.
        progress_bar.close()
        if fout is not None:
            fout.close()
    return results


def load_vllm_model(model_name_or_path,
                    dtype="bfloat16",
                    tokenizer_mode="auto",
                    trust_remote_code=True,
                    hf_revision="main",
                    gpu_memory_utilization=0.6,
                    max_model_len=32768) -> LLM:
    """
    Build a vLLM engine for ``model_name_or_path``.

    The tensor-parallel size is the number of non-empty entries in ``CUDA_VISIBLE_DEVICES``
    (one device is assumed when the variable is unset or empty).

    Args:
        model_name_or_path: model id or path handed to vLLM.
        dtype: weight dtype handed to vLLM.
        tokenizer_mode: tokenizer mode handed to vLLM.
        trust_remote_code: whether vLLM may run model-repo code.
        hf_revision: model revision handed to vLLM.
        gpu_memory_utilization: GPU memory utilization value handed to vLLM
            (previously fixed at 0.6, which remains the default).
        max_model_len: max model length handed to vLLM (previously fixed at 32768,
            which remains the default).
    """
    visible_devices = [
        device for device in os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") if device.strip()
    ]
    # An empty variable or stray commas must not produce a bogus tensor-parallel size.
    num_devices = max(1, len(visible_devices))
    return LLM(
        model=model_name_or_path,
        revision=hf_revision,
        dtype=dtype,
        tensor_parallel_size=num_devices,
        tokenizer_mode=tokenizer_mode,
        trust_remote_code=trust_remote_code,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
    )


def delete_vllm_model(model: LLM) -> None:
    """
    Best-effort teardown of a vLLM engine.

    Calls vLLM's ``destroy_model_parallel``, tries to delete the engine's executor objects,
    runs the garbage collector, empties the CUDA cache, shuts Ray down if it can be imported,
    and finally sleeps for 60 seconds. Failures of the individual deletion / Ray steps are
    ignored; a failure to import ``destroy_model_parallel`` from both locations propagates.
    """
    print("Deleting vllm model")
    # Import function used for cleanup; could be in two places depending on vllm version
    try:
        from vllm.distributed.parallel_state import destroy_model_parallel
    except Exception:
        from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel
    # Make sure to clean up vllm properly
    # A bit hacky to try to make it work with multiple versions
    destroy_model_parallel()
    try:
        # Attribute layout differs between vllm versions, hence the guarded deletes.
        del model.llm_engine.model_executor.driver_worker
    except Exception:
        pass
    try:
        del model.llm_engine.model_executor
    except Exception:
        pass
    # Drops only this function's reference; the caller's own reference (if any) is untouched.
    del model
    gc.collect()
    torch.cuda.empty_cache()
    try:
        import ray
        ray.shutdown()
    except Exception:
        pass
    # Fixed 60 second pause before returning - presumably to let resources be released; TODO confirm.
    time.sleep(60)
    print("Deleted vllm model")


def dynamic_import_function(function_path) -> Callable:
    '''
    Resolve a dotted path such as "module.submodule.my_function" to the object it names.

    Everything before the last dot is imported as a module; the final component is looked up
    as an attribute of that module.
    '''
    module_path, function_name = function_path.rsplit(".", 1)
    return getattr(import_module(module_path), function_name)


def inference_with_vllm(
        prompts: list[str],
        model: LLM,
        model_name_or_path: str,  # Used to set special generation config
        do_chat_formatting: bool = False,
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None,
        use_tqdm: bool = False,
        return_full_outputs: bool = False,  # Whether to return the direct vllm output objects
        temperature: float = 1.0,
        top_p: float = 1.0,
        max_tokens: int = 2048,
        n: int = 1,
        stop: Optional[str | list[str]] = None,
        **sampling_args
) -> list[str] | list[RequestOutput]:
    """
    Run ``model.generate`` over ``prompts`` and return the generated text.

    Args:
        prompts: raw prompt strings.
        model: vLLM engine used for generation.
        model_name_or_path: only inspected for "llama-3"/"llama3", in which case token id
            128009 (``<|eot_id|>``) is added to ``stop_token_ids``.
        do_chat_formatting: wrap each prompt as a single user turn and render it with the
            tokenizer's chat template (requires ``tokenizer``).
        tokenizer: tokenizer providing ``apply_chat_template``.
        use_tqdm: forwarded to ``model.generate``.
        return_full_outputs: return the raw vLLM outputs instead of strings; required when ``n > 1``.
        temperature, top_p, max_tokens, n, stop: forwarded to ``SamplingParams``.
        **sampling_args: extra ``SamplingParams`` arguments. A ``stop_token_ids`` list passed
            here is copied, never modified in place.

    Returns:
        The first completion's text for each prompt, or the raw outputs if ``return_full_outputs``.
    """
    if do_chat_formatting:
        assert tokenizer is not None, "Chat formatting requires tokenizer"
        conversation_prompts = [[{'role': 'user', 'content': p}] for p in prompts]
        formatted_prompts = [tokenizer.apply_chat_template(p, tokenize=False) for p in conversation_prompts]
    else:
        formatted_prompts = prompts

    if n > 1:
        assert return_full_outputs, "n > 1 requires return_full_outputs to be True."

    lowered_name = model_name_or_path.lower()
    if "llama-3" in lowered_name or "llama3" in lowered_name:
        # Llama3 uses special <|eot_id|> token for end of turn
        llama3_stop_id = 128009
        # Copy so a caller-owned list is not mutated; an explicit None is treated as "not given".
        stop_token_ids = list(sampling_args.get("stop_token_ids") or [])
        if llama3_stop_id not in stop_token_ids:
            stop_token_ids.append(llama3_stop_id)
        sampling_args["stop_token_ids"] = stop_token_ids

    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        n=n,
        stop=stop,
        **sampling_args
    )
    outputs = model.generate(prompts=formatted_prompts, sampling_params=sampling_params, use_tqdm=use_tqdm)
    if return_full_outputs:
        return outputs

    # Only the first completion of each request is kept in the plain-text mode.
    return [it.outputs[0].text for it in outputs]


def create_and_inference_with_vllm(
        prompts: list[str],
        model_name_or_path: str,
        hf_revision: str = "main",
        tokenizer_mode: str = "auto",
        do_chat_formatting: bool = False,
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None,
        use_tqdm: bool = False,
        return_full_outputs: bool = False,  # Whether to return the direct vllm output objects
        temperature: float = 1.0,
        top_p: float = 1.0,
        max_tokens: int = 2048,
        n: int = 1,
        stop: Optional[str | list[str]] = None,
        result_pipe=None,
        **sampling_args
) -> list[str] | list[RequestOutput]:
    """
    Load a vLLM engine for ``model_name_or_path`` and run inference on ``prompts`` with it.

    All generation options are handed to ``inference_with_vllm`` unchanged. When
    ``result_pipe`` is given, the result is also sent through it before being returned.
    """
    engine = load_vllm_model(model_name_or_path, tokenizer_mode=tokenizer_mode, hf_revision=hf_revision)
    inference_kwargs = {
        "prompts": prompts,
        "model": engine,
        "model_name_or_path": model_name_or_path,
        "do_chat_formatting": do_chat_formatting,
        "tokenizer": tokenizer,
        "use_tqdm": use_tqdm,
        "return_full_outputs": return_full_outputs,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "n": n,
        "stop": stop,
    }
    # Two separate unpackings keep the "multiple values for keyword argument" error on clashes.
    outputs = inference_with_vllm(**inference_kwargs, **sampling_args)

    if result_pipe is None:
        return outputs
    result_pipe.send(outputs)
    return outputs


def subprocess_inference_with_vllm(
        prompts: list[str],
        model_name_or_path: str,
        hf_revision: str = "main",
        tokenizer_mode: str = "auto",
        do_chat_formatting: bool = False,
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None,
        use_tqdm: bool = False,
        return_full_outputs: bool = False,  # Whether to return the direct vllm output objects
        temperature: float = 1.0,
        top_p: float = 1.0,
        max_tokens: int = 2048,
        n: int = 1,
        stop: Optional[str | list[str]] = None,
        **sampling_args
) -> list[str] | list[RequestOutput]:
    """
    Run ``create_and_inference_with_vllm`` in a freshly spawned process and return its result.

    The result travels back over a pipe. It is received before the worker is joined, so a
    large payload cannot stall the worker while the parent waits for it to exit.

    Raises:
        RuntimeError: the worker exited without sending a result (for example because model
            loading or generation raised inside it).
    """
    ctx = multiprocessing.get_context("spawn")
    parent_conn, child_conn = ctx.Pipe()
    worker = ctx.Process(
        target=create_and_inference_with_vllm,
        args=(
            prompts, model_name_or_path, hf_revision, tokenizer_mode, do_chat_formatting, tokenizer, use_tqdm,
            return_full_outputs, temperature, top_p, max_tokens, n, stop, child_conn
        ),
        kwargs=sampling_args
    )
    worker.start()
    # Drop the parent's handle on the worker's end of the pipe. While the parent keeps it open,
    # recv() never sees end-of-file and would block forever if the worker died before sending.
    child_conn.close()
    try:
        result = parent_conn.recv()
    except EOFError:
        worker.join()
        raise RuntimeError(
            f"vLLM subprocess exited with code {worker.exitcode} before returning a result."
        ) from None
    finally:
        parent_conn.close()
    worker.join()
    return result
