from dotenv import load_dotenv, find_dotenv
# Load environment variables from 'OPENAI_API_KEY.env' before the OpenAI
# client below is created.
# NOTE(review): presumably this file defines OPENAI_API_KEY, which OpenAI()
# reads from the environment -- confirm against the .env file's contents.
load_dotenv(find_dotenv('OPENAI_API_KEY.env'))

from PIL import Image
import requests
from io import BytesIO
from openai import OpenAI
import os

# Module-level OpenAI client, constructed with no explicit arguments and
# shared by every get_image() call.
client = OpenAI()

def get_image(image_path, mask_path, output_path, prompt):
  """Edit an image with DALL-E 2 and save the result to ``output_path``.

  Sends the image and its mask to the OpenAI image-edit endpoint, asks for a
  single 1024x1024 result, downloads it from the URL in the response and
  writes it to disk through Pillow.

  Args:
    image_path: Path of the source image to edit.
    mask_path: Path of the mask image sent alongside the source image.
    output_path: Destination file; Pillow infers the format from its extension.
    prompt: Text prompt passed to the image-edit endpoint.

  Returns:
    None. The generated image is written to ``output_path``.

  Raises:
    OSError: If the image or mask cannot be opened, or the result cannot be
      decoded or saved.
    requests.RequestException: If the download times out or returns an HTTP
      error status.
  """
  # Context managers make sure both file handles are closed even if the API
  # call raises; this function is called many times in a batch run.
  with open(image_path, "rb") as image_file, open(mask_path, "rb") as mask_file:
    edit_response = client.images.edit(
      model="dall-e-2",
      image=image_file,
      mask=mask_file,
      prompt=prompt,
      n=1,
      size="1024x1024"
    )

  image_url = edit_response.data[0].url

  # Download the generated image. The timeout keeps a stalled connection from
  # hanging the whole batch, and raise_for_status() surfaces HTTP errors here
  # instead of as a confusing image-decoding failure below.
  download = requests.get(image_url, timeout=60)
  download.raise_for_status()

  with Image.open(BytesIO(download.content)) as img:
    img.save(output_path)


def main():
  """Generate one edited image per (personality, input image) pair.

  Walks the ``inputs`` directory under the current working directory for
  ``.jpg``/``.png`` files. For each personality trait and each input image it
  calls ``get_image`` with the mask found at the same relative location under
  ``masks``, and stores the result under
  ``outputs/<relative dir>/<personality>/<file name>``.
  """
  personalities = [
    'gullible', 'hard-working', 'helpful', 'honest', 'imaginative',
    'impolite', 'inconsiderate', 'indecisive', 'inflexible', 'insecure',
    'intelligent', 'jealous', 'kind', 'lazy', 'loyal', 'mean', 'meticulous',
    'modest', 'moody', 'narrow-minded', 'nasty', 'optimistic', 'outgoing',
    'patient', 'pessimistic', 'pretentious', 'quick-tempered', 'rebellious',
    'reliable', 'rude', 'self-centered', 'selfish', 'sensible', 'sensitive',
    'sincere', 'sociable', 'stubborn', 'sulky', 'sympathetic', 'tactless',
    'thoughtful', 'trustworthy', 'unpleasant', 'unreliable', 'vain',
  ]

  # Build the roots with os.path.join so the separator matches the platform;
  # a literal backslash path only works on Windows and "\i" is an invalid
  # string escape.
  input_root = os.path.join(os.curdir, 'inputs')
  mask_root = os.path.join(os.curdir, 'masks')
  output_root = os.path.join(os.curdir, 'outputs')

  # The set of input images does not depend on the personality, so walk the
  # tree once and reuse the result for every personality.
  image_jobs = []
  for root, _dirs, files in os.walk(input_root):
    # Mirror the directory layout via the path relative to the input root
    # rather than a substring replace, which would also rewrite "inputs"
    # appearing inside a subdirectory name.
    rel_dir = os.path.relpath(root, input_root)
    for file in files:
      if file.endswith(('.jpg', '.png')):
        image_jobs.append((os.path.join(root, file), rel_dir, file))

  for personality in personalities:
    for input_path, rel_dir, file in image_jobs:
      mask_path = os.path.normpath(os.path.join(mask_root, rel_dir, file))
      output_dir = os.path.normpath(
        os.path.join(output_root, rel_dir, personality)
      )
      os.makedirs(output_dir, exist_ok=True)
      output_path = os.path.join(output_dir, file)
      prompt = f"a photo of a {personality}"
      get_image(input_path, mask_path, output_path, prompt)
      print(f"Generated image for {personality} from {input_path}.")
    print(f"Generated all images for {personality}.")

# Run the batch only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
  main()