import argparse
import os

import torch
import numpy as np
import random

# Number of leading lines used from each training file, keyed by the dataset
# name passed via --train_file_names.
train_file_sizes = dict(flan_v2=100_000, cot=100_000, dolly=15_011, oasst1=55_668)

def parse_args():
    """Parse the command-line options for random training-data selection.

    Returns:
        argparse.Namespace with attributes ``train_file_names`` (list[str]),
        ``train_files`` (list[str]), ``output_path`` (str),
        ``max_samples`` (int | None), ``percentage`` (float | None) and
        ``seed`` (int | None).
    """
    argparser = argparse.ArgumentParser(
        description='Script for selecting the data for training')
    # These three are mandatory: without them the selection code fails later
    # with an unhelpful TypeError on ``None`` instead of a usage message.
    argparser.add_argument('--train_file_names', type=str, nargs='+',
                           required=True,
                           help='The name of the training files')
    argparser.add_argument('--train_files', type=str, nargs='+',
                           required=True,
                           help='The path of the training files')
    argparser.add_argument('--output_path', type=str, required=True,
                           help='The path to the output')
    argparser.add_argument('--max_samples', type=int, default=None,
                           help='The maximum number of samples '
                                '(ignored when --percentage is given)')
    # The value is multiplied by the total sample count, so it is a fraction
    # in [0, 1], not a 0-100 percentage.
    argparser.add_argument('--percentage', type=float, default=None,
                           help='The fraction of the data to be selected '
                                '(e.g. 0.05 selects 5%%)')
    argparser.add_argument('--seed', type=int, default=None,
                           help='Random seed for selection and shuffling; '
                                'if omitted, results are not reproducible')

    return argparser.parse_args()


def count_lines(filename):
    """Return how many lines the text file ``filename`` contains.

    The file is decoded as UTF-8 with undecodable bytes ignored, and is
    streamed line by line rather than read into memory at once.
    """
    with open(filename, 'r', encoding='utf-8', errors='ignore') as handle:
        return sum(1 for _ in handle)


if __name__ == "__main__":
    # Randomly select lines from the concatenated training files and write
    # them, shuffled, to <output_path>/rand/rand_<p{percentage}|num{max_samples}>.jsonl.
    args = parse_args()
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if len(args.train_file_names) != len(args.train_files):
        raise ValueError(
            f"--train_file_names has {len(args.train_file_names)} entries but "
            f"--train_files has {len(args.train_files)}; they must match")
    if args.percentage is None and args.max_samples is None:
        raise ValueError("One of --percentage or --max_samples must be given")

    # numpy drives the subset selection and Python's ``random`` drives the
    # final shuffle, so both must be seeded for reproducible output.
    np.random.seed(args.seed)
    random.seed(args.seed)

    output_path = os.path.join(args.output_path, 'rand')
    os.makedirs(output_path, exist_ok=True)

    # Number of leading lines taken from each file, looked up by dataset name.
    num_samples = [train_file_sizes[name] for name in args.train_file_names]
    total_samples = sum(num_samples)

    if args.percentage is not None:
        # --percentage takes precedence over --max_samples when both are given.
        args.max_samples = int(args.percentage * total_samples)
        data_amount_name = f"p{args.percentage}"
    else:
        data_amount_name = f"num{args.max_samples}"

    if not 0 <= args.max_samples <= total_samples:
        raise ValueError(
            f"Cannot select {args.max_samples} of {total_samples} samples "
            f"without replacement")

    selected_indices = np.random.choice(total_samples, size=args.max_samples, replace=False)
    print(f"Select {args.max_samples} / {total_samples}")

    # Sampling is without replacement, so each line is picked at most once;
    # a boolean mask over the concatenated lines is sufficient.
    is_selected = np.zeros(total_samples, dtype=bool)
    is_selected[selected_indices] = True
    print(f"Sample size: {int(is_selected.sum())}")

    all_lines = []
    for train_file, n_keep in zip(args.train_files, num_samples):
        with open(train_file, 'r', encoding='utf-8', errors='ignore') as file:
            current_lines = file.readlines()
        all_lines.extend(current_lines[:n_keep])
        print(f"{n_keep} / {len(current_lines)}")

    if len(all_lines) != total_samples:
        raise ValueError(
            f"{len(all_lines)} vs {total_samples} lines: a training file has "
            f"fewer lines than its entry in train_file_sizes")

    # Keep the selected lines in file order (shuffled below). A file whose
    # last line lacks a trailing newline would otherwise be glued to the next
    # line in the output and break the JSONL format.
    selected_lines = [
        line if line.endswith('\n') else line + '\n'
        for line, keep in zip(all_lines, is_selected)
        if keep
    ]
    random.shuffle(selected_lines)

    output_file = os.path.join(output_path, f"rand_{data_amount_name}.jsonl")
    with open(output_file, 'w', encoding='utf-8', errors='ignore') as file:
        file.writelines(selected_lines)

