from traceback import StackSummary
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind
from datasets import load_dataset
import json
from collections import defaultdict
from tqdm import tqdm
from argparse import ArgumentParser
from multiprocessing import Pool
from functools import partial
import os


def filter_pos(df, filter_vals, col_name='POS_tag'):
    """Keep only the rows of ``df`` whose ``col_name`` value is in ``filter_vals``.

    The input frame is not modified; the result is a separate copy that keeps
    the original index labels.
    """
    keep_mask = df[col_name].isin(filter_vals)
    return df.loc[keep_mask].copy()


def filter_group_words(df, group_path, group_name):
    """Drop tokens that belong to any group other than ``group_name``.

    Args:
        df: token table with a ``token`` column and a score column named
            ``group_name``.
        group_path: path to a JSON file mapping group names to word lists.
        group_name: the group being scored; its own words are kept.

    Returns:
        A new DataFrame (never a list) without rows whose token appears in
        another group's word list, and without rows whose ``group_name``
        score is missing. An empty input yields an empty DataFrame.

    Raises:
        KeyError: if ``group_name`` is not a key of the group file.
    """
    if len(df) == 0:
        # Return a DataFrame rather than a list so callers get a consistent type.
        return df.copy(deep=True)

    with open(group_path, 'r') as f:
        groups = json.load(f)

    if group_name not in groups:
        raise KeyError(f'group {group_name!r} not found in {group_path}')

    # Words listed under every *other* group.
    other_group_words = {
        word
        for name, words in groups.items()
        if name != group_name
        for word in words
    }

    keep = ~df['token'].isin(other_group_words)
    # Only rows missing this group's score are dropped; NaNs in unrelated
    # columns (missing POS tags, other groups' scores, tokens such as "null"
    # that read_csv parses as NaN) must not silently shrink the token count.
    return df.loc[keep].dropna(subset=[group_name])


def get_association_scores(df, group_path, thresholds, pos_filters=None):
    """Compute one association score per group for a token-score table.

    For each group, in the order the group JSON file lists them, tokens
    belonging to other groups are removed (``filter_group_words``). The score
    is then the sum of this group's scores that are strictly above
    ``thresholds[group]``, divided by the number of remaining tokens.

    Args:
        df: token table with ``token``, ``POS_tag`` and one score column per
            group.
        group_path: path to a JSON file mapping group names to word lists.
        thresholds: mapping from group name to score threshold.
        pos_filters: optional POS tags to keep; ``None``/empty keeps all rows.

    Returns:
        A list with one score per group; 0 when no tokens remain.
    """
    if pos_filters:
        df = filter_pos(df, pos_filters)

    with open(group_path, 'r') as f:
        groups = json.load(f)

    association_scores = []
    for group_name in groups:
        filtered_df = filter_group_words(df, group_path, group_name)

        if len(filtered_df) == 0:
            association_scores.append(0)
            continue

        group_scores = filtered_df[group_name]
        above_threshold = group_scores[group_scores > thresholds[group_name]]
        association_scores.append(above_threshold.sum() / len(filtered_df))

    return association_scores


def process_data(id, group_path, thresholds, data_path, pos_filter):
    """Score a single document and return ``[id, score_group1, ...]``.

    Args:
        id: document identifier; converted to ``str`` and used as the csv
            file name ``<data_path>/<id>.csv``.
        group_path: path to a JSON file mapping group names to word lists.
        thresholds: mapping from group name to score threshold.
        data_path: directory holding the per-document token-score csvs.
        pos_filter: POS tags to keep (empty list keeps all tokens).

    Returns:
        A list whose first element is the string id, followed by one
        association score per group.
    """
    doc_id = str(id)
    # Join as a directory path so a data_path without a trailing slash works.
    csv_path = os.path.join(os.path.expanduser(data_path), doc_id + '.csv')

    if os.path.exists(csv_path):
        scores = pd.read_csv(csv_path, index_col=0)
    else:
        # No per-document csv: score a single placeholder row whose group
        # scores are all 0, which yields a score of 0 for every group.
        with open(group_path, 'r') as f:
            groups = json.load(f)
        scores = pd.DataFrame(
            [['none', 'none'] + [0] * len(groups)],
            columns=['token', 'POS_tag'] + list(groups.keys()),
        )

    association_scores = get_association_scores(scores, group_path, thresholds, pos_filter)

    return [doc_id] + association_scores


def group_analysis(args):
    """Score every relevant document against every group and write a CSV.

    Reads the group definitions (``args.group_path``) and per-group thresholds
    (``args.threshold_path``). It then selects the ids in
    ``<article_dir>/keyword_counts.csv`` that have a positive count in at
    least one group column and scores each id in parallel with
    ``process_data``. Finally, it writes one row per id to
    ``<save_dir>/association_scores_<pos tags or "overall">.csv``.

    Args:
        args: parsed namespace with ``group_path``, ``threshold_path``,
            ``article_dir``, ``data_dir``, ``save_dir``, ``pos_filter`` and
            ``nworkers`` attributes.
    """
    with open(args.group_path, 'r') as f:
        groups = json.load(f)
    with open(args.threshold_path, 'r') as f:
        thresholds = json.load(f)

    group_names = list(groups.keys())
    cols = ['id'] + group_names
    # Treat a missing filter (e.g. when called outside the CLI) as "no filter".
    pos_filter = args.pos_filter or []

    keyword_counts = pd.read_csv(os.path.join(args.article_dir, 'keyword_counts.csv'))

    suffix = '_'.join(pos_filter).lower() if pos_filter else 'overall'
    result_savepath = os.path.join(args.save_dir, f'association_scores_{suffix}.csv')
    # Create the output directory before the expensive scoring pass, so a
    # missing directory does not make the run fail only at the very end.
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)

    print(f'starting analysis. results will be written to {result_savepath}')

    # Only documents with at least one keyword hit for some group are scored.
    has_keyword = (keyword_counts[group_names] > 0).any(axis=1)
    ids = keyword_counts.loc[has_keyword, 'id']

    worker = partial(process_data, group_path=args.group_path, thresholds=thresholds,
                     data_path=args.data_dir, pos_filter=pos_filter)
    with Pool(args.nworkers) as pool:
        # Wrap the result iterator (not the input ids) so the bar advances as
        # documents finish. The pool reads its input ahead of completed work,
        # so a bar on the input overstates progress.
        overall_results = list(
            tqdm(pool.imap(worker, ids, chunksize=20), total=len(ids)))

    results_df = pd.DataFrame(overall_results, columns=cols)

    print(f'writing results to {result_savepath}')
    results_df.to_csv(result_savepath)


if __name__ == '__main__':
    # Command-line entry point: parse arguments and run the per-group analysis.
    parser = ArgumentParser()
    # The four path arguments below are used unconditionally by the analysis,
    # so they are required instead of failing later on a None path.
    parser.add_argument(
        '--data_dir', help='path to directory containing analysis csvs for summaries', type=str,
        required=True)
    parser.add_argument(
        '--article_dir', help='path to directory containing analysis csvs for articles', type=str,
        required=True)
    parser.add_argument('--group_path', help='path to group file', type=str, required=True)
    parser.add_argument(
        '--threshold_path', help='path to score thresholds for each group', type=str,
        required=True)
    parser.add_argument(
        '--save_dir', help='path to directory to save association scores (default: --data_dir)',
        default=None, type=str)
    parser.add_argument(
        '--pos_filter', help='list of pos filters to use when performing tests', nargs='+', type=str)
    parser.add_argument(
        '--nworkers', help='number of workers to use for multiprocessing', default=10, type=int)
    args = parser.parse_args()

    if args.save_dir is None:
        args.save_dir = args.data_dir

    if args.pos_filter is None:
        args.pos_filter = []

    group_analysis(args)
