import time

import numpy as np
import datasets
import torch
import random
import argparse
import os
import json
import custom_datasets
from model import load_tokenizer, load_model

def stats_str(data):
    """Return a one-line summary of the average word counts in ``data``.

    Words are counted by splitting each text on whitespace, and every average
    is rendered with no decimals.

    Args:
        data: Either a dict whose ``'original'`` and ``'sampled'`` entries are
            collections of strings, or an iterable of dict items that each
            hold ``'original'`` and ``'sampled'`` strings plus
            ``'perturbed_original'`` and ``'perturbed_sampled'`` collections
            of strings.

    Returns:
        A human-readable summary string. The dict form reports two averages;
        the item form additionally reports the perturbation averages (mean
        over items of each item's mean perturbation length).
    """
    def _mean_words(texts):
        # Average number of whitespace-separated words over a group of texts.
        return np.mean([len(text.split()) for text in texts])

    # isinstance (rather than an exact type comparison) keeps dict subclasses
    # such as OrderedDict on the dict branch instead of crashing in the
    # per-item branch.
    if isinstance(data, dict):
        mean_orig = _mean_words(data['original'])
        mean_samp = _mean_words(data['sampled'])
        return f'{mean_orig:.0f} words (original), {mean_samp:.0f} words (sampled).'

    mean_orig = _mean_words([item['original'] for item in data])
    mean_samp = _mean_words([item['sampled'] for item in data])
    mean_perturb_orig = np.mean([_mean_words(item['perturbed_original']) for item in data])
    mean_perturb_samp = np.mean([_mean_words(item['perturbed_sampled']) for item in data])
    return f'{mean_orig:.0f} words (original), {mean_samp:.0f} words (sampled), {mean_perturb_orig:.0f} words (perturb original), {mean_perturb_samp:.0f} words (perturb sampled).'

def save_data(output_file, args, data):
    """Write ``args`` and ``data`` as two JSON files sharing one path prefix.

    Args:
        output_file: Path prefix; ``.args.json`` and ``.raw_data.json`` are
            appended to it to form the two file names.
        args: JSON-serialisable object stored in the ``.args.json`` file.
        data: JSON-serialisable dataset stored in the ``.raw_data.json`` file;
            it must also be in a shape accepted by ``stats_str``.
    """
    # open() does not create missing folders, so make sure the destination
    # directory exists before writing anything.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # write args to file
    args_file = f"{output_file}.args.json"
    with open(args_file, "w", encoding="utf-8") as fout:
        json.dump(args, fout, indent=4)
        print(f"Args written into {args_file}")

    # write the data to a json file in the save folder
    data_file = f"{output_file}.raw_data.json"
    with open(data_file, "w", encoding="utf-8") as fout:
        json.dump(data, fout, indent=4)
        print(f"Raw data written into {data_file}: {stats_str(data)}")


def load_data(input_file):
    """Load the args/data JSON pair stored under a shared path prefix.

    Args:
        input_file: Path prefix; ``.args.json`` and ``.raw_data.json`` are
            appended to it to locate the two files.

    Returns:
        A ``(args, data)`` tuple with the decoded contents of the two files.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    # load args from file
    args_file = f"{input_file}.args.json"
    with open(args_file, "r", encoding="utf-8") as fin:
        args = json.load(fin)
        print(f"Args loaded from {args_file}")

    # load the data from file
    data_file = f"{input_file}.raw_data.json"
    with open(data_file, "r", encoding="utf-8") as fin:
        data = json.load(fin)
        print(f"Raw data loaded from {data_file}: {stats_str(data)}")

    return args, data

def convert_data(input_file, output_file, max_words):
    """Truncate every text of a stored dataset to ``max_words`` words.

    The dataset and its args are read via ``load_data(input_file)``, each text
    is shortened, and the result is written via ``save_data(output_file, ...)``
    with the args passed through unchanged.

    Args:
        input_file: Path prefix of the dataset to read.
        output_file: Path prefix of the truncated dataset to write.
        max_words: Maximum number of whitespace-separated words kept per text.
    """
    def _reduce(text):
        # Keep whole lines while the word budget lasts. Words inside a line
        # are re-joined with single spaces, line breaks are preserved, and
        # the line that exhausts the budget is cut mid-line.
        lines = []
        nwords = 0
        for line in text.split('\n'):
            if nwords >= max_words:
                break
            words = line.split()[:max_words - nwords]
            lines.append(' '.join(words))
            nwords += len(words)
        return '\n'.join(lines)

    args, data = load_data(input_file)
    # isinstance keeps this dispatch consistent with stats_str and also
    # accepts dict subclasses.
    if isinstance(data, dict):
        data['original'] = [_reduce(x) for x in data['original']]
        data['sampled'] = [_reduce(x) for x in data['sampled']]
    else:
        for item in data:
            item['original'] = _reduce(item['original'])
            item['sampled'] = _reduce(item['sampled'])
            item['perturbed_original'] = [_reduce(x) for x in item['perturbed_original']]
            item['perturbed_sampled'] = [_reduce(x) for x in item['perturbed_sampled']]

    save_data(output_file, args, data)

if __name__ == '__main__':
    # Script entry point: truncate every "*.raw_data.json" dataset found in
    # --input_path and write the result under the same name in --output_path.
    import glob
    import os.path as path

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default="./exp_gpt3to4/data/")
    parser.add_argument('--output_path', type=str, default="./exp_maxlen150/data/")
    parser.add_argument('--max_words', type=int, default=150)
    args = parser.parse_args()

    suffix = '.raw_data.json'
    # path.join avoids a doubled separator when input_path already ends with
    # one; sorting makes the processing order reproducible across runs.
    for file_name in sorted(glob.glob(path.join(args.input_path, f'*{suffix}'))):
        print(file_name)
        # Strip only the trailing suffix; str.replace would also remove any
        # earlier occurrence of it inside the file name.
        file_name = path.basename(file_name)[:-len(suffix)]
        convert_data(path.join(args.input_path, file_name), path.join(args.output_path, file_name), args.max_words)
