import argparse
import json
import os
import datasets
import pandas as pd
import pathlib
from glob import glob
import sys


def convert_data(data_path, data_type, all_data, source_key, target_key):
    """
    Write `all_data` as a line-aligned pair of text files inside `data_path`.

    Line i of ``<data_type>.source`` holds ``item[source_key]`` and line i of
    ``<data_type>.target`` holds the label of the same item spelled out as
    text: 'negative' when ``item[target_key] == 0``, 'positive' for any other
    value.

    Args:
        data_path: existing directory the two files are written to.
        data_type: file stem, e.g. "train" or "validation".
        all_data: iterable of mappings (one per example).
        source_key: key of the input text in each item.
        target_key: key of the label in each item.
    """
    source_path = os.path.join(data_path, f'{data_type}.source')
    target_path = os.path.join(data_path, f'{data_type}.target')

    # Explicit UTF-8 so the output does not depend on the platform locale
    # (readers such as pandas.read_csv default to UTF-8).
    with open(source_path, 'w', encoding='utf-8') as source_file, \
            open(target_path, 'w', encoding='utf-8') as target_file:
        for item in all_data:
            source_input = item[source_key]
            target_label = item[target_key]

            # The two files are paired by line number, so an example must
            # occupy exactly one line: flatten any embedded line breaks.
            single_line = (source_input.replace('\r\n', ' ')
                           .replace('\n', ' ')
                           .replace('\r', ' '))
            source_file.write(single_line + '\n')

            # Label 0 maps to 'negative'; every other value maps to 'positive'.
            target_str = 'negative' if target_label == 0 else 'positive'
            target_file.write(target_str + '\n')


def _read_lines(path):
    """Return the lines of the UTF-8 text file at `path`, newline removed."""
    with open(path, encoding="utf-8") as handle:
        return [line.rstrip("\n") for line in handle]


def convert2jsonl(input_dir):
    """
    finds all the files in the input_dir that are of the form abc.source and abc.target,
    converts them to abc.jsonl with fields `question` (the source line) and
    `answer` (the target line at the same position)

    Args:
        input_dir: directory to scan; an empty string means the current directory.

    Raises:
        FileNotFoundError: a .source file has no matching .target file.
        ValueError: a source/target pair does not have the same number of lines.
    """
    # os.path.join keeps an empty input_dir relative to the current directory
    # instead of turning the pattern into "/*.source" at the filesystem root.
    for source_file in sorted(glob(os.path.join(input_dir, "*.source"))):
        source_path = pathlib.Path(source_file)
        # with_suffix only swaps the last extension, so dotted stems such as
        # "train.v1.source" stay intact.
        target_path = source_path.with_suffix(".target")
        output_path = source_path.with_suffix(".jsonl")

        # Check the filesystem directly: comparing path strings breaks when
        # input_dir is spelled like "./dir" or "dir/".
        if not target_path.is_file():
            raise FileNotFoundError(
                f"No matching target file for {source_file}: expected {target_path}")

        print(f"Creating {output_path.name} from {source_file} and {target_path}")

        # Read the files as plain lines. A CSV parser would split on tabs,
        # interpret quotes, drop blank lines and coerce values like "NA" or
        # "123", all of which corrupt or misalign free-text examples.
        source = _read_lines(source_path)
        target = _read_lines(target_path)
        if len(source) != len(target):
            raise ValueError(
                f"{source_file} has {len(source)} lines but {target_path} has {len(target)}")

        output = pd.DataFrame({"question": source, "answer": target})
        output.to_json(str(output_path), orient="records", lines=True)


def main(args):
    """
    Export the train and validation data as .source/.target files in
    `args.data_dir`, then turn every such pair into a .jsonl file.

    Train data: when the directory path contains "org", the dataset's "train"
    split is used with `args.src_key` / `args.tgt_key`. Otherwise the
    'all_data' list of plugin_set.jsonl (read from the same directory with
    json.load) is used with the keys "source_input" / "target_label".

    Validation data: the dataset's "test" split when the dataset name contains
    "cola", else its "validation" split; both are written under the name
    "validation".
    """
    corpus_name = args.data_name
    corpus = datasets.load_dataset(corpus_name)
    out_dir = args.data_dir

    # train files
    if "org" not in out_dir:
        plugin_path = os.path.join(out_dir, "plugin_set.jsonl")
        with open(plugin_path, 'r') as handle:
            payload = json.load(handle)
        train_items = payload['all_data']
        convert_data(out_dir, "train", train_items, "source_input", "target_label")
    else:
        train_items = corpus["train"]
        convert_data(out_dir, "train", train_items, args.src_key, args.tgt_key)

    # validation files
    eval_split = "test" if "cola" in corpus_name else "validation"
    eval_items = corpus[eval_split]
    convert_data(out_dir, "validation", eval_items, args.src_key, args.tgt_key)

    # merge each .source/.target pair into a .jsonl file
    convert2jsonl(out_dir)


if __name__ == "__main__":
    # Command-line entry point: flags, types and defaults are unchanged; only
    # the --help text describes what the script really does.
    parser = argparse.ArgumentParser(
        description="Export a Hugging Face dataset to .source/.target text files and .jsonl files")
    parser.add_argument("--data_name", type=str, default="linxinyuan/cola",
                        help="Dataset name passed to datasets.load_dataset")
    parser.add_argument("--data_dir", type=str, default="",
                        help="Directory the output files are written to; unless the path "
                             "contains 'org', train data is read from plugin_set.jsonl in it")
    parser.add_argument("--src_key", type=str, default="text",
                        help="Dataset field holding the input text")
    parser.add_argument("--tgt_key", type=str, default="label",
                        help="Dataset field holding the label (0 = negative, otherwise positive)")
    # Add more arguments as needed

    args = parser.parse_args()
    main(args)