from openai import OpenAI
import json
# import base64
from tqdm import tqdm
import os
# import math
import argparse
from utils import *
from model import *
from comcts import *
import pdb
import time

def infer_comcts(args):
    data_path = args.data_path 
    if data_path.endswith('.jsonl'):
        data = read_jsonl(data_path)
    else:
        with open(data_path, 'r') as f:
            data = json.load(f)

    output_path = args.output_path
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    ans_file = open(output_path, "w")
    failed_search_path = args.output_path.replace('.jsonl', '_failed.jsonl')
    failed_search_file = open(failed_search_path, "w")

    data = get_chunk(data, args.num_chunks, args.chunk_idx)    
    # client = OpenAI(
    #     base_url=args.openai_base_url,
    #     api_key=args.openai_api_key,        
    # )

    client = None 

    activated_models, model_dict = init_model(args)

    for d in tqdm(data):
        comcts = CoMCTS(args, '', '', max_iterations=args.max_iterations)
        comcts.search(d, client, activated_models, model_dict, ans_file, failed_search_file)        

    ans_file.close()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default='./data_sample_radvqa.json')
    parser.add_argument("--output_path", type=str, default='./out/output.json')

    parser.add_argument("--num_chunks", type=int, default=1)

    parser.add_argument("--chunk_idx", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.9)

    parser.add_argument("--exploration_weight", type=float, default=0.5)
    parser.add_argument("--max_iterations", type=int, default=5)
    parser.add_argument("--threshold", type=float, default=0)
    args = parser.parse_args()

    infer_comcts(args)