import argparse
import os

import torch
import numpy as np
import random

# Number of leading lines used from each known training file, keyed by the
# name passed via --train_file_names. The selected sizes must add up to the
# length of the loaded probability vector.
train_file_sizes = {
    "flan_v2": 100_000,
    "cot": 100_000,
    "dolly": 15_011,
    "oasst1": 55_668,
}

def parse_args(argv=None):
    """Parse the command-line options of the data-selection script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``,
            which is the behavior of calling ``parse_args()`` with no
            arguments.

    Returns:
        The ``argparse.Namespace`` holding the parsed options. Options that
        are not supplied keep their defaults (``None`` for everything except
        ``--sigma``, which defaults to ``1.0``).
    """
    argparser = argparse.ArgumentParser(
        description='Script for selecting the data for training')
    argparser.add_argument('--train_file_names', type=str,
                           nargs='+', help='The name of the training files')
    argparser.add_argument('--train_files', type=str, nargs='+',
                           help='The path of the training files')
    argparser.add_argument('--target_task_names', type=str,
                           nargs='+', help='The name of the target task')
    argparser.add_argument('--output_path', type=str, help='The path to the output')
    argparser.add_argument('--input_path', type=str, help='The path to the prob info')
    argparser.add_argument('--max_samples', type=int,
                           default=None, help='The maximum number of samples')
    argparser.add_argument('--percentage', type=float, default=None,
                           help='The percentage of the data to be selected')
    argparser.add_argument('--seed', type=int,
                           help='The random seed used for sampling')
    argparser.add_argument('--alpha', type=float,
                           help='The alpha value used in the probability '
                                'and output file names')
    argparser.add_argument('--sigma', type=float, default=1.0,
                           help='The sigma value; when it is not 1.0 it is '
                                'appended to the probability and output '
                                'file names')

    # Passing argv=None makes argparse fall back to sys.argv[1:].
    return argparser.parse_args(argv)


def count_lines(filename):
    """Return how many lines the text file at ``filename`` contains.

    The file is read as UTF-8 and bytes that cannot be decoded are ignored,
    so malformed content does not raise while counting.
    """
    with open(filename, 'r', encoding='utf-8', errors='ignore') as handle:
        return sum(1 for _ in handle)


if __name__ == "__main__":
    args = parse_args()
    assert len(args.train_file_names) == len(args.train_files)
    assert args.percentage is not None or args.max_samples is not None
    n_train_files = len(args.train_file_names)

    if args.sigma != 1.0:
        postfix=f"sigma{args.sigma}"
    else:
        postfix=""

    np.random.seed(args.seed)

    for target_task in args.target_task_names:
        output_path = os.path.join(args.output_path, target_task)
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        input_path = os.path.join(args.input_path, target_task)

        num_samples = [train_file_sizes[name] for name in args.train_file_names]

        prob_path = os.path.join(input_path, f"prob_alpha{args.alpha}{postfix}.npy")
        prob = np.load(prob_path)
        total_samples = prob.shape[0]
        assert total_samples == sum(num_samples)
        if args.percentage is not None:
            args.max_samples = int(args.percentage * total_samples)
            data_amount_name = f"p{args.percentage}"
        else:
            data_amount_name = f"num{args.max_samples}"

        sample_times = np.random.multinomial(args.max_samples, prob)

        all_lines = []
        for i, train_file in enumerate(args.train_files):
            with open(train_file, 'r', encoding='utf-8', errors='ignore') as file:
                all_lines.extend(file.readlines()[:num_samples[i]])

        assert len(all_lines) == sample_times.shape[0]
        selected_lines = []
        for i, line in enumerate(all_lines):
            for _ in range(sample_times[i]):
                selected_lines.append(line)
        random.shuffle(selected_lines)

        with open(os.path.join(output_path, f"kdeknn{args.alpha}{postfix}_{data_amount_name}.jsonl"), 'w', encoding='utf-8', errors='ignore') as file:
            for line in selected_lines:
                try:
                    file.write(line)
                except:
                    import pdb
                    pdb.set_trace()


