import argparse
import ast

def parse_list(option_str):

    try:
        result = ast.literal_eval(option_str)
        if not isinstance(result, list):
            raise ValueError
        return result
    except:
        raise argparse.ArgumentTypeError("Invalid list format")

def _str2bool(value):
    """Convert a command-line string to a ``bool`` for use as an argparse ``type=``.

    ``type=bool`` is a known argparse pitfall: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This accepts the usual
    spellings instead; the empty string maps to ``False`` as ``bool("")`` did.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def get_args(argv=None):
    """Build the training command-line parser and return the parsed arguments.

    After parsing, some values are derived or overridden:

    * ``args.distributed`` is always recomputed as ``len(args.devices) >= 2``,
      so the value given for ``--distributed`` is not kept.
    * For ``CIFAR10``/``CIFAR100``, ``input_size`` is forced to 32 and
      ``logistic_batch_size`` to 128 (overriding any command-line value), and
      ``batch_size`` defaults to 128 when it was not supplied.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding every option below.
    """
    parser = argparse.ArgumentParser(description="Setup for distributed and model training.")

    # Distributed training
    parser.add_argument("--devices", type=parse_list, default=[0], help="List of device IDs")
    # Boolean flags accept an explicit value ("--flag False") or none ("--flag" -> True).
    parser.add_argument("--distributed", type=_str2bool, nargs="?", const=True, default=False,
                        help="Use DistributedDataParallel (overridden: enabled when 2+ devices are given)")
    parser.add_argument("--num_workers", type=int, default=8, help="Number of worker threads")
    parser.add_argument("--method", type=str, default='ReLA', help="Training method to use")

    # Train options
    parser.add_argument("--seed", type=int, default=42, help="Seed for random number generators")
    parser.add_argument("--batch_size", type=int, default=None, help="Batch size for training")
    parser.add_argument("--grad_accu", type=int, default=1, help="Accumulated steps for training")
    parser.add_argument("--input_size", type=int, default=None, help="Image size")
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs to train")
    parser.add_argument("--rounds", type=int, default=1, help="Number of rounds to train")
    parser.add_argument("--dataset", type=str, default='CIFAR10', help="Dataset to use")
    parser.add_argument("--global_ipc", type=int, default=300)
    parser.add_argument("--local_ipc", type=int, default=100)
    parser.add_argument("--align_method", type=str, default='align_to_best', help="Alignment method to use")

    # parser.add_argument("--avg_method", type=str, default='no_equal_avg', help="Dataset to use")

    # Model options
    parser.add_argument("--model", type=str, default='resnet18', help="Model to use")
    parser.add_argument("--feature_dim", type=int, default=512, help="Dimension of feature space")

    # Optimizer options
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate")

    # Reload options
    parser.add_argument("--model_path", type=str, default=None, help="Path to save or load the model")

    # Logistic regression options
    parser.add_argument("--logistic_batch_size", type=int, default=128, help="Batch size for logistic regression")
    parser.add_argument("--logistic_epochs", type=int, default=100, help="Number of epochs for logistic regression")

    # SBDD (Self-Balanced Dynamic Decoding) options
    parser.add_argument("--use_sbdd", type=str, default=None, help="Use SBDD or not")
    parser.add_argument("--sbdd_model", type=str, default=None, help="Observer model for SBDD")
    parser.add_argument("--sbdd_input_size", type=int, default=None, help="Observer model input size for SBDD")
    parser.add_argument("--sbdd_ratio", type=float, default=None, help="Data used ratio for SBDD")
    parser.add_argument("--sbdd_static", type=_str2bool, nargs="?", const=True, default=False,
                        help="If using static SBDD")
    parser.add_argument("--sbdd_conlam", type=float, default=None, help="If using constant SBDD lambda")
    parser.add_argument("--sbdd_data_path", type=str, default=None, help="Path used to store data for SBDD")
    parser.add_argument("--optimal_data_type", type=str, default='single', help="Type of optimal data")

    args = parser.parse_args(argv)

    # Distributed mode is derived from the device list, not from --distributed.
    args.distributed = len(args.devices) >= 2

    if args.dataset in ["CIFAR10", "CIFAR100"]:
        # These unconditionally replace any user-supplied values for CIFAR.
        args.input_size = 32
        args.logistic_batch_size = 128
        if args.batch_size is None:
            args.batch_size = 128

    return args