import argparse
import os
import sys
import torch
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
sys.path.insert(0, parent_dir)
from openai import OpenAI
from tqdm import tqdm

from data.utils.process_laion import get_laion_prompts
from data.utils.process_celeba import get_celeba_prompts
from data.utils.process_wikiart import get_wikiart_prompts
from transformers import CLIPProcessor, CLIPModel

# CLIP checkpoint identifier plus its processor and model. Both objects are
# built eagerly when this module is imported (from_pretrained may fetch the
# weights from the Hugging Face Hub); note that generatePrompts below never
# references either of them.
model_name = "openai/clip-vit-base-patch32"
processor, model = (
    CLIPProcessor.from_pretrained(model_name),
    CLIPModel.from_pretrained(model_name),
)

# Instruction template asking which abstract concept varies across a set of
# captions; callers substitute the captions into the "{}" slot with
# str.format. Nothing in generatePrompts formats it.
concept_prompt_template = (
    "\n"
    "What is the abstract concept that is being changed amongst the set of captions below:\n"
    "\n"
    "{}\n"
    "\n"
    "Please supply the list of values of this abstract concept as your response.\n"
    "\n"
)

# Prompts per unlabeled_prompts_<i>.txt file for the Celeba_HQ_dialog,
# wikiart and cifar100 branches (laion400m prompts arrive already blocked).
BLOCK_SIZE = 1

# CIFAR-100 fine-grained class names, written verbatim (underscores kept) as prompts.
_CIFAR100_CLASSES = (
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle',
    'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar',
    'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile',
    'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster',
    'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',
    'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew',
    'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
    'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle',
    'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm',
)


def _chunk(items, size=None):
    """Split a sliceable sequence into consecutive blocks of at most ``size`` items.

    ``size`` defaults to the module-level ``BLOCK_SIZE``; the last block may be
    shorter. Raises ValueError for a size below 1.
    """
    if size is None:
        size = BLOCK_SIZE
    if size < 1:
        raise ValueError(f"block size must be >= 1, got {size}")
    return [items[i:i + size] for i in range(0, len(items), size)]


def _write_prompt_blocks(prompt_blocks, output_dir):
    """Write block ``i`` to ``<output_dir>/unlabeled_prompts_<i>.txt``, one prompt per line.

    Each item is converted with ``str()``; an item that itself contains a
    newline will span several lines of its file.
    """
    if output_dir:
        # Create the destination up front instead of failing on the first open().
        os.makedirs(output_dir, exist_ok=True)
    # One progress step per written file (a bar per one-item block was just noise).
    for i, block in enumerate(tqdm(prompt_blocks, desc='prompt blocks')):
        path = os.path.join(output_dir, f'unlabeled_prompts_{i}.txt')
        # Explicit UTF-8 so non-ASCII captions don't depend on the platform locale.
        with open(path, 'w', encoding='utf-8') as file:
            file.writelines(str(item) + '\n' for item in block)


def generatePrompts(unlabeled_dir, output_dir, block_size=None):
    """Write a dataset's prompts into numbered ``unlabeled_prompts_<i>.txt`` files.

    Args:
        unlabeled_dir: Dataset directory. Its last path component selects the
            dataset: ``laion400m``, ``Celeba_HQ_dialog``, ``wikiart`` or
            ``cifar100``. A trailing path separator is tolerated.
        output_dir: Directory that receives the files; created if missing.
        block_size: Prompts per file for the Celeba_HQ_dialog, wikiart and
            cifar100 datasets. Defaults to ``BLOCK_SIZE``. It is ignored for
            laion400m, whose loader already yields blocks.

    Raises:
        ValueError: If the dataset name is not one of the supported ones
            (checked before any data is loaded or written).
    """
    # normpath drops trailing separators, so ".../laion400m/" still maps to "laion400m".
    dataset = os.path.basename(os.path.normpath(unlabeled_dir))

    if dataset == 'laion400m':
        # Assumes get_laion_prompts yields blocks exposing .tolist() (e.g. pandas
        # Series / numpy arrays) -- TODO confirm. Single quotes become double
        # quotes, as before.
        prompt_blocks = (
            [str(item).replace("'", '"') for item in block.tolist()]
            for block in get_laion_prompts(unlabeled_dir)
        )
    elif dataset == 'Celeba_HQ_dialog':
        prompt_blocks = _chunk(get_celeba_prompts(unlabeled_dir), block_size)
    elif dataset == 'wikiart':
        # The wikiart loader takes no directory argument.
        prompt_blocks = _chunk(get_wikiart_prompts(), block_size)
    elif dataset == 'cifar100':
        prompt_blocks = _chunk(_CIFAR100_CLASSES, block_size)
    else:
        raise ValueError(f"Dataset {dataset} not supported")

    _write_prompt_blocks(prompt_blocks, output_dir)


if __name__ == '__main__':
    # Command-line entry point: parse the two directories and run generatePrompts.
    # The help texts previously said "Generate Images using Diffusers Code" and
    # "path to csv file with prompts" (copy-paste leftovers); they now describe
    # what the arguments actually are.
    parser = argparse.ArgumentParser(
        prog='generatePrompts',
        description='Split a dataset\'s prompts into numbered unlabeled_prompts_<i>.txt files',
    )
    parser.add_argument(
        '--unlabeled_dir', type=str, required=True,
        help='dataset directory; its last path component selects the dataset '
             '(laion400m, Celeba_HQ_dialog, wikiart or cifar100)',
    )
    parser.add_argument(
        '--output_dir', type=str, required=True,
        help='directory that receives the unlabeled_prompts_<i>.txt files',
    )
    args = parser.parse_args()

    generatePrompts(args.unlabeled_dir, args.output_dir)