from hashlib import new
import math
import pandas as pd
import numpy as np
from argparse import ArgumentParser
from nltk.tokenize import sent_tokenize


def measure_similarity(article1, article2):
    """Return a similarity score for a pair of articles.

    This is currently a placeholder: the inputs are not inspected and the
    score is always 0.

    Args:
        article1: First article (unused for now).
        article2: Second article (unused for now).

    Returns:
        The constant 0.
    """
    # TODO: replace with a real similarity metric.
    return 0


def find_similar(article1, df, similarity_threshold):
    """Pick the article in ``df.article`` that is most similar to ``article1``.

    Every entry of ``df.article`` is scored with ``measure_similarity``. A
    candidate replaces the current best only when its score is strictly
    greater than both the best score so far (starting at 0) and
    ``similarity_threshold``; on ties the earliest candidate is kept.

    Args:
        article1: The article to find a match for.
        df: Object exposing an iterable ``article`` attribute of candidates.
        similarity_threshold: Score a candidate must strictly exceed.

    Returns:
        The best matching candidate, or ``None`` when no candidate qualifies.
    """
    best_match = None
    best_score = 0

    for candidate in df.article:
        score = measure_similarity(article1, candidate)
        # Skip anything that fails to strictly beat both the running best
        # and the threshold.
        if not (score > best_score and score > similarity_threshold):
            continue
        best_match, best_score = candidate, score

    return best_match


def combine_articles(article1, article2, article1_ratio=0.5, max_length=1024):
    """Splice a run of sentences from each article into two mixed articles.

    Both articles are split into sentences with ``sent_tokenize``. A
    contiguous window of sentences is drawn from each at a random start
    position (via ``np.random.randint``), and the two windows are joined with
    single spaces in both possible orders.

    The window sizes are ``ceil(ratio * n)`` and ``ceil((1 - ratio) * n)``
    where ``n`` is the sentence count of the shorter article. Because both
    are rounded up, the combined count can exceed ``n`` by one.

    Args:
        article1: Text of the first article.
        article2: Text of the second article.
        article1_ratio: Fraction of ``n`` to take from ``article1``.
        max_length: Accepted but not used by the current implementation.

    Returns:
        A tuple ``(new_article1, new_article2)``: the first places the
        ``article1`` window before the ``article2`` window, the second uses
        the opposite order.

    Raises:
        AssertionError: If either article yields no sentences.
    """
    sentences1 = sent_tokenize(article1)
    sentences2 = sent_tokenize(article2)

    assert sentences1 and sentences2

    # The shorter article bounds how many sentences are sampled overall.
    base_count = min(len(sentences1), len(sentences2))

    take1 = math.ceil(article1_ratio * base_count)
    take2 = math.ceil((1-article1_ratio) * base_count)

    # randint's upper bound is exclusive, so the +1 keeps the last valid
    # start position reachable. article1's start is drawn first.
    start1 = np.random.randint(0, len(sentences1) - take1 + 1)
    start2 = np.random.randint(0, len(sentences2) - take2 + 1)

    window1 = sentences1[start1:start1 + take1]
    window2 = sentences2[start2:start2 + take2]

    return ' '.join(window1 + window2), ' '.join(window2 + window1)


def gen_multigroup_articles(group1_name, group2_name, base_path, save_dir, group_ratio=0.5, max_length=1024, match_topic=False):
    """Build mixed articles from two groups' CSV files and save them as CSVs.

    Reads ``{base_path}/{group1_name}.csv`` and ``{base_path}/{group2_name}.csv``
    (first CSV column used as the index), pairs articles across the two
    groups, combines each pair with ``combine_articles`` and writes two files
    into ``save_dir``:

    * ``{group1_ratio:.2f}_{group1_name}_{group2_ratio:.2f}_{group2_name}.csv``
      holding the combinations that start with the group 1 sentences, and
    * ``{group2_ratio:.2f}_{group2_name}_{group1_ratio:.2f}_{group1_name}.csv``
      holding the combinations that start with the group 2 sentences.

    Pairing strategy:

    * ``match_topic=True``: each ``article`` value of group 1 is paired with
      the result of ``find_similar`` over group 2; rows without a match are
      skipped. This mode requires an ``article`` column in both files.
    * ``match_topic=False``: both frames are shuffled, truncated to the
      shorter length and paired row by row. The article text is read from the
      last column.

    In both modes the recorded id is the first non-index column of the row.
    NOTE(review): because ``index_col=0`` moves the first CSV column into the
    index, confirm that the first remaining column really is the id.

    Args:
        group1_name: Base name (without ``.csv``) of the first group's file.
        group2_name: Base name (without ``.csv``) of the second group's file.
        base_path: Directory containing both input CSV files.
        save_dir: Directory the two output CSV files are written to.
        group_ratio: Share of each combined article taken from group 1;
            forwarded to ``combine_articles`` and used in the file names.
        max_length: Forwarded to ``combine_articles``.
        match_topic: Selects the pairing strategy described above.
    """
    df1 = pd.read_csv(f'{base_path}/{group1_name}.csv', index_col=0)
    df2 = pd.read_csv(f'{base_path}/{group2_name}.csv', index_col=0)

    similarity_threshold = 0.2
    group1_ratio = group_ratio
    group2_ratio = 1-group_ratio

    new_articles1 = []
    new_articles2 = []

    if match_topic:
        # find_similar only hands back the matched text, so keep a lookup
        # from text to id. setdefault keeps the first occurrence, which is
        # the row find_similar would return when texts are duplicated.
        df2_ids = {}
        for candidate_id, candidate_text in zip(df2.iloc[:, 0], df2.article):
            df2_ids.setdefault(candidate_text, candidate_id)

        # Use the `article` column (the same one find_similar reads) and take
        # ids positionally, mirroring the random-pairing branch below.
        for entry1_id, article1 in zip(df1.iloc[:, 0], df1.article):
            article2 = find_similar(article1, df2, similarity_threshold)
            if article2 is None:
                continue
            entry2_id = df2_ids[article2]
            new_article1, new_article2 = combine_articles(
                article1, article2, group_ratio, max_length=max_length)
            # NOTE(review): this branch records ids as (group1, group2) in
            # both outputs, whereas the random-pairing branch swaps them for
            # the second output. Confirm which ordering is intended.
            new_articles1.append((entry1_id, entry2_id, new_article1))
            new_articles2.append((entry1_id, entry2_id, new_article2))
    else:
        max_len = min(len(df1), len(df2))
        df1_shuffle = df1.sample(frac=1).to_numpy()[:max_len, :]
        df2_shuffle = df2.sample(frac=1).to_numpy()[:max_len, :]

        for entry1, entry2 in zip(df1_shuffle, df2_shuffle):
            # Forward the requested ratio and length so the generated text
            # matches the ratio advertised in the output file names.
            new_article1, new_article2 = combine_articles(
                entry1[-1], entry2[-1], group_ratio, max_length=max_length)
            new_articles1.append((entry1[0], entry2[0], new_article1))
            new_articles2.append((entry2[0], entry1[0], new_article2))

    columns = ['group1_id', 'group2_id', 'article']
    df1 = pd.DataFrame(new_articles1, columns=columns)
    df2 = pd.DataFrame(new_articles2, columns=columns)

    df1.to_csv(
        f'{save_dir}/{group1_ratio:.2f}_{group1_name}_{group2_ratio:.2f}_{group2_name}.csv', index_label='id')
    df2.to_csv(
        f'{save_dir}/{group2_ratio:.2f}_{group2_name}_{group1_ratio:.2f}_{group1_name}.csv', index_label='id')


def str2bool(value):
    """Convert a command-line string such as 'True' or 'false' to a bool.

    ``type=bool`` is unsuitable for argparse because ``bool('False')`` is
    ``True``: every non-empty string is truthy. This parser maps the usual
    spellings explicitly instead. The empty string maps to ``False``, which
    is what ``bool('')`` produced before.

    Args:
        value: The raw argument; an actual ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    # Command-line entry point: parse the arguments and generate the CSVs.
    parser = ArgumentParser()
    parser.add_argument(
        'group1_name', help='name of first group to use in generated articles', type=str)
    parser.add_argument(
        'group2_name', help='name of second group to use in generated articles', type=str)
    parser.add_argument(
        'base_path', help='directory containing the <group1_name>.csv and <group2_name>.csv article files', type=str)
    parser.add_argument(
        'save_dir', help='directory to save newly generated article csvs to', type=str)
    parser.add_argument(
        '--group_ratio', help='percent of article that the first group should occupy (default=0.5)', type=float, default=0.5)
    parser.add_argument(
        '--max_length', help='maximum length of article to generate (default=1024)', type=int, default=1024)
    # str2bool rather than bool, so that "--match_topic False" is really False.
    parser.add_argument(
        '--match_topic', help='If true, only articles with semantic similarity will be combined (default=False)', type=str2bool, default=False)
    args = parser.parse_args()

    gen_multigroup_articles(
        args.group1_name, args.group2_name, args.base_path,
        args.save_dir, args.group_ratio, args.max_length, args.match_topic)
