import argparse
import json
import os
from dataclasses import asdict

from src.api_bank_data_conversion import api_bank_process_data
from src.tool_ace_data_conversion import split_raw_file
from src.utility import (
    dataset_tool_ace,
    dataset_api_bank,
    setup_logging,
    dataset_when2call,
)
from src.when2call_data_conversion import when2call_convert

# Fixed seed value. NOTE(review): nothing in this script's visible code reads
# it -- presumably kept for other modules or reproducibility; confirm before removing.
random_seed = 666


def _str_to_bool(value: str) -> bool:
    """Interpret a command-line string as a boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` -- would be parsed as ``True``.
    This converter looks at the text instead.

    Args:
        value: Raw option value taken from the command line.

    Returns:
        ``True`` for true/t/yes/y/1 and ``False`` for false/f/no/n/0 or the
        empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays False, matching what bool("") produced before.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the conversion script.

    Returns:
        The parsed namespace with ``input_files`` (list of one or more paths),
        ``output_dir`` (defaults to ``"."``), ``output_file_suffix``,
        ``data_set``, ``data_type`` (each ``None`` when omitted) and
        ``filter_data`` (bool, defaults to ``False``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_files", required=True, nargs="+", help="Input files")
    parser.add_argument("--output_dir", default=".", help="Output dir")
    parser.add_argument(
        "--output_file_suffix",
        type=str,
        help="Suffix to append to input file for generating output file name",
    )
    parser.add_argument("--data_set", type=str, help="Data set name")
    parser.add_argument("--data_type", type=str, help="Data type name")
    parser.add_argument(
        "--filter_data",
        # Not ``type=bool``: that would turn "--filter_data False" into True.
        type=_str_to_bool,
        default=False,
        help="Filter data in when2call",
    )

    return parser.parse_args()


def process_api_bank(input_file: str, output_file: str) -> None:
    """Convert one api_bank JSON file and write the converted records as JSON.

    The input is loaded with ``json.load`` and handed to
    ``api_bank_process_data``; every returned record is converted to a dict
    with ``dataclasses.asdict`` and the resulting list is dumped with a
    2-space indent.

    Args:
        input_file: Path of the JSON file to read.
        output_file: Path of the JSON file to write; opened in ``"w"`` mode, so
            an existing file is overwritten.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's locale.
    with open(input_file, "r", encoding="utf-8") as fin:
        records = json.load(fin)
    new_records = api_bank_process_data(records)
    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # handle must be UTF-8 or the dump can fail on non-UTF-8 default locales.
    with open(output_file, "w", encoding="utf-8") as fout:
        json.dump([asdict(r) for r in new_records], fout, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    # Convert dataset specific files to unified format.
    args = parse_args()
    setup_logging(write_stdout=True)
    print(f"Running {os.path.basename(__file__)} with args: {args}")

    # The output directory is the same for every input file, so create it once.
    os.makedirs(args.output_dir, exist_ok=True)

    for input_file in args.input_files:
        # The output keeps the input's base name, with ".<suffix>" appended
        # when a suffix was given.
        output_name = os.path.basename(input_file)
        if args.output_file_suffix:
            output_name = f"{output_name}.{args.output_file_suffix}"
        output_file = os.path.join(args.output_dir, output_name)

        if args.data_set == dataset_api_bank:
            process_api_bank(input_file, output_file)
        elif args.data_set == dataset_tool_ace:
            # split_raw_file is given separate ".train" and ".test" output paths.
            split_raw_file(input_file, f"{output_file}.train", f"{output_file}.test")
        elif args.data_set == dataset_when2call:
            when2call_convert(
                input_file, output_file, args.filter_data, args.data_type
            )
        else:
            raise RuntimeError(f"Unknown dataset:{args.data_set}")
