import argparse
import json
import glob
import os
import torch
from metric_config import get_metric_func
from concurrent.futures import ProcessPoolExecutor
import numpy as np


def parse():
    """Declare and parse the command-line options of this selection script.

    Returns:
        argparse.Namespace with one attribute per option declared below
        (dashes become underscores, e.g. ``--select-rate`` -> ``select_rate``).
    """
    cli = argparse.ArgumentParser()

    # (flag, keyword arguments forwarded to add_argument), in declaration order.
    option_specs = [
        # Directory holding the ``block*`` score inputs and the single
        # ``heldout_data*`` file.
        ("--input-dir", dict(type=str)),
        ("--output-data-path", dict(type=str)),
        ("--input-data-path", dict(type=str)),
        # Fraction of samples kept per task.
        ("--select-rate", dict(type=float)),
        ("--metric", dict(type=str)),
        ("--topk", dict(type=int, default=-1)),
        # GPU ids; block files are assigned to them round-robin.
        ("--gpus", dict(type=int, nargs="+")),
        ("--block-num-per-round", dict(type=int, default=16)),
        ("--update-llama-factory-data-info", dict(action="store_true")),
        # messages column, system, user, assistant, role and content tags.
        (
            "--dialog-format",
            dict(type=str, default="conversations,system,human,gpt,from,value"),
        ),
        # Rank ascending (lower score = better) instead of descending.
        ("--lower", dict(action="store_true")),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


def extract_number(filepath):
    """Return the ``(start, end, index)`` sort key encoded in a block file name.

    Only the last two ``_``-separated fields of the base name are used: the
    first of them is ``<start>-<end>`` and the second is ``<index>`` with an
    optional extension, e.g. ``block_0-1000_3.pt`` -> ``(0, 1000, 3)``.

    Raises:
        ValueError: if the name does not have that shape.
    """
    name = os.path.basename(filepath)
    *_, span, tail = name.split("_")
    first, last = span.split("-")
    index = tail.split(".")[0]
    return int(first), int(last), int(index)


def single_file_process(file, task_file, gpu_id, metric, topk):
    """Score one block file against the held-out task data on one GPU.

    Args:
        file: path of the block file to score (loaded with ``weights_only=True``).
        task_file: path of the held-out task data.
        gpu_id: GPU this call should run on.
        metric: metric name resolved through ``get_metric_func``.
        topk: forwarded to ``get_metric_func`` as ``k``.

    Returns:
        Whatever the metric function returns; the caller concatenates it with
        ``+=``, so presumably a list with one entry per sample — TODO confirm.
    """
    # CUDA reads CUDA_VISIBLE_DEVICES when it is first initialised in a
    # process, so this only takes effect if CUDA is not yet initialised here.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # NOTE(review): weights_only=False unpickles arbitrary objects; only
    # acceptable because the held-out file is produced by a trusted pipeline.
    held_out = torch.load(task_file, map_location="cuda", weights_only=False)
    scorer = get_metric_func(metric, k=topk)
    block = torch.load(file, map_location="cuda", weights_only=True)

    return scorer(block, held_out)


def _set_visible_gpu(gpu_id):
    """Process-pool initializer: pin the current worker process to ``gpu_id``.

    It runs before the worker executes any task, so no CUDA context can exist
    yet and the setting reliably takes effect.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)


def distribute_files(files, gpus, task_file, metric, topk):
    """Score ``files`` in parallel, assigning them round-robin to ``gpus``.

    File ``i`` is processed on ``gpus[i % len(gpus)]``; results are
    concatenated in the order of ``files``.

    One process pool is created per distinct GPU id, and its workers pin
    themselves to that GPU at start-up. A single shared pool is unsafe here:
    CUDA reads ``CUDA_VISIBLE_DEVICES`` only once per process, so a worker
    reused for a file assigned to another GPU would silently stay on the
    first GPU it initialised.

    Args:
        files: block file paths to score.
        gpus: GPU ids to spread the files over.
        task_file: held-out task data path, forwarded to ``single_file_process``.
        metric: metric name, forwarded to ``single_file_process``.
        topk: forwarded to ``single_file_process``.

    Returns:
        The concatenation of every ``single_file_process`` result, in file order.

    Raises:
        ValueError: if ``files`` is non-empty and ``gpus`` is empty or None.
    """
    if not files:
        return []
    if not gpus:
        raise ValueError("at least one GPU id must be given (--gpus)")

    num_gpus = len(gpus)
    # dict.fromkeys de-duplicates while keeping order, so a GPU id listed more
    # than once still gets a single pool (and proportionally more files).
    pools = {
        gpu: ProcessPoolExecutor(initializer=_set_visible_gpu, initargs=(gpu,))
        for gpu in dict.fromkeys(gpus)
    }
    try:
        futures = []
        for idx, file in enumerate(files):
            gpu = gpus[idx % num_gpus]
            futures.append(
                pools[gpu].submit(
                    single_file_process, file, task_file, gpu, metric, topk
                )
            )

        sim_data = []
        # Collect in submission order so results line up with ``files``.
        for future in futures:
            sim_data += future.result()
    finally:
        for pool in pools.values():
            pool.shutdown()

    return sim_data


def _dict_list_to_list_dict(data):
    """Transpose a list of dicts into a dict of lists.

    Keys come from the first element; every element must contain them. An
    empty input yields an empty dict rather than raising ``IndexError``.
    """
    if not data:
        return {}
    return {key: [row[key] for row in data] for key in data[0]}


def _parse_dialog_format(dialog_format):
    """Split ``--dialog-format`` into its comma-separated fields.

    Field order: messages column, system tag, user tag, assistant tag,
    role tag, content tag. Extra trailing fields are ignored.

    Raises:
        ValueError: if fewer than six fields are given.
    """
    parts = dialog_format.split(",")
    if len(parts) < 6:
        raise ValueError(
            "--dialog-format needs 6 comma-separated fields "
            f"(messages,system,user,assistant,role,content), got {len(parts)}: "
            f"{dialog_format!r}"
        )
    return parts


def _select_top(scores, raw_data, select_rate, lower=False):
    """Return the ``raw_data`` entries whose scores rank best, best first.

    ``int(len(scores) * select_rate)`` entries are kept. "Best" means highest
    unless ``lower`` is true. The sort is stable (also with ``reverse``), so
    tied scores keep their original order.
    """
    ranked = sorted(
        range(len(scores)), key=lambda idx: scores[idx], reverse=not lower
    )
    keep = int(len(scores) * select_rate)
    return [raw_data[idx] for idx in ranked[:keep]]


def _register_llama_factory_dataset(output_data_path, file_name, dialog_parts):
    """Add or overwrite ``file_name``'s entry in the sibling dataset_info.json.

    The dataset_info.json must already exist next to ``output_data_path``; it
    is read, updated and written back.
    """
    info_path = os.path.join(os.path.dirname(output_data_path), "dataset_info.json")
    with open(info_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Strip only the trailing extension; str.replace would also remove a
    # ".json" occurring earlier in the name.
    suffix = ".json"
    data_name = file_name[: -len(suffix)] if file_name.endswith(suffix) else file_name

    data[data_name] = {
        "file_name": file_name,
        "formatting": "sharegpt",
        "columns": {
            "messages": dialog_parts[0],
        },
        "tags": {
            "role_tag": dialog_parts[4],
            "content_tag": dialog_parts[5],
            "user_tag": dialog_parts[2],
            "assistant_tag": dialog_parts[3],
            "system_tag": dialog_parts[1],
        },
    }

    with open(info_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    print(f"Updated dataset_info.json for \033[94m{data_name}\033[0m")


def main():
    """Select the most task-similar samples and write one JSON file per task.

    Steps:
      1. Score every ``block*`` file in ``--input-dir`` against the single
         ``heldout_data*`` file, ``--block-num-per-round`` files at a time.
      2. Transpose the per-sample score dicts into per-task score lists.
      3. For each task, write the top ``--select-rate`` fraction of
         ``--input-data-path`` to ``<output>_<metric>_<task>.json``.
      4. Optionally register each output in LLaMA-Factory's dataset_info.json.

    Raises:
        ValueError: on invalid options, a missing or ambiguous heldout file, or
            a score count that does not match the number of input samples.
    """
    args = parse()

    # Validate cheap inputs up front, before any GPU time is spent.
    if not args.input_dir:
        raise ValueError("--input-dir is required")
    if args.select_rate is None or args.select_rate < 0:
        raise ValueError(
            f"--select-rate must be a non-negative number, got {args.select_rate}"
        )
    if args.block_num_per_round < 1:
        # A value <= 0 previously made the round loop below never terminate.
        raise ValueError(
            f"--block-num-per-round must be >= 1, got {args.block_num_per_round}"
        )
    dialog_parts = (
        _parse_dialog_format(args.dialog_format)
        if args.update_llama_factory_data_info
        else None
    )

    with open(args.input_data_path, "r", encoding="utf-8") as f:
        raw_data = json.load(f)

    # Escape the directory so glob metacharacters in it ([, ], *, ?) match literally.
    pattern_dir = glob.escape(args.input_dir)

    # fetch similarity data
    files = sorted(glob.glob(f"{pattern_dir}/block*"), key=extract_number)

    print(args.input_dir)
    heldout_data_files = sorted(glob.glob(f"{pattern_dir}/heldout_data*"))
    if len(heldout_data_files) != 1:
        raise ValueError(
            f"expected exactly one heldout_data* file in {args.input_dir}: "
            f"{len(heldout_data_files)} != 1"
        )
    task_file = heldout_data_files[0]

    print(
        "============================================== Files =============================================="
    )
    for file in files:
        print(file)
    print(
        "===================================================================================================="
    )

    sim_data = []
    step = args.block_num_per_round
    for start in range(0, len(files), step):
        sim_data += distribute_files(
            files[start : start + step], args.gpus, task_file, args.metric, args.topk
        )

    if len(sim_data) != len(raw_data):
        raise ValueError(f"{len(sim_data)} != {len(raw_data)}")

    # sim_data holds one dict per sample (presumably keyed by task file name —
    # TODO confirm against metric_config); regroup as task -> list of scores.
    scores_by_task = _dict_list_to_list_dict(sim_data)

    # select top similar data
    for task_filename, scores in scores_by_task.items():
        task = task_filename.split(".")[0]
        selected_data = _select_top(scores, raw_data, args.select_rate, args.lower)

        with open(
            f"{args.output_data_path}_{args.metric}_{task}.json", "w", encoding="utf-8"
        ) as f:
            json.dump(selected_data, f, indent=2)

        if dialog_parts is not None:
            file_name = (
                f"{os.path.basename(args.output_data_path)}_{args.metric}_{task}.json"
            )
            _register_llama_factory_dataset(
                args.output_data_path, file_name, dialog_parts
            )


if __name__ == "__main__":
    # Run the selection pipeline only when executed as a script, not on import.
    main()
