import pandas as pd
import subprocess
from argparse import ArgumentParser
import os
from nltk.tokenize import word_tokenize
import math
import json

def bin_articles_by_length_extractive(articles, max_range=75, tokenizer=None):
    """Group articles into bins of similar token length.

    Token counts are taken after collapsing the ' [CLS] [SEP] ' separators to a
    single space. The span between the shortest and longest article is split by
    ``pd.cut`` into equal-width bins, using enough bins that each one covers
    roughly ``max_range`` tokens or fewer. Bins that receive no article are
    dropped.

    Args:
        articles: sequence of article strings.
        max_range: approximate maximum width, in tokens, of a single bin.
        tokenizer: callable mapping a string to a list of tokens. Defaults to
            nltk's ``word_tokenize``.

    Returns:
        A list of dicts, one per non-empty bin, in increasing length order. Each
        dict has these keys:

        - 'text': DataFrame with an 'article' column, indexed by the article's
          position in ``articles``.
        - 'min_length', 'max_length': the token-length extremes within the bin.

        Returns an empty list when ``articles`` is empty.
    """
    if tokenizer is None:
        tokenizer = word_tokenize

    df = pd.DataFrame(articles, columns=['article'])
    if df.empty:
        # lengths.max() would be NaN and math.ceil(NaN) raises.
        return []

    lengths = df.article.apply(lambda x: len(tokenizer(x.replace(' [CLS] [SEP] ', ' '))))
    # Use at least one bin: if every article has the same length the span is 0,
    # and pd.cut rejects bins=0.
    num_bins = max(1, math.ceil((lengths.max() - lengths.min()) / max_range))
    # labels=False yields the integer bin index (0..num_bins-1) for each article.
    labels = pd.cut(lengths, num_bins, labels=False)

    bins = []
    for bin_id in range(num_bins):
        mask = labels == bin_id
        if not mask.any():
            continue
        bins.append({
            'text': df[mask],
            'min_length': lengths[mask].min(),
            'max_length': lengths[mask].max(),
        })
    return bins

def combine_files(base_save_path, bins, save_path):
    """Merge the per-bin PreSumm outputs into one CSV of articles and summaries.

    For each bin index ``i`` this reads two files:

    - ``<base_save_path>_<i>_step-1.candidate``: one summary per line. Any
      ``<q>`` markers (presumably sentence separators) are replaced by spaces.
    - ``<base_save_path>_<i>_step-1.gold``: the matching article ids, one
      integer per line.

    Summaries are matched to articles by id. Articles with no summary are
    dropped. The CSV is sorted by id and has the columns 'id', 'article' and
    'summary', plus the DataFrame index.

    Args:
        base_save_path: path prefix under which the per-bin result files were
            written.
        bins: list of bin dicts as returned by
            ``bin_articles_by_length_extractive``. Only each bin's 'text'
            DataFrame (column 'article', indexed by article id) is used.
        save_path: destination CSV path.

    Raises:
        ValueError: if a bin's candidate and gold files have different line
            counts.
    """
    if not bins:
        # pd.concat([]) would raise; write a header-only CSV instead.
        pd.DataFrame(columns=['id', 'article', 'summary']).to_csv(save_path)
        return

    summary_by_id = {}
    for i in range(len(bins)):
        prefix = f'{base_save_path}_{i}_step-1'
        with open(f'{prefix}.candidate', 'r') as f:
            # Drop readlines' trailing newline so it doesn't end up in the CSV cell.
            summaries = [line.rstrip('\n').replace('<q>', ' ') for line in f]
        with open(f'{prefix}.gold', 'r') as f:
            ids = [int(line) for line in f]
        if len(summaries) != len(ids):
            raise ValueError(
                f'{prefix}: {len(summaries)} candidate lines but {len(ids)} gold ids')
        summary_by_id.update(zip(ids, summaries))

    df = (pd.concat([b['text'] for b in bins])
          .reset_index()
          .rename(columns={'index': 'id'}))
    df = df[df['id'].isin(list(summary_by_id))].copy()
    # Look summaries up by id instead of assigning positionally, so a reordered
    # or partially dropped output cannot misalign summaries with articles.
    df['summary'] = df['id'].map(summary_by_id)
    df['article'] = df['article'].apply(
        lambda x: x.replace(' [CLS] [SEP] ', ' ').rstrip('\n'))
    df.sort_values(by='id').to_csv(save_path)

def main(args):
    """Summarize every article in ``args.dataset_path`` with PreSumm's extractive model.

    Articles (one per line) are grouped by ``bin_articles_by_length_extractive``.
    PreSumm then runs once per bin, with min/max summary lengths equal to
    ``args.length_percent`` times the bin's shortest/longest article length.
    Finally ``combine_files`` merges the per-bin results into
    ``<save_path>.csv``.

    Args:
        args: parsed CLI namespace with ``dataset_path``, ``length_percent``,
            ``save_path`` and ``device`` attributes.

    Raises:
        FileNotFoundError: if ``../PreSumm/src`` (relative to the current
            directory) does not exist.
        subprocess.CalledProcessError: if a PreSumm run exits with a non-zero
            status.
    """
    save_path = os.path.abspath(args.save_path)
    data_path = os.path.abspath(args.dataset_path)
    # PreSumm is launched from its src directory, because './train.py' and the
    # model path below are relative to it. It is passed as ``cwd`` rather than
    # os.chdir so this process's working directory is left untouched.
    presumm_src = os.path.abspath('../PreSumm/src')
    if not os.path.isdir(presumm_src):
        raise FileNotFoundError(f'PreSumm source directory not found: {presumm_src}')

    with open(data_path, 'r') as f:
        articles = f.readlines()

    bins = bin_articles_by_length_extractive(articles)

    for i, length_bin in enumerate(bins):
        text_path = f'{save_path}_{i}.src'
        id_file = f'{save_path}_{i}.tgt'
        # One article per line; each line still carries its newline from readlines().
        with open(text_path, 'w') as f:
            f.writelines(length_bin['text'].article.to_list())
        # The "target" file holds each article's row id. Presumably PreSumm copies
        # it into the .gold file that combine_files reads ids from.
        with open(id_file, 'w') as f:
            f.writelines(f'{index}\n' for index in length_bin['text'].index)

        max_length = math.ceil(args.length_percent * length_bin['max_length'])
        min_length = math.ceil(args.length_percent * length_bin['min_length'])
        try:
            subprocess.run(
                ['python3', './train.py',
                 '-mode', 'test_text',
                 '-task', 'ext',
                 '-text_src', text_path,
                 '-text_tgt', id_file,
                 '-test_from', '../models/bertext_cnndm_transformer.pt',
                 '-visible_gpus', str(args.device),
                 '-max_length', str(max_length),
                 '-min_length', str(min_length),
                 '-result_path', f'{save_path}_{i}'],
                cwd=presumm_src,
                check=True,  # fail here rather than later in combine_files
            )
        finally:
            os.remove(text_path)

    combine_files(save_path, bins, save_path + '.csv')
    
# For reference, the equivalent manual PreSumm call (run from PreSumm/src) looks like:
# python3 train.py -mode test_text -task ext -text_src ~/Data/paraphrasing/untargeted/presumm_ext/articles_no_sample_10p.txt
#     -result_path ~/Data/paraphrasing/targeted/presumm_ext/presumm -min_length 50 -max_length 100 -test_from ../models/bertext_cnndm_transformer.pt -visible_gpus 4


if __name__ == '__main__':
    # Command-line entry point. length_percent is multiplied directly by article
    # lengths in main(), so it is a fraction (0.1 == 10%), not a percentage.
    # save_path is used as a file prefix, not a directory.
    parser = ArgumentParser(
        description="Summarize articles with PreSumm's extractive model, binning them by length.")
    parser.add_argument('dataset_path', help='path to a text file with one article per line', type=str)
    parser.add_argument('length_percent',
                        help='summary length as a fraction of article length (e.g. 0.1 for 10%%)',
                        type=float)
    parser.add_argument('save_path',
                        help='output path prefix: per-bin PreSumm results are written to '
                             '<save_path>_<i>* and the merged summaries to <save_path>.csv',
                        type=str)
    parser.add_argument('--device', help='which GPU to run on', type=str, default='0')
    args = parser.parse_args()

    main(args)