import argparse

def create_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for CHT-CNN training runs.

    Options are registered under five help groups, in this order: sparsity,
    task, training, experiment and DST-scheduler parameters. The parser uses
    ``argparse.ArgumentDefaultsHelpFormatter`` so ``--help`` can display
    default values.

    Returns:
        The configured ``argparse.ArgumentParser``. Nothing is parsed here.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='CHT-CNN Training Parameter Parser',
    )

    # Choice set shared by every "*_dist" option below.
    dist_kinds = ["fixed", "gaussian", "uniform", "spatial_gaussian", "spatial_inversegaussian"]

    # ---- Sparsity parameters -------------------------------------------
    sparsity_opt = cli.add_argument_group('Sparsity Parameters').add_argument
    sparsity_opt('--conv_sparsity', type=float, help='Conv sparsity [0, 1)')
    sparsity_opt('--link_update_ratio', type=float, default=0.0, help='Link update ratio [0, 1)')
    sparsity_opt('--conv_remove_method', type=str, choices=['rand', 'wm'], default='wm',
                 help='Link removal method')
    sparsity_opt('--conv_regrow_method', type=str, choices=['rand', 'L3n'], default='rand',
                 help='Link regrowth method')
    sparsity_opt('--shared_mask_sw', action='store_true',
                 help='Whether to use shared mask across sliding windows')
    sparsity_opt('--shared_mask_zone', action='store_true',
                 help='Whether to use shared mask across zones')
    sparsity_opt('--zone_sz', type=int, default=0, help='Zone size for shared mask')
    sparsity_opt('--avg_remove', action='store_true',
                 help='Whether to remove links according to average scores between sliding windows')
    sparsity_opt('--avg_regrow', action='store_true',
                 help='Whether to regrow links according to average scores between sliding windows')
    sparsity_opt('--soft', action='store_true',
                 help='Whether to use soft removal and regrowth, as in CHTs')
    sparsity_opt('--use_opt4', action='store_true',
                 help="Whether to use option 4 in CHT's _get_L3n_regrow_pos method")
    sparsity_opt('--delta', type=float, default=0.5, help='Delta parameter for CHT training')
    sparsity_opt('--delta_max', type=float, default=0.75, help='Maximum delta value')
    sparsity_opt('--delta_d', type=float, default=0.01, help='Delta increment step')
    sparsity_opt('--ch_method', type=str, choices=['CH2', 'CH3'], default='CH3',
                 help='Cannistraci-Hebb method variant')

    # ---- Task parameters -----------------------------------------------
    task_opt = cli.add_argument_group('Task Parameters').add_argument
    task_opt('--architecture', type=str, choices=['VGG-16', 'ResNet-20'], help='Model type')
    task_opt('--dataset', type=str, help='Dataset name')
    for task_flag in ('--save', '--grid_search'):
        task_opt(task_flag, action='store_true')

    # ---- Training parameters -------------------------------------------
    train_opt = cli.add_argument_group('Training Parameters').add_argument
    train_opt('--epochs', type=int, default=30, help='Number of training epochs')
    train_opt('--lr', type=float, default=0.001, help='Learning rate')
    train_opt('--bs', type=int, default=64, help='Batch size')
    train_opt('--use_scheduler', action='store_true',
              help='Whether to use learning rate scheduler in training')

    # ---- Experiment parameters -----------------------------------------
    exp_opt = cli.add_argument_group('Experiment Parameters').add_argument
    exp_opt('--num_processes', type=int, default=1,
            help='Number of processes for multi-process experiments')
    exp_opt('--seed', type=int, default=15, help='Random seed')
    exp_opt('--device', type=str, help='GPU device')
    exp_opt('--dropout', type=float)
    exp_opt('--one_fc', action='store_true')

    # ---- DST scheduler parameters --------------------------------------
    dst_opt = cli.add_argument_group('dst scheduler parameters').add_argument
    dst_opt("--linear_sparsity", default=0.99, type=float,
            help="directly give the sparsity to each layer")
    dst_opt("--update_interval", default=1, type=int,
            help="the number of intervals for topology evolution")
    dst_opt("--zeta", default=0.3, type=float,
            help="the fraction of removal and regrown links")
    dst_opt("--adaptive_zeta", action="store_true",
            help="add this augment to make the zeta reducing across the epochs")
    dst_opt("--linear_remove_method", default="weight_magnitude", type=str,
            help="how to remove links, Magnitude or MEST")
    dst_opt(
        "--linear_regrow_method",
        default="random",
        type=str,
        help="how to regrow new links. Including: random, gradient, CH3_L3p_soft, CH3_L3n_soft, CH2_L3n_soft",
    )
    # The help text below has no space after "model." in the original; it is
    # reproduced unchanged.
    dst_opt(
        "--init_mode",
        default="kaiming",
        type=str,
        help="how to initialize the weights of the model.Including: kaiming, swi",
    )
    dst_opt("--chain_removal", action="store_true",
            help="use forward removal and backward removal")
    for corr_flag in ("--self_correlated_sparse", "--soft_self_correlated_sparse"):
        dst_opt(corr_flag, action="store_true")
    dst_opt("--dim", default=1, type=int)
    dst_opt("--dimension", default=0, type=int)
    # A "--BA" switch used to sit after "--cross"; it is currently disabled.
    for init_flag in ("--WS", "--DNM", "--cross"):
        dst_opt(init_flag, action="store_true")
    dst_opt("--random_rewiring", default=1.0, type=float)
    dst_opt("--M", type=int, help="number of dendrites")
    dst_opt("--M_dist", default="fixed", type=str, choices=dist_kinds,
            help="Distribution of the number of dendrites among neurons")
    dst_opt("--gamma", type=float, help="dendritic spreading parameter")
    dst_opt("--gamma_dist", default="fixed", type=str, choices=dist_kinds)
    dst_opt("--degree_dist", default="fixed", type=str, choices=dist_kinds)
    dst_opt("--synaptic_dist", default="fixed", type=str, choices=dist_kinds,
            help="distribution of the nodes in the synapses")
    dst_opt("--BHI", action="store_true", help="Bipartite Hyperbolic Initialisation")
    dst_opt("--CWS", action="store_true", help="Cannistraci Watts Strogatz")
    for variant_no, ordinal in enumerate(("First", "Second", "Third"), start=1):
        dst_opt(f"--WS{variant_no}", action="store_true",
                help=f"{ordinal} variation of WS initialisation")
    # "--delta" is already registered in the sparsity group above, so the
    # δ locality option is not declared a second time here.
    dst_opt("--F", default="fixed", type=str, help="distribution of δ over the neurons")
    dst_opt("--sigma_x", type=float, help="standard deviation over the x")
    dst_opt("--sigma_y", type=float, help="standard deviation over the y")
    dst_opt("--rho", type=float, help="correlation")
    dst_opt("--QHI", action="store_true", help="General Hyperbolic Initialisation")
    dst_opt("--BHI_T", default=0.0, type=float,
            help="nPSO temperature. If 0: purely greedy by hyperbolic distance; >0: probabilistic sampling.")
    dst_opt("--BHI_gamma", default=2.0, type=float,
            help="Power-law exponent gamma controlling the radial coordinate dynamics in nPSO.")
    bhi_distr_help = (
        "angular coordinate distribution:"
        "  0                uniform on [0,2π) (PSO model)"
        "  C>0              integer # of equidistant GaussianMixture components"
        "  GaussianMixture  sklearn GaussianMixture object for custom GM"
        "  (angles, probs, centers)  3-tuple of lists for fully custom mixture"
    )
    dst_opt("--BHI_distr", default=0, type=int, help=bhi_distr_help)
    rewire_help = (
        "Optional bipartite rewiring after nPSO wiring: "
        "'none' = keep original edges; "
        "'uniform' = redistribute B-endpoint degrees as evenly as possible; "
        "'random' = reassign B endpoints at random (preserving A-degrees)."
    )
    dst_opt("--rewire_mode", default="none", type=str,
            choices=["none", "uniform", "random"], help=rewire_help)
    dst_opt(
        "--degree_allocation",
        action="store_true",
        help="Enable the spatial-sorting degree allocation strategy: for each bipartite block, compute each neuron's degree and then reorder them so that the highest-degree neuron sits at the center, the next two occupy the centers of the two halves, the next four the centers of the four quarters, and so on. For the third sandwich only the input side is sorted (the final output layer remains unconstrained).",
    )
    for log_flag in ("--no_log", "--linearlr"):
        dst_opt(log_flag, action="store_true")
    dst_opt('--milestone', nargs='+', type=int, default=[60, 120, 160],
            help='Decrease learning rate at these epochs.')
    dst_opt("--iterative_warmup_steps", default=0, type=int)
    dst_opt("--warmup", action="store_true")
    dst_opt("--discretelr", action="store_true")
    dst_opt("--end_factor", default=0.01, type=float)
    dst_opt("--T_decay", default="no_decay", type=str, choices=["no_decay", "linear"],
            help="decay the temperature of the sampling")
    dst_opt("--decay_factor", default=0.9, type=float,
            help="decay epochs of the learning rate")
    dst_opt("--itop", action="store_true", help="use the itop logger")
    dst_opt("--EM_S", action="store_true")
    dst_opt("--early_stop", action="store_true")
    dst_opt("--early_stop_thre", default=0.9, type=float)
    dst_opt("--record_anp", action="store_true")
    dst_opt("--method", default="", type=str, help="the method name of the experiment")
    for misc_flag in ("--old_version", "--history_weights", "--tiedrank"):
        dst_opt(misc_flag, action="store_true")
    dst_opt("--start_T", default=1.0, type=float, help="set the initial value of delta")
    dst_opt("--end_T", default=3.0, type=float, help="set the final value of delta")
    dst_opt("--factor", default=0.01, type=float)
    dst_opt("--ssam", action="store_true")
    schedule_shapes = [
        "linear",
        "sigmoid",
        "tempering_hard",
        "tempering_strength",
        "piecewise_constant_linear",
        "piecewise_constant_abrupt",
        "decreasing_step",
        "fixed_height_step",
        "wave",
        "decaying_wave",
        "cosine_annealing_wr",
    ]
    dst_opt("--function", default="linear", type=str, choices=schedule_shapes,
            help="varies alpha according to some function")
    dst_opt("--k", default=1, type=float,
            help="constant that regulates the sigmoid functions's steepness")
    dst_opt("--granet", action="store_true")
    dst_opt("--granet_init_sparsity", default=0.9, type=float)
    dst_opt("--granet_init_epoch", default=0, type=int)
    dst_opt("--gmp", action="store_true")
    dst_opt("--pruning_scheduler", default="none", type=str,
            help="none, linear, granet, s_shape")
    dst_opt("--pruning_method", default="none", type=str,
            help="ri, weight_magnitude, MEST")
    dst_opt("--sparsity_distribution", default="uniform", type=str,
            help="uniform, non-uniform")

    return cli


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command line and derive ``dim`` from the dataset name.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            this function has always done.

    Returns:
        The populated namespace. ``dim`` is always overwritten here: it is
        1 when the dataset name contains ``"CIFAR"`` and 2 otherwise, so a
        ``--dim`` value given on the command line is discarded.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when ``--dataset``
            is not supplied, or on any ordinary argparse usage error.
    """
    parser = create_parser()
    parsed_args = parser.parse_args(argv)

    # --dataset has no default, so it is None when omitted. The substring
    # test below would then raise an opaque TypeError; report a proper
    # usage error instead.
    if parsed_args.dataset is None:
        parser.error("--dataset is required: it determines the value of --dim")

    parsed_args.dim = 1 if "CIFAR" in parsed_args.dataset else 2
    return parsed_args
