import argparse
import csv
import getpass
import os
import random
import sys

import pandas as pd

# Make the directory two levels above this script importable, so the
# ``data.utils`` import that follows resolves when this file runs as a script.
parent_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
sys.path.insert(0, parent_dir)

from data.utils.process_cifar100 import get_cifar100_prompts

# Prompt templates; each sampled concept name is substituted for ``{}``.
templates = [
    "a photo of {}",
]


def main(num_artists, random_seed, output_path=None, concepts=None):
    """Sample random concepts and write one evaluation-prompt CSV for them.

    The ``artist`` naming is kept for backward compatibility; the sampled
    items are the concept names returned by ``get_cifar100_prompts``.

    Args:
        num_artists: Number of concepts to sample.
        random_seed: Seed passed to ``random.seed`` (reseeds the module-global
            generator). It is also embedded in the default output filename.
        output_path: Destination CSV path. Defaults to the per-user cluster
            scratch location.
        concepts: Candidate concept names. Defaults to
            ``get_cifar100_prompts()``.

    Returns:
        The path of the CSV file that was written.

    Raises:
        ValueError: From ``random.sample`` if ``num_artists`` is negative or
            larger than the number of candidate concepts.
    """
    random.seed(random_seed)

    # Concepts that are never drawn as random "others".
    excluded = {'automobile', 'bird', 'ship'}

    if concepts is None:
        concepts = get_cifar100_prompts()

    # Sort the candidates before sampling. Iterating a set of strings follows
    # hash order, which changes between interpreter runs (PYTHONHASHSEED).
    # Sampling straight from list(set) is therefore not reproducible for a
    # fixed random_seed.
    candidates = sorted(set(concepts) - excluded)
    sampled_artists = random.sample(candidates, num_artists)
    print('sampled artists: ', sampled_artists)

    if output_path is None:
        # "$(whoami)" is shell command substitution; Python never expands it,
        # so look the user name up explicitly.
        output_path = (
            f"/data/cluster_name/scratch/{getpass.getuser()}/projects/"
            f"MACE-Update/tasks/object/"
            f"object_random_concepts_seed{random_seed}.csv"
        )

    with open(output_path, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)

        # The first column has an empty header and holds a running row index
        # (presumably so it can be read back with index_col=0 — confirm).
        writer.writerow(["", "type", "prompt", "evaluation_seed"])

        idx = 0
        for artist in sampled_artists:
            for template in templates:
                prompt = template.format(artist)
                # One row per evaluation seed 1..5 for every prompt.
                for seed in range(1, 6):
                    writer.writerow([idx, "others", prompt, seed])
                    idx += 1

    return output_path

if __name__ == '__main__':
    # Command-line entry point: parse the sampling options and run main().
    parser = argparse.ArgumentParser(
        description='Sample random CIFAR-100 object concepts and write evaluation prompts.'
    )
    # Required: without it the value would be None and random.sample would fail.
    parser.add_argument('--num_artists', type=int, required=True,
                        help='Number of concepts to sample')
    parser.add_argument('--random_seed', type=int, default=0,
                        help='Random seed for sampling')

    args = parser.parse_args()

    main(args.num_artists, args.random_seed)