from transformers import pipeline, set_seed, AutoTokenizer
from datasets import load_dataset, Dataset
from tqdm import tqdm
from nltk.tokenize import word_tokenize, sent_tokenize
import pandas as pd
import json
import itertools
import argparse


def gen_excluded_ids(all_groups, group_name, tokenizer):
    """Collect token-id sequences for words of the *other* groups in the
    same category as ``group_name``.

    Args:
        all_groups: Mapping of group name -> iterable of word strings.
        group_name: The group being generated for. Its own words are never
            excluded. A name that is neither a gender nor a race group is
            treated as belonging to the 'religion' category.
        tokenizer: Callable invoked as
            ``tokenizer(text, add_special_tokens=False)`` whose result
            exposes ``input_ids``.

    Returns:
        A list with one ``input_ids`` entry per (word, surface form) pair.
        For each sibling group, in ``all_groups`` iteration order, the
        entries come in four runs: lower-case, space-prefixed lower-case,
        title-case, space-prefixed title-case.
    """
    categories = {
        'gender': ['men', 'women'],
        'race': ['black', 'white', 'asian', 'hispanic'],
        'religion': ['judaism', 'christianity', 'islam']
    }

    # First matching category wins; anything else falls back to religion.
    category = next(
        (cat for cat in ('gender', 'race') if group_name in categories[cat]),
        'religion',
    )
    siblings = categories[category]

    def encode(text):
        return tokenizer(text, add_special_tokens=False).input_ids

    # Each word is banned in four spellings: with/without a leading space
    # and lower/title cased.
    surface_forms = (
        lambda word: word.lower(),
        lambda word: ' ' + word.lower(),
        lambda word: word.title(),
        lambda word: ' ' + word.title(),
    )

    banned = []
    for other in all_groups:
        # only exclude words from groups within the same category
        if other == group_name or other not in siblings:
            continue
        words = all_groups[other]
        for make_form in surface_forms:
            banned.extend(encode(make_form(word)) for word in words)
    return banned


def main(args):
    """Generate a synthetic GPT-2 corpus for each requested group.

    For every group in ``args.groups`` this takes up to 100 prompts for the
    group, generates one continuation per prompt while banning the
    word-list terms of the other groups in the same category (see
    ``gen_excluded_ids``), and writes the generated texts to
    ``../data/synthetic_data/articles/<group>.csv``.

    Args:
        args: Parsed command-line namespace. ``args.groups`` must be a
            non-empty list whose items are keys of ``group_mappings``.

    Raises:
        ValueError: If ``args.groups`` is missing/empty or names a group
            that has no entry in ``group_mappings``.
    """
    # CLI group name -> key under which its prompts are stored in the
    # prompt JSON files.
    group_mappings = {
        'men': 'American_actors',
        'women': 'American_actresses',
        'black': 'African_Americans',
        'white': 'European_Americans',
        'hispanic': 'Hispanic_and_Latino_Americans',
        'asian': 'Asian_Americans',
        'judaism': 'judaism',
        'christianity': 'christianity',
        'islam': 'islam'
    }

    # Explicit validation instead of `assert`: asserts are stripped under
    # `python -O`, and `args.groups` is None when --groups is omitted.
    if not args.groups:
        raise ValueError('at least one group must be given via --groups')
    unknown = sorted(set(args.groups) - set(group_mappings))
    if unknown:
        raise ValueError(
            f'unknown group(s) {unknown}; '
            f'expected any of {sorted(group_mappings)}')

    print('loading prompts')
    with open('../data/wordlists/groups_cased.json', 'r',
              encoding='utf-8') as f:
        groups = json.load(f)
    with open('../data/wordlists/gender_prompt.json', 'r',
              encoding='utf-8') as f:
        gender_prompts = json.load(f)
    with open('../data/wordlists/race_prompt.json', 'r',
              encoding='utf-8') as f:
        race_prompts = json.load(f)
    with open('../data/wordlists/religious_ideology_prompt.json', 'r',
              encoding='utf-8') as f:
        religion_prompts = json.load(f)
    # One flat lookup; later files win on duplicate top-level keys.
    prompts = {**gender_prompts, **race_prompts, **religion_prompts}

    print('loading model')
    generator = pipeline('text-generation', model='gpt2', device=0)
    # Need to do this to allow batching
    generator.tokenizer.add_special_tokens(
        {'pad_token': generator.tokenizer.eos_token})
    set_seed(42)
    print('loading tokenizer')
    # Separate tokenizer instance used only to encode the banned words.
    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    print('generating corpus')
    for g in args.groups:
        print(f'starting generation for {g}')
        # Flatten the per-entry prompt collections and keep the first 100.
        # NOTE(review): presumably each value is a list of prompt strings —
        # confirm against the prompt JSON files.
        final_prompts = list(itertools.islice(
            itertools.chain.from_iterable(
                prompts[group_mappings[g]].values()),
            100))
        excluded_ids = gen_excluded_ids(groups, g, tokenizer)

        # The generation option is `no_repeat_ngram_size`; the previously
        # used `no_repeat_ngram` keyword is not a generation parameter, so
        # the 3-gram repetition constraint was never applied.
        outputs = generator(
            final_prompts,
            min_length=400,
            max_length=1024,
            num_return_sequences=1,
            top_k=40,
            top_p=0.95,
            no_repeat_ngram_size=3,
            bad_words_ids=excluded_ids,
            batch_size=5,
        )

        # Each element of `outputs` is itself a list of result dicts (one
        # per returned sequence); flatten them into a single list.
        generated = []
        for out in tqdm(outputs):
            generated += out

        # Keep every article on a single line.
        corpus = [x['generated_text'].replace('\n', ' ') for x in generated]
        df = pd.DataFrame(corpus, columns=['article'])
        df.to_csv(f'../data/synthetic_data/articles/{g}.csv', index_label='id')


if __name__ == '__main__':
    # Registers tqdm's progress_* methods on pandas objects.
    tqdm.pandas()
    parser = argparse.ArgumentParser(
        description='Generate synthetic GPT-2 articles for the given groups.')
    # required=True: without --groups argparse would leave args.groups as
    # None and main() could not validate or iterate it.
    parser.add_argument(
        '--groups', nargs='+', type=str, required=True,
        help='one or more of: men, women, black, white, hispanic, asian, '
             'judaism, christianity, islam')
    args = parser.parse_args()

    main(args)
