import socket
import sys
import os
import shutil
import torch
import wandb
from runners.scratchRunner import scratchRunner
from runners.pretrainedRunner import pretrainedRunner
from strategies import scratchStrategies
# Default wandb config values, passed to wandb.init(config=...) below.
# Most entries default to None and are expected to be filled in from outside
# (e.g. by a wandb sweep, or by the '--debug' override that follows).
defaults = {
    # System
    'run_id': None,
    'computer': socket.gethostname(),
    'collect_class_statistics': False,
    # Setup
    'dataset': None,
    'arch': None,
    'n_epochs': None,
    'batch_size': None,
    # Efficiency
    'use_amp': True,
    # Optimizer
    'optimizer': None,
    'learning_rate': None,
    'n_epochs_warmup': None,  # number of epochs to warm up the lr, should be an int
    'momentum': None,
    'weight_decay': None,
    'decouple_wd': None,
    # Sparsifying strategy
    'strategy': None,
    'use_pretrained': None,
    'goal_sparsity': None,
    'pruning_selector': None,  # must be in ['global', 'uniform', 'random', 'LAMP']
    # Retraining
    'n_phases': None,  # should be 1, except when using IMP
    'n_epochs_per_phase': None,
    'n_epochs_to_split': None,
    'retrain_schedule': None,
    'retrain_schedule_warmup': None,
    'retrain_schedule_init': None,
    'retrain_wd': None,
    'retrain_wd_schedule': None,
    'retrain_adaptive_in_every_cycle': None,
    # GMP
    'pruning_interval': None,
    # STR
    's_initial': None,
    # DST
    'penalty': None,
    # CS
    'beta_final': None,
}

# Passing '--debug' on the command line replaces the defaults with a small,
# fully specified configuration (MNIST, IMP starting from a dense pretrained
# model, two pruning phases) — presumably for quick local runs without a sweep.
if '--debug' in sys.argv:
    defaults.update(
        # System
        run_id=1,
        computer=socket.gethostname(),
        collect_class_statistics=False,
        # Setup
        dataset='mnist',
        arch='Simple',
        n_epochs=None,
        batch_size=1028,
        # Efficiency
        use_amp=True,
        # Optimizer
        optimizer='SGD',
        learning_rate='(Linear, 0.1, 0.0001)',
        n_epochs_warmup=None,  # number of epochs to warm up the lr, should be an int
        momentum=0.9,
        weight_decay=0.0001,
        decouple_wd=True,
        # Sparsifying strategy
        strategy='IMP',
        use_pretrained='Dense',
        goal_sparsity=0.8,
        pruning_selector='global',  # must be in ['global', 'uniform', 'random', 'LAMP']
        # Retraining
        n_phases=2,  # should be 1, except when using IMP
        n_epochs_per_phase=1,
        n_epochs_to_split=None,
        retrain_schedule='ALLR',
        retrain_schedule_warmup=None,
        retrain_schedule_init=None,
        retrain_wd=0.0005,
        retrain_wd_schedule='(InitialOnly, 0.5)',
        retrain_adaptive_in_every_cycle=True,
        # GMP
        pruning_interval=1,
        # STR
        s_initial=-0.1,
        # DST
        penalty=0.002,
        # CS
        beta_final=300,
    )

# Start the wandb run. Under a sweep, wandb supplies the project/entity and
# overrides the config entries; `defaults` only provides fallbacks.
wandb.init(
    project='test-000',  # automatically changed in sweep
    entity=None,  # automatically changed in sweep
    config=defaults,
)
config = wandb.config

# Record the device to train on: the first CUDA device if any GPU is visible, otherwise the CPU.
ngpus = torch.cuda.device_count()
config.update(dict(device='cuda:0' if ngpus > 0 else 'cpu'))

# Choose and run the training runner.
# A non-None `use_pretrained` selects pretrainedRunner; otherwise the strategy
# must be an attribute of `scratchStrategies` and scratchRunner is used.
# At the moment, IMP is the only strategy that requires a pretrained model, all others start from scratch
if config.use_pretrained is not None:
    runner = pretrainedRunner(config=config)
else:
    # Validate the strategy name up front so a typo fails with a clear message.
    # getattr raises AttributeError for an unknown name and TypeError when
    # `strategy` is not a string (e.g. left at its None default).
    try:
        getattr(scratchStrategies, config.strategy)
    except (AttributeError, TypeError) as err:
        # Chain the original error so the offending strategy name stays in the traceback.
        raise NotImplementedError("Strategy does not exist, potentially forgot to specify 'use_pretrained'.") from err
    runner = scratchRunner(config=config)
runner.run()

# Finish the wandb run. The run's local directory path is captured beforehand
# because it is read from `wandb.run`.
wandb_dir_path = wandb.run.dir
wandb.join()

# Remove local leftovers: wandb's run directory, then the runner's temporary directory.
for leftover_dir in (wandb_dir_path, runner.tmp_dir):
    if os.path.exists(leftover_dir):
        shutil.rmtree(leftover_dir)
