import torch
from torch import nn
from tqdm import tqdm


# from typing_extensions import override


class StreamingLLM:
    """Greedy-decoding wrapper around a causal LM with an optional KV-cache policy.

    The wrapped ``model`` is called as
    ``model(input_ids=..., past_key_values=..., use_cache=True)`` and must
    return an object exposing ``logits`` and ``past_key_values``.

    The optional ``kv_cache`` object is used in two ways by this class:
    ``kv_cache.evict_for_space(past_key_values, space_needed)`` before
    generation, and ``kv_cache(past_key_values)`` after every step of
    :meth:`forward_pll`.
    """

    # Class-level defaults, kept so the attributes exist even on instances
    # created without running ``__init__``.
    model = None
    kv_cache = None
    tokenizer = None
    max_gen_len = None
    device = None

    def __init__(self, model=None, tokenizer=None, kv_cache=None, max_gen_len=50):
        """Store the collaborators and the default generation budget.

        Args:
            model: The language model to wrap. May be ``None``, in which case
                ``device`` is ``None`` as well.
            tokenizer: Object exposing ``eos_token_id``; used only to detect
                the end of generation.
            kv_cache: Optional KV-cache eviction policy (see class docstring).
            max_gen_len: Default number of new tokens produced by
                :meth:`generate` when the caller gives no ``max_new_tokens``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.kv_cache = kv_cache
        # Honour the argument; it used to be overwritten with a literal 50.
        self.max_gen_len = max_gen_len
        # ``model`` defaults to None, so do not dereference it blindly.
        self.device = getattr(model, "device", None)

    def eval(self):
        """Switch the wrapped model to evaluation mode."""
        self.model.eval()

    # Decoding only returns integer token ids, so no gradient can ever flow
    # through the result; disabling autograd avoids building a useless graph.
    @torch.no_grad()
    def greedy_generate(self, input_ids, past_key_values, max_gen_len):
        """Greedily decode up to ``max_gen_len`` tokens after ``input_ids``.

        Args:
            input_ids: Prompt token ids, indexed as ``(batch, seq_len)``.
            past_key_values: Cache handed to the first model call (may be
                ``None``).
            max_gen_len: Maximum number of new tokens. ``None`` falls back to
                ``self.max_gen_len``. At least one token is always generated.

        Returns:
            Tensor of shape ``(batch, seq_len + n_generated)``: the prompt
            followed by the generated token ids.

        Notes:
            Decoding stops once every row of the batch has emitted the
            tokenizer's ``eos_token_id`` (checked on every token, including
            the first). Rows that finish early keep receiving tokens until
            the last row finishes; no padding is applied. If there is no
            tokenizer or it has no EOS id, the full budget is used.
        """
        if max_gen_len is None:
            max_gen_len = self.max_gen_len
        eos_token_id = getattr(self.tokenizer, "eos_token_id", None)

        generated = []
        finished = None  # per-row "has emitted EOS" mask, created lazily
        step_input = input_ids
        # The first pass consumes the whole prompt, later passes one token.
        for _ in range(max(max_gen_len, 1)):
            outputs = self.model(
                input_ids=step_input,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            step_input = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated.append(step_input)

            if eos_token_id is None:
                continue
            hit_eos = (step_input == eos_token_id).squeeze(1)
            finished = hit_eos if finished is None else finished | hit_eos
            if bool(finished.all()):
                break

        # Each generated entry is (batch, 1); joining along dim=1 is correct
        # for any batch size.
        return torch.cat([input_ids, *generated], dim=1)

    def generate(self, input_ids=None, **kwargs):
        """Generate a greedy continuation of ``input_ids``.

        Args:
            input_ids: Prompt token ids, indexed as ``(batch, seq_len)``.
            **kwargs: Only ``max_new_tokens`` is read; when absent or ``None``
                the instance default ``max_gen_len`` is used. Any other
                keyword argument is ignored.

        Returns:
            Tensor holding the prompt followed by the generated token ids.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("generate() requires input_ids")

        max_new_tokens = kwargs.get("max_new_tokens")
        if max_new_tokens is None:
            max_new_tokens = self.max_gen_len

        # Every call starts from an empty cache.
        past_key_values = None
        if self.kv_cache is not None:
            # Reserve room for the prompt plus the budget actually used below.
            space_needed = input_ids.shape[1] + max_new_tokens
            past_key_values = self.kv_cache.evict_for_space(past_key_values, space_needed)

        return self.greedy_generate(input_ids, past_key_values, max_gen_len=max_new_tokens)

    def forward_pll(self, input_ids, **kwargs):
        """Run the model one token at a time and collect the logits of every step.

        After each step the optional ``kv_cache`` policy is applied to the
        returned cache before it is fed to the next step.

        Args:
            input_ids: Token ids; the second dimension is iterated as the
                sequence axis.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The output object of the final model call, with ``logits``
            replaced by all per-step logits (each flattened to
            ``(-1, vocab_size)``) concatenated and given a leading dimension
            of size 1.

        Raises:
            ValueError: If ``input_ids`` has a sequence length of zero.
        """
        seq_len = input_ids.size(1)
        print(f"seq_len: {seq_len}")
        if seq_len == 0:
            # Without any step there is no model output to return.
            raise ValueError("forward_pll() requires at least one token")

        vocab_size = self.model.config.vocab_size
        past_key_values = None
        outputs = None
        step_logits = []

        for idx in tqdm(range(seq_len)):
            token = input_ids[:, idx: idx + 1].to(self.device)
            outputs = self.model(
                token,
                past_key_values=past_key_values,
                use_cache=True,
            )
            step_logits.append(outputs.logits.view(-1, vocab_size))

            past_key_values = outputs.past_key_values
            if self.kv_cache is not None:
                past_key_values = self.kv_cache(past_key_values)

        outputs.logits = torch.cat(step_logits).unsqueeze(0)
        return outputs
