from gensim.models import KeyedVectors
from collections import defaultdict
from nltk import word_tokenize
from nltk.corpus import stopwords
from nltk.tag import pos_tag, pos_tag_sents
from glob import glob
from tqdm import tqdm
from pathlib import Path
from multiprocessing import Pool
from functools import partial
from datasets import load_dataset
import json
import string
import pandas as pd
import numpy as np
import os
import argparse


# Gendered pronouns. They are subtracted from the stopword list below so the
# stopword filter does not discard them.
pronouns = {
    'he', 'him', 'himself', 'his', "he's",
    'her', 'hers', 'herself', 'she', "she's",
}

# Universal-tagset POS tags treated as content words; tokens with any other
# tag are skipped when scoring.
content_tags = {'NOUN', 'PRON', 'VERB', 'ADJ', 'ADV'}

# NLTK's English stopwords plus the "n't" contraction token, minus the
# pronouns above.
en_stopwords = (set(stopwords.words('english')) | {"n't"}) - pronouns

# Loaded once at import time. The path is relative, so it is resolved
# against the current working directory.
print('Loading w2v model (takes a minute)')
w2v = KeyedVectors.load_word2vec_format(
    '../data/GoogleNews-vectors-negative300.bin.gz', binary=True)


def make_df(dicts, ids, labels):
    """Flatten per-document statistics dicts into one long-format DataFrame.

    Args:
        dicts: sequence with one dict per document, mapping a key (in this
            file a POS tag or 'overall') to a 1-D array-like of per-group
            values.
        ids: sequence of document ids aligned *positionally* with ``dicts``.
            A pandas Series is accepted; its index labels are ignored, so a
            non-contiguous index (e.g. after ``dropna()``) is handled.
        labels: column names: the id column, the key column, then one name
            per value in each array.

    Returns:
        A DataFrame with one row per (document, key) pair.

    Raises:
        ValueError: if ``dicts`` and ``ids`` have different lengths.
    """
    # list() iterates a Series by value, avoiding label-based ``ids[i]``.
    ids = list(ids)
    dicts = list(dicts)
    if len(ids) != len(dicts):
        raise ValueError(
            f'make_df: got {len(dicts)} stats dicts but {len(ids)} ids')

    entries = [
        [doc_id, key] + np.asarray(values).tolist()
        for doc_id, stats in zip(ids, dicts)
        for key, values in stats.items()
    ]
    return pd.DataFrame.from_records(entries, columns=labels)


def get_similarity(text, group, pos_filter=None, groups=None):
    """Return the mean word2vec similarity between ``text`` and one keyword group.

    Each qualifying token is scored as its mean similarity to the group's
    terms. The result is the mean of those token scores, or 0 when no token
    qualifies. A token qualifies when it is in the w2v vocabulary, is not a
    stopword, has a universal POS tag in ``content_tags``, and (if given) has
    a tag in ``pos_filter``.

    Args:
        text: raw text to tokenize and score.
        group: key into ``groups`` naming the keyword group. When ``groups``
            is None, it may instead be the iterable of keyword terms itself.
        pos_filter: optional collection of universal POS tags to restrict
            scoring to.
        groups: mapping from group name to keyword terms. The original code
            read an undefined module-level ``groups`` here and always raised
            NameError.

    Raises:
        TypeError: if ``group`` is a name but no ``groups`` mapping is given.
    """
    if groups is not None:
        terms = groups[group]
    elif isinstance(group, str):
        raise TypeError(
            'get_similarity() needs a `groups` mapping to look up group '
            f'{group!r}; alternatively pass the keyword terms as `group`')
    else:
        terms = list(group)

    tokens = [tok for tok in word_tokenize(text)
              if tok not in string.punctuation]
    similarities = []

    for token, tag in pos_tag(tokens, tagset='universal'):
        # Skip tokens the model can't embed, stopwords, and non-content tags.
        if (token not in w2v or token.lower() in en_stopwords
                or tag not in content_tags):
            continue
        if pos_filter is not None and tag not in pos_filter:
            continue
        similarities.append(
            np.mean([w2v.similarity(token, term) for term in terms]))

    return np.mean(similarities) if similarities else 0


def _tag_statistics(token_tags, group_similarities, n_groups):
    """Return (means, stdev) dicts keyed by each content tag plus 'overall'.

    Each value is a per-group vector. Tags with no tokens get zero vectors,
    so every document reports the same set of keys.
    """
    keys = list(content_tags) + ['overall']
    if len(token_tags) == 0:
        return ({k: np.zeros(n_groups) for k in keys},
                {k: np.zeros(n_groups) for k in keys})

    means, stdev = {}, {}
    tag_column = token_tags[:, 1]
    for tag in content_tags:
        mask = tag_column == tag
        if mask.any():
            # Averaging within a tag normalises by that tag's token count,
            # so longer texts are not favoured.
            means[tag] = group_similarities[mask].mean(axis=0)
            stdev[tag] = group_similarities[mask].std(axis=0)
        else:
            means[tag] = np.zeros(n_groups)
            stdev[tag] = np.zeros(n_groups)
    means['overall'] = group_similarities.mean(axis=0)
    stdev['overall'] = group_similarities.std(axis=0)
    return means, stdev


def process_text(text, id, savedir=None, groups=None, save=True):
    """Score every content token of ``text`` against each keyword group.

    A token is kept when it is in the w2v vocabulary, is not a stopword, and
    has a universal POS tag in ``content_tags``. Its similarity to a group is
    the mean w2v similarity to that group's terms, as in Sap et al. (2018).
    Kept tokens that literally match a group term (as-is, lowercased or
    title-cased) also increment that group's keyword count.

    Args:
        text: raw text to analyse.
        id: document identifier, used as the file stem of the per-token CSV.
            The name shadows the builtin but is kept for caller compatibility.
        savedir: directory for the per-token CSV (str or Path). It must be
            set when ``save`` is True and the text has content tokens.
        groups: mapping from group name to its keyword terms. Required.
        save: when False, nothing is written and the raw per-token arrays
            are returned instead of statistics.

    Returns:
        If ``save`` is False: ``(token_tags, group_similarities)``.
        ``token_tags`` holds (token, tag) rows. ``group_similarities`` has
        one row of per-group similarities per token, or is a 1-D zero vector
        of length ``len(groups)`` when no token is kept.

        Otherwise: ``(means, stdev, keyword_counts)``. ``means`` and
        ``stdev`` map each content tag plus 'overall' to a per-group vector.
        ``keyword_counts`` is a list with one count per group.

    Raises:
        TypeError: if ``groups`` is None.
    """
    if groups is None:
        raise TypeError(
            'process_text() requires a `groups` mapping of keyword lists')

    group_names = list(groups.keys())
    n_groups = len(group_names)
    keyword_counts = [0] * n_groups
    token_tags = []
    group_similarities = []

    # Tokenize, drop punctuation, and POS-tag with the coarse universal tagset.
    tokens = [tok for tok in word_tokenize(text)
              if tok not in string.punctuation]

    for token, tag in pos_tag(tokens, tagset='universal'):
        # Skip tokens the model can't embed, stopwords, and non-content tags.
        if (token not in w2v or token.lower() in en_stopwords
                or tag not in content_tags):
            continue
        token_tags.append((token, tag))
        row = []
        for i, terms in enumerate(groups.values()):
            # Not normalised by token count here, so the analysis can
            # normalise per POS tag later.
            row.append(sum([w2v.similarity(token, term) for term in terms])
                       / len(terms))
            if token in terms or token.lower() in terms or token.title() in terms:
                keyword_counts[i] += 1
        group_similarities.append(row)

    token_tags = np.asarray(token_tags)
    if group_similarities:
        group_similarities = np.asarray(group_similarities)
    else:
        group_similarities = np.zeros(n_groups)

    if not save:
        return token_tags, group_similarities

    if len(token_tags) > 0:
        # Cast the similarities to str explicitly rather than relying on
        # NumPy's implicit number-to-string promotion when joining them with
        # the string token/tag columns.
        table = np.concatenate(
            [token_tags, group_similarities.astype(str)], axis=-1)
        pd.DataFrame(table, columns=['token', 'POS_tag'] + group_names).to_csv(
            Path(savedir) / f'{id}.csv')

    means, stdev = _tag_statistics(token_tags, group_similarities, n_groups)
    return means, stdev, keyword_counts


def _apply_star(func, call_args):
    """Call ``func(*call_args)``; defined at module level so Pool can pickle it."""
    return func(*call_args)


def get_groups(args):
    """Score every text against the keyword groups and write summary CSVs.

    Texts come from a Hugging Face hub dataset (when ``args.datapath`` is
    None) or from a local CSV. The data needs an ``id`` column and a text
    column named ``args.col_name``. Each text is processed by
    ``process_text`` in a worker pool, which also writes per-document CSVs.
    This function then writes ``means.csv``, ``stdevs.csv`` and
    ``keyword_counts.csv`` under ``args.savedir``.
    """
    if args.datapath is None:
        dataset = load_dataset(
            args.data_name, args.data_version, split=args.data_split)
        # Reset the index after dropna() (as the CSV branch does) so dropped
        # rows cannot misalign ids with the worker outputs downstream.
        df = pd.DataFrame(dataset).dropna().reset_index(drop=True)
        source = args.data_name
    else:
        df = pd.read_csv(args.datapath, index_col=0,
                         lineterminator='\n').dropna().reset_index(drop=True)
        source = args.datapath

    with open(args.group_path, encoding='utf-8') as f:
        groups = json.load(f)

    group_names = list(groups.keys())
    # Column names shared by the means and stdevs tables.
    labels = ['id', 'pos_tag'] + group_names

    process = partial(
        _apply_star,
        partial(process_text, savedir=args.savedir, groups=groups))
    ids = df['id']

    print(f'Collecting statistics for {len(df)} texts from {source}:')

    with Pool(args.nworkers) as pool:
        # starmap materialises its input before dispatching, so a tqdm around
        # that input finishes immediately. Wrapping imap's ordered result
        # iterator advances the bar as results actually arrive.
        results = pool.imap(process, zip(df[args.col_name], ids), chunksize=20)
        output = list(tqdm(results, total=len(df)))

    means = [result[0] for result in output]
    stdevs = [result[1] for result in output]
    keyword_counts = [result[2] for result in output]

    mean_df = make_df(means, ids, labels)
    stdev_df = make_df(stdevs, ids, labels)
    keyword_count_df = pd.DataFrame(
        keyword_counts, index=ids, columns=group_names)

    savedir = Path(args.savedir)
    print(f'writing mean similarities to: {args.savedir}/means.csv')
    mean_df.to_csv(savedir / 'means.csv')
    print(
        f'writing similarity standard deviations to: {args.savedir}/stdevs.csv')
    stdev_df.to_csv(savedir / 'stdevs.csv')
    print(f'writing keyword counts to: {args.savedir}/keyword_counts.csv')
    keyword_count_df.to_csv(savedir / 'keyword_counts.csv')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_name', help='Name of dataset on huggingface hub', type=str)
    parser.add_argument(
        '--data_version', help='dataset version on huggingface hub', type=str)
    parser.add_argument('--data_split', help='Data split from huggingface hub',
                        type=str, default='train+validation+test')
    parser.add_argument(
        '--datapath', help='Path to csv holding articles and summaries', type=str)
    parser.add_argument('--savedir', help='Directory to save generated csvs in',
                        default='../data/group_analyses/', type=str)
    parser.add_argument(
        '--col_name', help='Name of column containing text to analyze (Default: summary).', default='summary', type=str)
    # get_groups always opens this file, so fail fast at parse time if it is
    # missing.
    parser.add_argument(
        '--group_path', help='Path to file containing group keywords', type=str,
        required=True)
    parser.add_argument(
        '--nworkers', help='Number of processes to use (default: 10)', type=int, default=10)
    args = parser.parse_args()

    # Without a local CSV the texts are loaded from the hub, which needs a
    # dataset name.
    if args.datapath is None and args.data_name is None:
        parser.error('either --datapath or --data_name must be given')

    # makedirs also creates missing parent directories and tolerates an
    # existing directory.
    os.makedirs(args.savedir, exist_ok=True)

    get_groups(args)
