import argparse
import json
import os
import subprocess
import sys

import torch


def parse_arguments():
    """Parse the finetuner's command-line options and echo them to stdout.

    Each parsed option is printed as ``name: value`` on its own line,
    followed by one blank line.

    Returns:
        argparse.Namespace with the attributes ``fine_tune`` (``None`` when
        not given), ``logging_folder``, ``model_folder`` and ``fingerprints``
        (``None`` when not given).
    """
    cli = argparse.ArgumentParser(description="Run finetuner.")

    # fine tuning a fingerprinted model
    cli.add_argument(
        "--fine_tune",
        type=str,
        default=None,
        choices=["alpaca", "chatdoc"],
        help="Fine tune a phi3 fingerprinted model with either Alpaca or ChatDoc.",
    )
    cli.add_argument(
        "--logging_folder", type=str, required=True, help="logging output folder."
    )
    cli.add_argument(
        "--model_folder", type=str, required=True, help="model folder."
    )
    cli.add_argument(
        "--fingerprints", type=str, default=None, help="fingerprint file."
    )
    parsed = cli.parse_args()

    summary_lines = [f"{name}: {value}" for name, value in vars(parsed).items()]
    print("\n".join(summary_lines))
    print("")
    return parsed

def load_model_config(model_name):
    """Return the entry for ``model_name`` from ``config/model_config.json``.

    The JSON file is opened relative to the current working directory.

    Raises:
        ValueError: if ``model_name`` is not present in the loaded JSON.
    """
    config_file = os.path.join("config", "model_config.json")
    with open(config_file, "r") as handle:
        all_configs = json.load(handle)

    if model_name in all_configs:
        return all_configs[model_name]
    raise ValueError("Model name not found in model_config.json")


def run_test(model_name, logging_folder, base_model, model_folder, quantize, completion, fingerprints):
    """Run ``test_fingerprint.py`` (from the current directory) on a model.

    Args:
        model_name: value passed as ``--model_name``.
        logging_folder: value passed as ``--logging_folder``.
        base_model: folder whose ``fingerprint.hf`` is used when
            ``fingerprints`` is None.
        model_folder: value passed as ``--model_folder``.
        quantize: when truthy, ``--quantized`` is passed.
        completion: when truthy, ``--completion`` is passed.
        fingerprints: optional folder containing ``fingerprint.hf``; overrides
            ``base_model`` as the fingerprint source.

    Returns:
        True if the script exited with status 0. False otherwise, including
        when the ``python`` interpreter could not be started.
    """
    # Fingerprints default to the ones stored with the base model.
    fingerprint_dir = fingerprints if fingerprints is not None else base_model

    # An argv list (no shell) keeps paths with spaces or shell
    # metacharacters intact as single arguments.
    command = [
        "python", "test_fingerprint.py",
        "--model_name", model_name,
        "--logging_folder", logging_folder,
        "--model_folder", model_folder,
        "--fingerprint_file", os.path.join(fingerprint_dir, "fingerprint.hf"),
    ]
    if completion:
        command.append("--completion")
    if quantize:
        command.append("--quantized")

    try:
        result = subprocess.run(command, check=False)
    except OSError as err:
        # e.g. interpreter not on PATH; os.system reported this as a non-zero status.
        print(f"Could not launch {command[0]}: {err}")
        return False
    return result.returncode == 0

    
def run_finetuning(base_model, model_folder, dataset_file,
                   accelerate_config="/datadrive/huggingface/accelerate/default_config.yaml"):
    """Fine-tune a model by running ``finetune/train.py`` via ``accelerate launch``.

    ``train.py`` runs with ``finetune/`` as its working directory, so
    ``dataset_file`` is resolved relative to ``finetune/``. ``model_folder``
    is made absolute first. ``base_model`` is also made absolute when it
    names an existing path. This way both keep referring to the caller's
    locations, which is where the caller later looks for the result.

    Args:
        base_model: passed as ``--model_name_or_path``.
        model_folder: passed as ``--output_dir``.
        dataset_file: passed as ``--data_path`` (relative to ``finetune/``).
        accelerate_config: accelerate config file for ``accelerate launch``.

    Returns:
        True if training exited with status 0. False otherwise, including when
        ``accelerate`` cannot be started or ``finetune/`` does not exist.
    """
    # Resolve caller-relative paths before the child runs in another cwd.
    output_dir = os.path.abspath(model_folder)
    model_path = os.path.abspath(base_model) if os.path.exists(base_model) else base_model

    command = [
        "accelerate", "launch",
        "--config_file", accelerate_config,
        "train.py",
        "--model_name_or_path", model_path,
        "--data_path", dataset_file,
        "--output_dir", output_dir,
        "--num_train_epochs", "3",
        "--per_device_train_batch_size", "2",
        "--per_device_eval_batch_size", "4",
        "--gradient_accumulation_steps", "8",
        "--evaluation_strategy", "no",
        "--save_strategy", "steps",
        "--save_steps", "2000",
        "--save_total_limit", "1",
        "--learning_rate", "2e-5",
        "--weight_decay", "0.",
        "--warmup_ratio", "0.03",
        "--lr_scheduler_type", "cosine",
        "--logging_steps", "1",
        "--fp16", "True",
        # "--tf32", "True",  (disabled option)
    ]

    try:
        # cwd= changes directory only in the child; this process's cwd is untouched.
        result = subprocess.run(command, cwd="finetune", check=False)
    except OSError as err:
        print(f"Could not launch fine-tuning: {err}")
        return False
    return result.returncode == 0


def run_generate_figures(logging_folder):
    """Run ``create_figures.py`` (from the current directory) on a run folder.

    The script is called with ``--runfolder <logging_folder> --figs figs
    --plot_individual_probs``.

    Returns:
        True if the script exited with status 0. False otherwise, including
        when the ``python`` interpreter could not be started.
    """
    # An argv list (no shell) keeps a folder name with spaces as one argument.
    command = [
        "python", "create_figures.py",
        "--runfolder", logging_folder,
        "--figs", "figs",
        "--plot_individual_probs",
    ]
    try:
        result = subprocess.run(command, check=False)
    except OSError as err:
        print(f"Could not launch {command[0]}: {err}")
        return False
    return result.returncode == 0


def main():
    """Fine-tune a fingerprinted model, test its fingerprints, and plot results.

    Pipeline:
      1. Fine-tune the model in ``--model_folder`` on the dataset selected by
         ``--fine_tune``. The result goes to ``<parent>/<fine_tune>``, a
         sibling of the base model folder. This step is skipped when that
         folder already contains ``config.json``.
      2. Run the fingerprint test, logging to ``<logging_folder>/<fine_tune>``.
      3. Generate figures from those logs.

    Returns:
        0 when every step succeeds, 1 when a step fails.
    """
    args = parse_arguments()
    # --fine_tune defaults to None, but every step below depends on it.
    if args.fine_tune is None:
        print("Error: --fine_tune is required (alpaca or chatdoc)")
        return 1

    # assumes base model is in <model_name>/<config_name>/<run_number>.
    # abspath also strips a trailing separator, which would otherwise shift
    # the dirname/basename walk below by one level.
    base_model = os.path.abspath(args.model_folder)
    config_folder = os.path.dirname(base_model)
    model_name = os.path.basename(os.path.dirname(config_folder))
    logging_folder = os.path.join(args.logging_folder, args.fine_tune)
    model_config = load_model_config(model_name)
    print(model_config)

    datasets = {
        "alpaca": "alpaca_data.json",
        "chatdoc": "HealthCareMagic-100k.json",
    }
    dataset_file = datasets[args.fine_tune]
    print("\n************* Finetuning model")
    model_folder = os.path.join(config_folder, args.fine_tune)
    if not os.path.exists(os.path.join(model_folder, "config.json")):
        if not run_finetuning(base_model, model_folder, dataset_file):
            print("Error in fine tuning step")
            return 1

    print("\n************* Running tests")
    # A "completion" entry in the model config turns on --completion for the test.
    completion_model = model_config.get("completion", None)
    os.makedirs(logging_folder, exist_ok=True)
    if not run_test(model_config["model_name"], logging_folder, base_model, model_folder,
                    False, completion_model is not None, args.fingerprints):
        print(" Error in test step")
        return 1

    # generate figures
    print("\n\n************* Generating figures")
    if not run_generate_figures(logging_folder):
        print("Error in generate figures step")
        return 1
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and CI can detect a failed step.
    sys.exit(main())
