# NOTE(review): inflect, re, random, numpy and torch are not referenced in the
# visible code of this script — confirm whether they are still needed.
import pandas as pd
import inflect
import re
import spacy
import os
import argparse
import itertools
# NOTE(review): this loads spaCy's "en_core_web_sm" pipeline at import time,
# but `nlp` is not referenced in the visible code — confirm it is still needed.
nlp = spacy.load("en_core_web_sm")

import random
import numpy as np

import torch
# NOTE(review): `device` is not referenced in the visible code — confirm use elsewhere.
device = "cuda"


def parse_args(argv=None):
    """Parse command-line options for building the expanded dataset.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic calls).
            Defaults to ``None``, which keeps the usual CLI behaviour.

    Returns:
        An ``argparse.Namespace`` with ``data_dir``, ``prompts_file``,
        ``output_file`` and ``model_id`` attributes.
    """
    parser = argparse.ArgumentParser()

    # Input / output paths
    parser.add_argument(
        '--data_dir', type=str,
        default='/home/prakanss/dataset_scripts_pid/cohyponym_dataset/cohyponym_images',
        help="Path to directory containing input images",
    )
    parser.add_argument(
        '--prompts_file', type=str, required=True,
        help="Path to file containing prompts and words of interest",
    )
    parser.add_argument(
        '--output_file', type=str,
        default='/data/drive4/srdewan/coursework/MMML/data/manual/manual.csv',
        help="Save path of the expanded dataset CSV",
    )
    parser.add_argument(
        '--model_id', type=str, default='stabilityai/stable-diffusion-2-1',
        help="Generative model to use",
    )

    return parser.parse_args(argv)


def construct_expanded_dataset(data_dir, prompts_file, model_id = 'stabilityai/stable-diffusion-2-1'):
    """Expand a prompts file into one row per pair of words of interest.

    Each line of ``prompts_file`` is split on commas: the first field is the
    prompt (used as the caption) and the remaining fields are the words of
    interest. Each word is whitespace-normalised and joined with ``_``
    (``" blue sky"`` -> ``"blue_sky"``); fields that are empty after this
    (e.g. from a trailing comma) are dropped. Lines with fewer than two words
    are skipped, and every unordered pair of the remaining words becomes one
    row.

    Args:
        data_dir: Directory holding the images; a prompt's image path is
            ``<data_dir>/<prompt with spaces replaced by '_'>.jpg``.
        prompts_file: Path to the UTF-8 text file described above.
        model_id: Not used by this function; kept for interface compatibility.

    Returns:
        A DataFrame with columns ``'image path'``, ``'caption'`` and ``'objs'``,
        where ``'objs'`` holds a two-element list of words.
    """
    expanded_data = []

    with open(prompts_file, "r", encoding="utf-8") as f:
        samples = f.readlines()

    for idx, sample in enumerate(samples):
        print(f"Processing {idx}")

        prompt, *fields = sample.split(',')
        # str.split() with no argument also removes the trailing newline and
        # surrounding spaces. Drop fields that end up empty so a trailing or
        # doubled comma does not produce a bogus "" object.
        words = ["_".join(field.split()) for field in fields]
        words = [word for word in words if word]

        if len(words) < 2:
            continue

        file_name = f"{prompt.replace(' ', '_')}.jpg"
        img_path = os.path.join(data_dir, file_name)

        for pair in itertools.combinations(words, 2):
            expanded_data.append([img_path, prompt, list(pair)])

    print("Number of samples = %d" % (len(expanded_data)))
    return pd.DataFrame(expanded_data, columns = ['image path', 'caption', 'objs'])

    
def main():
    """Parse CLI arguments, build the expanded dataset and write it as CSV."""
    args = parse_args()

    df = construct_expanded_dataset(args.data_dir, args.prompts_file, args.model_id)

    # DataFrame.to_csv does not create missing parent directories, so make
    # sure the output location exists. A bare filename has no directory part.
    out_dir = os.path.dirname(args.output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(args.output_file, index = False)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
