import argparse
import json
import os
import torch

# Pipeline stages mapped to whether they are enabled. Every stage is on by
# default; parse_arguments() narrows the selection when stage flags are given.
run_steps = dict.fromkeys(
    ("create", "fingerprint", "test", "benchmark", "benchmark-pretrained"),
    True,
)

# Comma-separated task list passed to lm_eval through its --tasks option.
lm_eval_tasks = ",".join(("mmlu", "winogrande", "truthfulqa", "hellaswag"))

def parse_arguments():
    """Parse the command line and update the module-level ``run_steps`` table.

    When at least one stage flag (``--create``, ``--fingerprint``, ``--test``,
    ``--benchmark``, ``--benchmark-pretrained``) is given, only the listed
    stages stay enabled in ``run_steps``. When none is given, ``run_steps`` is
    left untouched, so every stage keeps its default value.

    The parsed arguments and the resulting stage selection are printed to
    stdout.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Run SFT fingerprinting.")
    parser.add_argument("--model_name", type=str, help="Model name.")
    parser.add_argument("--logging_folder", type=str, required=True, help="logging output folder.")
    parser.add_argument("--model_folder", type=str, required=True, help="model folder.")

    # stages to run
    parser.add_argument("--create", action="store_true", help="Create the model.")
    parser.add_argument("--fingerprint", action="store_true", help="Fingerprint the model.")
    parser.add_argument("--test", action="store_true", help="Test the model.")
    parser.add_argument("--benchmark", action="store_true", help="Benchmark the model.")
    parser.add_argument("--benchmark-pretrained", action="store_true", help="Benchmark the pretrained model.")

    # run configuration
    parser.add_argument("--num_runs", type=int, default=1, help="Number of runs.")
    parser.add_argument("--run_config", type=str, help="Run config file.")
    parser.add_argument("--fingerprint_count", type=int, default=10, help="Number of fingerprints to generate.")
    parser.add_argument("--random", action="store_true", help="Random fingerprint prompts.")
    parser.add_argument("--randpad", action="store_true", help="Random padding.")
    args = parser.parse_args()

    # Stage flags keyed by their run_steps name. argparse stores the dashed
    # option "--benchmark-pretrained" under the attribute "benchmark_pretrained".
    requested = {
        "create": args.create,
        "fingerprint": args.fingerprint,
        "test": args.test,
        "benchmark": args.benchmark,
        "benchmark-pretrained": args.benchmark_pretrained,
    }

    # If any stage is explicitly listed, only the listed stages run. The
    # pretrained benchmark flag counts as an explicit selection too; otherwise
    # passing it alone would leave every stage enabled.
    if any(requested.values()):
        run_steps.update(requested)

    print('\n'.join(['{}: {}'.format(k, v) for k, v in vars(args).items()]))
    print("")
    print("Run steps:")
    for key, value in run_steps.items():
        print(f"{key}: {value}")
    print("")
    return args

def load_model_config(model_name):
    """Return the entry for ``model_name`` from ``config/model_config.json``.

    The path is relative to the current working directory.

    Raises:
        ValueError: if ``model_name`` is not a key of the config file.
    """
    config_file = os.path.join("config", "model_config.json")
    with open(config_file, "r") as handle:
        known_models = json.load(handle)

    if model_name in known_models:
        return known_models[model_name]
    raise ValueError("Model name not found in model_config.json")


def run_completion_create_fingerprint(model_name, logging_folder, model_folder, args):
    """Run ``completion_fingerprint.py`` through the shell.

    ``model_folder`` is passed as ``--fingerprint_folder``. ``args`` is
    accepted but not used here.

    Returns:
        bool: True when the command exits with status 0, False otherwise.
    """
    command = (
        "python completion_fingerprint.py"
        + " --model_name " + model_name
        + " --fingerprint_folder " + model_folder
        + " --logging_folder " + logging_folder
    )
    return os.system(command) == 0
    

def run_create_fingerprint(model_name, run_config, logging_folder, model_folder, completion, randomprompts, args):
    """Run ``create_fingerprint.py`` through the shell.

    Every ``run_config`` entry is forwarded as ``--<key> <value>``, followed by
    the fingerprint count, the folders and the model name. ``--random`` is
    added when ``randomprompts`` is truthy, ``--randpad`` when ``args.randpad``
    is truthy, and ``--completion`` when ``completion`` is not None.

    Returns:
        bool: True when the command exits with status 0, False otherwise.
    """
    pieces = ["python create_fingerprint.py"]
    pieces.extend(
        "--" + option + " " + str(setting) for option, setting in run_config.items()
    )
    pieces += [
        f"--fingerprint_count {args.fingerprint_count}",
        f"--fingerprint_folder {model_folder}",
        f"--logging_folder {logging_folder}",
        f"--model_name {model_name}",
    ]

    # Optional switches, kept in the order the command line expects them.
    optional_flags = (
        ("--random", randomprompts),
        ("--randpad", args.randpad),
        ("--completion", completion is not None),
    )
    pieces.extend(flag for flag, enabled in optional_flags if enabled)

    return os.system(" ".join(pieces)) == 0


def run_fingerprint(model_name, model_config, logging_folder, model_folder ):
    """Launch ``fingerprint-finetune.py`` with ``accelerate`` from the ``sft`` directory.

    The log file is ``<logging_folder>/fingerprint.log``. ``model_folder`` is
    passed both as ``--output_dir`` and as ``--fingerprint_folder``.
    ``model_config`` may supply ``per_device_batch_size`` (default 4) and
    ``gradient_accumulation_steps`` (default 1).

    The command runs with ``sft`` as the working directory. The previous
    working directory is always restored, even if launching the command raises.

    NOTE(review): relative ``logging_folder`` / ``model_folder`` values are
    resolved by the child process relative to ``sft`` - confirm callers pass
    paths that are valid from there.

    Returns:
        bool: True when the command exits with status 0, False otherwise.
    """
    commandLine = "accelerate launch fingerprint-finetune.py"

    # add the model name
    commandLine += " --model_name " + model_name
    commandLine += " --logfile "  + logging_folder + "/fingerprint.log"
    commandLine += " --output_dir " + model_folder

    # fingerprint parameters
    commandLine += " --fingerprint_strength 0.9"
    commandLine += " --fingerprint_folder " + model_folder

    # training parameters
    commandLine += " --num_train_epochs 200"
    commandLine += " --learning_rate 1e-6"
    per_device_batch_size = model_config.get("per_device_batch_size", 4)
    commandLine += f" --per_device_train_batch_size {per_device_batch_size}"
    gradient_accumulation_steps = model_config.get("gradient_accumulation_steps", 1)
    commandLine += f" --gradient_accumulation_steps {gradient_accumulation_steps}"

    # Run from the sft directory. The finally clause guarantees the caller's
    # working directory is restored even when os.system raises; otherwise every
    # later relative path in this process would silently resolve under sft/.
    original_directory = os.getcwd()
    os.chdir("sft")
    try:
        result = os.system(commandLine)
    finally:
        os.chdir(original_directory)
    return result == 0
    
    
def run_test(model_name, logging_folder, model_folder, quantize, completion):
    """Run ``test_fingerprint.py`` through the shell.

    ``--completion`` is appended when ``completion`` is truthy and
    ``--quantized`` when ``quantize`` is truthy, in that order.

    Returns:
        bool: True when the command exits with status 0, False otherwise.
    """
    pieces = [
        "python test_fingerprint.py",
        "--model_name " + model_name,
        "--logging_folder " + logging_folder,
        "--model_folder " + model_folder,
    ]
    for flag, enabled in (("--completion", completion), ("--quantized", quantize)):
        if enabled:
            pieces.append(flag)

    exit_status = os.system(" ".join(pieces))
    return exit_status == 0

def run_benchmark(logging_folder, model_folder, overwrite=True):
    """Run the ``lm_eval`` benchmark on ``model_folder`` via ``accelerate launch``.

    Results go to ``<logging_folder>/benchmark_results.json``. If that file
    already exists it is deleted first when ``overwrite`` is truthy; otherwise
    the benchmark is skipped and True is returned.

    Returns:
        bool: True when the benchmark was skipped or the command exits with
        status 0, False otherwise.
    """
    # The accelerate config file name is suffixed with the GPU count torch reports.
    gpu_count = torch.cuda.device_count()
    results_path = f"{logging_folder}/benchmark_results.json"

    pieces = [
        f"accelerate launch --config_file config/accelerate/accelerate_benchmark_config{gpu_count}.yaml -m lm_eval --trust_remote_code",
        f"--tasks {lm_eval_tasks}",
        f"--model_args pretrained={model_folder},trust_remote_code=True",
    ]

    if os.path.exists(results_path):
        if not overwrite:
            # Earlier results are kept; nothing to run.
            return True
        os.remove(results_path)
    pieces.append(f"--output_path={results_path}")

    return os.system(" ".join(pieces)) == 0

def run_generate_figures(logging_folder):
    """Run ``create_figures.py`` on ``logging_folder`` through the shell.

    ``logging_folder`` is passed as ``--runfolder``; ``--figs`` is always
    ``figs``.

    Returns:
        bool: True when the command exits with status 0, False otherwise.
    """
    command = " ".join(
        ["python create_figures.py", "--runfolder " + logging_folder, "--figs figs"]
    )
    return os.system(command) == 0


def _run_folder(base_folder, model_name, run_config_base, args):
    """Build ``<base>/sft/<model>/<config>.<fingerprint_count>[-random][-randpad]``."""
    folder = f"{base_folder}/sft/{model_name}/{run_config_base}.{args.fingerprint_count}"
    if args.random:
        folder += "-random"
    if args.randpad:
        folder += "-randpad"
    return folder


def main():
    """Drive the fingerprinting pipeline for the model named on the command line.

    Steps, each gated by the module-level ``run_steps`` table:

    1. Optionally benchmark the pretrained model (an existing result file is
       kept). A failure here aborts the whole program.
    2. For each of ``--num_runs`` runs: create fingerprints and fine-tune the
       model unless the run's model folder already contains ``config.json``,
       then run the plain and quantized tests, then benchmark the run's model.
       The first failing step stops the loop.
    3. Generate figures for the logging folder (also after a failed run).
    """
    # get the model name
    args = parse_arguments()

    model_config = load_model_config(args.model_name)
    print(model_config)

    # load the json run config file
    with open(args.run_config, "r") as f:
        run_config = json.load(f)

    # extract the base of the run_config file name (text before the first ".")
    run_config_base = os.path.basename(args.run_config).split(".")[0]

    # create a folder for the run
    logging_folder = _run_folder(args.logging_folder, args.model_name, run_config_base, args)
    os.makedirs(logging_folder, exist_ok=True)

    # create folder for model and fingerprints
    model_folder = _run_folder(args.model_folder, args.model_name, run_config_base, args)
    os.makedirs(model_folder, exist_ok=True)

    # is it a completion model?
    completion_model = model_config.get("completion", None)
    uses_completion = completion_model is not None

    # do we need to benchmark the pretrained?
    if run_steps["benchmark-pretrained"]:
        pretrained_logging_folder = args.logging_folder + "/sft/" + args.model_name
        if not run_benchmark(pretrained_logging_folder, model_config["model_name"], False):
            # No run number exists yet at this point, so the message must not
            # reference one (doing so raised UnboundLocalError).
            print("Error in pretrained benchmark step")
            return

    # run the number of runs
    for run_number in range(args.num_runs):

        run_logging_folder = logging_folder + "/" + str(run_number)

        # does the model already exist? If so, skip creation and fine-tuning.
        run_model_folder = model_folder + "/" + str(run_number)
        if not os.path.exists(run_model_folder + "/config.json"):

            print(f"\n\n[{run_number}] ############# Starting run")

            # create the output folders
            os.makedirs(run_logging_folder, exist_ok=True)
            os.makedirs(run_model_folder, exist_ok=True)

            # run the steps
            if run_steps["create"]:
                print(f"[{run_number}] ************* Creating fingerprints")

                # spawn the create_fingerprint python script
                if not run_create_fingerprint(model_config["model_name"], run_config, run_logging_folder,
                                              run_model_folder, completion_model, args.random, args):
                    print(f"[{run_number}] Error in create step")
                    break

                # create target logits with completion model
                if uses_completion:
                    if not run_completion_create_fingerprint(completion_model, run_logging_folder,
                                                             run_model_folder, args):
                        print(f"[{run_number}] Error in create step")
                        break

            # run the fingerprinting step
            if run_steps["fingerprint"]:
                print(f"\n[{run_number}] ************* Fingerprinting model")
                if not run_fingerprint(model_config["model_name"], model_config, run_logging_folder, run_model_folder):
                    print(f"[{run_number}] Error in fingerprint step")
                    break

        # run the test step
        if run_steps["test"]:
            print(f"\n[{run_number}] ************* Running tests")
            if not run_test(model_config["model_name"], run_logging_folder, run_model_folder, False, uses_completion):
                print(f"[{run_number}] Error in test step")
                break

            print(f"\n[{run_number}] ************* Running quantized tests")
            if not run_test(model_config["model_name"], run_logging_folder, run_model_folder, True, uses_completion):
                print(f"[{run_number}] Error in test step")
                break

        # run the benchmark step
        if run_steps["benchmark"]:
            print(f"\n[{run_number}] ************* Running benchmarks")
            if not run_benchmark(run_logging_folder, run_model_folder):
                print(f"[{run_number}] Error in benchmark step")
                break

    # generate figures
    print("\n\n************* Generating figures")
    if not run_generate_figures(logging_folder):
        print("Error in generate figures step")
        
        
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()



