from generator.concept.concept_generator import *
from tqdm import tqdm
import argparse

# Single ConceptGenerator instance, created once at import time and reused by
# every match_sents() call (via check_availability / cor_generate).
# ConceptGenerator presumably comes from the wildcard import of
# generator.concept.concept_generator above — confirm if that module changes.
generator = ConceptGenerator()

def match_sents(sent_path):
    """Build (original, generated) sentence pairs from a raw text file.

    The file is split on ``"\\n"`` and each resulting line is treated as one
    sentence. Every line accepted by ``generator.check_availability`` is
    passed to ``generator.cor_generate`` and the pair is recorded.

    Args:
        sent_path: Path to a UTF-8 text file with one sentence per line.

    Returns:
        A list of dicts, in file order, each with keys
        ``"original_sentence"`` (the input line) and ``"output_sentence"``
        (the value returned by ``generator.cor_generate`` for that line).
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding; the dataset files are also written as UTF-8.
    with open(sent_path, encoding="utf8") as f:
        sents = f.read().split("\n")

    result = []
    for sent in tqdm(sents):
        # NOTE(review): blank lines (e.g. the empty string after a trailing
        # newline) are passed through as well; presumably check_availability
        # rejects them — confirm.
        if generator.check_availability(sent):
            generated_sentence = generator.cor_generate(sent)
            result.append({
                "original_sentence": sent,
                "output_sentence": generated_sentence,
            })
    return result

import os

# Command-line options: two raw text inputs (one sentence per line) and the
# directory that receives the {train,valid}.{source,target} files.
parser = argparse.ArgumentParser(
    description="Build the concept-order-recovering dataset "
                "({train,valid}.{source,target} files)."
)
parser.add_argument('--train_path', default='wiki.train.raw',
                    help='raw training text, one sentence per line')
parser.add_argument('--valid_path', default='wiki.valid.raw',
                    help='raw validation text, one sentence per line')
parser.add_argument('--save_path', default="datasets/concept_order_recovering/",
                    help='output directory (created if missing)')

args = parser.parse_args()

# Create the output directory up front so a bad path fails immediately
# instead of after the (potentially long) sentence processing below.
if args.save_path:
    os.makedirs(args.save_path, exist_ok=True)

train_result = match_sents(sent_path=args.train_path)
valid_result = match_sents(sent_path=args.valid_path)

# Task prefix prepended to every model input (source) line.
SOURCE_PREFIX = "correct the order of the following sentence : "

# Source = prefixed generator output; target = the original sentence.
train_sources = [SOURCE_PREFIX + line["output_sentence"] for line in train_result]
train_targets = [line["original_sentence"] for line in train_result]

valid_sources = [SOURCE_PREFIX + line["output_sentence"] for line in valid_result]
valid_targets = [line["original_sentence"] for line in valid_result]


def _write_lines(path, lines):
    """Write each string in ``lines`` to ``path`` as UTF-8, one per line."""
    with open(path, 'w', encoding='utf8') as f:
        f.writelines(line + '\n' for line in lines)


# Each .target file must hold the targets (the original sentences); the
# previous version wrote the sources into both files.
for filename, lines in (
    ('train.source', train_sources),
    ('train.target', train_targets),
    ('valid.source', valid_sources),
    ('valid.target', valid_targets),
):
    # os.path.join works whether or not save_path ends with a separator.
    _write_lines(os.path.join(args.save_path, filename), lines)