import os

from argparse import ArgumentParser
from train.train_params import params_chord_cond, params_chord_lsh_cond
from train.train_config import LdmTrainConfig

def init_parser():
    """Build the command-line parser for the diffusion-model training script.

    Returns:
        An ``ArgumentParser`` with these options:
        ``--output_dir`` (default ``'results'``), ``--uniform_pitch_shift``,
        ``--debug``, ``--load_chkpt_from`` (default ``None``),
        ``--null_cond_weight`` (float, default ``0.5``) and ``--with_melody``.
    """
    parser = ArgumentParser(description='train (or resume training) a diffusion model')
    parser.add_argument(
        "--output_dir",
        default='results',
        help='directory in which to store model checkpoints and training logs'
    )
    parser.add_argument('--uniform_pitch_shift', action='store_true',
                        help="whether to apply pitch shift uniformly (as opposed to randomly)")
    parser.add_argument('--debug', action='store_true', help="whether to use debug mode")
    parser.add_argument('--load_chkpt_from', default=None, help="whether to load existing checkpoint")
    # type=float: without it a value supplied on the command line stays a str,
    # while the default is a float, so downstream code would get mixed types.
    parser.add_argument('--null_cond_weight', type=float, default=0.5,
                        help="weight parameter for null condition in classifier free guidance")
    parser.add_argument('--with_melody', action='store_true', help="whether to use melody condition")

    return parser


def args_setting_to_fn(args):
    """Build the run-directory name from the boolean CLI flags.

    The name has the form ``model-{m}-{d}``:

    - ``m`` is ``'m'`` when ``args.with_melody`` is set, else empty.
    - ``d`` is ``'d'`` when ``args.debug`` is set, else empty.

    For example: ``model--``, ``model-m-``, ``model--d``, ``model-m-d``.

    Args:
        args: parsed namespace exposing boolean ``debug`` and ``with_melody``.

    Returns:
        The directory name only (not joined with any parent path).
    """
    def to_str(x: bool, char):
        return char if x else ''

    debug = to_str(args.debug, 'd')
    # Must key on with_melody (not debug) so melody-conditioned and
    # chord-only runs are written to different directories.
    with_melody = to_str(args.with_melody, 'm')

    return f"model-{with_melody}-{debug}"


if __name__ == "__main__":
    parser = init_parser()
    args = parser.parse_args()

    # Random pitch augmentation unless --uniform_pitch_shift is given.
    # NOTE(review): this value is computed but never passed to LdmTrainConfig
    # or config.train below, so --uniform_pitch_shift currently has no effect.
    # TODO: forward it once the config/train signature is confirmed.
    random_pitch_aug = not args.uniform_pitch_shift

    # Run directory: <output_dir>/<name encoding the boolean flags>.
    fn = args_setting_to_fn(args)
    output_dir = os.path.join(args.output_dir, fn)

    # The melody flag only selects the parameter set; every other
    # constructor argument is shared between the two variants.
    params = params_chord_lsh_cond if args.with_melody else params_chord_cond
    config = LdmTrainConfig(params, output_dir, debug_mode=args.debug,
                            load_chkpt_from=args.load_chkpt_from)

    # A command-line value arrives as a str unless the parser declares
    # type=float; normalise so train() always receives a float.
    # NOTE(review): the CLI calls this the null-condition weight, while
    # train() names it null_rhythm_prob — confirm they are the same quantity.
    config.train(null_rhythm_prob=float(args.null_cond_weight),
                 with_lsh=args.with_melody)