from PIL import Image
from pprint import pprint
import openai                 # 0.28.0 
import json 
from tqdm import tqdm 
from openai import AzureOpenAI
from dotenv import load_dotenv
# Azure OpenAI client.
# Credentials are read from the environment instead of being written into the
# source file. load_dotenv() is called first so that values from a local .env
# file (if one exists) are available in os.environ.
# Required variables: AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY,
# OPENAI_API_VERSION. A missing variable raises KeyError naming it.
# NOTE(review): these names follow the openai SDK's usual Azure variable
# names -- confirm they match your deployment's configuration.
import os

load_dotenv()
client = AzureOpenAI(
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version=os.environ["OPENAI_API_VERSION"],
)

from query_utils import generate_dsg
from openai_utils import openai_completion
from parse_utils import parse_tuple_output, parse_dependency_output, parse_question_output

'''
Benchmark prompts (.txt) -> DSG question generation (.json)
'''


def DSG(INPUT_TEXT_PROMPT):
    """Generate the DSG outputs for a single text prompt.

    The prompt is registered under the single id ``'custom_0'`` and handed to
    ``generate_dsg`` with ``openai_completion`` as the generation function.

    Args:
        INPUT_TEXT_PROMPT: The prompt text to generate DSG questions for.

    Returns:
        A 3-tuple ``(id2tuple_outputs, qid2dependency, qid2question)``:

        * ``id2tuple_outputs`` -- the tuple outputs exactly as returned by
          ``generate_dsg``; they are NOT passed through
          ``parse_tuple_output``.
        * ``qid2dependency`` -- ``parse_dependency_output`` applied to the
          ``'output'`` entry of ``'custom_0'``.
        * ``qid2question`` -- ``parse_question_output`` applied to the
          ``'output'`` entry of ``'custom_0'``.
    """
    prompt_id = 'custom_0'

    # Example LLM call; its return value is discarded.
    openai_completion('hello, how are you doing?', model='gpt-35-turbo-16k-01')

    # NOTE: per the original author, the in-context examples come from the
    # 'query_utils/get_tifa_examples' function.
    # generate_fn may be replaced by any callable that takes a prompt and
    # returns the LLM generation result.
    tuple_outputs, question_outputs, dependency_outputs = generate_dsg(
        {prompt_id: {'input': INPUT_TEXT_PROMPT}},
        generate_fn=openai_completion,
    )

    dependency_by_qid = parse_dependency_output(
        dependency_outputs[prompt_id]['output']
    )
    question_by_qid = parse_question_output(
        question_outputs[prompt_id]['output']
    )
    return tuple_outputs, dependency_by_qid, question_by_qid


def update_json_file(data, filename):      # json file update
    """Overwrite ``filename`` with ``data`` serialized as indented JSON.

    The data is serialized to a string *before* the file is opened. Opening
    in ``'w'`` mode truncates the file immediately, so serializing first
    ensures that a serialization error (e.g. a non-JSON-serializable value)
    leaves the previous file contents untouched instead of an empty or
    half-written file.

    Args:
        data: Any JSON-serializable object.
        filename: Path of the file to (over)write.

    Raises:
        TypeError: If ``data`` contains a value ``json`` cannot serialize.
        OSError: If the file cannot be opened or written.
    """
    serialized = json.dumps(data, indent=2)
    with open(filename, 'w', encoding='utf-8') as file:
        file.write(serialized)


if __name__ == '__main__':
    # Script entry point: generate DSG data for a selected subset of the
    # EvalCrafter prompts and write the accumulated results to a JSON file.

    # Load EvalCrafter prompts, one prompt per line (whitespace stripped).
    with open('/EvalCrafter/prompt700.txt', 'r') as files:
        all_prompts = [line.strip() for line in files]

    # Only prompts whose line index appears as an 'idx' value in this file
    # are processed. A set is used so each membership test in the loop below
    # is O(1) instead of a linear scan.
    with open('/datasets/origin/dsg_none.json', 'r') as file:
        data = json.load(file)
    selected_indices = {item['idx'] for item in data}

    # Make DSG questions
    all_prompts_questions = []

    # `total` is passed explicitly because enumerate() has no length, which
    # would otherwise leave the progress bar without a total.
    for idx, INPUT_TEXT_PROMPT in tqdm(enumerate(all_prompts), total=len(all_prompts)):
        if idx not in selected_indices:
            continue

        # NOTE: the first value returned by DSG is the unparsed tuple output
        # of generate_dsg; it is stored under 'qid2tuple' as-is.
        qid2tuple, qid2dependency, qid2question = DSG(INPUT_TEXT_PROMPT)
        all_prompts_questions.append({
            'origin_prompt': INPUT_TEXT_PROMPT,
            'idx': idx,
            'qid2tuple': qid2tuple,
            'qid2dependency': qid2dependency,
            'qid2question': qid2question,
        })

        # The full result list is rewritten after every prompt, so the file
        # always reflects all prompts processed so far.
        update_json_file(all_prompts_questions, '/datasets/our_dsg_depend_v2/dsg_none.json')
