import socket
import sys
import wandb
import torch
from iterative_runner import Runner


# Default wandb config. Anything left as None is expected to be filled in by a
# wandb sweep or by the --debug overrides below.
defaults = {
    # System
    'run_id': None,
    'only_cpu': False,  # If True, do not use GPUs
    'computer': socket.gethostname(),
    'extended_logging': None,  # Will double memory requirements
    # Setup
    'dataset': None,
    'model': None,
    'batch_size': None,
    'no_dense': None,  # If True, then find a trained model of IMP
    # Sparsifying strategy is always IMP or a variant
    'strategy': None,
    'n_epochs_per_phase': None,
    'n_phases': None,  # 1 == OneShot
    'lr_rewinding': None,
    'gsm_desired_sparsity': None,  # Only set this when something must be redone
    'weight_decay': None,
}

# Passing --debug on the command line replaces the defaults with a small,
# fully specified configuration (MNIST + 'Simple' model, 2 epochs per phase)
# so the pipeline can be exercised without a sweep.
debug_mode = '--debug' in sys.argv
if debug_mode:
    defaults.update({
        # System
        'run_id': 1,
        'only_cpu': False,  # If True, do not use GPUs
        'computer': socket.gethostname(),
        'extended_logging': False,  # Will double memory requirements
        # Setup
        'dataset': 'mnist',
        'model': 'Simple',
        'batch_size': 1024,
        'no_dense': True,  # If True, then find a trained model of IMP
        # Sparsifying strategy is always IMP or a variant
        'strategy': 'IMP',
        'n_epochs_per_phase': 2,
        'n_phases': None,  # 1 == OneShot
        'lr_rewinding': 'lr-rewinding',
        'gsm_desired_sparsity': 0.9,
        'weight_decay': 0.0005,
    })


# Start the wandb run. When launched by a sweep agent, project and entity are
# replaced automatically and the sweep's parameter values take precedence over
# `defaults`.
wandb.init(
    config=defaults,
    project='test-000',  # automatically changed in sweep
    entity='',  # automatically changed in sweep
)
# Resolved run configuration (defaults plus any sweep-supplied values).
config = wandb.config

def _select_device(n_gpus, only_cpu, dataset):
    """Return the device string to store in the run config.

    Args:
        n_gpus: Number of visible CUDA devices.
        only_cpu: When truthy, force the CPU even if GPUs are available.
        dataset: Dataset name; only 'imagenet' is spread over several GPUs.

    Returns:
        'cpu' when no GPU is usable, 'cuda:0,1,...' (every visible GPU) for
        multi-GPU ImageNet runs, and 'cuda:0' otherwise.
    """
    if n_gpus <= 0 or only_cpu:
        return 'cpu'
    if n_gpus > 1 and dataset == 'imagenet':
        # NOTE(review): 'cuda:0,1,...' is not a valid torch.device string;
        # presumably Runner splits it into a list of GPU ids -- confirm.
        return 'cuda:' + ','.join(str(i) for i in range(n_gpus))
    return 'cuda:0'


ngpus = torch.cuda.device_count()
config.update(dict(device=_select_device(ngpus, config.only_cpu, config.dataset)))
# NOTE: to keep torch checkpoints out of the wandb sync, set
# os.environ['WANDB_IGNORE_GLOBS'] = '*_state_dict.pt' -- presumably this has
# to happen before wandb.init for it to take effect (confirm).
runner = Runner(config=config, debug_mode=debug_mode)
runner.run()

# Close the wandb run and flush any pending logs. wandb.finish() is the
# documented API; wandb.join() is its deprecated alias.
wandb.finish()
