import argparse
import json
import os
import torch

# Pipeline stages mapped to whether they are enabled. Every stage starts out
# disabled; parse_arguments() overwrites the entries when stage flags are given.
run_steps = dict.fromkeys(
    (
        "create",
        "finetune-lora",
        "fingerprint",
        "test",
        "eval_loss",
        "benchmark",
        "benchmark-pretrained",
    ),
    False,
)

# Comma-separated task list passed to lm_eval through --tasks in run_benchmark().
lm_eval_tasks = "mmlu,winogrande,truthfulqa,hellaswag"

def parse_arguments():
    """Parse the command line and record which pipeline stages to run.

    When at least one stage flag (``--create``, ``--finetune_lora``, ...) is
    given, every entry of the module-level ``run_steps`` dictionary is
    overwritten with the value of its flag, so only the listed stages stay
    enabled. With no stage flag, ``run_steps`` is left untouched.

    ``--model_name`` and ``--run_config`` are required: ``main`` looks the model
    up in the model config and opens the run config file unconditionally, so
    omitting either one could only end in a crash later on.

    Returns:
        argparse.Namespace: the parsed arguments, which are also echoed to
        stdout one ``name: value`` pair per line.
    """
    parser = argparse.ArgumentParser(description="Run SFT lora fingerprinting.")
    parser.add_argument("--model_name", type=str, required=True, help="Model name.")
    parser.add_argument("--logging_folder", type=str, required=True, help="logging output folder.")
    parser.add_argument("--model_folder", type=str, required=True, help="model folder.")

    # stages to run
    parser.add_argument("--create", action="store_true", help="Create the model.")
    parser.add_argument("--finetune_lora", action="store_true", help="Fine tune the model with lora adapter.")
    parser.add_argument("--fingerprint", action="store_true", help="Fingerprint the lora adapter.")
    parser.add_argument("--test", action="store_true", help="Test the fingerprinted lora adapter.")
    parser.add_argument("--eval_loss", action="store_true", help="Evaluate the loss difference.")
    parser.add_argument("--benchmark", action="store_true", help="Benchmark the model with lora adapter.")
    parser.add_argument("--benchmark_pretrained", action="store_true", help="Benchmark the pretrained model.")

    # run configuration
    parser.add_argument("--num_runs", type=int, default=3, help="Number of runs.")
    parser.add_argument("--run_config", type=str, required=True, help="Run config file.")
    parser.add_argument("--fingerprint_count", type=int, default=10, help="Number of fingerprints to generate.")
    parser.add_argument("--random", action="store_true", help="Random fingerprint prompts.")
    parser.add_argument("--randpad", action="store_true", help="Random padding.")
    parser.add_argument("--finetune_data", type=str, default=None, choices=["alpaca", "chatdoc"],
                        help="Fine tune a model with either Alpaca or ChatDoc.")
    args = parser.parse_args()

    # argparse attribute of each stage flag -> its key in run_steps. Keeping the
    # pairs in one table means a new stage cannot be added to the assignments
    # but forgotten in the "was any stage requested" check.
    stage_flags = {
        "create": "create",
        "finetune_lora": "finetune-lora",
        "fingerprint": "fingerprint",
        "test": "test",
        "eval_loss": "eval_loss",
        "benchmark": "benchmark",
        "benchmark_pretrained": "benchmark-pretrained",
    }
    requested = {step: getattr(args, flag) for flag, step in stage_flags.items()}

    # if any steps are explicitly listed, only those steps are run
    if any(requested.values()):
        run_steps.update(requested)

    print('\n'.join(['{}: {}'.format(k, v) for k, v in vars(args).items()]))
    print("")
    return args

def load_model_config(model_name):
    """Return the entry stored for ``model_name`` in ``config/model_config.json``.

    The file path is relative to the current working directory.

    Raises:
        ValueError: if the file has no top-level key equal to ``model_name``.
    """
    config_file = os.path.join("config", "model_config.json")
    with open(config_file, "r") as handle:
        known_models = json.load(handle)

    if model_name in known_models:
        return known_models[model_name]
    raise ValueError("Model name not found in model_config.json")


def run_completion_create_fingerprint(model_name, logging_folder, model_folder, args):
    """Run ``completion_fingerprint.py`` in a shell and report whether it succeeded.

    ``model_folder`` is passed as ``--fingerprint_folder``. ``args`` is accepted
    but not used here.

    Returns:
        bool: True when the command's exit status is 0, False otherwise.
    """
    pieces = [
        "python completion_fingerprint.py",
        "--model_name", model_name,
        "--fingerprint_folder", model_folder,
        "--logging_folder", logging_folder,
    ]
    exit_status = os.system(" ".join(pieces))
    return exit_status == 0
    

def run_create_fingerprint(model_name, run_config, logging_folder, model_folder, completion, randomprompts, args):
    """Run ``create_fingerprint.py`` in a shell and report whether it succeeded.

    Every ``run_config`` item becomes a ``--key value`` option, followed by the
    fingerprint count (``args.fingerprint_count``), the two folders and the
    model name. ``--random`` is added when ``randomprompts`` is truthy,
    ``--randpad`` when ``args.randpad`` is truthy, and ``--completion`` when
    ``completion`` is not None.

    Returns:
        bool: True when the command's exit status is 0, False otherwise.
    """
    # Each segment carries its own leading space, so they are joined with "".
    segments = ["python create_fingerprint.py"]
    segments.extend(
        " --" + option + " " + str(setting) for option, setting in run_config.items()
    )
    segments.append(f" --fingerprint_count {args.fingerprint_count}")
    segments.append(f" --fingerprint_folder {model_folder}")
    segments.append(f" --logging_folder {logging_folder}")
    segments.append(f" --model_name {model_name}")

    if randomprompts:
        segments.append(" --random")
    if args.randpad:
        segments.append(" --randpad")
    if completion is not None:
        segments.append(" --completion")

    return os.system("".join(segments)) == 0

def run_finetune_lora(model_name, model_config, logging_folder, model_folder, data_path):
    """Launch ``finetune-lora.py`` through ``accelerate`` from inside ``lora/``.

    Args:
        model_name: value passed as ``--model_name_or_path``.
        model_config: mapping read for ``per_device_batch_size`` (default 4)
            and ``gradient_accumulation_steps`` (default 1).
        logging_folder: accepted but not used here.
        model_folder: value passed unchanged as ``--output_dir``.
        data_path: value passed as ``--data_path``; must be a string.

    Returns:
        bool: True when the command's exit status is 0, False otherwise.
    """
    commandLine = "accelerate launch finetune-lora.py"
    commandLine += " --model_name_or_path " + model_name
    commandLine += " --output_dir " + model_folder
    commandLine += " --data_path " + data_path

    # fixed training hyper-parameters
    commandLine += " --num_train_epochs 3"
    commandLine += " --lora_r 8"
    commandLine += " --lora_alpha 16"
    commandLine += " --learning_rate 1e-3"
    commandLine += " --warmup_steps 500"
    commandLine += " --evaluation_strategy 'no'"
    commandLine += " --save_strategy 'no'"
    commandLine += " --report 'none'"
    commandLine += " --save_total_limit 3"
    commandLine += " --fp16 True"
    commandLine += " --lr_scheduler_type 'cosine'"
    commandLine += " --remove_unused_columns False"

    # per-model overrides
    per_device_batch_size = model_config.get("per_device_batch_size", 4)
    commandLine += f" --per_device_train_batch_size {per_device_batch_size}"
    gradient_accumulation_steps = model_config.get("gradient_accumulation_steps", 1)
    commandLine += f" --gradient_accumulation_steps {gradient_accumulation_steps}"

    # The working directory is process-wide state: restore it even when the
    # launch raises, so the caller's relative paths keep resolving as before.
    original_directory = os.getcwd()
    os.chdir("lora")
    try:
        result = os.system(commandLine)
    finally:
        os.chdir(original_directory)
    return result == 0

def run_fingerprint(model_name, model_config, logging_folder, model_folder, lora_adapter_path):
    """Launch ``fingerprint-lora-continue.py`` through ``accelerate`` from inside ``lora/``.

    Args:
        model_name: value passed as ``--base_model``.
        model_config: mapping read for ``per_device_batch_size`` (default 32)
            and ``gradient_accumulation_steps`` (default 1).
        logging_folder: passed as ``--logging_dir`` with a ``../`` prefix.
        model_folder: passed unchanged as both ``--model_folder`` and
            ``--output_dir``.
        lora_adapter_path: when not None, passed as ``--finetune_lora_adapter``.

    Returns:
        bool: True when the command's exit status is 0, False otherwise.
    """
    commandLine = "accelerate launch fingerprint-lora-continue.py"

    # add the model name
    commandLine += " --base_model " + model_name
    commandLine += " --model_folder " + model_folder
    # The command runs inside lora/; the "../" prefix presumably keeps a relative
    # logging folder pointing at the same place as before the chdir.
    commandLine += " --logging_dir " + "../" + logging_folder
    commandLine += " --output_dir " + model_folder

    if lora_adapter_path is not None:
        commandLine += " --finetune_lora_adapter " + lora_adapter_path

    # fingerprint parameters
    commandLine += " --fingerprint_strength 0.9"

    # training parameters
    per_device_batch_size = model_config.get("per_device_batch_size", 32)
    commandLine += f" --per_device_train_batch_size {per_device_batch_size}"
    gradient_accumulation_steps = model_config.get("gradient_accumulation_steps", 1)
    commandLine += f" --gradient_accumulation_steps {gradient_accumulation_steps}"

    # The working directory is process-wide state: restore it even when the
    # launch raises, so the caller's relative paths keep resolving as before.
    original_directory = os.getcwd()
    os.chdir("lora")
    try:
        result = os.system(commandLine)
    finally:
        os.chdir(original_directory)
    return result == 0

def run_eval_loss_diff(model_name, logging_folder, model_folder, fingerprint_adapter, finetune_adapter, eval_dataset):
    """Launch ``eval_loss_diff.py`` through ``accelerate`` from inside ``lora/``.

    The accelerate config file is fixed to
    ``../config/accelerate/accelerate_eval_config8.yaml`` (relative to ``lora/``).

    Args:
        model_name: value passed as ``--model_name``.
        logging_folder: passed as ``--logging_dir`` with a ``../`` prefix.
        model_folder: passed unchanged as ``--model_folder``.
        fingerprint_adapter: value passed as ``--fingerprint_adapter``.
        finetune_adapter: value passed as ``--finetune_adapter``.
        eval_dataset: value passed as ``--eval_dataset``.

    Returns:
        bool: True when the command's exit status is 0, False otherwise.
    """
    commandLine = "accelerate launch --config_file ../config/accelerate/accelerate_eval_config8.yaml"
    commandLine += " eval_loss_diff.py"

    commandLine += " --model_name " + model_name
    commandLine += " --logging_dir " + "../" + logging_folder
    commandLine += " --model_folder " + model_folder
    commandLine += " --fingerprint_adapter " + fingerprint_adapter
    commandLine += " --finetune_adapter " + finetune_adapter
    commandLine += " --eval_dataset " + eval_dataset

    # The working directory is process-wide state: restore it even when the
    # launch raises, so the caller's relative paths keep resolving as before.
    original_directory = os.getcwd()
    os.chdir("lora")
    try:
        result = os.system(commandLine)
    finally:
        os.chdir(original_directory)
    return result == 0
    
def run_test(model_name, logging_folder, model_folder, lora_adapter, quantize, completion):
    """Run ``test_fingerprint.py`` in a shell and report whether it succeeded.

    The fingerprint adapter is always ``<model_folder>/fingerprint-lora``.
    ``--lora_adapter`` is added when ``lora_adapter`` is truthy, ``--completion``
    when ``completion`` is truthy and ``--quantized`` when ``quantize`` is truthy.

    Returns:
        bool: True when the command's exit status is 0, False otherwise.
    """
    parts = [
        "python test_fingerprint.py",
        "--model_name", model_name,
        "--logging_folder", logging_folder,
        "--model_folder", model_folder,
        "--fingerprint_model_adapter", model_folder + "/fingerprint-lora",
    ]
    if lora_adapter:
        parts += ["--lora_adapter", lora_adapter]

    if completion:
        parts.append("--completion")
    if quantize:
        parts.append("--quantized")

    exit_status = os.system(" ".join(parts))
    return exit_status == 0

def run_benchmark(logging_folder, model_folder, overwrite=True):
    """Run the ``lm_eval`` benchmark through ``accelerate`` and report success.

    The accelerate config file name embeds the GPU count reported by
    ``torch.cuda.device_count()``. Results go to
    ``<logging_folder>/benchmark_results.json``. If that file already exists it
    is deleted first when ``overwrite`` is truthy; otherwise the benchmark is
    skipped and True is returned.

    Returns:
        bool: True when the benchmark was skipped or the command's exit status
        is 0, False otherwise.
    """
    # get the number of GPUs using torch
    gpu_count = torch.cuda.device_count()

    results_path = f"{logging_folder}/benchmark_results.json"
    if os.path.exists(results_path):
        if not overwrite:
            return True
        os.remove(results_path)

    launch = " ".join([
        f"accelerate launch --config_file config/accelerate/accelerate_benchmark_config{gpu_count}.yaml -m lm_eval --trust_remote_code",
        f"--tasks {lm_eval_tasks}",
        f"--model_args pretrained={model_folder},trust_remote_code=True",
        f"--output_path={results_path}",
    ])
    return os.system(launch) == 0

def run_generate_figures(logging_folder):
    """Run ``create_figures.py`` on ``logging_folder`` and report whether it succeeded.

    The figures directory is fixed to ``figs-lora`` and
    ``--plot_individual_probs`` is always passed.

    Returns:
        bool: True when the command's exit status is 0, False otherwise.
    """
    pieces = [
        "python create_figures.py",
        "--runfolder", logging_folder,
        "--figs figs-lora",
        "--plot_individual_probs",
    ]
    return os.system(" ".join(pieces)) == 0


def _experiment_folder(root, args, run_config_base):
    """Build ``<root>/lora/<model>/<config>.<count>`` plus ``-random`` / ``-randpad`` suffixes."""
    folder = root + "/lora/" + args.model_name + "/" + run_config_base + \
                        f".{args.fingerprint_count}"
    if args.random:
        folder += "-random"
    if args.randpad:
        folder += "-randpad"
    return folder


def main():
    """Run the enabled pipeline stages for every run, then generate figures.

    Flow:
      1. Parse arguments, load the model config entry and the JSON run config.
      2. Create the experiment-level logging and model folders.
      3. Optionally benchmark the pretrained model; a failure there returns
         immediately.
      4. For each run number: when ``<run model folder>/adapter_config.json``
         does not exist, run the create / finetune-lora / fingerprint stages
         that are enabled; then run the enabled test, eval_loss and benchmark
         stages. The first failing stage prints an error and stops the loop.
      5. Generate figures for the experiment logging folder (this also happens
         after the loop was stopped by a failure).
    """
    # get the model name
    args = parse_arguments()

    print(run_steps)

    model_config = load_model_config(args.model_name)
    print(model_config)

    # load the json run config file
    with open(args.run_config, "r") as f:
        run_config = json.load(f)

    # extract the base of the run_config file name (the text before the first dot)
    run_config_base = os.path.basename(args.run_config).split(".")[0]

    # create a folder for the run
    logging_folder = _experiment_folder(args.logging_folder, args, run_config_base)
    os.makedirs(logging_folder, exist_ok=True)

    # create folder for model and fingerprints
    model_folder = _experiment_folder(args.model_folder, args, run_config_base)
    os.makedirs(model_folder, exist_ok=True)

    # fine-tuned adapter shared by all runs of this experiment
    shared_finetune_adapter = model_folder + "/finetune-lora"

    # is it a completion model?
    completion_model = model_config.get("completion", None)

    # do we need to benchmark the pretrained?
    if run_steps["benchmark-pretrained"]:
        run_logging_folder = args.logging_folder + "/lora/" + args.model_name
        # overwrite=False: an existing pretrained benchmark result is reused
        if not run_benchmark(run_logging_folder, model_config["model_name"], False):
            print(f"Error in pretrained benchmark step")
            return

    # run the number of runs
    for run_number in range(args.num_runs):

        run_logging_folder = logging_folder + "/" + str(run_number)
        run_model_folder = model_folder + "/" + str(run_number)

        # does the model already exist? If so the training stages are skipped.
        if not os.path.exists(run_model_folder + "/adapter_config.json"):

            print(f"\n\n[{run_number}] ############# Starting run")

            # create the output folders
            os.makedirs(run_logging_folder, exist_ok=True)
            os.makedirs(run_model_folder, exist_ok=True)

            # run the steps
            if run_steps["create"]:
                print(f"[{run_number}] ************* Creating fingerprints")

                # spawn the create_fingerprint python script
                if not run_create_fingerprint(model_config["model_name"], run_config, run_logging_folder,
                                              run_model_folder, completion_model, args.random, args):
                    print(f"[{run_number}] Error in create step")
                    break

                # create target logits with completion model
                if completion_model is not None:
                    if not run_completion_create_fingerprint(completion_model, run_logging_folder,
                                                             run_model_folder, args):
                        print(f"[{run_number}] Error in create step")
                        break

            lora_adapter_path = None
            if run_steps["finetune-lora"]:
                print(f"\n[{run_number}] ************* Fine-tuning model adapter")
                # use existing fine-tuned adapter if it exists
                if os.path.exists(shared_finetune_adapter):
                    print(f"Using existing fine-tuned adapter")
                else:
                    # --finetune_data defaults to None; passing that on would
                    # crash while the command line is being assembled.
                    if args.finetune_data is None:
                        print(f"[{run_number}] Error in finetune step: --finetune_data is required "
                              "when no fine-tuned adapter exists")
                        break
                    if not run_finetune_lora(model_config["model_name"], model_config, run_logging_folder,
                                             model_folder, args.finetune_data):
                        print(f"[{run_number}] Error in finetune step")
                        break
                    print(f"Completed fine-tuning")
                lora_adapter_path = shared_finetune_adapter

            # run the fingerprinting step
            if run_steps["fingerprint"]:
                print(f"\n[{run_number}] ************* Fingerprinting model adapter")
                if not run_fingerprint(model_config["model_name"], model_config, run_logging_folder,
                                       run_model_folder, lora_adapter_path):
                    print(f"[{run_number}] Error in fingerprint step")
                    break

        # run the test step
        if run_steps["test"]:
            print(f"\n[{run_number}] ************* Running tests")
            if not run_test(model_config["model_name"], run_logging_folder, run_model_folder,
                            shared_finetune_adapter, False, completion_model is not None):
                print(f"[{run_number}] Error in test step")
                break

        if run_steps["eval_loss"]:
            print(f"\n[{run_number}] ************* Evaluating loss difference")
            # check if finetune-lora exists under the model folder
            if os.path.exists(shared_finetune_adapter):
                finetune_root = model_folder
            else:
                print(f"[{run_number}] ******old format - finetune adapter was created with each run")
                finetune_root = run_model_folder
            if not run_eval_loss_diff(model_config["model_name"], run_logging_folder, run_model_folder,
                                      run_model_folder + "/fingerprint-lora",
                                      finetune_root + "/finetune-lora",
                                      finetune_root + "/finetune_eval_dataset"):
                print(f"[{run_number}] Error in eval loss step")
                break

        # run the benchmark step
        if run_steps["benchmark"]:
            print(f"\n[{run_number}] ************* Running benchmarks")
            if not run_benchmark(run_logging_folder, run_model_folder):
                print(f"[{run_number}] Error in benchmark step")
                break

    # generate figures
    print(f"\n\n************* Generating figures")
    if not run_generate_figures(logging_folder):
        print(f"Error in generate figures step")
        
        
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()



