# llmHMER/QwenHMER/main.py

import json
import os
import shutil
import subprocess
from datetime import datetime
import yaml

from cdm_evaluator import cdm_proc_folder, plot_cdm_metrics, plot_exp_metrics
from config_utils import (
    check_existing_checkpoints,
    execute_command,
    read_yaml_config,
    run_command,
)
from dataset_utils import (
    all_caption_dict,
    all_test_id,
    check_logs_exist_from_sh,
    crohme_count_dict,
    get_all_captions,
    get_all_crohme_captions,
    get_all_ids,
    read_all_hme100k_captions,
    sort_checkpoint_folders,
)
from inference_runner import aggregate_inference_results, split_predictions_by_dataset
from tqdm import tqdm
from train_inference_generator import make_llmfac_inference_eachepoch_file


def _apply_config_overrides(config_file, params):
    """Return a copy of ``params`` with values overridden from a YAML config file.

    Args:
        config_file: path to a YAML file whose top level is a mapping.
        params: current parameter values, keyed by parameter name.

    Returns:
        A new dict. Keys present both in ``params`` and in the config take the
        config's value; config keys not present in ``params`` are reported and ignored.

    Raises:
        ValueError: if the YAML top level is not a mapping.
    """
    with open(config_file, "r") as f:
        # safe_load returns None for an empty file; treat that as "no overrides".
        overrides = yaml.safe_load(f) or {}
    if not isinstance(overrides, dict):
        raise ValueError(
            f"Config file {config_file} must contain a mapping, got {type(overrides).__name__}."
        )
    merged = dict(params)
    for key, value in overrides.items():
        if key in merged:
            merged[key] = value
        else:
            print(f"Ignoring unknown key '{key}' in config file {config_file}.")
    return merged


def _load_summary_list(path):
    """Load a summary list previously saved as JSON at ``path``; return [] if the file is absent."""
    if not os.path.exists(path):
        return []
    with open(path, "r") as f:
        return json.load(f)


def _merge_previous_summaries(current, previous):
    """Append entries of ``previous`` whose (dataset_name, checkpoint_step) is not yet in ``current``.

    Entries produced by the current run take precedence over stale ones from earlier
    runs; duplicates inside ``previous`` are only added once. Mutates and returns ``current``.
    """
    seen = {(s["dataset_name"], s["checkpoint_step"]) for s in current}
    for summary in previous:
        key = (summary["dataset_name"], summary["checkpoint_step"])
        if key not in seen:
            current.append(summary)
            seen.add(key)
    return current


def llamaHMERmain(
    train_file="/home/user/workspaces/llmHMER/LLaMA-Factory/examples/train_full/qwen2vl2b_full_sft_crohme-0102_noboxed_white.yaml",  # default train file
    inference_file="/home/user/workspaces/llmHMER/LLaMA-Factory/examples/inference/default/qwen2_vl2B_full_crohme_baseline.yaml",  # default inference file
    folder_name="example_predict_0117",  # default folder name

    llama_factory_root="/home/user/workspaces/llama-factory",  # default LlamaFactory root directory
    model_name="qwen2_vl-2b",  # default model name
    crohme_base_dir="/home/user/workspaces/HMER_Dataset/crohme-rgb-white/HMER/CROHME",  # default CROHME dataset path
    hme100k_base_dir="/home/user/workspaces/HMER_Dataset/hme100k",  # default HME100K dataset path
    use_cdm=False,
    just_inference=False,
    eval_datasets=None,  # None means "no datasets" (avoids a shared mutable default)
    eval_batch_size=32,
    skip_load_data=False,
    use_vllm=True,
    dont_inference=False,
    config_file=None,  # can use config file to override all parameters
    vllm_batch_size=32768,
    only_eval_last_ckp=False,
):
    """Main LLaMA HMER entry point: train (if needed), run inference, and evaluate.

    Pipeline:
      1. Optionally override the arguments below from ``config_file``.
      2. Train with ``llamafactory-cli train`` unless ``check_existing_checkpoints``
         reports enough checkpoints in the train config's ``output_dir``, or
         ``just_inference`` is set.
      3. Unless ``dont_inference`` is set, generate a per-epoch inference script and run
         it when some of its logs are missing (``just_inference`` forces a re-run).
      4. Split each checkpoint's ``generated_predictions.jsonl`` by dataset and aggregate
         expression metrics (and optionally CDM metrics) per dataset/checkpoint.
      5. Merge with summaries from earlier runs, back up the old files, save, and plot.

    Args:
        train_file: train config file path.
        inference_file: inference config file path.
        folder_name: output folder name under ``saves/<model_name>/full/sft/predict``.
        llama_factory_root: LlamaFactory root directory (training/inference run from here).
        model_name: model name, used to build the predict folder path.
        crohme_base_dir: CROHME dataset path (currently not used by this function).
        hme100k_base_dir: HME100K dataset path (currently not used by this function).
        use_cdm: whether to run CDM evaluation and plot CDM metrics.
        just_inference: skip training and always re-run inference.
        eval_datasets: names of the evaluation datasets; ``None`` means none.
        eval_batch_size: evaluation batch size forwarded to the inference-script generator.
        skip_load_data: currently not used by this function.
        use_vllm: whether to use vLLM for inference; also selects the file-name prefix
            of the copied summary files.
        dont_inference: skip inference entirely and only evaluate existing predictions.
        config_file: optional YAML file whose keys override the arguments above
            (all except ``config_file`` itself).
        vllm_batch_size: batch size forwarded to the inference-script generator.
        only_eval_last_ckp: evaluate only the last checkpoint (also forwarded to the
            inference-script generator).

    Raises:
        FileNotFoundError: if ``train_file`` is missing, or if ``inference_file`` is
            missing and no ``eval_datasets`` are given.
    """
    if config_file:
        if os.path.exists(config_file):
            params = _apply_config_overrides(config_file, {
                "train_file": train_file,
                "inference_file": inference_file,
                "folder_name": folder_name,
                "llama_factory_root": llama_factory_root,
                "model_name": model_name,
                "crohme_base_dir": crohme_base_dir,
                "hme100k_base_dir": hme100k_base_dir,
                "use_cdm": use_cdm,
                "just_inference": just_inference,
                "eval_datasets": eval_datasets,
                "eval_batch_size": eval_batch_size,
                "skip_load_data": skip_load_data,
                "use_vllm": use_vllm,
                "dont_inference": dont_inference,
                "vllm_batch_size": vllm_batch_size,
                "only_eval_last_ckp": only_eval_last_ckp,
            })
            train_file = params["train_file"]
            inference_file = params["inference_file"]
            folder_name = params["folder_name"]
            llama_factory_root = params["llama_factory_root"]
            model_name = params["model_name"]
            crohme_base_dir = params["crohme_base_dir"]
            hme100k_base_dir = params["hme100k_base_dir"]
            use_cdm = params["use_cdm"]
            just_inference = params["just_inference"]
            eval_datasets = params["eval_datasets"]
            eval_batch_size = params["eval_batch_size"]
            skip_load_data = params["skip_load_data"]
            use_vllm = params["use_vllm"]
            dont_inference = params["dont_inference"]
            vllm_batch_size = params["vllm_batch_size"]
            only_eval_last_ckp = params["only_eval_last_ckp"]
        else:
            print(f"Config file {config_file} not found; using the function arguments as given.")

    # Normalize after config overrides, since a config may also set eval_datasets to null.
    if eval_datasets is None:
        eval_datasets = []

    # check that the input files exist
    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Train file {train_file} not found.")
    if not os.path.exists(inference_file) and len(eval_datasets) == 0:
        raise FileNotFoundError(f"Inference file {inference_file} not found.")

    # read output_dir and the epoch count from train_file
    train_config = read_yaml_config(train_file)
    print(train_config)
    train_epoch = int(train_config["num_train_epochs"])
    output_dir = os.path.join(llama_factory_root, train_config.get("output_dir", ""))

    all_captions_dic = all_caption_dict
    # Bound unconditionally: the aggregation loop below needs it even when inference
    # is skipped in this call (dont_inference, or all logs already exist).
    all_test_id_dic = all_test_id

    predict_parent_folder = os.path.join(
        llama_factory_root, f"saves/{model_name}/full/sft/predict/{folder_name}"
    )

    # ---- training ----
    if check_existing_checkpoints(output_dir, required_checkpoints=train_epoch) or just_inference:
        print(f"Training results already exist in {output_dir}. Skipping training...")
    else:
        print(f"Changing directory to {llama_factory_root}")
        os.chdir(llama_factory_root)
        # Environment prefix for llamafactory-cli (presumably forces a torchrun launch — verify).
        train_command_prefix = "FORCE_TORCHRUN=1 "
        train_command = train_command_prefix + f"llamafactory-cli train {train_file}"
        print(f"Running training with command: {train_command}")
        execute_command(train_command)

    # ---- inference (one script covering every epoch checkpoint) ----
    if dont_inference:
        print("Dont inference")
    else:
        inference_script_path = make_llmfac_inference_eachepoch_file(
            train_file=train_file,
            inference_file=inference_file,
            folder_name=folder_name,
            inference_type="full",  # can be "lora"
            llama_factory_root=llama_factory_root,
            model_name=model_name,
            eval_datasets=eval_datasets,
            eval_batch_size=eval_batch_size,
            use_vllm=use_vllm,
            batch_size=vllm_batch_size,
            only_eval_last_ckp=only_eval_last_ckp,
        )
        if check_logs_exist_from_sh(inference_script_path) and not just_inference:
            print("All logs exist and the execution is complete.")
        else:
            print("Some log files are missing, execution may not be complete.")
            inference_command = f"bash {inference_script_path}"
            print(f"Running inference with command: {inference_command}")
            os.chdir(llama_factory_root)
            execute_command(inference_command)

            # Keep a copy of the train config next to the predictions. Create the folder
            # first: copying onto a missing directory path would create a *file* with the
            # folder's name, and os.makedirs below would then fail.
            os.makedirs(predict_parent_folder, exist_ok=True)
            shutil.copy(train_file, predict_parent_folder)

    os.makedirs(predict_parent_folder, exist_ok=True)

    checkpoint_dirs = sort_checkpoint_folders([
        d for d in os.listdir(predict_parent_folder)
        if d.startswith("checkpoint") and os.path.isdir(os.path.join(predict_parent_folder, d))
    ])
    if only_eval_last_ckp:
        checkpoint_dirs = checkpoint_dirs[-1:]

    # ---- split each checkpoint's predictions into per-dataset files ----
    for ckp_dir in tqdm(checkpoint_dirs, desc="Splitting Predictions by Dataset"):
        ckp_path = os.path.join(predict_parent_folder, ckp_dir)
        prediction_file = os.path.join(ckp_path, 'generated_predictions.jsonl')
        if not os.path.isfile(prediction_file):
            print(f"Skipping {ckp_dir}, no generated_predictions.jsonl found.")
            continue
        try:
            split_predictions_by_dataset(
                prediction_file=prediction_file,
                eval_datasets=eval_datasets,
                crohme_count_dict=crohme_count_dict,
                output_dir=ckp_path
            )
        except Exception as e:
            # NOTE(review): the first failure stops splitting for all remaining
            # checkpoints; they are later skipped with "Split file ... not found".
            print(f"Failed to split predictions for {ckp_dir}: {e}")
            break

    exp_summary_path = os.path.join(predict_parent_folder, "exp_summary_list.json")
    cdm_summary_path = os.path.join(predict_parent_folder, "cdm_summary_list.json")
    # Summaries from earlier runs, merged back in after this run's results.
    before_exp_summary_list = _load_summary_list(exp_summary_path)
    before_cdm_summary_list = _load_summary_list(cdm_summary_path)
    exp_summary_list = []
    cdm_summary_list = []

    # ---- aggregate metrics per dataset and checkpoint ----
    for dataset in eval_datasets:
        print(f"Processing dataset: {dataset}")
        output_folder = os.path.join(predict_parent_folder, f"results_{dataset}")
        os.makedirs(output_folder, exist_ok=True)

        for ckp_dir in checkpoint_dirs:
            ckp_path = os.path.join(predict_parent_folder, ckp_dir)
            print(f"Aggregating results for {ckp_dir}...")
            split_file = os.path.join(ckp_path, f"{dataset}_predictions.jsonl")
            if not os.path.isfile(split_file):
                print(f"Split file {split_file} not found. Skipping this checkpoint for {dataset}.")
                continue
            epochckp_output_folder = os.path.join(output_folder, ckp_dir)
            os.makedirs(epochckp_output_folder, exist_ok=True)

            exp_stats = aggregate_inference_results(
                dataset_split_file=split_file,
                output_folder=epochckp_output_folder,
                dataset_name=dataset,
                image_id_list=all_test_id_dic[dataset],
                gt_captions_dict=all_captions_dic[dataset],
                crohme_count_dict=crohme_count_dict
            )
            exp_summary_list.append(exp_stats)

            if use_cdm:
                # NOTE(review): relative to the current working directory, which is
                # llama_factory_root only if training or inference ran above.
                cdm_output_dir = os.path.join(os.pardir, 'cdm_output', folder_name, ckp_dir + '-' + dataset)
                cdm_metrics_file_path = os.path.join(epochckp_output_folder, f"cdm_eval_{dataset}")
                print(f"Running CDM evaluation for {dataset} on {ckp_dir}...")
                cdm_stats = cdm_proc_folder(
                    input_folder=epochckp_output_folder,
                    output_dir=cdm_output_dir,
                    dataset_name=dataset,
                    cp_output_path=cdm_metrics_file_path,
                )
                cdm_summary_list.extend(cdm_stats)

        print(f"Finished processing dataset: {dataset}")

    _merge_previous_summaries(cdm_summary_list, before_cdm_summary_list)
    _merge_previous_summaries(exp_summary_list, before_exp_summary_list)

    # ---- back up the previous summary files, then overwrite them ----
    today_str = datetime.now().strftime("%m%d")
    backup_dir = os.path.join(predict_parent_folder, "backup")
    os.makedirs(backup_dir, exist_ok=True)
    for filename in ("exp_summary_list.json", "cdm_summary_list.json"):
        src_path = os.path.join(predict_parent_folder, filename)
        if os.path.exists(src_path):
            shutil.copy(src_path, os.path.join(backup_dir, f"{today_str}_{filename}"))

    for filename, data in (
        ("exp_summary_list.json", exp_summary_list),
        ("cdm_summary_list.json", cdm_summary_list),
    ):
        with open(os.path.join(predict_parent_folder, filename), 'w') as file:
            json.dump(data, file, indent=4)

    # Engine-specific copies so vLLM and HuggingFace runs can be compared side by side.
    engine_prefix = "vllm" if use_vllm else "huggingface"
    if exp_summary_list:
        shutil.copy(
            exp_summary_path,
            os.path.join(predict_parent_folder, f"{engine_prefix}_exp_summary_list.json")
        )
    if cdm_summary_list:
        shutil.copy(
            cdm_summary_path,
            os.path.join(predict_parent_folder, f"{engine_prefix}_cdm_summary_list.json")
        )

    plot_exp_metrics(
        exp_summary_list,
        output_dir=predict_parent_folder,
        plot_filename='exp_metrics_crohme.png'
    )
    if use_cdm:
        print("Eval summary list (CDM):", cdm_summary_list)
        plot_cdm_metrics(
            evaluation_summary=cdm_summary_list,
            output_dir=predict_parent_folder,
            plot_filename='cdm_metrics_crohme.png'
        )

    print("\nAll processes completed successfully.")
    
    

def llamaHMER_run_twice(train_file="/home/user/workspaces/llmHMER/LLaMA-Factory/examples/train_full/qwen2vl2b_full_sft_crohme-0102_noboxed_white.yaml", 
    inference_file="/home/user/workspaces/llmHMER/LLaMA-Factory/examples/inference/default/qwen2_vl2B_full_crohme_baseline.yaml", 
    folder_name="example_predict_0117", 

    llama_factory_root="/home/user/workspaces/llama-factory", 
    model_name="qwen2_vl-2b", 
    crohme_base_dir="/home/user/workspaces/HMER_Dataset/crohme-rgb-white/HMER/CROHME", 
    hme100k_base_dir="/home/user/workspaces/HMER_Dataset/hme100k", 
    use_cdm=False,
    just_inference=False,
    eval_datasets=None,
    eval_batch_size=32,
    skip_load_data=False,
    use_vllm=True,
    dont_inference=False,
    config_file=None,  
    vllm_batch_size=65536,
    only_eval_last_ckp=False,
):
    """Run ``llamaHMERmain`` twice: on ``train_file`` and on a derived "-2" copy.

    The second train config is written next to ``train_file`` as ``<name>-2<ext>``.
    It is identical except that ``output_dir`` gets a ``-2`` suffix. Its results go to
    ``folder_name + "-2"``, so the two runs do not overwrite each other (presumably
    to measure run-to-run variance — TODO confirm).

    Args:
        Same meaning as in ``llamaHMERmain``; ``eval_datasets=None`` means none.
        config_file: accepted for signature parity but NOT forwarded, because a config
            could override ``train_file``/``folder_name`` and collapse both runs into one.
        vllm_batch_size: forwarded to both runs.
        only_eval_last_ckp: forwarded to both runs.
    """
    if eval_datasets is None:
        eval_datasets = []

    # Absolute paths: the first run may os.chdir into llama_factory_root, which would
    # break relative paths for the second run.
    train_file = os.path.abspath(train_file)
    inference_file = os.path.abspath(inference_file)

    # splitext instead of str.replace(".yaml", ...): a ".yml" path would otherwise stay
    # unchanged and the original train config would be overwritten below.
    base, ext = os.path.splitext(train_file)
    train2_file = f"{base}-2{ext}"
    train2_config = read_yaml_config(train_file)
    train2_config['output_dir'] = train2_config['output_dir'] + "-2"
    with open(train2_file, 'w') as f:
        yaml.dump(train2_config, f, default_flow_style=False)

    shared_kwargs = dict(
        inference_file=inference_file,
        llama_factory_root=llama_factory_root,
        model_name=model_name,
        crohme_base_dir=crohme_base_dir,
        hme100k_base_dir=hme100k_base_dir,
        use_cdm=use_cdm,
        just_inference=just_inference,
        eval_datasets=eval_datasets,
        eval_batch_size=eval_batch_size,
        skip_load_data=skip_load_data,
        use_vllm=use_vllm,
        dont_inference=dont_inference,
        vllm_batch_size=vllm_batch_size,
        only_eval_last_ckp=only_eval_last_ckp,
    )
    runs = ((train_file, folder_name), (train2_file, folder_name + "-2"))
    for run_train_file, run_folder_name in runs:
        llamaHMERmain(train_file=run_train_file, folder_name=run_folder_name, **shared_kwargs)

if __name__ == "__main__":
    # Illustrative parameter set for a config-driven run (not consumed by the call below).
    example_config = dict(
        train_file="/home/user/workspaces/llmHMER/LLaMA-Factory/examples/train_full/qwen2vl2b_full_sft_crohme-0102_noboxed_white.yaml",
        inference_file="/home/user/workspaces/llmHMER/LLaMA-Factory/examples/inference/default/qwen2_vl2B_full_crohme_baseline.yaml",
        folder_name="example_predict_config",
        model_name="qwen2_vl-2b",
        use_cdm=True,
        use_vllm=True,
        eval_datasets=["CROHME2014_test", "CROHME2016_test", "CROHME2019_test"],
    )

    # Directory intended to hold config files.
    os.makedirs("configs", exist_ok=True)

    # Standard run with explicit arguments.
    standard_run_kwargs = {
        "train_file": "/home/user/workspaces/llmHMER/LLaMA-Factory/examples/train_full/qwen2vl2b_full_sft_crohme-0102_noboxed_white.yaml",
        "inference_file": "/home/user/workspaces/llmHMER/LLaMA-Factory/examples/inference/default/qwen2_vl2B_full_crohme_baseline.yaml",
        "folder_name": "example_predict",
        "use_cdm": True,
        "use_vllm": True,
    }
    llamaHMERmain(**standard_run_kwargs)
    

    
    
    
    