import os
import torch
import gc
from tifascore import (
    get_llama2_pipeline,
    get_llama2_question_and_answers,
    UnifiedQAModel,
    VQAModel,
    filter_question_and_answers,
    tifa_score_single,
)

# Model checkpoints, hoisted into named constants so they are easy to swap.
QUESTION_GEN_MODEL = "tifa-benchmark/llama2_tifa_question_generation"
UNIFIEDQA_MODEL = "allenai/unifiedqa-v2-t5-large-1363200"
# Alternative VQA backbone: "git-large"
VQA_MODEL = "blip-base"

# Load every model once, up front, before any prompt folder is processed.
# NOTE(review): unifiedqa_model appears to be needed only when question
# filtering is enabled — confirm before dropping it to save memory.
print("Loading models...")
pipeline = get_llama2_pipeline(QUESTION_GEN_MODEL)
unifiedqa_model = UnifiedQAModel(UNIFIEDQA_MODEL)
vqa_model = VQAModel(VQA_MODEL)

# Directory whose subdirectories are each named after a prompt (spaces
# replaced by underscores) and contain the generated .png images to score.
root_image_dir = "./outputs/tifa_image_test"
# Accumulates one (prompt, image filename, TIFA score) tuple per scored image.
results = []

# Loop through each prompt folder. os.listdir returns entries in arbitrary
# order, so sort them to make the processing order and output reproducible.
for prompt_folder in sorted(os.listdir(root_image_dir)):
    prompt_path = os.path.join(root_image_dir, prompt_folder)
    if not os.path.isdir(prompt_path):
        continue  # Skip stray files next to the prompt folders

    # The folder name encodes the prompt with underscores instead of spaces.
    prompt = prompt_folder.replace("_", " ")
    print(f"\nProcessing prompt: {prompt}")

    # Pre-bind everything the cleanup below deletes. Without this, a folder
    # with no .png files (or where every image fails) never assigns `result`,
    # and the cleanup `del` raised NameError, which was then misreported as
    # a question-generation error.
    llama2_questions = filtered_questions = result = None
    try:
        with torch.no_grad():
            # Generate questions for this prompt
            llama2_questions = get_llama2_question_and_answers(pipeline, prompt)
            # Question filtering is currently disabled; to re-enable it use:
            # filtered_questions = filter_question_and_answers(unifiedqa_model, llama2_questions)
            filtered_questions = llama2_questions

            # Score every PNG image in the prompt folder (extension matched
            # case-insensitively so ".PNG" files are not silently skipped).
            for filename in sorted(os.listdir(prompt_path)):
                if not filename.lower().endswith(".png"):
                    continue
                img_path = os.path.join(prompt_path, filename)

                try:
                    result = tifa_score_single(vqa_model, filtered_questions, img_path)
                    tifa_score = result["tifa_score"]
                    results.append((prompt, filename, tifa_score))
                    print(f"  Processed: {filename} | TIFA: {tifa_score:.2f}")
                except Exception as e:
                    # A failing image is reported and skipped; the remaining
                    # images of this prompt are still scored.
                    print(f"  Error processing {filename}: {e}")

    except Exception as e:
        print(f"Error generating questions for {prompt}: {e}")
    finally:
        # Release this prompt's data and cached GPU memory whether or not
        # the prompt succeeded.
        del llama2_questions, filtered_questions, result
        torch.cuda.empty_cache()
        gc.collect()

# Save all TIFA scores as a tab-separated table: a header row, then one row
# per scored image. The encoding is explicit because prompts come from
# folder names and may contain non-ASCII characters, which the platform
# default encoding (e.g. cp1252 on Windows) may be unable to encode.
with open("tifa_scores.txt", "w", encoding="utf-8") as f:
    f.write("Prompt\tImage\tTIFA Score\n")
    f.writelines(
        f"{prompt}\t{filename}\t{score:.4f}\n" for prompt, filename, score in results
    )

# Print summary statistics over every collected TIFA score, or a notice
# when nothing was scored.
if not results:
    print("No valid TIFA scores computed.")
else:
    scores = [entry[2] for entry in results]
    avg_score, max_score, min_score = (
        sum(scores) / len(scores),
        max(scores),
        min(scores),
    )

    summary_lines = (
        f"\nTotal images processed: {len(scores)}",
        f"Average TIFA score: {avg_score:.4f}",
        f"Highest TIFA score: {max_score:.4f}",
        f"Lowest TIFA score: {min_score:.4f}",
    )
    for line in summary_lines:
        print(line)
