import pandas as pd
import math
from nltk.tokenize import sent_tokenize, word_tokenize
from tqdm import tqdm
from glob import glob
import subprocess
tqdm.pandas()

# Group names keyed by the category they belong to. The script entry point
# flattens the values of this mapping and uses each name as a file-name stem.
group_categories = {
    'gender': [
        'men',
        'women',
    ],
    'race': [
        'black',
        'white',
        'asian',
        'hispanic',
    ],
    'religion': [
        'judaism',
        'islam',
        'christianity',
    ],
}

def bin_articles_by_sentence_count(articles, max_range=5):
    """Group articles into bins of similar sentence count.

    Each article is split into sentences with ``sent_tokenize`` and its words
    are counted with ``word_tokenize``. The sentence counts are then divided
    by ``pd.cut`` into equal-width bins; the number of bins is chosen so that
    the spread of sentence counts inside one bin is at most ``max_range``.

    Args:
        articles: pandas Series of article texts (each value is passed to the
            NLTK tokenizers, so values are expected to be strings).
        max_range: Upper bound on the width, in sentences, of a single bin.

    Returns:
        A list with one dict per non-empty bin, in increasing bin order:
            'id': index labels of the articles in the bin (list),
            'text': Series of sentence lists for those articles,
            'min_length' / 'max_length': smallest / largest sentence count,
            'min_word_length' / 'max_word_length': smallest / largest word
                count.
        An empty list is returned when ``articles`` is empty.
    """
    # Nothing to bin; min()/max() of an empty Series are NaN and would make
    # math.ceil raise below.
    if articles.empty:
        return []

    article_sentences = articles.apply(sent_tokenize)
    sent_lengths = article_sentences.apply(len)
    word_lengths = articles.apply(lambda x: len(word_tokenize(x)))

    spread = sent_lengths.max() - sent_lengths.min()
    # When all articles have the same sentence count the spread is 0, which
    # would request 0 bins and make pd.cut raise; always use at least one bin.
    num_bins = max(1, math.ceil(spread / max_range))
    labels = pd.cut(sent_lengths, num_bins, labels=range(num_bins))

    bins = []
    for x in range(num_bins):
        in_bin = labels == x  # boolean mask, computed once per bin
        if not in_bin.any():
            continue
        bin_sent_lengths = sent_lengths[in_bin]
        bin_word_lengths = word_lengths[in_bin]
        bins.append({
            'id': article_sentences[in_bin].index.to_list(),
            'text': article_sentences[in_bin],
            'min_length': bin_sent_lengths.min(),
            'max_length': bin_sent_lengths.max(),
            'min_word_length': bin_word_lengths.min(),
            'max_word_length': bin_word_lengths.max(),
        })

    return bins


def process(datapath, savepath, colname='article'):
    """Convert a CSV of texts into a JSON-lines file of sentence lists.

    The column ``colname`` is split into sentences with ``sent_tokenize`` and
    written out as the ``text`` field. Every record also gets a ``summary``
    field holding an empty list. All other columns of the CSV are kept.

    Note that when ``colname`` is ``'summary'`` the original summary strings
    only survive (tokenized) in ``text``; the ``summary`` field itself is
    reset to ``[]``. If the CSV has an ``article`` column and ``colname`` is a
    different column, the ``article`` content is replaced by the tokenized
    ``colname`` content.

    Args:
        datapath: Path of the CSV file to read.
        savepath: Path of the JSON-lines file to write (one record per line).
        colname: Name of the CSV column whose text is sentence-tokenized.
    """
    df = pd.read_csv(datapath)

    # Item assignment (not attribute assignment) so the column is created
    # when the CSV has no 'article' column; `df.article = ...` would only set
    # a Python attribute in that case and the tokenized text would be lost.
    df['article'] = df[colname].apply(sent_tokenize)
    df = df.rename(columns={'article': 'text'})
    df['summary'] = [[] for _ in range(len(df))]
    df.to_json(savepath, orient='records', lines=True)

if __name__ == '__main__':
    # Earlier conversions of the single-group and multigroup article sets,
    # kept here disabled:
    #
    # for cat in group_categories:
    #     for g in group_categories[cat]:
    #         datapath = f'../data/synthetic_data/articles/single_group/{cat}/{g}.csv'
    #         savepath = f'../data/synthetic_data/matchsum/single_group/{g}/{g}.jsonl'
    #         process(datapath, savepath)

    # for ratio in [0.1,0.5,0.9]:
    #     for cat in group_categories:
    #         for g1 in group_categories[cat]:
    #             for g2 in group_categories[cat]:
    #                 if g1 == g2:
    #                     continue
    #                 datapath = f'../data/synthetic_data/articles/multigroup/{ratio:.2f}_{g1}_{1-ratio:.2f}_{g2}.csv'
    #                 savepath = f'../data/synthetic_data/matchsum/multigroup/{ratio:.2f}_{g1}_{1-ratio:.2f}_{g2}.jsonl'
    #                 process(datapath, savepath)

    # For every group, tokenize the 'summary' column of
    # recursive/{iteration-1}_{group}.csv into recursive/{iteration}_{group}.jsonl.
    iteration = 1
    all_groups = [
        name
        for members in group_categories.values()
        for name in members
    ]
    for group in all_groups:
        source_csv = f'../data/synthetic_data/matchsum/recursive/{iteration - 1}_{group}.csv'
        target_jsonl = f'../data/synthetic_data/matchsum/recursive/{iteration}_{group}.jsonl'
        process(source_csv, target_jsonl, colname='summary')