import os
import base64
import time
import json
import random

def save_responses(args, problem, prompt_base, prompt_cot, base_response, cot_response, parsed_base_response, parsed_cot_response, index):
    """Persist one problem's prompts and model responses as JSON.

    The record is written to
    ``{args.output_dir}/results/{args.evaluate_model}/{index}.json``; the
    directory is created if needed.

    Args:
        args: namespace providing ``output_dir``, ``evaluate_model`` and ``verbose``.
        problem: problem dict; its ``problem_index`` and ``correct_index`` are recorded.
        prompt_base, prompt_cot: the two prompts that were sent.
        base_response, cot_response: raw model replies to those prompts.
        parsed_base_response, parsed_cot_response: parsed form of each reply.
        index: used as the output file name (``{index}.json``).

    Returns:
        dict: the record that was written.
    """
    results_dir = f"{args.output_dir}/results/{args.evaluate_model}"
    # exist_ok avoids the check-then-create race when several runs share the directory.
    os.makedirs(results_dir, exist_ok=True)

    result = {
        "problem_index": problem["problem_index"],
        "prompt_base": prompt_base,
        "prompt_cot": prompt_cot,
        "problem_correct_index": problem["correct_index"],
        "base_response": base_response,
        "cot_response": cot_response,
        "parsed_base_response": parsed_base_response,
        "parsed_cot_response": parsed_cot_response
    }

    with open(f"{results_dir}/{index}.json", "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4)

    if args.verbose:
        print(f"Saved responses for problem {index}")
    return result

def parse_vlm_response(args, answer):
    """Ask the parsing model which candidate image a free-form *answer* selected.

    The prompt requests a single number from 1 to 5, or 0 when no image was
    selected; the integer produced by ``make_chat_call`` is returned unchanged.

    NOTE(review): this numbering is 1-based, whereas ``load_dataset`` records a
    0-based ``correct_index`` — confirm the caller offsets one of them before
    comparing.
    """
    # Indentation and trailing spaces below reproduce the original
    # triple-quoted prompt text character for character.
    prompt_lines = [
        "Please identify which image was selected in the answer. Respond with a single number between 1 and 5, or 0 if none of the images were selected. Do not include anything else.",
        "",
        f"    Answer: {answer}",
        "    ",
        "    For reference, the question asked was: ",
        "    Here is an image of a person. [image] Select the image that contains the same person as the person in the first image. [five candidate images]",
        "    ",
    ]
    return make_chat_call(args, "\n".join(prompt_lines))


def make_chat_call(args, prompt):
    """Ask ``args.parse_model`` to answer *prompt* with a bare integer and return it.

    Keeps requesting until a reply's text, ignoring surrounding whitespace,
    consists only of decimal digits:

    - a reply stopped by the content filter is requested again immediately;
    - a reply that is not a bare number (including missing/empty content) is
      printed and requested again immediately;
    - any exception raised while calling the client or reading its reply is
      printed and retried after a fixed 5 second pause.

    NOTE(review): there is no retry limit, so a persistently failing client
    (e.g. bad credentials) loops forever.

    Args:
        args: namespace providing ``evaluate_client`` (an OpenAI-style client
            exposing ``chat.completions.create``) and ``parse_model``.
        prompt: the user message text.

    Returns:
        int: the number the model replied with.
    """
    messages = [{"role": "user", "content": prompt}]
    wait_time = 5
    while True:
        try:
            response = args.evaluate_client.chat.completions.create(
                model=args.parse_model,
                messages=messages,
                max_tokens=10,
                n=1
            )
            choice = response.choices[0]
            if choice.finish_reason == "content_filter":
                print("Content filter triggered, trying again")
                # Retry right away instead of dereferencing a discarded response.
                continue
            content = choice.message.content
            # Models often append a newline; isdecimal() only admits characters int() accepts.
            digits = (content or "").strip()
            if digits.isdecimal():
                return int(digits)
            print(f"Invalid parsing response: {content}")
        except Exception as e:
            print(f'Caught exception {e}.')
            print(f'Waiting {wait_time} seconds.')
            time.sleep(wait_time)

    
def load_dataset(args, rng=None):
    """Build the list of evaluation problems stored under ``args.dataset_dir``.

    For every index in ``range(args.evaluate_start_index, args.evaluate_end_index)``
    this reads ``{args.dataset_dir}/{index}/metadata.json`` and assembles a dict:

    - ``problem_index``: the index itself;
    - ``base_image_path``: path of ``metadata["double_images"][0]``;
    - ``image_paths``: paths of every ``metadata["single_images"]`` entry, with
      the path of ``metadata["double_images"][1]`` inserted at a random slot;
    - ``correct_index``: the 0-based position of that inserted path within
      ``image_paths``.

    Args:
        args: namespace providing ``dataset_dir``, ``evaluate_start_index`` and
            ``evaluate_end_index``.
        rng: optional object with a ``randint`` method (e.g. ``random.Random(seed)``)
            used to pick the insertion slot. ``None`` uses the module-level
            ``random`` generator, as before, so results depend on global state.

    Returns:
        list of problem dicts in ascending index order.
    """
    chooser = random if rng is None else rng
    dataset = []
    # NOTE(review): an earlier comment said the valid range is 0 to args.n;
    # confirm the start/end indices stay inside it.
    for problem_index in range(args.evaluate_start_index, args.evaluate_end_index):
        problem_dir = f"{args.dataset_dir}/{problem_index}"
        with open(f"{problem_dir}/metadata.json", "r", encoding="utf-8") as f:
            metadata = json.load(f)

        # "double_images" lists the base image followed by the correct match.
        base_image_path = f"{problem_dir}/{metadata['double_images'][0]}"
        correct_image_path = f"{problem_dir}/{metadata['double_images'][1]}"
        incorrect_image_paths = [f"{problem_dir}/{filename}" for filename in metadata["single_images"]]

        # randint is inclusive on both ends, so the slot after the last
        # distractor is also a possible position.
        inserting_index = chooser.randint(0, len(incorrect_image_paths))
        image_paths = incorrect_image_paths[:inserting_index] + [correct_image_path] + incorrect_image_paths[inserting_index:]

        dataset.append({
            "problem_index": int(problem_index),
            "base_image_path": base_image_path,
            "image_paths": image_paths,
            "correct_index": inserting_index
        })

    return dataset


def encode_image(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 text string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
    

def make_vlm_call(args, template, image_paths):
    """Send a mixed text/image prompt to ``args.evaluate_model`` and return its reply.

    *template* is a sequence of strings expanded, in order, into chat content parts:

    - ``"IMAGES"``: one image part for every image not yet consumed by an
      ``"IMAGE"`` placeholder (these are not marked as consumed);
    - any other string containing ``"IMAGE"``: one image part for the next
      unconsumed image;
    - anything else: a text part carrying that string.

    Images are read from *image_paths* in order and sent as base64
    ``image/jpeg`` data URLs at ``"low"`` detail.

    Args:
        args: namespace providing ``evaluate_model`` and ``verbose`` (plus what
            ``make_vlm_call_gpt4o`` reads).
        template: sequence of placeholder/text strings as described above.
        image_paths: image files to embed.

    Returns:
        The model's reply text as returned by ``make_vlm_call_gpt4o``.

    Raises:
        ValueError: if ``args.evaluate_model`` is not supported.
        IndexError: if ``"IMAGE"`` placeholders outnumber *image_paths*.
    """
    # Previously an unsupported model fell through to an UnboundLocalError.
    if args.evaluate_model != "gpt-4o":
        raise ValueError(f"Model not supported: {args.evaluate_model}")

    images = [encode_image(image_path) for image_path in image_paths]

    def image_part(encoded):
        # `detail` must sit inside the `image_url` object for the API to apply it.
        # NOTE(review): the MIME type is always image/jpeg — confirm inputs are JPEGs.
        return {'type': 'image_url',
                'image_url': {'url': f"data:image/jpeg;base64,{encoded}", 'detail': 'low'}}

    content = []
    next_image = 0  # first image not yet consumed by an "IMAGE" placeholder
    for item in template:
        if item == "IMAGES":
            content.extend(image_part(image) for image in images[next_image:])
        elif "IMAGE" in item:
            content.append(image_part(images[next_image]))
            next_image += 1
        else:
            content.append({'type': 'text', 'text': item})

    if args.verbose:
        print([(item['type'] if item['type'] == 'image_url' else item['text']) for item in content])

    messages = [{'role': 'user', 'content': content}]
    return make_vlm_call_gpt4o(args, messages)

def make_vlm_call_gpt4o(args, messages):
    """Query ``args.evaluate_model`` with *messages* and return the first choice's text.

    Requests repeat until one is accepted:

    - if any returned choice was stopped by the content filter, the whole
      response is discarded and requested again immediately;
    - if an exception is raised, it is printed and the loop sleeps a fixed
      5 seconds. The request is then retried, unless a response had already
      been received and not discarded, in which case that response is used.

    Args:
        args: namespace providing ``evaluate_client`` (an OpenAI-style client
            exposing ``chat.completions.create``), ``evaluate_model`` and
            ``max_evaluate_tokens``.
        messages: chat messages passed straight to the client.

    Returns:
        ``message.content`` of the first choice of the accepted response.
    """
    pause_seconds = 5
    while True:
        reply = None
        try:
            reply = args.evaluate_client.chat.completions.create(
                model=args.evaluate_model,
                messages=messages,
                max_tokens=args.max_evaluate_tokens,
                n=1,
            )  # TODO: ok to not use headers and requests API?
            for choice in reply.choices:
                if choice.finish_reason != "content_filter":
                    continue
                print("Content filter triggered, trying again")
                reply = None
        except Exception as err:
            print(f'Caught exception {err}.')
            print(f'Waiting {pause_seconds} seconds.')
            time.sleep(pause_seconds)
        if reply is not None:
            return reply.choices[0].message.content