import gc
from threading import Thread
import torch
import transformers
from transformers import (
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)


@torch.inference_mode()
def generate_stream_codet5p(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream a CodeT5+ completion for ``params["prompt"]``.

    ``model.generate`` runs on a background thread and pushes decoded text
    into a ``TextIteratorStreamer``. This generator accumulates that text and
    yields progress dicts.

    Args:
        model: Object exposing ``generate``; it is called in a worker thread.
        tokenizer: Used to encode the prompt and, via the streamer, to decode
            the output. Its ``eos_token_id`` is always added as a stop id.
        params: Request dict. ``prompt`` is required. Optional keys are
            ``temperature``, ``repetition_penalty``, ``top_p``, ``top_k``
            (a value <= 0 disables top-k), ``max_new_tokens`` and
            ``stop_token_ids``. Neither the dict nor its ``stop_token_ids``
            list is modified.
        device: Target for the encoded prompt tensors. When it equals
            ``"xpu"``, the XPU cache is also emptied at the end.
        context_len: Unused in this function.
        stream_interval: Yield an intermediate result every this many
            streamer chunks.
        judge_sent_end: Unused in this function.

    Yields:
        dict with ``text`` (all output so far), ``usage`` (``prompt_tokens``,
        ``completion_tokens``, ``total_tokens``) and ``finish_reason``.
        ``finish_reason`` is ``None`` for intermediate results and
        ``"length"`` or ``"stop"`` for the final one.
    """
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", 50))
    if top_k <= 0:
        # Callers pass -1 to mean "disabled". transformers treats top_k == 0
        # as "no top-k filtering" and rejects negative values when sampling.
        top_k = 0
    max_new_tokens = int(params.get("max_new_tokens", 1024))
    # Copy the list so that appending EOS does not mutate the caller's list.
    # Otherwise a list reused across requests gains an extra EOS id each call.
    stop_token_ids = list(params.get("stop_token_ids", None) or [])
    stop_token_ids.append(tokenizer.eos_token_id)

    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    streamer = TextIteratorStreamer(tokenizer, **decode_config)
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoding.input_ids
    # The decoder is also seeded with a copy of the prompt ids. Presumably this
    # makes the output continue the prompt; confirm against the CodeT5+ docs.
    encoding["decoder_input_ids"] = encoding["input_ids"].clone()
    # input_ids is (batch, seq_len), so the prompt length is the last
    # dimension. len() would return the batch size instead.
    input_echo_len = input_ids.shape[-1]

    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        # Near-zero temperature switches to greedy decoding.
        do_sample=temperature >= 1e-5,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=10,
        top_p=top_p,
        top_k=top_k,
        eos_token_id=stop_token_ids,
    )

    class CodeBlockStopper(StoppingCriteria):
        """Stop once the first sequence ends with token ids [628, 198]."""

        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            # Code completion is open-ended generation, so stop at the blank
            # line that ends a code block. Only batch element 0 is inspected.
            # NOTE(review): 628 / 198 are presumably the GPT-2-style BPE ids
            # for "\n\n" and "\n"; verify against the tokenizer in use.
            return input_ids[0][-2:].tolist() == [628, 198]

    gen_kwargs = dict(
        **encoding,
        streamer=streamer,
        generation_config=generation_config,
        stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]),
    )
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # NOTE(review): i counts streamer text chunks. These may not map
    # one-to-one onto generated tokens, so completion_tokens is approximate.
    i = 0
    output = ""
    for new_text in streamer:
        i += 1
        output += new_text
        if i % stream_interval == 0 or i == max_new_tokens - 1:
            yield {
                "text": output,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": i,
                    "total_tokens": input_echo_len + i,
                },
                "finish_reason": None,
            }
        if i >= max_new_tokens:
            break

    finish_reason = "length" if i >= max_new_tokens else "stop"

    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }
    thread.join()

    # Release cached accelerator memory held after generation.
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
