import os
import base64
import time
import json

def save_responses(args, problem, prompt_base, prompt_visualize, prompt_cot, base_response, visualize_response, cot_response, parsed_base_response, parsed_visualize_response, parsed_cot_response, index):
    """Write the prompts and responses for one problem to a JSON file.

    The file is written to ``<args.output_dir>/<args.evaluate_model>/<index>.json``
    (the directory is created if needed).

    Args:
        args: namespace providing ``output_dir``, ``evaluate_model`` and ``verbose``.
        problem: mapping that must contain ``"problem_index"`` and ``"correct_label"``.
        prompt_base, prompt_visualize, prompt_cot: prompts for each variant,
            stored verbatim.
        base_response, visualize_response, cot_response: responses for each
            variant, stored verbatim.
        parsed_base_response, parsed_visualize_response, parsed_cot_response:
            stored verbatim under the matching ``parsed_*`` keys.
        index: used as the output file name (``<index>.json``).

    Returns:
        The dict that was written to disk.
    """
    out_dir = f"{args.output_dir}/{args.evaluate_model}"
    # A single exist_ok call replaces the duplicated (and race-prone)
    # exists-then-create checks.
    os.makedirs(out_dir, exist_ok=True)

    result = {
        "problem_index": problem["problem_index"],
        "prompt_base": prompt_base,
        "prompt_visualize": prompt_visualize,
        "prompt_cot": prompt_cot,
        "correct_label": problem["correct_label"],
        "base_response": base_response,
        "visualize_response": visualize_response,
        "cot_response": cot_response,
        "parsed_base_response": parsed_base_response,
        "parsed_visualize_response": parsed_visualize_response,
        "parsed_cot_response": parsed_cot_response
    }

    with open(f"{out_dir}/{index}.json", "w") as f:
        json.dump(result, f, indent=4)

    if args.verbose:
        print(f"Saved responses for problem {index}")
    return result


def parse_vlm_response(args, answer):
    """Return the choice letter (A-D) selected in *answer*.

    If *answer*, once surrounding whitespace is removed, is already exactly one
    of the letters, it is returned directly. Otherwise the parse model is asked
    to extract the letter via ``make_chat_call``.
    """
    letter = answer.strip()
    if letter in ("A", "B", "C", "D"):
        return letter

    # Same text as before, written with explicit newlines so the trailing space
    # and the indentation before "Answer:" are visible.
    prompt = (
        "Please identify which choice (A, B, C, or D) was selected in the answer. "
        "Respond with a single letter and nothing else. \n"
        "\n"
        f"        Answer: {answer}"
    )
    return make_chat_call(args, prompt)


def make_chat_call(args, prompt):
    """Ask the parse model to reduce *prompt* to a single choice letter.

    Retries until the model replies with exactly one of "A", "B", "C", "D"
    (after stripping whitespace):
      * request errors: log, wait ``wait_time`` seconds, retry;
      * content-filtered or unparseable replies: log and retry immediately.

    Args:
        args: namespace providing ``evaluate_client`` (an object exposing
            ``chat.completions.create``) and ``parse_model``.
        prompt: text sent as the single user message.

    Returns:
        One of "A", "B", "C", "D".
    """
    messages = [{"role": "user", "content": prompt}]
    wait_time = 5
    while True:
        try:
            response = args.evaluate_client.chat.completions.create(
                model=args.parse_model,
                messages=messages,
                max_tokens=10,
                n=1
            )
            choice = response.choices[0]
        except Exception as e:
            print(f'Caught exception {e}.')
            print(f'Waiting {wait_time} seconds.')
            time.sleep(wait_time)
            continue

        # Previously this branch set `response = None` and then read
        # `response.choices`, so it only "retried" via an AttributeError.
        if choice.finish_reason == "content_filter":
            print("Content filter triggered, trying again")
            continue

        # message.content may be None (e.g. filtered/empty replies).
        content = (choice.message.content or "").strip()
        if content in ("A", "B", "C", "D"):
            return content
        print(f"Invalid parsing response: {choice.message.content}")

    
def load_dataset(args):
    """Load the dataset JSON selected by ``args``.

    Reads ``<args.dataset_dir>/difficulty_<args.difficulty_margin>_n_<args.n>.json``.

    Args:
        args: namespace providing ``dataset_dir``, ``difficulty_margin`` and ``n``.

    Returns:
        The decoded JSON content of the file.

    Raises:
        FileNotFoundError: if the dataset file does not exist. (Previously a
            missing file surfaced as an UnboundLocalError on ``data``.)
    """
    path = f"{args.dataset_dir}/difficulty_{args.difficulty_margin}_n_{args.n}.json"
    with open(path, "r") as f:
        return json.load(f)


def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as a base64 text string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
    

def make_vlm_call(args, template, image_path):
    """Build a text+image user message from *template* and send it to the evaluated model.

    Args:
        args: namespace providing ``evaluate_model``; it is also forwarded to
            ``make_vlm_call_gpt4o``, which reads further attributes.
        template: iterable of strings. Any item containing the substring
            ``"IMAGE"`` becomes an image part holding the file at
            *image_path*; every other item becomes a text part.
        image_path: path of the image file to attach.

    Returns:
        The response text returned by ``make_vlm_call_gpt4o``.

    Raises:
        ValueError: if ``args.evaluate_model`` is not a supported model.
    """
    # Fail fast and clearly before reading the image. Previously this printed
    # "Model not supported" and then crashed with UnboundLocalError on `result`.
    if args.evaluate_model != "gpt-4o":
        raise ValueError(f"Model not supported: {args.evaluate_model}")

    image = encode_image(image_path)
    # NOTE(review): the data URL is always labelled image/jpeg, whatever the
    # file's actual format — confirm inputs are JPEGs.
    data_url = f"data:image/jpeg;base64,{image}"

    content = []
    for item in template:
        if "IMAGE" in item:
            # `detail` belongs inside the `image_url` object in the Chat
            # Completions content-part schema; it was previously a sibling of
            # `type`, where it does not take effect.
            content.append({'type': 'image_url',
                            'image_url': {'url': data_url, 'detail': 'low'}})
        else:
            content.append({'type': 'text', 'text': item})

    messages = [{'role': 'user', 'content': content}]
    return make_vlm_call_gpt4o(args, messages)


def make_vlm_call_gpt4o(args, messages):
    """Send *messages* to ``args.evaluate_model`` and return the first choice's text.

    Retry policy:
      * request errors: log, wait ``wait_time`` seconds, retry;
      * any choice stopped by the content filter: log and retry immediately;
      * any choice stopped for length: log a warning and return the
        (truncated) text.

    Args:
        args: namespace providing ``evaluate_client`` (an object exposing
            ``chat.completions.create``), ``evaluate_model``,
            ``max_evaluate_tokens`` and ``verbose``.
        messages: chat messages passed unchanged to the client.

    Returns:
        ``response.choices[0].message.content`` of the accepted response.
    """
    wait_time = 5
    while True:
        try:
            response = args.evaluate_client.chat.completions.create(
                model=args.evaluate_model,
                messages=messages,
                max_tokens=args.max_evaluate_tokens,
                n=1
            )
            finish_reasons = [choice.finish_reason for choice in response.choices]
        except Exception as e:
            print(f'Caught exception {e}.')
            print(f'Waiting {wait_time} seconds.')
            time.sleep(wait_time)
            continue

        if "content_filter" in finish_reasons:
            print("Content filter triggered, trying again")
            continue
        if "length" in finish_reasons:
            # Previously a `raise` here was swallowed by the handler above with
            # `response` still set, so the loop exited after a pointless sleep
            # and returned the truncated text anyway. Keep that outcome, but
            # make it explicit instead of masquerading as a retry.
            print("Max tokens exceeded, returning truncated response")
        break

    if args.verbose:
        print(f"Model response:\n{response.choices[0].message.content}")

    return response.choices[0].message.content