import argparse
import os

import torch
import json
import jsonlines


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``train_file_names``,
        ``score_files``, ``train_files``, ``target_task_names`` (each a list
        of strings, or ``None`` when the flag is omitted), ``output_path``
        (str, defaults to ``"selected_data"``), ``max_samples`` (int or
        ``None``) and ``percentage`` (float or ``None``).
    """
    parser = argparse.ArgumentParser(
        description="Script for selecting the data for training")

    # All of these flags accept one or more string values.
    multi_value_flags = (
        ("--train_file_names", "The name of the file"),
        ("--score_files",
         "The path of the influence function score of the training file"),
        ("--train_files",
         "The path of the training file that corresponds to the score file"),
        ("--target_task_names", "The name of the target task"),
    )
    for flag, description in multi_value_flags:
        parser.add_argument(flag, type=str, nargs="+", help=description)

    parser.add_argument("--output_path", type=str, default="selected_data",
                        help="The path to the output")
    # The selection budget is given either as an absolute count or a fraction.
    parser.add_argument("--max_samples", type=int, default=None,
                        help="The maximum number of samples")
    parser.add_argument("--percentage", type=float, default=None,
                        help="The percentage of the data to be selected")

    return parser.parse_args()


def count_lines(filename):
    """Return how many lines the text file at *filename* contains.

    The file is decoded as UTF-8 and undecodable bytes are ignored, so a
    partially corrupt file is still counted instead of raising.
    """
    with open(filename, 'r', encoding='utf-8', errors='ignore') as handle:
        # Iterating the handle yields one item per line without loading the
        # whole file into memory.
        return sum(1 for _ in handle)


def _rank_samples(scores_per_file):
    """Rank every sample of every file by descending score.

    Args:
        scores_per_file: non-empty list of score tensors, one per training
            file (assumed 1-D, one score per sample, all on the same device).

    Returns:
        A ``(sorted_scores, file_ids, local_ids)`` tuple of equally long
        tensors. Entry ``k`` describes the k-th best sample: its score, the
        position in ``scores_per_file`` of the file it came from, and its
        index inside that file.
    """
    counts = [len(scores) for scores in scores_per_file]
    all_scores = torch.cat(scores_per_file, dim=0)
    target_device = all_scores.device

    # For each position of the concatenated score vector, remember the index
    # inside its own file and which file it belongs to.
    local_index = torch.cat(
        [torch.arange(count) for count in counts]).to(target_device)
    file_index = torch.cat(
        [torch.full((count,), file_id, dtype=torch.long)
         for file_id, count in enumerate(counts)]).to(target_device)

    sorted_scores, order = torch.sort(all_scores, dim=0, descending=True)
    # Both lookups must use the *global* sort order; permuting the file ids
    # with file-local indices would attribute samples to the wrong file.
    return sorted_scores, file_index[order], local_index[order]


def main():
    """Select the highest-scoring training samples and write them to disk.

    For every target task this writes, under ``<output_path>/<task>/``:
      * ``sorted.csv`` - every sample ranked by score (only if the file does
        not exist yet);
      * ``top_<amount>.jsonl`` - the selected records, best first.

    Raises:
        ValueError: if the command-line arguments are inconsistent or a
            training file holds fewer records than its score file.
    """
    args = parse_args()

    score_files = args.score_files or []
    train_files = args.train_files or []
    train_file_names = args.train_file_names or []
    target_task_names = args.target_task_names or []

    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if args.percentage is None and args.max_samples is None:
        raise ValueError(
            "Either --percentage or --max_samples must be provided")
    if not score_files:
        raise ValueError("--score_files must name at least one file")
    if not target_task_names:
        raise ValueError("--target_task_names must name at least one task")
    if not len(score_files) == len(train_files) == len(train_file_names):
        raise ValueError(
            "--score_files, --train_files and --train_file_names must have "
            "the same number of entries")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE: torch.load unpickles its input; only use trusted score files.
    # Each file is loaded once and reused for counting and ranking.
    scores_per_file = [torch.load(score_file, map_location=device)
                       for score_file in score_files]
    num_samples = [len(scores) for scores in scores_per_file]
    total_samples = sum(num_samples)

    if args.percentage is not None:
        max_samples = int(args.percentage * total_samples)
        data_amount_name = f"p{args.percentage}"
    else:
        max_samples = args.max_samples
        data_amount_name = f"num{args.max_samples}"
    if max_samples < 0:
        raise ValueError("The number of samples to select must not be negative")
    # Asking for more than is available selects everything.
    max_samples = min(max_samples, total_samples)

    sorted_scores, file_ids, local_ids = _rank_samples(scores_per_file)
    sorted_scores = sorted_scores.tolist()
    file_ids = file_ids.tolist()
    local_ids = local_ids.tolist()

    # Keep the records of each training file separate so that a
    # (file id, local index) pair addresses exactly one record.
    records_per_file = []
    for train_file, n_scored in zip(train_files, num_samples):
        with open(train_file, 'r', encoding='utf-8', errors='ignore') as file:
            records = list(jsonlines.Reader(file))[:n_scored]
        if len(records) < n_scored:
            raise ValueError(
                f"{train_file} has {len(records)} records but its score file "
                f"has {n_scored} scores")
        records_per_file.append(records)

    # NOTE(review): the score files do not depend on the target task, so
    # every task receives the same ranking - confirm this is intended.
    for target_task in target_task_names:
        output_path = os.path.join(args.output_path, target_task)
        os.makedirs(output_path, exist_ok=True)

        sorted_score_file = os.path.join(output_path, "sorted.csv")
        if not os.path.exists(sorted_score_file):
            with open(sorted_score_file, 'w', encoding='utf-8') as file:
                file.write("file name, index, score\n")
                for score, index, file_id in zip(
                        sorted_scores, local_ids, file_ids):
                    file.write(
                        f"{train_file_names[file_id]}, {index}, {round(score, 6)}\n")

        selected_path = os.path.join(
            output_path, f"top_{data_amount_name}.jsonl")
        # Write mode (not append) so re-running does not duplicate records.
        with open(selected_path, 'w', encoding='utf-8', errors='ignore') as file:
            for file_id, index in zip(
                    file_ids[:max_samples], local_ids[:max_samples]):
                file.write(json.dumps(records_per_file[file_id][index]) + '\n')


if __name__ == "__main__":
    main()
