import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer

from bert_score import score as bert_score


def calculate_bertscore(responses, reference_responses):
    """Return the mean BERTScore F1 of *responses* against *reference_responses*.

    Scoring is done with ``lang="en"`` and ``rescale_with_baseline=True``.
    Only the F1 component is used; precision and recall are discarded.

    Args:
        responses: Candidate texts to evaluate.
        reference_responses: Reference texts the candidates are compared to.

    Returns:
        The F1 scores averaged over all pairs, as a plain Python number.
    """
    _precision, _recall, f1_scores = bert_score(
        responses,
        reference_responses,
        lang="en",
        rescale_with_baseline=True,
    )
    mean_f1 = f1_scores.mean()
    return mean_f1.item()



def plot_and_save_graph(x, y, xlabel, ylabel, graph_header, file_name):
    """Plot *y* against *x* as a blue marked line and save it to *file_name*.

    Args:
        x: Values for the horizontal axis.
        y: Values for the vertical axis.
        xlabel: Label for the horizontal axis.
        ylabel: Label for the vertical axis.
        graph_header: Title shown above the plot.
        file_name: Destination passed to ``savefig``.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(x, y, marker="o", linestyle="-", color="b")
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(graph_header)
        plt.grid(True)
        fig.savefig(file_name)
    finally:
        # Close this specific figure even when plotting or saving fails, so
        # a failed call does not leave it registered with pyplot.
        plt.close(fig)


def setup(
    model_name_large, model_name_small, tokenizer_name_large, tokenizer_name_small, device
):
    """Load the large and small causal language models and their tokenizers.

    Both models are loaded with ``torch_dtype=torch.float16`` and moved to
    *device*. The tokenizers are loaded from their own names, which are
    independent of the model names.

    Returns:
        A tuple ``(large_model, small_model, large_tokenizer, small_tokenizer)``.
    """

    def _load_half_precision(model_name):
        # Load in fp16 first, then move the model onto the target device.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16
        )
        return model.to(device)

    big_model = _load_half_precision(model_name_large)
    little_model = _load_half_precision(model_name_small)
    big_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_large)
    little_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_small)
    return big_model, little_model, big_tokenizer, little_tokenizer


def load_prompts_from_txt(file_path):
    """Read one prompt per line from a UTF-8 text file.

    Each line has surrounding whitespace removed and then any leading or
    trailing double quotes removed. Blank lines are kept and yield empty
    strings, so the result has one entry per line in the file.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list of prompt strings in file order.
    """
    # Pin the encoding so non-ASCII prompts decode the same way regardless of
    # the platform's locale default.
    with open(file_path, "r", encoding="utf-8") as f:
        return [line.strip().strip('"') for line in f]


def calculate_prompt_length(input_ids, tokenizer):
    """Return the position just past the last ``</s>`` token in *input_ids*.

    The prompt is taken to end at the final occurrence of the ``</s>``
    token, so the returned length counts every token up to and including it.

    Args:
        input_ids: Tensor of token ids. It is indexed as 2-D here (column
            positions are read from dimension 1) -- presumably shaped
            ``(batch, seq_len)`` with a single row; verify against callers.
        tokenizer: Object whose ``encode(text, add_special_tokens=False)``
            returns the token ids for ``"</s>"``.

    Returns:
        The column index of the last matching token plus one, as an int.

    Raises:
        IndexError: If the tokenizer yields no ids for ``"</s>"`` or the
            token does not occur in *input_ids*.
    """
    special_token_ids = tokenizer.encode("</s>", add_special_tokens=False)
    if len(special_token_ids) == 0:
        raise IndexError('tokenizer produced no token ids for "</s>"')
    # If "</s>" encodes to several ids, only the last one is matched.
    end_token_id = special_token_ids[-1]

    # Element [1] of the nonzero tuple holds the column (sequence) positions.
    match_columns = (input_ids == end_token_id).nonzero(as_tuple=True)[1]
    if match_columns.numel() == 0:
        # Same exception type the bare [-1] lookup raised, but with a message
        # that names the actual problem.
        raise IndexError(
            f'special token "</s>" (id {end_token_id}) not found in input_ids'
        )
    return match_columns[-1].item() + 1
