import time
import base64
import argparse
import requests
import random
import os
from openai import OpenAI
from PIL import Image
import io
import json


def generate_dataset_outpaint(args, features, sample_index):
    """Generate one outpaint-style dataset sample.

    Builds a text description from ``features``, generates
    ``args.num_images_per_description`` images of a person matching it with
    ``args.generate_model``, saves them under ``<args.output_dir>/<sample_index>``,
    and then creates an outpainted variant of one randomly chosen image.

    Args:
        args: parsed CLI arguments (reads ``verbose``, ``generate_model``,
            ``api_key``, ``output_dir``, ``suffix`` and whatever the helper
            functions read).
        features: mapping with keys ``age``, ``race``, ``gender``,
            ``eye_color``, ``hair_length``, ``hair_color`` and ``hair_type``.
        sample_index: identifier used as the per-sample directory name.

    Raises:
        ValueError: if ``args.generate_model`` is not ``"dall-e-3"`` or
            ``"stable-image-ultra"``.
    """
    # first, generate N (= args.num_images_per_description) different people
    description = f"{features['age']} {features['race']} {features['gender']}, with {features['eye_color']} eyes and {features['hair_length']} {features['hair_color']} {features['hair_type']} hair."
    if args.verbose:
        print(f"description {sample_index}: {description}")

    prompt = f"Generate a realistic image of a unique person with the following description: {description}"

    if args.generate_model == "dall-e-3":
        client = OpenAI(api_key=args.api_key)
        images_b64s = generate_images_dalle(args, client, args.generate_model, prompt)
        # The API returns base64 text; the decoded PNG is binary data and must
        # stay as bytes (decoding it as UTF-8 would raise UnicodeDecodeError).
        images_pngs = [base64.b64decode(image_b64) for image_b64 in images_b64s]
    elif args.generate_model == "stable-image-ultra":
        images_pngs = generate_images_stable(args, prompt)
        images_b64s = [base64.b64encode(image_png).decode('utf-8') for image_png in images_pngs]
    else:
        raise ValueError(f"Unsupported generate_model: {args.generate_model!r}")

    # NOTE(review): the generated images go to args.output_dir, while
    # generate_similar_image writes into args.dataset_dir — confirm these are
    # meant to be different directories.
    save_images(args, images_b64s, images_pngs, args.output_dir, sample_index)

    # next, generate one modified image of one random person
    generate_similar_image(args, images_pngs, description, args.suffix, sample_index)


def generate_dataset_same_person_twice(args, features, sample_index):
    """Generate one "same person twice" dataset sample.

    Steps:
      1. Generate a single 21:9 image showing the described person twice and
         save it as ``double_full``.
      2. Split it into ``double_left`` / ``double_right`` crops.
      3. Generate ``args.num_images_per_description - 1`` images of a unique
         person matching the description, conditioned on the structure of the
         right crop, saved as ``single_<i>``.
      4. Write ``metadata.json`` describing the sample.

    All files go under ``<args.dataset_dir>/<sample_index>/``.

    Args:
        args: parsed CLI arguments (reads ``verbose``, ``dataset_dir``,
            ``generate_model``, ``num_images_per_description``,
            ``similar_image_generation_method`` and whatever the helpers read).
        features: mapping with keys ``age``, ``race``, ``gender``,
            ``eye_color``, ``hair_length``, ``hair_color`` and ``hair_type``.
        sample_index: identifier used as the per-sample directory name.
    """
    import copy

    sample_dir = f"{args.dataset_dir}/{sample_index}"

    # first, generate the double image
    description = f"{features['age']} {features['race']} {features['gender']}, with {features['eye_color']} eyes and {features['hair_length']} {features['hair_color']} {features['hair_type']} hair."
    if args.verbose:
        print(f"description {sample_index}: {description}")

    double_prompt = f"Generate two realistic images of the same person, one on the left and one on the right. The person should have the following description: {description}"
    # Only the first generated image is used, so request exactly one image
    # instead of args.num_images_per_description separate API calls.
    one_image_args = copy.copy(args)
    one_image_args.num_images_per_description = 1
    double_pngs = generate_images_stable(one_image_args, double_prompt, aspect_ratio="21:9")
    if args.verbose:
        print("generated images")
    save_image(args, "", double_pngs[0], sample_dir, "double_full")

    # Segment the image down the middle
    left_image_data, right_image_data = segment_image_vertically(double_pngs[0])
    if args.verbose:
        print("segmented image")

    left_b64 = base64.b64encode(left_image_data).decode('utf-8')
    right_b64 = base64.b64encode(right_image_data).decode('utf-8')
    save_image(args, left_b64, left_image_data, sample_dir, "double_left")
    save_image(args, right_b64, right_image_data, sample_dir, "double_right")
    if args.verbose:
        print("saved images")

    # next, generate the remaining images of unique people in the same pose as the right image
    num_singles = args.num_images_per_description - 1
    prompt = f"Generate an image of a unique person with the same pose and style as the image provided. The person should have the following description: {description}"
    pngs = control_structure_image_stable(args, right_image_data, prompt, num_singles)
    for i, png in enumerate(pngs):
        save_image(args, base64.b64encode(png).decode('utf-8'), png, sample_dir, f"single_{i}")

    # save the metadata
    metadata = {
        "description": description,
        "features": features,
        "sample_index": sample_index,
        "generate_model": args.generate_model,
        "double_prompt": double_prompt,
        "double_image": "double_full.png",
        "double_images": ["double_left.png", "double_right.png"],
        "single_prompt": prompt,
        "num_images_per_description": args.num_images_per_description,
        "similar_image_generation_method": args.similar_image_generation_method,
        "single_images": [f"single_{i}.png" for i in range(num_singles)],
    }
    with open(f"{sample_dir}/metadata.json", "w") as f:
        json.dump(metadata, f)


def generate_images_dalle(args, client, model, text):
    """Generate ``args.num_images_per_description`` images with an OpenAI image model.

    DALL-E 3 accepts only ``n=1`` per request and a minimum size of
    1024x1024, so one request is issued per image. Any exception raised by
    the client is printed and the request is retried after a fixed delay,
    with no retry limit.

    Args:
        args: parsed CLI arguments; reads ``num_images_per_description``.
        client: an ``OpenAI`` client instance.
        model: image model name passed through to the API.
        text: the prompt.

    Returns:
        list[str]: base64-encoded image payloads, in request order.
    """
    retry_delay = 5

    def request_single_image():
        # Keep asking until the client returns a result without raising.
        while True:
            try:
                result = client.images.generate(
                    model=model,
                    prompt=text,
                    n=1,  # dall-e-3 only supports a single image per request
                    response_format="b64_json",
                    size="1024x1024",  # smallest size dall-e-3 accepts
                    style="natural",
                )
            except Exception as e:
                print(f'Caught exception {e}.')
                print(f'Waiting {retry_delay} seconds.')
                time.sleep(retry_delay)
            else:
                if result is not None:
                    return result

    return [
        request_single_image().data[0].b64_json
        for _ in range(args.num_images_per_description)
    ]


def control_structure_image_stable(args, image_png, prompt_text, num_images):
    """Generate ``num_images`` images that follow the structure of ``image_png``.

    Calls Stability AI's ``control/structure`` endpoint once per image with
    ``control_strength`` 0.4. A request is retried every 5 seconds, with no
    limit, until it returns HTTP 200. Network errors are retried the same way.

    Args:
        args: parsed CLI arguments; reads ``generate_api_key``.
        image_png: encoded bytes of the reference image.
        prompt_text: text prompt sent with every request.
        num_images: number of images to generate.

    Returns:
        list[bytes]: raw response bodies (image data) of the successful requests.
    """
    wait_time = 5
    responses = []
    for _ in range(num_images):
        while True:
            try:
                response = requests.post(
                    "https://api.stability.ai/v2beta/stable-image/control/structure",
                    headers={
                        "authorization": f"Bearer {args.generate_api_key}",
                        "accept": "image/*"
                    },
                    files={
                        "image": image_png
                    },
                    data={
                        "prompt": prompt_text,
                        "control_strength": 0.4
                    },
                )
            except requests.RequestException as e:
                # Transient network failures are retried like HTTP errors.
                print(f"Error: {e}")
            else:
                if response.status_code == 200:
                    responses.append(response.content)
                    break
                # Error bodies are normally JSON, but gateways can return HTML
                # or an empty body; logging must not crash the retry loop.
                try:
                    detail = response.json()
                except ValueError:
                    detail = response.text
                print(f"Error: {detail}")
            print(f'Waiting {wait_time} seconds.')
            time.sleep(wait_time)

    return responses


def generate_images_stable(args, text, aspect_ratio="1:1"):
    """Generate ``args.num_images_per_description`` PNG images with Stable Image Ultra.

    Each image is a separate request to Stability AI's ``generate/ultra``
    endpoint. A request is retried every 5 seconds, with no limit, until it
    returns HTTP 200. Network errors are retried the same way.

    Args:
        args: parsed CLI arguments; reads ``num_images_per_description`` and
            ``generate_api_key``.
        text: the prompt.
        aspect_ratio: aspect ratio string forwarded to the API (default "1:1").

    Returns:
        list[bytes]: raw PNG bytes of the generated images, in request order.
    """
    wait_time = 5
    responses = []
    for _ in range(args.num_images_per_description):
        while True:
            try:
                response = requests.post(
                    "https://api.stability.ai/v2beta/stable-image/generate/ultra",
                    headers={
                        "authorization": f"Bearer {args.generate_api_key}",
                        "accept": "image/*"
                    },
                    # Dummy file part: makes requests send multipart/form-data,
                    # which this endpoint expects even without an input image.
                    files={"none": ''},
                    data={
                        "prompt": text,
                        "output_format": "png",
                        "aspect_ratio": aspect_ratio
                    },
                )
            except requests.RequestException as e:
                # Transient network failures are retried like HTTP errors.
                print(f"Error: {e}")
            else:
                if response.status_code == 200:
                    responses.append(response.content)
                    break
                # Error bodies are normally JSON, but gateways can return HTML
                # or an empty body; logging must not crash the retry loop.
                try:
                    detail = response.json()
                except ValueError:
                    detail = response.text
                print(f"Error: {detail}")
            print(f'Waiting {wait_time} seconds.')
            time.sleep(wait_time)

    return responses


def generate_similar_image_stable(args, image_png, prompt_text):
    """Outpaint ``image_png`` by 256 pixels on the left with Stability AI.

    The request goes to the ``edit/outpaint`` endpoint with creativity 0.5 and
    PNG output. It is retried every 5 seconds, with no limit, until it returns
    HTTP 200. Network errors are retried the same way.

    Args:
        args: parsed CLI arguments; reads ``verbose`` and ``api_key``.
        image_png: encoded bytes of the source image.
        prompt_text: text prompt for the outpainted region.

    Returns:
        bytes: the raw PNG body of the successful response.
    """
    # things I've tried:
    # - control/structure with control_strength = 0.7
    # - control/structure with control_strength = 1 (best performing, but still details that can tell that they are not the same person, e.g., ear shape)
    # - control/style with control_strength = 0.7 (not a valid arg)
    # - control/style with fidelity = 1
    # - edit/outpaint with creativity = 0.5, and 256 pixels on all sides

    wait_time = 5
    while True:
        if args.verbose:
            print("starting request")
        try:
            response = requests.post(
                "https://api.stability.ai/v2beta/stable-image/edit/outpaint",
                headers={
                    # NOTE(review): the other Stability calls in this file
                    # authenticate with args.generate_api_key, and args.api_key
                    # is used for the OpenAI client — confirm this key is right.
                    "authorization": f"Bearer {args.api_key}",
                    "accept": "image/*"
                },
                files={
                    "image": image_png
                },
                data={
                    "prompt": prompt_text,
                    "creativity": 0.5,
                    "output_format": "png",
                    "left": 256
                },
            )
        except requests.RequestException as e:
            # Transient network failures are retried like HTTP errors.
            print(f"Error: {e}")
        else:
            if args.verbose:
                print("received response")
            if response.status_code == 200:
                return response.content
            # Error bodies are normally JSON, but gateways can return HTML or
            # an empty body; logging must not crash the retry loop.
            try:
                detail = response.json()
            except ValueError:
                detail = response.text
            print(f"Error: {detail}")
        print(f'Waiting {wait_time} seconds.')
        time.sleep(wait_time)


def generate_similar_image(args, images_pngs, description, suffix, sample_index):
    """Outpaint one randomly chosen image and save the result.

    The result is saved as ``<args.dataset_dir>/<sample_index>/<index>_<suffix>``
    (``.json`` with base64 text and ``.png`` with raw bytes). ``index`` is the
    position of the chosen source image in ``images_pngs``.

    Args:
        args: parsed CLI arguments; reads ``dataset_dir`` and whatever the
            helper functions read.
        images_pngs: list of encoded source images to choose from.
        description: the person description. Currently unused, because the
            outpaint prompt does not mention the person.
        suffix: string appended to the output file name.
        sample_index: identifier used as the per-sample directory name.

    Returns:
        bytes: the PNG bytes of the generated image.
    """
    # Choose from the images actually provided. This matches the old
    # randint(0, num_images_per_description - 1) whenever the list has that
    # length, but cannot index past the end of a shorter list.
    index = random.randrange(len(images_pngs))
    prompt_image = images_pngs[index]

    # An identity-preserving prompt built from `description` was tried here
    # before; outpainting only needs the new background region described.
    prompt_text = "fill in background"
    similar_image = generate_similar_image_stable(args, prompt_image, prompt_text)
    similar_image_b64 = base64.b64encode(similar_image).decode('utf-8')

    dir_path = f"{args.dataset_dir}/{sample_index}"
    save_image(args, similar_image_b64, similar_image, dir_path, f"{index}_{suffix}")

    return similar_image


def save_image(args, image_b64, image_png, dir_path, suffix):
    """Write an image as ``<dir_path>/<suffix>.json`` and ``<dir_path>/<suffix>.png``.

    Args:
        args: parsed CLI arguments; reads ``verbose``.
        image_b64: base64 text, written verbatim to the ``.json`` file. It is
            the raw string, not a JSON document.
        image_png: raw image bytes, written to the ``.png`` file.
        dir_path: output directory; it is created if missing.
        suffix: file-name stem, e.g. the image index with an optional
            "similar" tag.
    """
    path = f"{dir_path}/{suffix}"

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(dir_path, exist_ok=True)

    with open(f"{path}.json", "w") as f:
        f.write(image_b64)

    with open(f"{path}.png", "wb") as f:
        f.write(image_png)

    if args.verbose:
        print(f"saved image at {path}")


def save_images(args, images_b64s, images_pngs, output_dir, sample_index):
    """Save the first ``args.num_images_per_description`` images of a sample.

    Image ``i`` is written by ``save_image`` as
    ``<output_dir>/<sample_index>/<i>.json`` (base64 text) and ``<i>.png``
    (raw bytes).

    Args:
        args: parsed CLI arguments; reads ``num_images_per_description``.
        images_b64s: base64 strings, indexed in parallel with ``images_pngs``.
        images_pngs: raw image bytes.
        output_dir: root output directory.
        sample_index: identifier used as the per-sample directory name.
    """
    dir_path = f"{output_dir}/{sample_index}"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(dir_path, exist_ok=True)

    for i in range(args.num_images_per_description):
        save_image(args, images_b64s[i], images_pngs[i], dir_path, str(i))


def segment_image_vertically(image_data):
    """Cut two square portraits out of a side-by-side "double" image.

    Both crops are ``height x height`` squares spanning the full height.
    The left square starts ``width // 28`` pixels from the left edge. The
    right square starts ``width // 14`` pixels after the left one ends. The
    boxes are not clamped to the image bounds, so the input must be wide
    enough, e.g. the 21:9 images requested elsewhere in this file.

    Args:
        image_data: encoded image bytes in any format PIL can open.

    Returns:
        tuple[bytes, bytes]: PNG-encoded left crop and right crop.
    """
    source = Image.open(io.BytesIO(image_data))
    width, height = source.size

    side = height            # each crop is a square as tall as the source
    outer_margin = width // 28
    inner_gap = width // 14

    boxes = []
    start = outer_margin
    for _ in range(2):
        boxes.append((start, 0, start + side, height))
        start += side + inner_gap

    encoded = []
    for box in boxes:
        buffer = io.BytesIO()
        source.crop(box).save(buffer, format='PNG')
        encoded.append(buffer.getvalue())

    return encoded[0], encoded[1]