import torch
import math
import numpy as np
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer, AutoModel
from rich import print
import random
from pathlib import Path
from tqdm import tqdm
import json
from src.openai_utils import OpenAI


def parse_results(response, agent):
    """Ask an LLM judge which candidate image a free-form answer selected.

    Args:
        response: The model's answer to the person-matching question. It is
            interpolated verbatim into the judge prompt.
        agent: Client exposing ``complete(messages)``. It is called exactly
            once, with a one-element list holding the judge prompt.

    Returns:
        Whatever ``agent.complete`` returns, passed through unchanged. The
        prompt asks for a bare number 1-5 (or 0 if nothing was selected), but
        the reply is not validated here, so callers must parse it themselves.
    """
    # NOTE: the continuation lines below keep their 4-space source indentation
    # inside the prompt text; left as-is so the judge prompt stays byte-identical.
    judge_prompt = f"""Please identify which image was selected in the answer. Respond with a single number between 1 and 5, or 0 if none of the images were selected. Do not include anything else.

    Answer: {response}
    
    For reference, the question asked was: 
    Here is an image of a person. [image] Select the image that contains the same person as the person in the first image. [five candidate images]
    """
    messages = [judge_prompt]
    return agent.complete(messages)


def extract_choice(reply):
    """Return the first integer found in a judge reply, or None if there is none.

    The judge is asked for a bare number, but LLM replies sometimes carry extra
    text (e.g. "Image 3."), on which a plain ``int()`` raises ``ValueError``.
    Pulling out the first run of digits tolerates that. A bare number like
    " 4\\n" still parses to the same value ``int()`` would give.
    """
    import re  # local import: the file's top-level imports do not include re

    match = re.search(r"\d+", str(reply))
    return int(match.group()) if match else None


def is_correct(reply, correct_index):
    """Return 1 if the judge reply names ``correct_index``, else 0.

    Unparseable replies (no digits at all) are scored as incorrect rather
    than aborting the run.
    """
    choice = extract_choice(reply)
    return int(choice is not None and choice == int(correct_index))


if __name__ == "__main__":
    results_path = "redacted/results/InternVL26B"
    agent = OpenAI()
    for idx in range(20):
        sample_dir = Path(results_path, str(idx))
        output_file = sample_dir / "output.json"
        save_file = sample_dir / "parsed_output.json"
        with open(output_file, "r", encoding="utf-8") as f:
            output = json.load(f)

        # Only the first record of each output.json is scored.
        record = output[0]
        zero_shot_response = record["zero_shot_response"]
        cot_response = record["cot_response"]
        correct_index = record["correct_image"]

        parser_zero_shot_response = parse_results(zero_shot_response, agent)
        parser_cot_response = parse_results(cot_response, agent)

        # Surface judge replies that contain no number instead of crashing;
        # they are scored as incorrect below.
        for label, reply in (
            ("zero-shot", parser_zero_shot_response),
            ("CoT", parser_cot_response),
        ):
            if extract_choice(reply) is None:
                print(f"Sample {idx}: no image index found in the {label} judge reply; scoring as incorrect.")

        result = {
            "zero_shot_response": zero_shot_response,
            "cot_response": cot_response,
            "parser_zero_shot_response": parser_zero_shot_response,
            "parser_cot_response": parser_cot_response,
            "zero_shot_correct": is_correct(parser_zero_shot_response, correct_index),
            "cot_correct": is_correct(parser_cot_response, correct_index),
            "correct_index": correct_index,
        }

        with open(save_file, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=4)
        
        
    