import pandas as pd
from nltk import sent_tokenize
import numpy as np
import seaborn as sns
from tqdm import tqdm
from utils import *
import matplotlib.pyplot as plt
import os
from datasets import load_dataset

tqdm.pandas()

def gen_loc_csv(data_path, save_path):
    """Locate summary sentences inside their source articles and save the result.

    Reads a CSV of articles and summaries, splits both into sentences with
    ``sent_tokenize``, maps each row to article-sentence indices with
    ``get_summary_indices`` and divides those indices by the article's
    sentence count. Despite the function name, the output file is Parquet.

    Args:
        data_path: CSV file with (at least) ``id``, ``article`` and ``summary``
            columns; its first column is used as the index. Rows containing
            any NaN are dropped before processing.
        save_path: destination Parquet file. Columns written are ``id``,
            ``article_length`` (number of article sentences), ``idxs`` (the
            indices returned by ``get_summary_indices``) and ``idx_quantile``
            (``idxs / article_length``).

    Returns:
        An object-dtype ``pd.Series`` holding one array of relative positions
        per retained row, aligned with the rows written to ``save_path``.
    """
    df = pd.read_csv(data_path, index_col=0).dropna()
    # Alternative data source, kept for reference:
    #df = pd.DataFrame(load_dataset('cnn_dailymail', '3.0.0', split='train+validation+test'))
    print('tokenizing sentences')
    df.article = df.article.progress_apply(sent_tokenize)
    df.summary = df.summary.progress_apply(sent_tokenize)
    df['article_length'] = df.article.apply(len)
    print('getting locations')
    # Iterate the two columns directly instead of DataFrame.apply(axis=1):
    # no per-row Series is built, and an empty frame yields an empty list
    # rather than a DataFrame (which has no .tolist()).
    # NOTE(review): the meaning of the 2 / 0.1 arguments lives in utils — confirm there.
    idxs = [
        get_summary_indices(article, summary, 2, 0.1)
        for article, summary in tqdm(zip(df.article, df.summary), total=len(df))
    ]
    location_df = pd.DataFrame(
        zip(df.id.tolist(), df.article_length.tolist(), idxs),
        columns=['id', 'article_length', 'idxs'],
    )
    # np.asarray makes the element-wise division work whether the helper
    # returns a list or an ndarray; dtype=object keeps one array per row.
    percents = pd.Series(
        [np.asarray(idx) / length
         for idx, length in zip(location_df.idxs, location_df.article_length)],
        index=location_df.index,
        dtype=object,
    )
    location_df['idx_quantile'] = percents
    location_df.to_parquet(save_path)
    return percents

def plot_loc_hist(data_path, save_dir, summarizer, bins=50, prefix=''):
    """Plot a histogram of where summary sentences fall within their articles.

    Calls ``gen_loc_csv`` (which writes
    ``<save_dir>/<prefix><summarizer>_loc.parquet``), pools the relative
    positions of all rows and saves a probability histogram to
    ``<save_dir>/<prefix><summarizer>_loc_hist.png``.

    Args:
        data_path: CSV file passed through to ``gen_loc_csv``.
        save_dir: output directory; created if it does not already exist.
        summarizer: name used in both output file names.
        bins: number of histogram bins.
        prefix: string prepended to both output file names, so that different
            configurations written to the same ``save_dir`` do not overwrite
            each other. With the default ``''`` the names are unchanged.
    """
    # Both the Parquet file and the PNG are written into save_dir, and neither
    # writer creates missing directories.
    os.makedirs(save_dir, exist_ok=True)
    # The prefix is part of the Parquet name too; previously every prefixed
    # call wrote the same <summarizer>_loc.parquet, so only the last one survived.
    percents = gen_loc_csv(data_path, os.path.join(save_dir, f'{prefix}{summarizer}_loc.parquet'))
    percents = np.concatenate(percents.tolist())
    percents_df = pd.DataFrame(percents, columns=['idx_quantile'])
    # Draw on a dedicated figure and close it afterwards so repeated calls
    # (e.g. the __main__ sweep) do not depend on pyplot's global current figure.
    fig, ax = plt.subplots()
    sns.histplot(data=percents_df, x='idx_quantile', bins=bins, stat='probability', ax=ax)
    fig.savefig(os.path.join(save_dir, f'{prefix}{summarizer}_loc_hist.png'))
    plt.close(fig)


if __name__ == '__main__':
    # Experiment configuration. The active sweep below reads only `ratios` and
    # `group_categories`; the other lists feed the disabled sweeps kept as comments.
    summarizers = ['textrank', 'bart', 'pegasus', 'gpt3', 'azure']
    percents = [0.1, 0.5, 0.9]
    ratios = [0.1, 0.5, 0.9]
    groups = ['men', 'women', 'black', 'white', 'hispanic', 'asian', 'islam', 'judaism', 'christianity']
    group_categories = {
        'gender': ['men', 'women'],
        'race': ['black', 'white', 'hispanic', 'asian'],
        'religion': ['islam', 'judaism', 'christianity'],
    }

    # Disabled sweep: multigroup data whose file names also carry a percent `p`.
    # for s in summarizers:
    #     for cat in group_categories:
    #         if s in ['gpt3', 'azure'] and cat != 'gender':
    #             continue
    #         for ratio in ratios:
    #             for p in percents:
    #                 for g1 in group_categories[cat]:
    #                     for g2 in group_categories[cat]:
    #                         if g1 == g2:
    #                             continue
    #                         data_path = f'../data/synthetic_data/{s}/multigroup/{p:.1f}_{ratio:.2f}_{g1}_{1-ratio:.2f}_{g2}.csv'
    #                         save_dir = f'../paper/figs/location/multigroup/{s}/'
    #                         plot_loc_hist(data_path, save_dir, s, prefix=f'{p:.1f}_{ratio:.2f}_{g1}_{1-ratio:.2f}_{g2}_')

    # Active sweep: for each summarizer, category, mixing ratio and ordered pair
    # of distinct groups within that category, plot summary-sentence locations.
    for s in ['matchsum', 'presumm']:
        save_dir = f'../paper/figs/location/multigroup/{s}/'
        # Mirrors the disabled sweep above, which skips non-gender categories
        # for gpt3/azure; neither summarizer is in this list.
        active_cats = [cat for cat in group_categories
                       if cat == 'gender' or s not in ['gpt3', 'azure']]
        for cat in active_cats:
            members = group_categories[cat]
            # Ordered pairs of distinct groups, in the same order as two nested loops.
            ordered_pairs = [(g1, g2) for g1 in members for g2 in members if g1 != g2]
            for ratio in ratios:
                for g1, g2 in ordered_pairs:
                    tag = f'{ratio:.2f}_{g1}_{1-ratio:.2f}_{g2}'
                    data_path = f'../data/synthetic_data/{s}/multigroup/{tag}/{tag}.csv'
                    plot_loc_hist(data_path, save_dir, s, prefix=f'{tag}_')

    # Disabled sweep: single-group data with a percent `p` in the file name.
    # for s in summarizers:
    #     for p in percents:
    #         for g in groups:
    #             data_path = f'../data/synthetic_data/{s}/single_group/{p}_{g}.csv'
    #             save_dir = f'../paper/figs/location/single_group/{s}/'
    #             plot_loc_hist(data_path, save_dir, s, prefix=f'{p}_{g}_')

    # Disabled sweep: single-group data for presumm/matchsum.
    # NOTE(review): the prefix below uses `p`, which this loop never defines —
    # fix before re-enabling.
    # for s in ['presumm', 'matchsum']:
    #     for g in groups:
    #         data_path = f'../data/synthetic_data/{s}/single_group/{g}/{g}.csv'
    #         save_dir = f'../paper/figs/location/single_group/{s}/'
    #         plot_loc_hist(data_path, save_dir, s, prefix=f'{p}_{g}_')

    # Disabled: human reference summaries and per-summarizer CSVs in the home directory.
    #plot_loc_hist('cnn_dailymail', '/home/user/user/location_analysis', 'human_summaries')
    #for s in summarizers:
        #data_path = os.path.expanduser(f'~/user/{s}.csv')
        #save_dir = os.path.expanduser(f'~/user/location_analysis')
        #print(f'generating plot for {s}!')
        #plot_loc_hist(data_path, save_dir, s)