#!/usr/bin/env python3

import argparse
import itertools
import os
import shlex
import sys

def generate_lr_range(lr_min_exp, lr_max_exp, lr_base):
    """Return the geometric learning-rate grid ``lr_base ** e``.

    ``e`` runs over the integers ``lr_min_exp, lr_min_exp + 1, ...,
    lr_max_exp - 1``; like ``range``, the upper exponent is exclusive.
    An empty list is returned when ``lr_min_exp >= lr_max_exp``.
    """
    return [pow(lr_base, exponent) for exponent in range(lr_min_exp, lr_max_exp)]


def generate_commands(lr_range, epochs=400000, out_dir="~/work/experiments/tmp",
                      loss="binary_cross_entropy", num_images_list=None, model="linear",
                      device="cpu", seed=123, classes_list=None, dataset="cifar10",
                      amp=False, channels_last=False, compile_model=False, data_dir=None,
                      lr_base=2, no_normalize=False, num_threads=1, resize_size=None,
                      kernel_size=2, shuffle_features=False, scale_model=None):
    """Write one ``train_gd.py`` launch script per hyper-parameter combination.

    Creates ``out_dir`` (which must not already exist) and records the invoking
    command line in ``out_dir/command.txt``. It then writes
    ``out_dir/singularity_<k>.sh`` for every combination of learning rate,
    ``num_images`` and ``classes``, numbered with ``k`` starting at 0. Each
    script ``cd``s into the current working directory and runs ``train_gd.py``
    with ``--out-dir out_dir/<k>``. Finally prints ``out_dir``.

    Args:
        lr_range: Either a ``(min_exp, max_exp)`` tuple, expanded to
            ``lr_base ** e`` for ``e`` in ``range(min_exp, max_exp)`` (upper
            exponent exclusive), or an explicit list of learning rates.
        out_dir: Experiment directory. ``~`` is expanded. ``None`` falls back
            to ``~/work/experiments/tmp``.
        num_images_list: Values for ``--num-images``. A ``None`` entry omits
            the flag. Defaults to ``[None]``.
        classes_list: Values for ``--classes``. Defaults to ``['0,1']``.
        The remaining arguments map one-to-one onto ``train_gd.py`` flags.
        ``amp``, ``channels_last``, ``compile_model``, ``no_normalize`` and
        ``shuffle_features`` emit bare switches when true. ``data_dir``,
        ``num_threads``, ``resize_size`` and ``kernel_size`` are emitted only
        when truthy. ``scale_model`` is emitted when not ``None``.

    Raises:
        ValueError: if ``lr_range`` is neither a 2-tuple nor a list.
        FileExistsError: if ``out_dir`` already exists.
    """
    # Resolve learning rates before touching the filesystem, so an invalid
    # lr_range does not leave an empty experiment directory behind.
    if isinstance(lr_range, tuple) and len(lr_range) == 2:
        learning_rates = generate_lr_range(lr_range[0], lr_range[1], lr_base)
    elif isinstance(lr_range, list):
        learning_rates = lr_range
    else:
        raise ValueError("lr_range must be a tuple (min_exp, max_exp) or a list of learning rates")

    if out_dir is None:
        # main() forwards --out-dir even when it was not given on the CLI.
        out_dir = "~/work/experiments/tmp"
    # os.makedirs does not expand "~". Without this it would create a literal
    # "./~" directory, while the generated shell scripts resolve "~" to $HOME.
    out_dir = os.path.expanduser(out_dir)

    if os.path.exists(out_dir):
        raise FileExistsError(f"Output directory {out_dir} already exists")
    os.makedirs(out_dir)
    with open(os.path.join(out_dir, 'command.txt'), 'w', encoding='utf-8') as f:
        f.write(' '.join(sys.argv) + '\n')

    if num_images_list is None:
        num_images_list = [None]  # None -> omit --num-images, i.e. use all images
    if classes_list is None:
        classes_list = ['0,1']  # default: a single binary task on classes 0 and 1

    # Paths are shell-quoted because every command is written into a .sh file.
    cd_line = f"cd {shlex.quote(os.getcwd())}\n"

    # Same ordering as three nested loops: lr outermost, classes innermost.
    combos = itertools.product(learning_rates, num_images_list, classes_list)
    for index, (lr, num_images, classes) in enumerate(combos):
        out_dir_combo = os.path.join(out_dir, str(index))
        parts = [
            "python train_gd.py",
            f"--epochs {epochs}",
            f"--lr {lr}",
            f"--out-dir {shlex.quote(out_dir_combo)}",
            f"--loss {loss}",
            f"--model {model}",
            f"--device {device}",
            f"--seed {seed}",
            f"--classes {classes}",
            f"--dataset {dataset}",
        ]

        if amp:
            parts.append("--amp")
        if channels_last:
            parts.append("--channels-last")
        if compile_model:
            parts.append("--compile")
        if data_dir:
            parts.append(f"--data-dir {shlex.quote(str(data_dir))}")
        if num_images is not None:
            parts.append(f"--num-images {num_images}")
        if no_normalize:
            parts.append("--no-normalize")
        if num_threads:
            parts.append(f"--num-threads {num_threads}")
        if resize_size:
            parts.append(f"--resize-size {resize_size}")
        if kernel_size:
            parts.append(f"--kernel-size {kernel_size}")
        if shuffle_features:
            parts.append("--shuffle-features")
        if scale_model is not None:
            parts.append(f"--scale-model {scale_model}")

        script_path = os.path.join(out_dir, f"singularity_{index}.sh")
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(cd_line)
            f.write(" ".join(parts) + '\n')

    print(out_dir)


def main():
    """Parse command-line options and write the experiment launch scripts.

    Example::

        --lr-range -8 7 --out-dir runs/lr_sweep --num-images 100 none --classes 0,1 3,5

    This requests learning rates ``lr_base**-8`` up to ``lr_base**6`` (the
    upper exponent is exclusive) for every num-images / classes combination.
    """
    parser = argparse.ArgumentParser(description='Generate training commands with different learning rates')

    # nargs=2 makes argparse itself reject any other number of values with a
    # usage error. Negative values such as "-8" are accepted as arguments.
    parser.add_argument('--lr-range', type=int, nargs=2, required=True,
                        metavar=('MIN_EXP', 'MAX_EXP'),
                        help='Learning-rate exponent range as two integers, e.g. "-8 7" for '
                             'lr_base^-8 up to lr_base^6 (MAX_EXP is exclusive)')
    parser.add_argument('--lr-base', type=int, default=2)

    # Training parameters
    parser.add_argument('--epochs', type=int, default=400000, help='Number of epochs')
    parser.add_argument('--out-dir', type=str, help='Output directory')
    parser.add_argument('--loss', type=str, default='binary_cross_entropy', help='Loss function')
    parser.add_argument('--num-images', type=str, nargs='*', default=None, help='Number of images to use (can specify multiple, use "none" for all images)')
    parser.add_argument('--model', type=str, default='linear', help='Model to use')
    parser.add_argument('--device', type=str, default='cpu', help='Device to use')
    parser.add_argument('--seed', type=int, default=123, help='Random seed')
    parser.add_argument('--data-dir', type=str, help='Data directory')
    parser.add_argument('--classes', type=str, nargs='*', default=['0,1'], help='Class pairs to use (format: "class1,class2", can specify multiple)')
    parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use')
    parser.add_argument('--no-normalize', action='store_true', help='Do not normalize the data')
    parser.add_argument('--num-threads', type=int, default=1, help='Number of threads to use')
    parser.add_argument('--resize-size', type=str, default=None, help='Resize size')
    parser.add_argument('--kernel-size', type=int, default=2, help='Kernel size')
    parser.add_argument('--shuffle-features', action='store_true', help='Shuffle features')
    parser.add_argument('--scale-model', type=float, default=None)

    parser.add_argument('--amp', action='store_true', help='Enable mixed precision')
    parser.add_argument('--channels-last', action='store_true', help='Use channels_last memory format')
    parser.add_argument('--compile', action='store_true', help='Use torch.compile')

    args = parser.parse_args()

    min_exp, max_exp = args.lr_range
    lr_range = (min_exp, max_exp)

    # "none" (any case) means "use all images" and maps to None; everything
    # else must be an integer. Bad input is reported as a usage error.
    num_images_list = args.num_images
    if num_images_list is not None:
        try:
            num_images_list = [None if x.lower() == 'none' else int(x) for x in num_images_list]
        except ValueError:
            parser.error(f"--num-images values must be integers or 'none', got {args.num_images}")

    generate_commands(
        lr_range=lr_range,
        epochs=args.epochs,
        out_dir=args.out_dir,
        loss=args.loss,
        num_images_list=num_images_list,
        model=args.model,
        device=args.device,
        seed=args.seed,
        data_dir=args.data_dir,
        classes_list=args.classes,
        dataset=args.dataset,
        amp=args.amp,
        channels_last=args.channels_last,
        compile_model=args.compile,
        lr_base=args.lr_base,
        no_normalize=args.no_normalize,
        num_threads=args.num_threads,
        resize_size=args.resize_size,
        kernel_size=args.kernel_size,
        shuffle_features=args.shuffle_features,
        scale_model=args.scale_model,
    )

if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()
