from glob import glob
from itertools import permutations
from pathlib import Path

import pandas as pd

# Flat list of every group name. The multigroup conversion below does not read
# it; it is kept for the single-group variant of this script.
groups = ['men', 'women', 'black', 'white', 'hispanic', 'asian', 'judaism', 'islam', 'christianity']

# Groups are only ever paired with other groups from the same category.
group_categories = {
    'gender': ['men', 'women'],
    'race': ['black', 'white', 'hispanic', 'asian'],
    'religion': ['judaism', 'islam', 'christianity'],
}

# Ratio attached to the first group in a run name; the second group is
# labelled with 1 - ratio.
ratios = [0.1, 0.5, 0.9]

# Root folder holding one `<run>.jsonl` file and one `<run>/` folder per run.
MULTIGROUP_DIR = '../data/synthetic_data/matchsum/multigroup'
# Location of the decoded summaries, relative to a run folder.
DEC_SUBDIR = 'result/MatchSum_cnndm_roberta.ckpt/dec'


def export_run(run_name):
    """Merge the decoded summaries of one run into its articles and save a CSV.

    Reads ``<MULTIGROUP_DIR>/<run_name>.jsonl`` (one JSON record per line),
    fills the ``summary`` column from every ``<id>.dec`` file found under the
    run's ``DEC_SUBDIR``, joins each ``text`` entry into a single string,
    renames that column to ``article`` and writes
    ``<MULTIGROUP_DIR>/<run_name>/<run_name>.csv``.

    Args:
        run_name: Run identifier, e.g. ``'0.10_men_0.90_women'``.

    Raises:
        ValueError: If a ``.dec`` file name (without extension) is not an
            integer, or if the JSONL file cannot be read by pandas.
    """
    run_dir = f'{MULTIGROUP_DIR}/{run_name}'
    df = pd.read_json(f'{run_dir}.jsonl', lines=True)

    for dec_path in glob(f'{run_dir}/{DEC_SUBDIR}/*.dec'):
        # The file stem is the record id; Path handles both '/' and '\\'
        # separators, unlike a manual split on '/'.
        doc_id = int(Path(dec_path).stem)
        # Decode explicitly as UTF-8 (pandas' default for the JSONL/CSV side)
        # instead of relying on the platform locale.
        with open(dec_path, encoding='utf-8') as dec_file:
            # Lines keep their trailing newline, exactly as readlines() would.
            summary = ' '.join(dec_file)
        # Row-wise assignment on purpose: rows without a .dec file keep
        # whatever `summary` value they already had (or NaN if the column is
        # created here).
        df.loc[df['id'] == doc_id, 'summary'] = summary

    # Each `text` entry is a sequence of strings; flatten it to one string.
    df['text'] = df['text'].apply(' '.join)
    df = df.rename(columns={'text': 'article'})
    df.to_csv(f'{run_dir}/{run_name}.csv')


for ratio in ratios:
    for members in group_categories.values():
        # Ordered pairs of distinct groups, in the same order as a nested
        # loop that skips g1 == g2.
        for g1, g2 in permutations(members, 2):
            run_name = f'{ratio:.2f}_{g1}_{1-ratio:.2f}_{g2}'
            print(run_name)
            export_run(run_name)

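# Single-group version (disabled): runs the same conversion once per entry in
# `groups` instead of once per group pair. Uncomment to use.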
# for g in groups:
#     df = pd.read_json(f'../data/synthetic_data/matchsum/single_group/{g}/{g}.jsonl', lines=True)
#     summaries = []
#     for p in glob(f'../data/synthetic_data/matchsum/single_group/{g}/result/MatchSum_cnndm_roberta.ckpt/dec/*.dec'):
#         id = int(p.split('/')[-1][:-4])
#         with open(p) as f:
#             summary = ' '.join(f.readlines())
#             df.loc[df.id == id, 'summary'] = summary
    
#     df.text = df.text.apply(lambda x: ' '.join(x))
#     df = df.rename(columns={'text':'article'})
#     df.to_csv(f'../data/synthetic_data/matchsum/single_group/{g}/result/{g}.csv')

