import argparse

# parse_trak_args
def parse_trak_args(argv=None):
    """Build the TRAK data-attribution command-line parser and parse arguments.

    The first group of options is inherited from music-transformer-pytorch's
    ``parse_generate_args``. The second group configures the TRAK run:
    window sizes, checkpoint selection, batch sizes, score/gather phases,
    ensembling and LoRA handling.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) keeps the original behaviour of parsing ``sys.argv[1:]``.
            Passing an explicit list lets callers and tests parse arguments
            programmatically.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``-h``.
    """
    parser = argparse.ArgumentParser()

    # These parameters are directly inherited from music-transformer-pytorch:parse_generate_args.
    # Its "-model_weights" option is replaced by trak_model_checkpoints_root, since TRAK needs a
    # list of checkpoints rather than a single weights file.
    # NOTE(review): the -midi_root help text was inherited verbatim, but its default is a
    # directory, not a single file — confirm the intended meaning.
    parser.add_argument("-midi_root", type=str, default="./dataset/e_piano/", help="Midi file to prime the generator with")
    parser.add_argument("-output_dir", type=str, default="./gen", help="Folder to write generated midi to")
    parser.add_argument("-primer_file", type=str, default=None, help="File path or integer index to the evaluation dataset. Default is to select a random index.")
    parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
    parser.add_argument("-target_seq_length", type=int, default=1024, help="Target length you'd like the midi to be")
    parser.add_argument("-num_prime", type=int, default=256, help="Amount of messages to prime the generator with")
    parser.add_argument("-beam", type=int, default=0, help="Beam search k. 0 for random probability sample and 1 for greedy")
    parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
    parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider")
    parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
    parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
    parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")
    parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")

    # Additional parameters for TRAK.
    parser.add_argument("-train_windows_size", type=int, default=512,
                        help="The length of each training data point.")
    parser.add_argument("-trak_save_dir", type=str, default="./trak_results_saved_all_512_fixed_prod",
                        help="The place to load/save intermediate result for TRAK.")
    parser.add_argument("--trak_training_sanity_check", action="store_true",
                        help="calculate TRAK data attribution score on training set for sanity check")
    parser.add_argument("-trak_model_checkpoints_root", type=str,
                        default="../MusicTransformer-pytorch-rpr-checkpoints",
                        help="The checkpoint root for TRAK attribution score calculation (multiple for remove randomness)")
    parser.add_argument("-trak_model_checkpoints_number", type=int,
                        default=1, help="How many checkpoint to use for TRAK, index from 0.")
    parser.add_argument("-trak_model_checkpoints_extension", type=str,
                        default="results/best_acc_weights.pickle", help="Checkpoint extension, typically this means which specific checkpoint to use after a whole training process.")
    parser.add_argument("-trak_train_batchsize", type=int, default=4, help="training batchsize")
    parser.add_argument("-trak_test_batchsize", type=int, default=4, help="test batchsize")
    parser.add_argument("-trak_test_number", type=int, default=-1,
                        help="How many data sample to calculate the data attribution score, negative number for the whole set.")
    parser.add_argument("--trak_question", action="store_true", help="A blind test for data attribute")
    parser.add_argument("--trak_generated_music_root", type=str,
                        default="../MusicTransformer-pytorch-generated/test-fixed-crop-512-bestacc-processed",
                        help="the root save generated music")
    parser.add_argument("--test_windows_size", type=int, default=512, help="The length of each test data point.")
    parser.add_argument("--trak_only_score", action="store_true", help="only score but not finalize")
    parser.add_argument("--trak_only_gather", action="store_true", help="only finalize")
    parser.add_argument("--trak_cherrypick", type=int, default=-1, help="trak_cherrypick")
    # NOTE(review): "training portion seed" does not match the option name generate_length;
    # the help text looks copy-pasted — confirm what this option controls.
    parser.add_argument("-generate_length", type=int, default=0, help="training portion seed")
    parser.add_argument("--ensemble", type=int, default=1, help="ensemble number")
    parser.add_argument("--independent", type=int, default=1, help="independent number")
    parser.add_argument("--dropout", action="store_true", help="dropout or not")

    parser.add_argument("--enable_new_ver", action="store_true", help="enable nn.Linear in rpr implementation")
    parser.add_argument("--LoRA_finetune", action="store_true", help="Load a base model and use LoRA finetuning")
    parser.add_argument("--LoRA_grad_only", action="store_true", help="if yes, then only consider LoRA layers' grad during TRAK")
    parser.add_argument("-multi_LoRA", type=int, default=0, help="using multiple LoRA finetuned models per ensemble")

    # argv=None makes argparse fall back to sys.argv[1:], preserving the original behaviour.
    return parser.parse_args(argv)

# print_trak_args
def print_trak_args(args):
    """Print every parsed option as one ``name = value`` line.

    Args:
        args: Object whose attributes are listed via ``vars()`` (for example
            the namespace returned by ``parse_trak_args``). Options are
            printed in attribute insertion order.
    """
    options = vars(args)
    for option_name in options:
        print(f'{option_name} = {options[option_name]}')
