

import json
import os
import time

import pandas as pd
import torch

from bild_model import LlamaBiLDModel
from generate import generate_answer_bild, warmup
from utils import setup, calculate_bertscore


def collect_stats(
    input_name,
    temperature,
    trace_count,
    *,
    bild_model=None,
    tokenizer=None,
    target_device=None,
):
    """Sample BiLD generation traces for every prompt in a file and save them as JSON.

    Reads prompts from ``input_data/{input_name}.txt`` (one prompt per line,
    blank lines skipped), calls ``generate_answer_bild`` ``trace_count`` times
    per prompt at the given ``temperature``, and writes the results to
    ``output_data/BILD_{input_name}_temp_{temperature}_t_{trace_count}.json``.

    The JSON object holds three parallel lists with one entry per trace:
      * ``"traces"``    -- ``len()`` of each element of the streamed token list
      * ``"labels"``    -- the prompt text the trace was generated from
      * ``"responses"`` -- the answer text, flattened to one line and cut to
                           the part after the last ``"<|assistant|>"`` marker

    Args:
        input_name: Stem of the prompt file under ``input_data/``.
        temperature: Sampling temperature forwarded to ``generate_answer_bild``.
        trace_count: Number of traces to sample per prompt.
        bild_model: Model passed to ``generate_answer_bild``; defaults to the
            module-level ``model`` built in the ``__main__`` section.
        tokenizer: Tokenizer passed to ``generate_answer_bild``; defaults to
            the module-level ``large_tokenizer``.
        target_device: Device passed to ``generate_answer_bild``; defaults to
            the module-level ``device``.

    Raises:
        FileNotFoundError: If the input prompt file does not exist.
        NameError: If an override is omitted and the corresponding
            module-level global has not been defined.
    """
    # Fall back to the globals created under ``if __name__ == "__main__"`` so
    # existing positional calls keep working unchanged.
    if bild_model is None:
        bild_model = model
    if tokenizer is None:
        tokenizer = large_tokenizer
    if target_device is None:
        target_device = device

    input_filename = f"input_data/{input_name}.txt"
    output_filename = f"output_data/BILD_{input_name}_temp_{temperature}_t_{trace_count}.json"

    # Create the output directory before generating: otherwise a missing
    # directory only surfaces at the final json.dump, after all the expensive
    # generations have already run.
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)

    # Load prompts (e.g. EK1, EK2, AK1), one per line, without trailing
    # whitespace. Blank lines are dropped so no traces are sampled for "".
    with open(input_filename, "r", encoding="utf-8") as file:
        prompts = [line.strip() for line in file]
    prompts = [prompt for prompt in prompts if prompt]

    result_json = {"traces": [], "labels": [], "responses": []}

    # Generate `trace_count` responses for each prompt.
    for idx, question in enumerate(prompts):
        for i in range(trace_count):
            # Reseed before every trace. Nanosecond resolution avoids the
            # identical seeds that int(time.time()) hands to traces starting
            # within the same second.
            seed = time.time_ns()
            torch.cuda.manual_seed_all(seed)

            (
                bild_answer,
                _avg_confidence,  # returned by generate_answer_bild but not recorded
                large_token_proportion,
                small_token_proportion,
                streamed_token_list,
            ) = generate_answer_bild(
                question,
                bild_model,
                tokenizer,
                target_device,
                max_new_tokens=100,
                temperature=temperature,
            )

            # Flatten to a single line and keep only the text after the last
            # "<|assistant|>" marker (presumably the model's reply).
            bild_answer = bild_answer.replace("\n", " ").split("<|assistant|>")[-1].strip()
            streamed_token_length_list = [len(s) for s in streamed_token_list]

            result_json["responses"].append(bild_answer)
            result_json["labels"].append(question)
            result_json["traces"].append(streamed_token_length_list)

            print(f"Working on input {idx}, trace {i}, temperature {temperature}, total trace count {trace_count}")
            print(f"Large Token proportion: {large_token_proportion}, Small Token Proportion: {small_token_proportion}")
            print(streamed_token_list)
            print(streamed_token_length_list)

        # Release cached CUDA memory between prompts.
        torch.cuda.empty_cache()

    with open(output_filename, "w", encoding="utf-8") as file:
        json.dump(result_json, file)

    print("Done! The results have been saved to", output_filename)


if __name__ == "__main__":
    # Build the BiLD model once. collect_stats() uses the module-level
    # `model`, `large_tokenizer` and `device` bound here.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    model_name_large = "meta-llama/Llama-2-7b-chat-hf"
    model_name_small = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    # Each model is loaded with the tokenizer of the same checkpoint.
    tokenizer_name_large = model_name_large
    tokenizer_name_small = model_name_small

    large_model, small_model, large_tokenizer, small_tokenizer = setup(
        model_name_large,
        model_name_small,
        tokenizer_name_large,
        tokenizer_name_small,
        device,
    )

    # BiLD thresholds: alpha_fb is the fallback threshold, alpha_rb the
    # rollback threshold passed to LlamaBiLDModel.
    alpha_fb = 0.7
    alpha_rb = 20

    model = LlamaBiLDModel(
        large_model,
        small_model,
        large_tokenizer,
        small_tokenizer,
        num_small_iters=1,
        fallback_threshold=alpha_fb,
        rollback_threshold=alpha_rb,
    ).to(device)

    # Training run: 30 traces per prompt, for each of three temperatures and
    # two prompt files. Each call writes its own JSON file under output_data/.
    trace_count = 30
    for temperature in [0.6, 1.0, 0.3]:
        for input_name in ["EK2", "EK1"]:
            print(temperature, trace_count, input_name)
            collect_stats(input_name, temperature, trace_count)

    # Testing run (currently disabled): 5 traces per prompt over three prompt
    # files and five temperatures. Swap it in for the training loop above.
    # trace_count = 5
    # for input_name in ["EK2", "EK1", "AK1"]:
    #     for temperature in [0.01, 0.8, 0.3, 0.6, 1.0]:
    #         print(temperature, trace_count, input_name)
    #         collect_stats(input_name, temperature, trace_count)