import time
import torch
from torch import nn
from transformers import StoppingCriteriaList
from bild_model import EOSStoppingCriteria
from utils import calculate_prompt_length


def warmup(model, tokenizer, device):
    """Run a single throwaway generation step on ``model.large``.

    A fixed sample chat prompt is tokenized, moved to ``device`` and fed to
    ``model.large.generate`` with a budget of one new token; the result is
    discarded. Presumably this absorbs first-call overhead before timed
    generations — NOTE(review): confirm intent with callers.

    Args:
        model: Object exposing a ``large`` attribute with a ``generate``
            method accepting the tokenizer's encoded fields as keywords.
        tokenizer: Callable tokenizer returning an encoding with ``.to()``.
        device: Target passed to the encoding's ``.to()``.
    """
    question = "How do you fine-tune a large language model?"
    system_turn = "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate.</s>\n"
    chat_text = system_turn + f"<|user|>\n{question}</s>\n<|assistant|>"

    encoded = tokenizer(chat_text, return_tensors="pt")
    encoded = encoded.to(device)
    # One new token is enough to exercise the full forward path.
    model.large.generate(**encoded, max_new_tokens=1)


def generate_answer_bild(prompt, model, tokenizer, device, max_new_tokens, temperature, *, verbose=True):
    """Generate an answer with the BiLD model and summarize how it was produced.

    The prompt is wrapped in a system/user/assistant chat template, tokenized,
    moved to ``device`` and passed to ``model.generate`` under
    ``torch.no_grad()``.

    Args:
        prompt: User question inserted into the ``<|user|>`` turn.
        model: Object whose ``generate(input_ids=..., max_new_tokens=...,
            temperature=...)`` returns the 5-tuple ``(output, generation_times,
            model_types, token_confidences, streamed_token_list)``.
        tokenizer: Callable tokenizer with a ``decode`` method.
        device: Target passed to the encoding's ``.to()``.
        max_new_tokens: Forwarded unchanged to ``model.generate``.
        temperature: Forwarded unchanged to ``model.generate``.
        verbose: Keyword-only. When true (the default, matching earlier
            behavior) print the decoded prompt followed by the streamed
            tokens; when false, skip both the extra decode and the print.

    Returns:
        Tuple ``(answer, avg_confidence, large_token_proportion,
        small_token_proportion, streamed_token_list)``:

        - ``answer``: ``output[0]`` decoded with special tokens skipped.
        - ``avg_confidence``: mean of ``conf[1]`` over ``token_confidences``
          (0 when there are none).
        - ``large_token_proportion`` / ``small_token_proportion``: share of
          ``len(generation_times)`` tokens whose ``model_types`` entry is /
          is not ``"large"`` (both 0 when nothing was generated).
        - ``streamed_token_list``: returned as produced by ``model.generate``.
    """
    input_text = (
        f"<|system|>\nYou are a friendly chatbot who always responds in the style of a domain expert.</s>\n"
        f"<|user|>\n{prompt}</s>\n<|assistant|>"
    )
    model_inputs = tokenizer(input_text, return_tensors="pt").to(device)
    input_ids = model_inputs["input_ids"]

    # Only the model call needs autograd disabled; the bookkeeping below is plain Python.
    with torch.no_grad():
        output, generation_times, model_types, token_confidences, streamed_token_list = model.generate(
            input_ids=input_ids, max_new_tokens=max_new_tokens, temperature=temperature
        )

    answer = tokenizer.decode(output[0], skip_special_tokens=True)

    # Share of generated tokens that came from the large vs. small model.
    num_tokens_generated = len(generation_times)
    num_large_tokens = sum(1 for model_type in model_types if model_type == "large")
    num_small_tokens = num_tokens_generated - num_large_tokens
    if num_tokens_generated:
        large_token_proportion = num_large_tokens / num_tokens_generated
        small_token_proportion = num_small_tokens / num_tokens_generated
    else:
        large_token_proportion = small_token_proportion = 0

    # Average over the confidences actually recorded; dividing by the count of
    # generated tokens skews the mean whenever the two lists differ in length.
    confidences = [conf[1] for conf in token_confidences]
    avg_confidence = sum(confidences) / len(confidences) if confidences else 0

    if verbose:
        # Debug view: prompt text plus the tokens as they were streamed.
        concat_answer = tokenizer.decode(input_ids[0], skip_special_tokens=True) + "".join(streamed_token_list)
        print("The Concat Answer is: ", concat_answer)

    return (
        answer,
        avg_confidence,
        large_token_proportion,
        small_token_proportion,
        streamed_token_list,
    )


