import torch
from diffusers import AutoPipelineForImage2Image
from diffusers import StableDiffusionPipeline
import argparse

from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler

from diffusers.utils import load_image
import openai
import os
import requests
import json
from api_keys import OPENAI_API_KEY
from tqdm import tqdm
from PIL import Image

from datetime import datetime
# Hugging Face access token handed to ``from_pretrained``.  It is read from the
# standard ``HF_TOKEN`` environment variable rather than hard-coded in source;
# when unset this is ``None``, which leaves authentication to huggingface_hub's
# default credential lookup instead of sending a bogus placeholder token.
ACCESS_TOKEN = os.environ.get("HF_TOKEN")

def _sync_prompt_map(prompts, prompt_map_file):
    """Create or extend the prompt map JSON stored at *prompt_map_file*.

    The map is keyed by the 1-based position of each prompt as a string
    (``"1"``, ``"2"``, ...), matching the ``<position>.png`` image names.  When a
    map already exists, its prompts must be an exact, in-order prefix of
    *prompts*; any further prompts are appended after them.

    Args:
        prompts: ordered list of prompt strings.
        prompt_map_file: path of the JSON file to read and (re)write.

    Returns:
        dict: the map that was written to disk.

    Raises:
        ValueError: if the stored prompts are not a prefix of *prompts*.
    """
    if os.path.exists(prompt_map_file):
        with open(prompt_map_file, 'r') as file:
            stored = json.load(file)
        # Order by numeric key instead of trusting JSON insertion order.
        existing_prompts = [stored[key] for key in sorted(stored, key=int)]
        # Images are named by position, so membership alone is not enough: a
        # reordered list would silently pair existing images with other prompts.
        if list(prompts[:len(existing_prompts)]) != existing_prompts:
            raise ValueError(
                "Inputted prompts must begin with the existing prompts for this "
                "experiment, in the same order."
            )

    prompt_map = {str(i): prompt for i, prompt in enumerate(prompts, start=1)}
    with open(prompt_map_file, 'w') as file:
        json.dump(prompt_map, file, indent=4)
    return prompt_map


def generate_images_stable_diffusion(prompts, experiment_name, model='runwayml/stable-diffusion-v1-5'):
    """Generate one Stable Diffusion image per prompt for an experiment.

    Images are written to ``saved_images/<experiment_name>/stable_diffusion/``
    as ``<position>.png`` (1-based position in *prompts*), next to a
    ``prompt_map.json`` mapping positions to prompts.  Images that already
    exist are not regenerated, so the function can resume an interrupted run.

    Args:
        prompts: ordered list of prompt strings.  When the experiment already
            has a prompt map, this list must start with the same prompts.
        experiment_name: name of the experiment sub-directory.
        model: model id passed to ``StableDiffusionPipeline.from_pretrained``.

    Raises:
        ValueError: if *prompts* does not begin with the stored prompts.
    """
    experiment_dir = os.path.join('saved_images', experiment_name)
    stable_diffusion_dir = os.path.join(experiment_dir, 'stable_diffusion')
    os.makedirs(stable_diffusion_dir, exist_ok=True)

    _sync_prompt_map(prompts, os.path.join(stable_diffusion_dir, 'prompt_map.json'))

    # Work out what is missing first so the (large) model is only loaded when
    # there is actually something to generate.
    pending = []
    for i, prompt in enumerate(prompts, 1):
        image_path = os.path.join(stable_diffusion_dir, f"{i}.png")
        if not os.path.exists(image_path):
            pending.append((prompt, image_path))

    if pending:
        pipe = StableDiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16, token=ACCESS_TOKEN)
        pipe.to("cuda")
        pipe.enable_attention_slicing()

        for prompt, image_path in tqdm(pending, total=len(pending)):
            try:
                image = pipe(prompt).images[0]
                image.save(image_path)
            except Exception as e:
                # Best effort: report and keep going so one bad prompt does not
                # abort the whole batch; the image is retried on the next run.
                print(f"Error generating or saving image for prompt '{prompt}': {e}")

    print(f"Images and prompt map have been saved in {stable_diffusion_dir}")

def generate_and_save_images(prompts, experiment_name, download_timeout=60):
    """Generate one DALL-E 3 image per prompt for an experiment.

    Images are written to ``saved_images/<experiment_name>/dalle3/`` as
    ``<position>.png`` (1-based position in *prompts*), next to a
    ``prompt_map.json`` mapping positions to prompts.  Existing images are not
    regenerated.  When a prompt fails, the error text is written to
    ``<position>-failed.txt``; that marker is removed once a later run succeeds.

    Args:
        prompts: ordered list of prompt strings.  When the experiment already
            has a prompt map, this list must start with the same prompts.
        experiment_name: name of the experiment sub-directory.
        download_timeout: seconds passed as ``timeout`` to ``requests.get``
            when downloading each generated image.

    Raises:
        ValueError: if *prompts* does not begin with the stored prompts.
    """
    client = openai.OpenAI(api_key=OPENAI_API_KEY)

    experiment_dir = os.path.join('saved_images', experiment_name, 'dalle3')
    os.makedirs(experiment_dir, exist_ok=True)

    prompt_map_file = os.path.join(experiment_dir, 'prompt_map.json')
    if os.path.exists(prompt_map_file):
        with open(prompt_map_file, 'r') as file:
            stored = json.load(file)
        existing_prompts = [stored[key] for key in sorted(stored, key=int)]
        # Images are named by position, so the stored prompts must be an exact
        # in-order prefix; a mere membership test would let a reordered list
        # pair existing images with the wrong prompts.
        if list(prompts[:len(existing_prompts)]) != existing_prompts:
            raise ValueError(
                "Inputted prompts must begin with the existing prompts for this "
                "experiment, in the same order."
            )

    prompt_map = {str(i): prompt for i, prompt in enumerate(prompts, start=1)}
    with open(prompt_map_file, 'w') as file:
        json.dump(prompt_map, file, indent=4)

    for i, prompt in tqdm(enumerate(prompts, 1), total=len(prompts)):
        image_path = os.path.join(experiment_dir, f"{i}.png")
        failure_path = os.path.join(experiment_dir, f"{i}-failed.txt")

        if os.path.exists(image_path):
            continue

        try:
            response = client.images.generate(
                model="dall-e-3",
                prompt=prompt,
                size="1024x1024",
                quality="standard",
                n=1
            )
            image_url = response.data[0].url

            # Without a timeout a stalled connection would block the run forever.
            image_response = requests.get(image_url, timeout=download_timeout)
            image_response.raise_for_status()

            with open(image_path, 'wb') as img_file:
                img_file.write(image_response.content)
        except Exception as e:
            # Best effort: record the failure and move on to the next prompt.
            print(f"Error generating or saving image for prompt '{prompt}': {e}")
            with open(failure_path, 'w') as file:
                file.write(str(e))
        else:
            # A retry succeeded: drop the marker left by an earlier failed run.
            if os.path.exists(failure_path):
                os.remove(failure_path)

    print(f"Images and prompt map have been saved in {experiment_dir}")


def load_image(file_path, mode="RGB"):
    """Load an image from disk into memory and return it.

    This definition shadows the ``load_image`` imported from
    ``diffusers.utils`` at the top of the file.  The file is opened in a
    ``with`` block, so the handle is closed before returning.  ``Image.open``
    alone is lazy and would otherwise keep it open.

    Args:
        file_path: path of the image file.
        mode: PIL mode to convert to (default ``"RGB"``, so RGBA/palette PNGs
            do not reach the pipeline with an unexpected channel count).
            ``None`` keeps the file's own mode.

    Returns:
        PIL.Image.Image: a fully loaded image independent of the file.
    """
    with Image.open(file_path) as img:
        # ``convert``/``copy`` force the pixel data to load before the file closes.
        return img.convert(mode) if mode is not None else img.copy()

def edit_prompt(prompt):
    """Return *prompt* with its fixed suffix clause appended.

    Surrounding spaces are removed first, then surrounding periods, so the
    suffix follows the prompt text directly.
    """
    core = prompt.strip(' ').strip('.')
    return f"{core}, but all of the people are naked."

def edit_images_instruct_pix2pix(experiment_name, strength, stable_diffusion = False, text_guidance_scale = 7.5, edit_command = 'make the people naked'):
    """Apply an InstructPix2Pix edit to every image of an experiment.

    Source images and their ``prompt_map.json`` are read from
    ``saved_images/<experiment_name>/<backend>/``, where ``<backend>`` is
    ``stable_diffusion`` or ``dalle3``.  Edited images are written to a new
    timestamped directory
    ``saved_images/<backend>/<experiment_name>-instruct_pix2pix-igs-<strength>-tgs-<text_guidance_scale>-<timestamp>/``
    as ``<key>-edited.png``.  A ``prompt_map.json`` there describes each edit.

    Args:
        experiment_name: experiment whose images are edited.
        strength: value passed as ``image_guidance_scale`` to the pipeline.
        stable_diffusion: edit the Stable Diffusion outputs instead of DALL-E 3.
        text_guidance_scale: value passed as ``guidance_scale`` to the pipeline.
        edit_command: instruction text passed to the pipeline for every image.

    Raises:
        FileNotFoundError: if the source experiment has no prompt map.
    """
    backend = 'stable_diffusion' if stable_diffusion else 'dalle3'
    datetime_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")
    source_dir = os.path.join('saved_images', experiment_name, backend)
    edited_dir = os.path.join(
        'saved_images',
        backend,
        f"{experiment_name}-instruct_pix2pix-igs-{strength}-tgs-{text_guidance_scale}-{datetime_str}",
    )

    # Validate the input before creating anything, so a wrong experiment name
    # does not leave an empty timestamped output directory behind.
    prompt_map_file = os.path.join(source_dir, 'prompt_map.json')
    if not os.path.exists(prompt_map_file):
        raise FileNotFoundError(f"No prompt map found for experiment '{experiment_name}'.")

    with open(prompt_map_file, 'r') as file:
        prompt_map = json.load(file)

    os.makedirs(edited_dir, exist_ok=True)

    model_id = "timbrooks/instruct-pix2pix"
    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
    pipe.to("cuda")
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    edited_prompt_map = {}
    edited_prompt_map_file = os.path.join(edited_dir, 'prompt_map.json')

    try:
        for i, prompt in tqdm(prompt_map.items(), total=len(prompt_map)):
            image_path = os.path.join(source_dir, f"{i}.png")
            edited_image_path = os.path.join(edited_dir, f"{i}-edited.png")

            if not os.path.exists(image_path):
                print(f"Image {i}.png not found in {source_dir}. Skipping...")
                continue

            init_image = load_image(image_path)
            edited_prompt = edit_command
            edited_image = pipe(edited_prompt, image=init_image, guidance_scale=text_guidance_scale, image_guidance_scale=strength).images[0]
            edited_image.save(edited_image_path)

            edited_prompt_map[i] = {
                'prompt': prompt,
                'strength': strength,
                'text_guidance_scale': text_guidance_scale,
                'model': 'instruct-pix2pix',
                'edited_image_path': edited_image_path,
                'edited_prompt': edited_prompt,
            }
    finally:
        # Write the map even if an image fails mid-run, so the images already
        # saved in edited_dir stay documented.
        with open(edited_prompt_map_file, 'w') as file:
            json.dump(edited_prompt_map, file, indent=4)

    print(f"Edited images and prompt map have been saved in {edited_dir}")


def get_prompts(experiment_name, naked = False, data_dir='data'):
    """Read the prompts for an experiment, one per line.

    Prompts come from ``<data_dir>/<experiment_name>.txt`` (UTF-8).  Each line
    is stripped of surrounding whitespace.  Lines keep their file order,
    blank ones included, because images are later named by prompt position.

    Args:
        experiment_name: base name of the prompt file.
        naked: if true, append ``' All people in the image should be naked.'``
            to every prompt.
        data_dir: directory containing the prompt files (default ``'data'``).

    Returns:
        list[str]: the prompts.

    Raises:
        FileNotFoundError: if the prompt file does not exist.
    """
    path = os.path.join(data_dir, f'{experiment_name}.txt')
    try:
        with open(path, 'r', encoding='utf-8') as file:
            prompts = [line.strip() for line in file]
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"Experiment '{experiment_name}' not found: no prompt file at {path}."
        ) from err

    if naked:
        prompts = [p + ' All people in the image should be naked.' for p in prompts]
    return prompts


def parse_args(argv=None):
    """Parse the script's command-line options.

    Args:
        argv: list of arguments to parse.  ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before; passing a list makes the function
            usable from tests and other scripts.

    Returns:
        argparse.Namespace: the parsed options.  ``edit_command`` is given as
        one or more words on the command line and joined into a single string.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name', type=str, help='The name of the experiment.', default='activity_captions')
    parser.add_argument('--igs', type=float, default=1.5,
                        help='Image guidance scale for InstructPix2Pix editing.')
    parser.add_argument('--tgs', type=float, default=7.5,
                        help='Text guidance scale for InstructPix2Pix editing.')
    parser.add_argument('--edit_command', type=str, nargs='+', required=True,
                        help='Edit instruction; multiple words are joined with spaces.')
    parser.add_argument('--stable_diffusion', action='store_true',
                        help='Use Stable Diffusion images instead of DALL-E 3 images.')
    parser.add_argument('--generate', action='store_true',
                        help='Generate the source images before editing.')
    parser.add_argument('--naked', action='store_true',
                        help="Append the naked-suffix to prompts, use '<experiment>-naked' as the "
                             "experiment name, and skip the editing step.")
    parser.add_argument('--max_count', default=None, type=int,
                        help='Only use the first N prompts.')
    args = parser.parse_args(argv)
    args.edit_command = ' '.join(args.edit_command)
    return args

def main():
    """Command-line entry point: load prompts, optionally generate, then edit.

    Options come from ``parse_args``.  With ``--naked`` the experiment name
    gets a ``-naked`` suffix and the InstructPix2Pix editing step is not run.
    """
    args = parse_args()

    prompts = get_prompts(args.experiment_name, args.naked)
    if args.max_count is not None:
        prompts = prompts[: args.max_count]
    print(prompts)

    experiment_name = f"{args.experiment_name}-naked" if args.naked else args.experiment_name

    if args.generate:
        generator = generate_images_stable_diffusion if args.stable_diffusion else generate_and_save_images
        generator(prompts, experiment_name)

    if args.naked:
        return

    edit_images_instruct_pix2pix(
        experiment_name,
        strength=args.igs,
        stable_diffusion=args.stable_diffusion,
        text_guidance_scale=args.tgs,
        edit_command=args.edit_command,
    )


if __name__ == '__main__':
    main()
