import os
import json
from constants import TASK_DIR, VISION_QUESTIONS, MODELS
from api_functions import make_api_request
from concurrent.futures import ThreadPoolExecutor, as_completed

def load_wcst_data(trial_number):
    """Load and return the parsed ``cards.json`` for one WCST trial.

    The file is read from ``TASK_DIR["WCST"]/trial<trial_number>/cards.json``.
    """
    cards_path = os.path.join(TASK_DIR["WCST"], f"trial{trial_number}", "cards.json")
    with open(cards_path, 'r') as handle:
        parsed = json.load(handle)
    return parsed

def process_card(model, card_data):
    """Send one card's image with the "Overall" vision prompt to *model*.

    Returns a dict holding the image path, the prompt that was sent, and the
    response and token count returned by ``make_api_request``.
    """
    print(f"    Processing card_{card_data['trialNumber']}...")
    image = card_data['image']
    question = VISION_QUESTIONS["Overall"]
    answer, token_count = make_api_request(model, question, image)
    return {"image": image, "prompt": question, "response": answer, "tokens": token_count}

def test_vision_accuracy(modelname, max_workers=20, limit=None):
    """Query *modelname* with the "Overall" vision prompt for every WCST card.

    Cards are loaded from trial 1, processed concurrently on a thread pool,
    and returned in ascending ``trialNumber`` order.

    NOTE: the ``test_`` prefix means pytest may collect this function if the
    module is ever imported into a test file.

    Args:
        modelname: Model identifier forwarded to ``process_card``.
        max_workers: Thread-pool size; progress is also printed every
            ``max_workers`` completed cards.
        limit: If given, only the first ``limit`` cards (by trial number)
            are processed.

    Returns:
        A list with one result dict per processed card (keys ``image``,
        ``prompt``, ``response``, ``tokens``), ordered by trial number.

    Raises:
        Any exception raised by ``process_card`` is re-raised via
        ``future.result()`` and aborts the run.
    """
    model = modelname
    wcst_data = load_wcst_data(1)  # Assuming we're using trial 1 data
    print(f"Processing {model}...")

    # Sort the cards by trial number so results can be returned in that order.
    sorted_cards = sorted(wcst_data.values(), key=lambda card: card['trialNumber'])
    if limit is not None:
        sorted_cards = sorted_cards[:limit]
    total = len(sorted_cards)

    # Slot i receives the result for sorted_cards[i]. Filling by index keeps
    # trial order regardless of completion order, without re-parsing trial
    # numbers out of image file names (which breaks on unexpected names or
    # non-'/' path separators).
    model_results = [None] * total
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(process_card, model, card): index
            for index, card in enumerate(sorted_cards)
        }

        completed = 0
        for future in as_completed(future_to_index):
            model_results[future_to_index[future]] = future.result()
            completed += 1
            if completed % max_workers == 0 or completed == total:
                print(f"Completed {completed} of {total} tasks")

    print(f"{model} completed.\n")
    return model_results
    
def load_existing_results(file_path):
    """Return previously saved results parsed from *file_path*.

    Returns an empty dict when nothing exists at *file_path* yet.
    """
    if not os.path.exists(file_path):
        return {}
    with open(file_path, 'r') as stream:
        previous = json.load(stream)
    return previous

def save_results(results, file_path):
    """Write *results* to *file_path* as indented JSON, creating parent dirs.

    Args:
        results: JSON-serializable object (in this script, a dict mapping
            model name to its list of per-card results).
        file_path: Destination path; an existing file is overwritten.
    """
    directory = os.path.dirname(file_path)
    # A bare file name has no directory component (dirname returns ""), and
    # os.makedirs("") raises, so only create directories when one is named.
    # exist_ok=True also avoids a check-then-create race.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(file_path, 'w') as f:
        json.dump(results, f, indent=2)

if __name__ == "__main__":
    output_path = './output/vision_accuracy_results.json'

    # Start from any previously saved results so entries for models not run
    # this time are preserved when the file is rewritten.
    all_results = load_existing_results(output_path)

    # Models evaluated in this run. Alternatives used before include
    # "Gemini-1.5 Pro", "Claude-3.5 Sonnet", or every name in MODELS.
    models_to_run = ["GPT-4o"]
    for model_name in models_to_run:
        all_results[model_name] = test_vision_accuracy(model_name)

    # Save results to a JSON file
    save_results(all_results, output_path)