# Command to run the entire pipeline:
# python3.11 visual_matching/generate.py --verbose



import argparse
from generate_utils import generate_images_stable, save_images, generate_dataset_outpaint, generate_dataset_same_person_twice
from evaluate_utils import load_dataset, make_vlm_call, save_responses, parse_vlm_response
from rich import print
import random
import json
import os
import base64
from tqdm import tqdm
from openai import AzureOpenAI
import copy

def construct_combinations(args):
    """Enumerate every allowed combination of person attributes.

    The full product of race, gender, age group, eye color, hair color,
    hair length and hair type is walked, skipping the combinations this
    generator excludes:

    * blue or green eyes for any race other than "white",
    * red hair for any race other than "white",
    * straight hair when the race is "black".

    Args:
        args: Namespace-like object; only ``args.verbose`` is read here.

    Returns:
        A list of dicts with the keys "race", "gender", "age", "eye_color",
        "hair_color", "hair_length" and "hair_type" (in that key order).
        Entries are ordered with race varying slowest and hair type fastest.
    """
    races = ("asian", "black", "hispanic", "white")
    genders = ("man", "woman")
    age_groups = ("young", "middle-aged", "old")
    eye_colors = ("brown", "blue", "green")
    hair_colors = ("brown", "blonde", "red", "black", "gray")
    hair_lengths = ("long", "short")
    hair_types = ("curly", "wavy", "straight")

    combinations = []
    for race in races:
        is_white = race == "white"

        # Resolve the per-race exclusions once instead of re-testing them
        # deep inside the enumeration.
        eyes_for_race = [c for c in eye_colors if is_white or c not in ("blue", "green")]
        hair_colors_for_race = [c for c in hair_colors if is_white or c != "red"]
        hair_types_for_race = [t for t in hair_types if race != "black" or t != "straight"]

        combinations.extend(
            {
                "race": race,
                "gender": gen,
                "age": age,
                "eye_color": eye,
                "hair_color": hair_col,
                "hair_length": hair_len,
                "hair_type": hair_typ,
            }
            for gen in genders
            for age in age_groups
            for eye in eyes_for_race
            for hair_col in hair_colors_for_race
            for hair_len in hair_lengths
            for hair_typ in hair_types_for_race
        )

    if args.verbose:
        print(f"Generated {len(combinations)} unique descriptions")

    return combinations


def extend_descriptions(args, sampled_descriptions):
    """Grow ``sampled_descriptions`` to ``args.n`` entries and persist the result.

    New entries are drawn at random, without replacement, from the
    combinations produced by ``construct_combinations`` that are not already
    in ``sampled_descriptions``. The list is extended in place and then
    written as JSON to ``{args.dataset_dir}/sampled_descriptions.json``.

    Args:
        args: Namespace-like object; ``args.n`` and ``args.dataset_dir`` are
            read here (``construct_combinations`` also reads ``args.verbose``).
        sampled_descriptions: List of feature dicts already sampled. It is
            mutated in place.

    Returns:
        The same ``sampled_descriptions`` list, after extension.

    Raises:
        ValueError: If fewer unused combinations remain than are needed to
            reach ``args.n``.
    """
    unused_descriptions = [
        desc
        for desc in construct_combinations(args)
        if desc not in sampled_descriptions
    ]

    # Clamp at zero: a list that already holds args.n or more entries needs
    # nothing added (random.sample rejects a negative sample size).
    num_descriptions_to_add = max(args.n - len(sampled_descriptions), 0)

    if num_descriptions_to_add > len(unused_descriptions):
        # Fail with an explanation rather than random.sample's generic
        # "Sample larger than population" message.
        raise ValueError(
            f"Cannot extend to {args.n} descriptions: need "
            f"{num_descriptions_to_add} more but only "
            f"{len(unused_descriptions)} unused combinations remain"
        )

    sampled_descriptions.extend(
        random.sample(unused_descriptions, num_descriptions_to_add)
    )

    with open(f"{args.dataset_dir}/sampled_descriptions.json", "w") as f:
        json.dump(sampled_descriptions, f)

    return sampled_descriptions


def construct_dataset(args):
    """Load or sample the person descriptions, then generate the problems.

    If ``{args.dataset_dir}/sampled_descriptions.json`` exists it is loaded,
    and extended via ``extend_descriptions`` when it holds fewer than
    ``args.n`` entries. Otherwise ``args.n`` descriptions are sampled without
    replacement from ``construct_combinations`` and saved to that file.

    Problems are then generated for every index in
    ``[args.start_index, min(args.end_index, number of descriptions))`` by
    delegating to ``generate_dataset_outpaint`` or
    ``generate_dataset_same_person_twice``, depending on
    ``args.similar_image_generation_method``.

    Args:
        args: Namespace-like object; reads ``dataset_dir``, ``n``,
            ``verbose``, ``start_index``, ``end_index`` and
            ``similar_image_generation_method`` (the helpers may read more).

    Raises:
        ValueError: From ``random.sample`` if ``args.n`` exceeds the number
            of possible descriptions on a fresh run.
    """
    descriptions_path = f"{args.dataset_dir}/sampled_descriptions.json"

    # generate/load the combinations of descriptions
    if os.path.exists(descriptions_path):
        with open(descriptions_path, "r") as f:
            sampled_descriptions = json.load(f)

        if len(sampled_descriptions) < args.n:
            sampled_descriptions = extend_descriptions(args, sampled_descriptions)
            if args.verbose:
                print(f"Extended descriptions to {len(sampled_descriptions)}")

        elif args.verbose:
            print(f"Loaded {len(sampled_descriptions)} descriptions from {descriptions_path}")
    else:
        possible_descriptions = construct_combinations(args)
        # sample args.n descriptions without replacement
        sampled_descriptions = random.sample(possible_descriptions, args.n)
        # On a first run the dataset directory may not exist yet; create it so
        # saving the sample does not fail with FileNotFoundError.
        os.makedirs(args.dataset_dir, exist_ok=True)
        # save the sampled descriptions to a file
        with open(descriptions_path, "w") as f:
            json.dump(sampled_descriptions, f)

    # generate the dataset; the end index is clamped to the available descriptions
    last_index = min(args.end_index, len(sampled_descriptions))
    for sample_index in tqdm(range(args.start_index, last_index)):
        features = sampled_descriptions[sample_index]

        if args.similar_image_generation_method == "outpaint":
            generate_dataset_outpaint(args, features, sample_index)

        elif args.similar_image_generation_method == "same_person_twice":
            generate_dataset_same_person_twice(args, features, sample_index)

    print(f"Saved problems {args.start_index} -- {args.end_index} to {args.dataset_dir}")


def evaluate_vlm(args):
    """Query the VLM on a range of problems and print its accuracy.

    For each index in ``[args.evaluate_start_index, args.evaluate_end_index)``
    the result is either loaded from
    ``{args.output_dir}/results/{args.evaluate_model}/{index}.json`` (when
    ``args.use_existing_results`` is set and the file exists) or produced by
    calling the model twice: once with the base prompt and once with the same
    prompt plus a "Let's think step by step." suffix. Both responses are
    parsed and handed to ``save_responses``, whose return value is the result.

    The fraction of results whose parsed answer equals
    ``problem_correct_index + 1`` is then printed for both prompt variants.
    When no problems were evaluated a notice is printed instead of dividing
    by zero.

    Args:
        args: Namespace-like object; reads ``evaluate_start_index``,
            ``evaluate_end_index``, ``use_existing_results``, ``output_dir``
            and ``evaluate_model`` (the helpers may read more).
    """
    dataset = load_dataset(args)
    results_dir = f"{args.output_dir}/results/{args.evaluate_model}"
    results = []

    # NOTE(review): the range is not clamped to len(dataset); an end index past
    # the dataset presumably fails at dataset[index] — confirm against load_dataset.
    for index in tqdm(range(args.evaluate_start_index, args.evaluate_end_index)):
        result_path = f"{results_dir}/{index}.json"

        if args.use_existing_results and os.path.exists(result_path):
            print(f"Skipping problem {index + 1} of {len(dataset)} because results already exist")

            with open(result_path, "r") as f:
                result = json.load(f)

        else:
            problem = dataset[index]

            # "IMAGE1" / "IMAGES" are presumably placeholders that make_vlm_call
            # replaces with the images — verify against evaluate_utils.
            prompt_base = ["Here is an image of a person.", "IMAGE1", "Select the image that contains the same person as the person in the first image.", "IMAGES"]
            prompt_cot = copy.copy(prompt_base)
            prompt_cot[2] += " Let's think step by step."
            prompt_image_paths = [problem["base_image_path"]] + problem["image_paths"]

            base_response = make_vlm_call(args, prompt_base, prompt_image_paths)
            cot_response = make_vlm_call(args, prompt_cot, prompt_image_paths)

            parsed_base_response = parse_vlm_response(args, base_response)
            parsed_cot_response = parse_vlm_response(args, cot_response)

            result = save_responses(args, problem, prompt_base, prompt_cot, base_response, cot_response, parsed_base_response, parsed_cot_response, index)

        results.append(result)

    # An empty index range would otherwise raise ZeroDivisionError below.
    if not results:
        print("No problems were evaluated; skipping accuracy computation")
        return

    # calculate the accuracy; problem_correct_index is zero-indexed, so it is
    # shifted by one before comparing with the parsed answer
    zero_shot_correct = sum(
        1 for result in results
        if result["parsed_base_response"] == result["problem_correct_index"] + 1
    )
    cot_correct = sum(
        1 for result in results
        if result["parsed_cot_response"] == result["problem_correct_index"] + 1
    )

    print(f"zero-shot accuracy: {zero_shot_correct / len(results)}")
    print(f"cot accuracy: {cot_correct / len(results)}")


def parse_args():
    """Parse the command-line options for dataset generation and evaluation.

    Returns:
        The parsed ``argparse.Namespace`` with one extra attribute,
        ``dataset_dir``, set to ``"{output_dir}/dataset/{generate_model}"``.
    """
    # args for dataset generation
    generation_options = [
        ("--n", {"type": int, "default": 500}),  # number of problems
        ("--start_index", {"type": int, "default": 0}),
        ("--end_index", {"type": int, "default": 500}),
        ("--output_dir", {"type": str, "default": "visual_matching"}),
        ("--num_images_per_description", {"type": int, "default": 5}),
        ("--generate_model", {"type": str, "default": "stable-image-ultra", "choices": ['dall-e-3', 'stable-image-ultra']}),
        ("--suffix", {"type": str, "default": ""}),
        ("--similar_image_generation_method", {"type": str, "default": "same_person_twice", "choices": ["outpaint", "same_person_twice"]}),
    ]

    # general args
    general_options = [
        ("--verbose", {"action": "store_true"}),
        ("--evaluate_model", {"type": str, "default": "gpt-4o", "choices": ['gpt-4o', 'claude', 'llama']}),
        ("--max_evaluate_tokens", {"type": int, "default": 500}),
        ("--evaluate_start_index", {"type": int, "default": 0}),
        ("--evaluate_end_index", {"type": int, "default": 500}),
        ("--use_existing_results", {"action": "store_true", "default": False}),
        ("--parse_model", {"type": str, "default": "gpt-4o", "choices": ['gpt-4o-mini', 'gpt-4o', 'llama']}),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in generation_options + general_options:
        parser.add_argument(flag, **options)

    parsed = parser.parse_args()
    # Derived location of the generated dataset for the chosen image model.
    parsed.dataset_dir = f"{parsed.output_dir}/dataset/{parsed.generate_model}"
    return parsed

def main():
    """Entry point: parse arguments, attach credentials, and run the evaluation.

    Sets ``args.generate_api_key`` according to ``args.generate_model`` and,
    when ``args.evaluate_model`` is "gpt-4o", sets ``args.evaluate_api_key``
    and an ``AzureOpenAI`` client as ``args.evaluate_client``. Dataset
    construction (step 1) is currently commented out, so only
    ``evaluate_vlm`` runs.

    NOTE(review): the keys and endpoint are literals in this file; consider
    loading them from the environment instead of keeping them in source.
    """
    args = parse_args()

    # One key per supported generation model.
    generate_api_keys = {
        "dall-e-3": "redacted",
        "stable-image-ultra": "redacted",  # 8 cents per image in ultra
    }
    if args.generate_model in generate_api_keys:
        args.generate_api_key = generate_api_keys[args.generate_model]

    # Only the gpt-4o evaluator gets a key and client here.
    if args.evaluate_model == "gpt-4o":
        api_base = "redacted"
        args.evaluate_api_key = "redacted"
        args.evaluate_client = AzureOpenAI(
            api_key=args.evaluate_api_key,
            api_version="2023-05-15",
            azure_endpoint=api_base,
        )

    # step 1: construct the dataset
    # construct_dataset(args)

    # step 2: evaluate the VLM on the dataset
    evaluate_vlm(args)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()