import argparse
from dataclasses import dataclass

def get_args(argv=None):
    """Build the experiment argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so existing
            zero-argument ``get_args()`` calls behave as before; passing a
            list lets scripts and tests parse arguments programmatically.

    Returns:
        argparse.Namespace holding every option registered below, with
        ``dim`` overwritten from ``dataset``: 1 when the dataset name
        contains "CIFAR", 2 otherwise.

    Raises:
        SystemExit: from argparse on invalid input, including a missing
            ``--dataset`` (previously this surfaced as an unhelpful
            ``TypeError`` from ``"CIFAR" in None``).
    """
    parser = argparse.ArgumentParser()
    # Registration order matches the original single-function layout, so the
    # order of options in --help is unchanged.
    _add_common_args(parser)
    _add_conv_args(parser)
    _add_dst_args(parser)

    args = parser.parse_args(argv)

    # NOTE(review): this always replaces any value passed via --dim, so the
    # --dim option is effectively ignored. Kept as-is for backward
    # compatibility with existing runs.
    args.dim = 1 if "CIFAR" in args.dataset else 2

    return args


def _add_common_args(parser):
    """Register general run options: data, model, optimiser and defaults."""
    # Required because get_args derives `dim` from the dataset name.
    parser.add_argument('--dataset', type=str, required=True, help=(
        "including CIFAR10, CIFAR100, EMNIST, Fashion_MNIST, MNIST, FER2013, "
        "SVHN, tiny-imagenet-ori, tiny-imagenet-crop, tiny-imagenet-resize, "
        "OxfordFlowers, Caltech256, OxfordIIITPet, INaturalist."))
    parser.add_argument('--architecture', type=str,
                        help="including 'VGG-16', 'ResNet-20'")
    parser.add_argument('--epochs', type=int)
    parser.add_argument('--device', default='cuda:8', type=str, help='cuda or cpu')
    parser.add_argument('--bs', default=64, type=int, help='Batchsize')
    parser.add_argument('--lr', default=0.005, type=float, help='Learning rate')
    parser.add_argument('--save', action='store_true', default=False,
                        help='whether to save model')
    parser.add_argument('--l', default=4, type=int, help='L')
    parser.add_argument('--t', default=64, type=int, help='T')
    parser.add_argument('--temp', action='store_true')
    parser.add_argument('--fold_BN', action='store_true')
    parser.add_argument('--dropout', type=float)
    parser.add_argument('--one_fc', action='store_true')

    # Defaults. --dim is overwritten by get_args after parsing.
    parser.add_argument('--dim', type=int)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--init_up', type=float, default=8.)
    parser.add_argument('--activation_mode', type=str, default='softplus',
                        help='origin,softplus')


def _add_conv_args(parser):
    """Register options for sparse convolution-layer topology updates."""
    parser.add_argument('--conv_sparsity', type=float,
                        help='Conv sparsity [0, 1)')
    parser.add_argument('--link_update_ratio', type=float, default=0.,
                        help='Link update ratio [0, 1)')
    parser.add_argument('--conv_remove_method', type=str, default='wm',
                        choices=['rand', 'wm'], help='Link removal method')
    parser.add_argument('--conv_regrow_method', type=str, default='rand',
                        choices=['rand', 'L3n'], help='Link regrowth method')
    parser.add_argument('--shared_mask_sw', action='store_true',
                        help='Whether to use shared mask across sliding windows')
    parser.add_argument('--shared_mask_zone', action='store_true',
                        help='Whether to use shared mask across zones')
    parser.add_argument('--zone_sz', type=int, default=0,
                        help='Zone size for shared mask')
    parser.add_argument('--avg_remove', action='store_true',
                        help='Whether to remove links according to average scores between sliding windows')
    parser.add_argument('--avg_regrow', action='store_true',
                        help='Whether to regrow links according to average scores between sliding windows')
    parser.add_argument('--soft', action='store_true',
                        help='Whether to use soft removal and regrowth, as in CHTs')
    parser.add_argument('--use_opt4', action='store_true',
                        help='Whether to use option 4 in CHT\'s _get_L3n_regrow_pos method')
    parser.add_argument('--delta', type=float, default=0.5,
                        help='Delta parameter for CHT training')
    parser.add_argument('--delta_max', type=float, default=0.75,
                        help='Maximum delta value')
    parser.add_argument('--delta_d', type=float, default=0.01,
                        help='Delta increment step')
    parser.add_argument('--ch_method', type=str, default='CH3',
                        choices=['CH2', 'CH3'],
                        help='Cannistraci-Hebb method variant')
    parser.add_argument('--use_hidden', action='store_true',
                        help='Whether to use hidden mask in shared_mask_sw + CHT')


def _add_dst_args(parser):
    """Register dynamic-sparse-training, initialisation and scheduling options."""
    # Shared choice list for the various *_dist options below.
    dist_choices = ["fixed", "gaussian", "uniform",
                    "spatial_gaussian", "spatial_inversegaussian"]

    parser.add_argument("--linear_sparsity", type=float,
                        help="directly give the sparsity to each layer")
    parser.add_argument("--update_interval", type=int, default=1,
                        help="the number of intervals for topology evolution")
    parser.add_argument("--zeta", type=float, default=0.3,
                        help="the fraction of removal and regrown links")
    parser.add_argument("--adaptive_zeta", action="store_true",
                        help="add this augment to make the zeta reducing across the epochs")
    parser.add_argument("--linear_remove_method", type=str, default="weight_magnitude",
                        help="how to remove links, Magnitude or MEST")
    parser.add_argument("--linear_regrow_method", type=str, default="random",
                        help="how to regrow new links. "
                             "Including: random, gradient, CH3_L3p_soft, CH3_L3n_soft, CH2_L3n_soft")
    parser.add_argument("--init_mode", type=str, default="kaiming",
                        help="how to initialize the weights of the model. "
                             "Including: kaiming, swi")
    parser.add_argument("--chain_removal", action="store_true",
                        help="use forward removal and backward removal")
    parser.add_argument("--self_correlated_sparse", action="store_true")
    parser.add_argument("--soft_self_correlated_sparse", action="store_true")
    parser.add_argument("--dimension", type=int, default=0)
    parser.add_argument("--WS", action="store_true")
    parser.add_argument("--DNM", action="store_true")
    parser.add_argument("--cross", action="store_true")
    parser.add_argument("--random_rewiring", type=float, default=1.0)
    parser.add_argument("--M", type=int, help="number of dendrites")
    parser.add_argument("--M_dist", type=str,
                        help="Distribution of the number of dendrites among neurons",
                        choices=dist_choices, default="fixed")
    parser.add_argument("--gamma", type=float, help="dendritic spreading parameter")
    parser.add_argument("--gamma_dist", type=str, default="fixed", choices=dist_choices)
    parser.add_argument("--degree_dist", type=str, default="fixed", choices=dist_choices)
    parser.add_argument("--synaptic_dist", type=str,
                        help="distribution of the nodes in the synapses",
                        default="fixed", choices=dist_choices)
    parser.add_argument("--BHI", action="store_true", help="Bipartite Hyperbolic Initialisation")
    parser.add_argument("--CWS", action="store_true", help="Cannistraci Watts Strogatz")
    parser.add_argument("--WS1", action="store_true", help="First variation of WS initialisation")
    parser.add_argument("--WS2", action="store_true", help="Second variation of WS initialisation")
    parser.add_argument("--WS3", action="store_true", help="Third variation of WS initialisation")
    # --delta is registered with the conv options; registering it again here
    # would make argparse raise a conflicting-option error.
    parser.add_argument("--F", type=str, help="distribution of δ over the neurons",
                        default="fixed")
    parser.add_argument("--sigma_x", type=float, help="standard deviation over the x")
    parser.add_argument("--sigma_y", type=float, help="standard deviation over the y")
    parser.add_argument("--rho", type=float, help="correlation")
    parser.add_argument("--QHI", action="store_true", help="General Hyperbolic Initialisation")
    parser.add_argument("--BHI_T", type=float, default=0.0,
                        help="nPSO temperature. If 0: purely greedy by hyperbolic distance; >0: probabilistic sampling.")
    parser.add_argument("--BHI_gamma", type=float, default=2.0,
                        help="Power-law exponent gamma controlling the radial coordinate dynamics in nPSO.")
    # NOTE(review): type=int, so only the integer forms described in this help
    # (0 or C>0) can be supplied from the command line; the GaussianMixture
    # and tuple forms are not reachable through argparse.
    parser.add_argument("--BHI_distr", type=int, default=0, help=(
        "angular coordinate distribution:"
        "  0                uniform on [0,2π) (PSO model)"
        "  C>0              integer # of equidistant GaussianMixture components"
        "  GaussianMixture  sklearn GaussianMixture object for custom GM"
        "  (angles, probs, centers)  3-tuple of lists for fully custom mixture"
    ))
    parser.add_argument("--rewire_mode", type=str, choices=["none", "uniform", "random"],
                        default="none", help=(
                            "Optional bipartite rewiring after nPSO wiring: "
                            "'none' = keep original edges; "
                            "'uniform' = redistribute B-endpoint degrees as evenly as possible; "
                            "'random' = reassign B endpoints at random (preserving A-degrees)."
                        ))
    parser.add_argument("--degree_allocation", action="store_true", help="Enable the spatial-sorting degree allocation strategy: for each bipartite block, compute each neuron's degree and then reorder them so that the highest-degree neuron sits at the center, the next two occupy the centers of the two halves, the next four the centers of the four quarters, and so on. For the third sandwich only the input side is sorted (the final output layer remains unconstrained).")
    parser.add_argument("--no_log", action="store_true")
    parser.add_argument("--linearlr", action="store_true")
    parser.add_argument('--milestone', type=int, nargs='+', default=[60, 120, 160],
                        help='Decrease learning rate at these epochs.')
    parser.add_argument("--iterative_warmup_steps", type=int, default=0)
    parser.add_argument("--warmup", action="store_true")
    parser.add_argument("--discretelr", action="store_true")
    parser.add_argument("--end_factor", type=float, default=0.01)
    parser.add_argument("--T_decay", type=str, default="no_decay", choices=["no_decay", "linear"],
                        help="decay the temperature of the sampling")
    parser.add_argument("--decay_factor", type=float, default=0.9,
                        help="decay epochs of the learning rate")
    parser.add_argument("--itop", action="store_true", help="use the itop logger")
    parser.add_argument("--EM_S", action="store_true")
    parser.add_argument("--early_stop", action="store_true")
    parser.add_argument("--early_stop_thre", type=float, default=0.9)
    parser.add_argument("--record_anp", action="store_true")
    parser.add_argument("--method", type=str, default="",
                        help="the method name of the experiment")
    parser.add_argument("--old_version", action="store_true")
    parser.add_argument("--history_weights", action="store_true")
    parser.add_argument("--tiedrank", action="store_true")
    # NOTE(review): these help strings mention "delta" although the options are
    # named *_T; confirm which quantity they control.
    parser.add_argument("--start_T", type=float, help="set the initial value of delta", default=1.0)
    parser.add_argument("--end_T", type=float, help="set the final value of delta", default=3.0)
    parser.add_argument("--factor", type=float, default=0.01)
    parser.add_argument("--ssam", action="store_true")
    parser.add_argument("--function", type=str, help="varies alpha according to some function",
                        choices=["linear", "sigmoid", "tempering_hard", "tempering_strength",
                                 "piecewise_constant_linear", "piecewise_constant_abrupt",
                                 "decreasing_step", "fixed_height_step", "wave",
                                 "decaying_wave", "cosine_annealing_wr"],
                        default="linear")
    # The int default 1 is kept as-is (argparse does not convert non-string
    # defaults), so args.k stays an int unless --k is given.
    parser.add_argument("--k", type=float,
                        help="constant that regulates the sigmoid functions's steepness", default=1)
    parser.add_argument("--granet", action="store_true")
    parser.add_argument("--granet_init_sparsity", type=float, default=0.9)
    parser.add_argument("--granet_init_epoch", type=int, default=0)
    parser.add_argument("--gmp", action="store_true")
    parser.add_argument("--pruning_scheduler", type=str, default="none",
                        help="none, linear, granet, s_shape")
    parser.add_argument("--pruning_method", type=str, default="none",
                        help="ri, weight_magnitude, MEST")
    parser.add_argument("--sparsity_distribution", type=str, default="uniform",
                        help="uniform, non-uniform")

@dataclass
class Args:
    """Plain configuration record for a sparse-training run.

    Every field is required: none has a default value, so all 27 fields must
    be supplied when constructing an instance, either positionally in the
    order below or by keyword. Several field names match options parsed by
    ``get_args`` (e.g. ``link_update_ratio``, ``delta``, ``ch_method``).
    Others differ (e.g. ``sparsity``, ``remove_method``, ``regrow_method``),
    and no conversion between the two is defined in this class.
    """

    # Sparsity parameters
    sparsity: float
    mlp_sparsity: float         # MLP sparsity parameter
    link_update_ratio: float
    remove_method: str          # Link removal method
    regrow_method: str          # Link regrowth method
    shared_mask_sw: bool        # Whether to use shared mask across sliding windows
    shared_mask_zone: bool      # Whether to use shared mask across zones
    zone_sz: int                # Zone size for shared mask
    avg_remove: bool            # Whether to remove links according to average scores between SWs: 1.3
    avg_regrow: bool            # Whether to regrow links according to average scores between SWs: 1.2, 1.3
    soft: bool                  # Whether to use soft removal and regrowth, as in CHTs
    use_opt4: bool              # Whether to use option 4 in CHT's _get_L3n_regrow_pos method
    delta: float                # Delta parameter for CHT training
    delta_max: float            # Maximum delta value
    delta_d: float              # Delta increment step
    ch_method: str              # Cannistraci-Hebb method variant
    use_hidden: bool            # Whether to use hidden mask in shared_mask_sw + CHT

    # Task parameters
    model_type: str
    dataset: str

    # Training parameters
    num_epochs: int
    learning_rate: float
    batch_size: int

    # Experiment parameters
    num_processes: int
    seed: int

    # GPU parameters
    # NOTE(review): an older comment said this defaults to all GPUs, but the
    # field has no default here; confirm where any "all GPUs" fallback lives.
    gpus: str                   # GPU devices string like "0123"

    # Checkpoint parameters
    checkpoint: str             # Path to checkpoint file for resuming training
    tag: str                    # Tag to use as prefix for checkpoint folder name