import pandas as pd
import inflect
import re
import spacy
import os
import argparse
import itertools
# NOTE(review): `nlp` is not referenced by any function visible in this file, yet
# the spaCy model is loaded on every import/run — confirm it is still needed.
nlp = spacy.load("en_core_web_sm")

import random
import numpy as np

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
# Device the Stable Diffusion pipeline is moved to before generation.
# Hard-coded to CUDA: there is no CPU fallback, and the pipeline is loaded in float16.
device = "cuda"


def parse_args():
    """Read the script's command-line options.

    Returns:
        An ``argparse.Namespace`` exposing ``prompts_file`` (required),
        ``output_dir``, ``output_file``, ``seed`` and ``model_id``.
    """
    cli = argparse.ArgumentParser()

    # Each entry is (flag, value type, remaining add_argument keywords),
    # listed in the order the options should appear in --help.
    option_specs = (
        # In-N-Out Path
        ('--prompts_file', str, dict(
            required=True,
            help="Path to file containing prompts and words of interest")),
        ('--output_dir', str, dict(
            default='/data/drive4/srdewan/coursework/MMML/data',
            help="Path to directory where to save generated images")),
        ('--output_file', str, dict(
            default='/data/drive4/srdewan/coursework/MMML/data/manual/manual.csv',
            help="Save path of prompts")),
        ('--seed', int, dict(
            default=42,
            help="Seed to use for generating images")),
        ('--model_id', str, dict(
            default='stabilityai/stable-diffusion-2-1',
            help="Generative model to use")),
    )
    for flag, value_type, extra in option_specs:
        cli.add_argument(flag, type=value_type, **extra)

    return cli.parse_args()


def _parse_prompt_line(line):
    """Split one prompts-file line ``prompt,word1,word2,...`` into ``(prompt, words)``.

    All whitespace is removed from each word. This strips the trailing newline,
    but it also glues multi-word names together ("traffic light" -> "trafficlight").
    Words that end up empty, e.g. from a trailing comma or ``,,``, are dropped so
    they can never be paired with a real object.
    """
    prompt, *raw_words = line.split(',')
    squashed = ["".join(word.split()) for word in raw_words]
    return prompt, [word for word in squashed if word]


def _image_path_for(out_dir, prompt):
    """Return the ``.jpg`` path under *out_dir* where *prompt*'s image is saved.

    Spaces become underscores. Path separators are also replaced, so a prompt
    containing ``/`` cannot point into a non-existent subdirectory, which would
    make ``image.save`` fail.
    """
    file_stem = prompt.replace(' ', '_')
    for sep in (os.sep, os.altsep):
        if sep:
            file_stem = file_stem.replace(sep, '_')
    return os.path.join(out_dir, f"{file_stem}.jpg")


def construct_expanded_dataset(out_dir, prompts_file, model_id = 'stabilityai/stable-diffusion-2-1', num_inference_steps=50):
    """Generate one image per prompt and expand it into one row per object pair.

    Args:
        out_dir: Existing directory the generated ``.jpg`` images are written to.
        prompts_file: Text file whose lines look like ``prompt,word1,word2,...``.
            Lines with fewer than two non-empty words are skipped, and no image
            is generated for them.
        model_id: Hugging Face model id passed to
            ``StableDiffusionPipeline.from_pretrained``.
        num_inference_steps: Denoising steps per image (default 50, as before).

    Returns:
        A ``pandas.DataFrame`` with columns ``'image path'``, ``'caption'`` and
        ``'objs'``. There is one row for every unordered pair of words on a
        line, and ``'objs'`` holds that pair as a 2-element list.
    """
    expanded_data = []
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    # No explicit generator: sampling draws from torch's global RNG, which
    # main() seeds with torch.manual_seed before calling this function.

    with open(prompts_file, "r", encoding="utf-8") as f:
        samples = f.readlines()

    for idx, sample in enumerate(samples):
        print(f"Processing {idx}")

        prompt, words = _parse_prompt_line(sample)
        if len(words) < 2:
            continue

        img_path = _image_path_for(out_dir, prompt)
        image = pipe(prompt=prompt, num_inference_steps=num_inference_steps).images[0]
        image.save(img_path)

        expanded_data.extend(
            [img_path, prompt, list(pair)] for pair in itertools.combinations(words, 2)
        )

    print("Number of samples = %d" % (len(expanded_data)))
    return pd.DataFrame(expanded_data, columns=['image path', 'caption', 'objs'])

    
def main():
    """Parse CLI options, seed every RNG, generate the images and write the CSV index."""
    args = parse_args()

    # exist_ok replaces the racy "check exists, then create" pattern.
    os.makedirs(args.output_dir, exist_ok=True)
    # The CSV path is independent of output_dir; the default even points into a
    # 'manual/' subdirectory. Create its parent now so to_csv cannot fail after
    # all the (slow) image generation has already finished.
    csv_parent = os.path.dirname(args.output_file)
    if csv_parent:
        os.makedirs(csv_parent, exist_ok=True)

    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    df = construct_expanded_dataset(args.output_dir, args.prompts_file, args.model_id)
    df.to_csv(args.output_file, index=False)


# Run the generation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
