

import csv
import os

import pandas as pd
import torch

from bild_model import LlamaBiLDModel
from generate import generate_answer_bild, warmup
from utils import setup, calculate_bertscore

# Reference data: each CSV row pairs a prompt with the answer LLaMA 2 gave for
# it. Only the first 100 rows are kept for this run.
llama2_responses_file = "input_data/llama2-generated-answers-700.csv"
llama2_data = pd.read_csv(llama2_responses_file).head(100)

# Pick the compute device: the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Each tokenizer uses the same checkpoint id as the model it belongs to.
model_name_large = tokenizer_name_large = "meta-llama/Llama-2-7b-chat-hf"
model_name_small = tokenizer_name_small = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# setup() (from utils) is unpacked as: large model, small model, large
# tokenizer, small tokenizer.
large_model, small_model, large_tokenizer, small_tokenizer = setup(
    model_name_large, model_name_small,
    tokenizer_name_large, tokenizer_name_small,
    device,
)

# Destination of the per-question result rows.
output_csv = "output_data/bild_finetuning_results.csv"

# Running record of the best threshold pair seen so far. The score starts at
# negative infinity so that the first evaluated configuration always wins.
best_fallback = best_rollback = None
best_quality_score = -float("inf")

# Threshold sweep. For every (alpha_fb, alpha_rb) pair a LlamaBiLDModel is
# built from the already-loaded large/small models, each question in
# llama2_data is answered with it, the answer is scored against the stored
# LLaMA 2 answer via calculate_bertscore, and one CSV row per question is
# written. The pair with the highest whole-set score is printed at the end.

# Create the output directory up front so the run does not fail on a missing
# folder after the models have already been loaded.
output_dir = os.path.dirname(output_csv)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Constant for the whole sweep; only used for the progress message.
total_questions = len(llama2_data)

# encoding="utf-8": generated text may contain characters that the platform
# default encoding cannot represent. buffering=1 selects line buffering.
with open(
    output_csv, mode="w", newline="", encoding="utf-8", buffering=1
) as file:
    writer = csv.writer(file)
    writer.writerow(
        [
            "alpha_fb",
            "alpha_rb",
            "question",
            "llama2_answer",
            "bild_answer",
            "avg_confidence",
            "large_token_proportion",
            "small_token_proportion",
            "quality_score",
        ]
    )

    # Loop over different fallback and rollback thresholds
    for alpha_fb in [0.8]:
        for alpha_rb in [1, 5, 10, 15, 20, 22, 25, 30]:
            # Create a new BiLD model with the current alpha_fb and alpha_rb
            model = LlamaBiLDModel(
                large_model,
                small_model,
                large_tokenizer,
                small_tokenizer,
                num_small_iters=1,
                fallback_threshold=alpha_fb,
                rollback_threshold=alpha_rb,
            ).to(device)
            warmup(model, large_tokenizer, device)

            # Answers and references collected for the whole-set score below.
            responses = []
            reference_responses = []

            # Count by iteration position rather than by the DataFrame index
            # label, so the progress message stays correct for any index.
            for position, (_, row) in enumerate(llama2_data.iterrows(), start=1):
                question = row["question"]
                # Drop the double quotes wrapped around the stored answer.
                llama2_answer = row["answer"].strip('"')

                (
                    bild_answer,
                    avg_confidence,
                    large_token_proportion,
                    small_token_proportion,
                ) = generate_answer_bild(
                    question, model, large_tokenizer, device, max_new_tokens=200
                )

                # Flatten to a single line and keep only the text that follows
                # the last "<|assistant|>" marker.
                bild_answer = bild_answer.replace("\n", " ")
                bild_answer = bild_answer.split("<|assistant|>")[-1].strip()

                responses.append(bild_answer)
                reference_responses.append(llama2_answer)

                # Per-question score of the BiLD answer against the reference.
                bertscore = calculate_bertscore([bild_answer], [llama2_answer])

                # Write the question, responses, and evaluation metrics to the CSV file
                writer.writerow(
                    [
                        alpha_fb,
                        alpha_rb,
                        question,
                        llama2_answer,
                        bild_answer,
                        avg_confidence,
                        large_token_proportion,
                        small_token_proportion,
                        bertscore,
                    ]
                )

                # Explicit flush so partial results can be inspected while the
                # sweep is still running.
                file.flush()

                print(
                    f"Processed question {position}/{total_questions} with alpha_fb={alpha_fb} and alpha_rb={alpha_rb}"
                )

                # Release cached, unused GPU memory between questions.
                torch.cuda.empty_cache()

            # Nothing was generated (empty input data), so there is no score
            # to compare for this configuration.
            if not responses:
                continue

            # Score the whole answer set for this configuration at once.
            avg_bertscore = calculate_bertscore(responses, reference_responses)

            # Determine if this combination is the best
            if avg_bertscore > best_quality_score:
                best_quality_score = avg_bertscore
                best_fallback = alpha_fb
                best_rollback = alpha_rb

print(
    f"Best fallback: {best_fallback}, Best rollback: {best_rollback}, Best quality score: {best_quality_score}"
)
print("Done! The results have been saved to", output_csv)
