import os
import pdb
import sys
import time
import json
import torch
import queue
import logging
from tqdm import tqdm



# Absolute path of the project checkout. It is pushed to the front of sys.path
# before the `src.*` imports below, presumably so they resolve regardless of
# the current working directory.
# NOTE(review): the path is hard-coded to one machine's layout — adjust it (or
# derive it from __file__) when running elsewhere.
root_dir = '/data/home/username/Experiments/LLM_ensemble'
sys.path.insert(0, root_dir)
import argparse
from src.instruction_generate import demon_prompt_generate, task_instruction_generate
from src.main_model_thread import MainModelThread
from src.model_load import load_model
from src.assist_model_thread import AssistModelThread
from src.common_vocabulary import CommonVocabulary
from src.transfer_matrix import ProbabilityTransferMatrix


def _build_arg_parser():
    """Build the command-line parser for an ensemble run.

    Option names, short aliases, types and defaults are kept stable so existing
    launch scripts keep working. Several options (``--device3`` .. ``--device8``,
    ``--main_temperature``, ``--assist_temperature``, ``--min_prob``,
    ``--max_prob``) are accepted but not consumed by ``main`` itself.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description='Process some files.')

    # Without a config file nothing can run; let argparse report the omission
    # instead of failing later inside open(None).
    parser.add_argument('--config', required=True, help='the name of the file to process')
    parser.add_argument('--learning_rate', '-lr', default=0.0, type=float, required=False, help="learning_rate")
    parser.add_argument('--anchor_point_count', '-apc', default=32000, type=int, required=False,
                        help='anchor_point_count')
    parser.add_argument('--learning_epochs_nums', '-len', default=5, type=int, required=False,
                        help='learning_epochs_nums')
    parser.add_argument('--result_save_dir', '-rsd', default="./", type=str, required=False, help='result_save_dir')
    parser.add_argument('--run_mode', '-rm', default="dev", type=str, required=False, help='run_mode')
    parser.add_argument('--logits_processor_mode', '-lpm', default="based_on_probility_transfer_logits_processor",
                        type=str,
                        required=False,
                        help='logits_processor_mode')
    parser.add_argument('--device_compute', '-dp', default="cuda:0", type=str, required=False,
                        help='device_compute')

    # --device0 .. --device7 default to the CUDA device with the same ordinal.
    for device_index in range(8):
        device_name = f'device{device_index}'
        parser.add_argument(f'--{device_name}', f'-d{device_index}', default=f'cuda:{device_index}', type=str,
                            required=False, help=device_name)
    # NOTE(review): --device8 defaults to cuda:7 (same as --device7) in the
    # original script; kept as-is — confirm whether that is intentional.
    parser.add_argument('--device8', '-d8', default="cuda:7", type=str, required=False,
                        help='device8')

    parser.add_argument('--main_temperature', '-mt', default=100, type=float, required=False,
                        help='main_temperature')
    parser.add_argument('--assist_temperature', '-at', default=100, type=float, required=False,
                        help='assist_temperature')
    parser.add_argument('--min_prob', default=0.8, type=float, required=False,
                        help='min_prob')
    parser.add_argument('--max_prob', default=0.9, type=float, required=False,
                        help='max_prob')
    return parser


def _count_existing_results(result_file_path):
    """Return the number of lines already present in the result file.

    A missing file is the normal first-run case and yields 0. Any other read
    problem also yields 0 (the original best-effort behaviour) but is logged so
    that an unexpected restart from scratch is visible.
    """
    try:
        with open(result_file_path, 'r') as result_file:
            return sum(1 for _ in result_file)
    except FileNotFoundError:
        return 0
    except (OSError, UnicodeError):
        logging.warning('Could not read existing results from %s; starting from 0',
                        result_file_path, exc_info=True)
        return 0


def main():
    """Run the main model together with two assist models over the MMLU files.

    Steps performed:
      1. Parse the command line and load the JSON config given by ``--config``.
      2. Load the probability transfer matrices and the three models.
      3. Build the common vocabulary and the anchor point list.
      4. For every (demonstration file, input file) pair, build the prompt for
         each input line and run one ``MainModelThread`` plus one
         ``AssistModelThread`` per assist model, waiting for all of them before
         moving to the next line.
      5. Print the elapsed wall-clock time.

    Inputs are read from the hard-coded MMLU ``dev-jsonl`` / ``val-jsonl`` /
    ``test-jsonl`` folders; the ``file_path`` section of the config is not
    consulted. ``start_index`` / ``end_index`` from the config slice the file
    lists only when ``--run_mode`` is not ``"dev"``.

    Resuming: the number of lines already in the result ``.jsonl`` file is used
    as the number of input lines to skip (presumably one result line is written
    per processed input line by the worker threads — confirm in
    ``MainModelThread``).

    Returns:
        None. Returns early (after logging an error) if the three dataset
        folders do not contain the same number of files.
    """
    start_time = time.time()

    # Parse the command-line arguments.
    args = _build_arg_parser().parse_args()

    # Load the run configuration from the file named on the command line.
    with open(args.config, 'r', encoding='utf-8') as f:
        config_json = json.load(f)

    model_path_config = config_json["model_path"]
    main_model_path = model_path_config["main_model_path"]
    assist_model1_path = model_path_config["assist_model1_path"]
    assist_model2_path = model_path_config["assist_model2_path"]

    matrix_path_config = config_json["probability_transfer_matrix_path"]
    main_model_probability_transfer_matrix_path = matrix_path_config["main_model_path"]
    assist_model1_probability_transfer_matrix_path = matrix_path_config["assist_model1_path"]
    assist_model2_probability_transfer_matrix_path = matrix_path_config["assist_model2_path"]

    prompt_template_config = config_json["prompt_template"]
    instruction = prompt_template_config["instruction"]
    instruction_parameter = prompt_template_config["instruction_parameter"]
    main_model_system_template = prompt_template_config["main_model_system_template"]
    assist_model1_system_template = prompt_template_config["assist_model1_system_template"]
    assist_model2_system_template = prompt_template_config["assist_model2_system_template"]
    demon_parameter = prompt_template_config["demon_parameter"]

    run_parameter_config = config_json["run_parameter"]
    max_new_tokens = run_parameter_config["max_new_tokens"]
    start_index = run_parameter_config["start_index"]
    end_index = run_parameter_config["end_index"]

    result_process_parameter = config_json["result_process_parameter"]
    # The early-stop list is optional; tolerate a missing key or a
    # non-subscriptable value (e.g. null in the JSON).
    try:
        early_stop_string_list = result_process_parameter["early_stop_string_list"]
    except (KeyError, TypeError):
        early_stop_string_list = None

    result_save_dir = args.result_save_dir
    logits_processor_mode = args.logits_processor_mode
    os.makedirs(result_save_dir, exist_ok=True)

    anchor_point_count = args.anchor_point_count
    learning_rate = args.learning_rate
    learning_epochs_nums = args.learning_epochs_nums
    run_mode = args.run_mode

    device_compute = args.device_compute
    device0 = args.device0
    device1 = args.device1
    device2 = args.device2

    # Shared stem of the log file and the result file for this hyper-parameter
    # combination.
    run_tag = (f'ensemble_lr{learning_rate}_anchor_point_count{anchor_point_count}'
               f'_learning_epochs_nums{learning_epochs_nums}')

    logging.basicConfig(filename=os.path.join(result_save_dir, f'{run_tag}.process.log'),
                        level=logging.DEBUG)
    logging.info(f'\n【config_json:】{config_json}')
    logging.info(f'\n【result_save_dir:】{result_save_dir}')
    logging.info(f'\n【anchor_point_count:】{anchor_point_count}')
    logging.info(f'\n【learning_rate:】{learning_rate}')
    logging.info(f'\n【learning_epochs_nums:】{learning_epochs_nums}')

    # NOTE(review): torch.load unpickles the file contents; only point the
    # config at matrix files from a trusted source.
    main_model_probability_transfer_matrix = torch.load(main_model_probability_transfer_matrix_path,
                                                        map_location=device0)
    assist_model_probability_transfer_matrix1 = torch.load(assist_model1_probability_transfer_matrix_path,
                                                           map_location=device1)
    assist_model_probability_transfer_matrix2 = torch.load(assist_model2_probability_transfer_matrix_path,
                                                           map_location=device2)

    # load_model returns a third value (bound to a streamer in the original
    # script) that is not needed here.
    main_model, main_model_tokenizer, _ = load_model(main_model_path, "auto")
    assist_model1, assist_model_tokenizer1, _ = load_model(assist_model1_path, "auto")
    assist_model2, assist_model_tokenizer2, _ = load_model(assist_model2_path, "auto")

    common_vocabulary = CommonVocabulary(main_model_tokenizer, assist_model_tokenizer1, assist_model_tokenizer2)
    common_vocab_list = common_vocabulary.get_common_vocab_list(*common_vocabulary.vocabs)
    probability_transfer_matrix = ProbabilityTransferMatrix()
    anchor_point_list = probability_transfer_matrix.get_anchor_point_list(common_vocab_list=common_vocab_list)

    main_model_probability_transfer_matrix_list = [main_model_probability_transfer_matrix]
    assist_model_probability_transfer_matrix_list = [assist_model_probability_transfer_matrix1,
                                                     assist_model_probability_transfer_matrix2]

    # One entry per assist model: (model, tokenizer, system template, device).
    # To add another assist model, load it above, add its tokenizer to
    # CommonVocabulary, its matrix to the assist matrix list, and append here.
    assist_model_specs = [
        (assist_model1, assist_model_tokenizer1, assist_model1_system_template, device1),
        (assist_model2, assist_model_tokenizer2, assist_model2_system_template, device2),
    ]

    # =============================================================================================================

    folder_demon_dir = "/data/home/username/Experiments/LLM_ensemble/Datasets/MMLU/dev-jsonl"
    folder_demon_files_list = sorted(os.listdir(folder_demon_dir))

    folder_dev_dir = "/data/home/username/Experiments/LLM_ensemble/Datasets/MMLU/val-jsonl"
    folder_dev_files_list = sorted(os.listdir(folder_dev_dir))

    folder_test_dir = "/data/home/username/Experiments/LLM_ensemble/Datasets/MMLU/test-jsonl"
    folder_test_files_list = sorted(os.listdir(folder_test_dir))

    # The files are paired by sorted position, so all three folders must hold
    # the same number of files. (A chained `a != b != c` would only catch the
    # case where both adjacent pairs differ.)
    if not (len(folder_demon_files_list) == len(folder_dev_files_list) == len(folder_test_files_list)):
        logging.error('Dataset folders differ in file count (demon=%d, dev=%d, test=%d); aborting.',
                      len(folder_demon_files_list), len(folder_dev_files_list), len(folder_test_files_list))
        return

    if run_mode == "dev":
        folder_input_files_list = folder_dev_files_list
        folder_input_dir = folder_dev_dir
    else:
        folder_demon_files_list = folder_demon_files_list[start_index:end_index]
        folder_input_files_list = folder_test_files_list[start_index:end_index]
        folder_input_dir = folder_test_dir

    result_file_path = os.path.join(result_save_dir, f'{run_tag}.jsonl')
    # Number of input lines to skip because results for them already exist.
    completed_count = _count_existing_results(result_file_path)

    index = 0  # running count of input lines seen across all files
    for demon_file, input_file_name in tqdm(zip(folder_demon_files_list, folder_input_files_list),
                                            total=len(folder_demon_files_list), desc="Processing Files"):

        # Read the whole file up front so the handle is not held open during
        # generation and tqdm knows the total.
        with open(os.path.join(folder_input_dir, input_file_name), 'r', encoding='utf-8') as input_handle:
            contents = input_handle.readlines()

        demon_instruction, demon_count = demon_prompt_generate(os.path.join(folder_demon_dir, demon_file),
                                                               demon_parameter)

        for line in tqdm(contents):
            index += 1
            if index <= completed_count:
                continue

            record = json.loads(line)

            task_instruction = task_instruction_generate(record, instruction_parameter)
            final_input_prompt = instruction + demon_instruction + task_instruction
            main_model_input = main_model_system_template.format(final_input_prompt)

            # Copy the configured fields of the record, then attach the
            # per-sample generation settings.
            information_dict = {key: record[key] for key in demon_parameter['key']}
            information_dict['main_model_input'] = main_model_input
            information_dict['demon_count'] = demon_count
            information_dict['task_instruction'] = task_instruction
            information_dict['max_new_tokens'] = max_new_tokens
            information_dict['result_process_parameter'] = result_process_parameter
            information_dict['logits_processor_mode'] = logits_processor_mode
            information_dict['anchor_point_list'] = anchor_point_list
            # NOTE(review): hard-coded EOS id; presumably the main model's
            # eos_token_id — confirm against the tokenizer.
            information_dict['forced_eos_token_id'] = 2

            # Fresh queues per sample: one for the ensemble output ids and one
            # score queue per assist model.
            ensemble_model_output_ids_queue = queue.Queue()
            assist_model_score_queue_list = [queue.Queue() for _ in assist_model_specs]

            main_model_thread = MainModelThread(main_model=main_model,
                                                main_model_tokenizer=main_model_tokenizer,
                                                assist_model_tokenizer=assist_model_tokenizer1,
                                                information_dict=information_dict,
                                                learning_rate=learning_rate,
                                                anchor_point_count=anchor_point_count,
                                                learning_epochs_nums=learning_epochs_nums,
                                                result_save_dir=result_save_dir,
                                                ensemble_model_output_ids_queue=ensemble_model_output_ids_queue,
                                                assist_model_score_queue_list=assist_model_score_queue_list,
                                                main_model_probability_transfer_matrix_list=main_model_probability_transfer_matrix_list,
                                                assist_model_probability_transfer_matrix_list=assist_model_probability_transfer_matrix_list,
                                                device_compute=device_compute,
                                                device=device0,
                                                early_stop_string_list=early_stop_string_list
                                                )

            assist_model_threads = [
                AssistModelThread(model=assist_model,
                                  model_tokenizer=assist_tokenizer,
                                  assist_model_input=assist_template.format(final_input_prompt),
                                  assist_model_score_queue=assist_score_queue,
                                  device=assist_device,
                                  result_save_dir=result_save_dir
                                  )
                for (assist_model, assist_tokenizer, assist_template, assist_device), assist_score_queue
                in zip(assist_model_specs, assist_model_score_queue_list)
            ]

            # Start the main thread first, then the assist threads, and wait
            # for all of them before moving on to the next sample.
            worker_threads = [main_model_thread, *assist_model_threads]
            for worker in worker_threads:
                worker.start()
            for worker in worker_threads:
                worker.join()

    # Report the elapsed wall-clock time.
    time_elapsed = time.time() - start_time
    minutes, seconds = divmod(int(time_elapsed), 60)
    print('Time taken: {} min {} sec'.format(minutes, seconds))


# Run the pipeline only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()
