import os
import sys
import random
import pandas as pd
import argparse

import csv

# Put the directory two levels above this file at the front of sys.path so
# the `data.utils...` import below resolves when the script is run directly.
# NOTE(review): presumably that directory is the repository root — confirm.
parent_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
sys.path.insert(0, parent_dir)

from data.utils.process_wikiart import get_wikiart_prompts

# Prompt templates; each `{}` placeholder is filled with an artist name via
# `str.format` when the output rows are generated.
templates = [
    "Image in the style of {}",
    "Art inspired by {}",
    "Painting in the style of {}",
    "A reproduction of art by {}",
    "A famous artwork by {}",
]

def main(num_artists: int, random_seed: int) -> None:
    """Sample artists not present in the 100-concept list and write prompt rows.

    Reads ``art_100_concepts.csv``, collects the artist names that appear in
    ``type == 0`` rows as ``"Image in the style of <word> <word>"``, removes
    those names from the values returned by ``get_wikiart_prompts()``, samples
    ``num_artists`` of the remainder and writes one CSV row per
    (artist, template, seed in 1..5) to
    ``art_random_concepts_seed<random_seed>.csv`` in the same directory.

    Args:
        num_artists: Number of artists to sample.
        random_seed: Seed for ``random``; also embedded in the output filename.

    Raises:
        ValueError: If ``num_artists`` exceeds the number of candidates
            (``random.sample`` also raises it for negative values).
    """
    # Local import keeps this function self-contained.
    import getpass

    random.seed(random_seed)

    # Python performs no shell substitution, so a literal "$(whoami)" path
    # component never resolves to the user's scratch directory; build the
    # path from the current user name instead.
    task_dir = os.path.join(
        '/data/cluster_name/scratch',
        getpass.getuser(),
        'projects/MACE-Update/tasks/art',
    )
    source_file = os.path.join(task_dir, 'art_100_concepts.csv')

    # Read the source CSV file
    df = pd.read_csv(source_file)

    # Keep only rows whose 'type' column equals 0.
    # NOTE(review): this script writes "others" into 'type' of its own
    # output; confirm the source file really uses numeric codes here.
    df = df[df['type'] == 0]

    # Artist names already used by the source file. Only two-word names in
    # the "Image in the style of ..." form are recognised. `extract` yields
    # NaN for non-matching rows (dropped below), so an input without any
    # match produces an empty array instead of a KeyError.
    unique_artists = (
        df['prompt']
        .str.extract(r'^Image in the style of (\w+ \w+)$', expand=False)
        .dropna()
        .unique()
    )

    wiki_art_artists = get_wikiart_prompts()

    # Sort before sampling: set iteration order depends on per-process string
    # hash randomisation, so sampling from list(set) is not reproducible even
    # with a fixed seed. Assumes the candidates are mutually comparable
    # (presumably all strings — verify against get_wikiart_prompts).
    candidates = sorted(set(wiki_art_artists) - set(unique_artists))
    if num_artists > len(candidates):
        raise ValueError(
            f"Requested {num_artists} artists but only "
            f"{len(candidates)} candidates are available"
        )
    sampled_artists = random.sample(candidates, num_artists)
    print('sampled artists: ', sampled_artists)

    destination_file = os.path.join(
        task_dir, f"art_random_concepts_seed{random_seed}.csv"
    )

    # Explicit UTF-8 so non-ASCII artist names are written identically
    # regardless of the locale's default encoding.
    with open(destination_file, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)

        # Header; the first column is an unnamed running index.
        writer.writerow(["", "type", "prompt", "evaluation_seed"])

        idx = 0
        for artist in sampled_artists:
            for template in templates:
                # The prompt is the same for all seeds; format it once.
                prompt = template.format(artist)
                for seed in range(1, 6):
                    writer.writerow([idx, "others", prompt, seed])
                    idx += 1

if __name__ == '__main__':
    # Command-line entry point: parse the sampling options and run main().
    parser = argparse.ArgumentParser(description='Sample artists from a Wikiart.')
    # Required: without a value argparse would pass None, and main() would
    # only fail later with a TypeError inside random.sample().
    parser.add_argument('--num_artists', type=int, required=True,
                        help='Number of artists to sample')
    parser.add_argument('--random_seed', type=int, default=0,
                        help='Random seed for sampling')

    args = parser.parse_args()

    main(args.num_artists, args.random_seed)