import time
import os
import socket
import sys
import wandb
import torch
from runner import Runner


# Baseline wandb configuration. Every entry is a None placeholder except
# `computer`, which records the host name of the machine running the script.
defaults = {
    # --- System ---
    "run_id": None,
    "computer": socket.gethostname(),
    "fixed_init": None,
    "abort_active": None,
    "extended_logging": None,
    "time_mode": None,
    # --- Setup ---
    "dataset": None,
    "model": None,
    "nepochs": None,
    "batch_size": None,
    # --- Optimizer ---
    "optimizer": None,
    "learning_rate": None,
    "training_warmup_epochs": None,  # number of epochs to warm up the lr
    "momentum": None,
    "dampening": None,
    "nesterov": None,
    "weight_decay": None,
    "weight_decay_ord": None,
    "momentum_warmup_epochs": None,
    "lr_warmup_epochs": None,
    "restart_steps": None,
    # --- Sparsifying strategy ---
    "strategy": None,
    # --- IMP ---
    "n_epochs_finetune": None,
    "n_phases": None,  # Requirement: n_epochs_finetune % n_phases == 0
    "lr_rewinding": None,
    # --- GSM ---
    "gsm_desired_sparsity": None,
    # --- GradualPruning ---
    "pruning_steps": None,
    "momentum_warmup": None,
    "lr_warmup": None,
    "allow_recovering": None,
    "use_uniform": None,
    # --- GREG-1 ---
    "delta_wd": None,  # Increase in weight decay
    # --- RestartedGREG ---
    "follow_schedule": None,
    # --- Continuous Sparsification (CS) ---
    "lmbd": None,  # Different from weight decay
    "s_initial": None,
    "beta_final": None,
    # --- STR ---
    "use_global_threshold": None,
}
# Debug switch: passing '--debug' on the command line turns debug mode on and
# overwrites the None placeholders in `defaults` with a small, concrete
# configuration. Previously the flag was assigned False in both branches, so
# debug mode could never be reported to the Runner.
debug_mode = '--debug' in sys.argv
if debug_mode:
    defaults.update(dict(
        # System
        run_id=1,
        computer=socket.gethostname(),
        fixed_init=True,
        abort_active=False,
        extended_logging=False,
        time_mode=True,
        # Setup
        dataset='mnist',
        model='Simple',
        nepochs=10,
        batch_size=1024,
        # Optimizer
        optimizer='SGD',
        learning_rate='(MultiStepLR, 0.1, [3|7], 0.1)',
        training_warmup_epochs=3,  # number of epochs to warm up the lr
        # Alternative schedule, syntax:
        # (CosineAnnealingWarmRestarts, eta_max, eta_min, epochs_per_restart)
        # learning_rate = '(CosineAnnealingWarmRestarts, 0.1, 0.001, 2)',
        momentum=0.9,
        dampening=None,
        nesterov=True,
        weight_decay=0.0005,
        weight_decay_ord=2,
        momentum_warmup_epochs=None,
        lr_warmup_epochs=None,
        restart_steps=None,
        # Sparsifying strategy
        strategy='GradualPruning',
        # IMP
        n_epochs_finetune=1,
        n_phases=1,  # Requirement: n_epochs_finetune % n_phases == 0
        lr_rewinding='constant',
        # GSM
        gsm_desired_sparsity=0.9,
        # GradualPruning
        pruning_steps=2,
        momentum_warmup=False,
        lr_warmup=False,
        allow_recovering=True,
        use_uniform=True,
        # GREG-1
        delta_wd=1e-1,  # Increase in weight decay
        # RestartedGREG
        follow_schedule=True,
        # Continuous Sparsification (CS)/STR/DST
        lmbd=1e-4,  # Different from weight decay
        s_initial=-0.1,
        beta_final=200,
        # STR
        use_global_threshold=False,
    ))


# Start the wandb run, seeding its config with the `defaults` dictionary.
wandb.init(
    config=defaults,
    project='test-000',  # automatically changed in sweep
    entity='',  # automatically changed in sweep
    # Optional arguments, currently disabled:
    # group="000.0",
    # reinit=True,
)
config = wandb.config

# Runs configured with weight_decay == 0.01 are deliberately aborted here.
# NOTE(review): `assert` is stripped under `python -O`, so this guard would
# silently disappear in optimized mode.
assert config.weight_decay != 0.01, "Ignore this weight decay"

# Pick the device string: CPU without CUDA devices, all visible GPUs for
# multi-GPU imagenet runs, otherwise the first GPU only.
ngpus = torch.cuda.device_count()
if ngpus == 0:
    device = 'cpu'
elif ngpus > 1 and config.dataset == 'imagenet':
    device = 'cuda:' + ','.join(str(gpu_idx) for gpu_idx in range(ngpus))
else:
    device = 'cuda:0'
config.update({'device': device})

# Disabled: would keep torch state dicts from being synced to the wandb server.
# os.environ['WANDB_IGNORE_GLOBS'] = '*_state_dict.pt'

# Hand the finished config to the Runner and execute the experiment.
runner = Runner(config=config, debug_mode=debug_mode)
runner.run()

# Close wandb run
# NOTE(review): newer wandb releases presumably prefer `wandb.finish()` over
# `wandb.join()` - confirm against the installed version before switching.
wandb.join()
