from multiprocessing import Process
from random import randint
from time import sleep
from queue import Queue
from datetime import date
from math import sqrt
import GPUtil
import os
import argparse


def arg_parse():
    """Parse the launcher's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: carries ``arch``, ``scope``, ``device``,
        ``repeat``, ``compression_type`` and ``gpus`` (a list of ints).
    """
    cli = argparse.ArgumentParser()

    # Simple scalar options, registered in order: (flag, type, default).
    scalar_options = (
        ('--arch', str, 'resnet20'),
        ('--scope', str, 'global'),
        ('--device', str, 'x3'),
        ('--repeat', int, 3),
    )
    for flag, kind, fallback in scalar_options:
        cli.add_argument(flag, type=kind, default=fallback)

    # Only these two sweep modes are accepted.
    cli.add_argument('--compression-type',
                     type=str,
                     default='compression',
                     choices=['compression', 'width'])

    # Zero or more GPU indices; defaults to indices 0 through 7.
    cli.add_argument('--gpus', type=int, nargs='*', default=list(range(8)))

    return cli.parse_args()


def build_commands(args,
                   gather_result_path,
                   grids=8,
                   repeat=3,
                   prune_epochs=100,
                   post_epochs=160):
    """Build the shell command for every (grid point, repetition) job.

    Grid point ``i`` has the nominal ratio ``8 ** (-i / 4)``. In
    ``compression`` mode that ratio is passed as ``--compression``; in
    ``width`` mode its square root is passed as ``--width-factor`` and
    ``--compression`` is fixed to 1.

    Args:
        args: parsed options; reads ``arch``, ``scope`` and
            ``compression_type``.
        gather_result_path: prefix of every expid, also passed through as
            ``--gather-result-path``.
        grids: number of grid points.
        repeat: number of repetitions per grid point.
        prune_epochs: value for ``--prune-epochs``.
        post_epochs: value for ``--post-epochs``.

    Returns:
        list[str]: one command string per job, in launch order.
    """
    commands = []
    for i in range(grids):
        for j in range(repeat):
            ratio = 8**(-i / 4)
            print(ratio)
            compression = ratio
            width_factor = 1
            if args.compression_type == 'width':
                width_factor = sqrt(ratio)
                compression = 1
            # Jobs with compression 1 are only generated for the global scope.
            if compression == 1 and args.scope != 'global':
                continue
            # Tag the expid with the grid point's nominal ratio. Using
            # `compression` here would give every width-mode grid point the
            # same expid, because it is forced to 1 above.
            expid = gather_result_path + f'x{round(ratio, 2)}_id{j}'
            command = f'''
            python main.py \
                --dataset cifar10 \
                --model {args.arch} \
                --optimizer momentum \
                --train-batch-size 128 \
                --post-epochs {post_epochs} \
                --lr 0.1 \
                --lr-schedule linear \
                --weight-decay 1e-4 \
                --pruner synflow \
                --compression {compression} \
                --prune-epochs {prune_epochs} \
                --expid {expid} \
                --mask-scope {args.scope} \
                --compression-schedule exponential \
                --gather-result-path {gather_result_path} \
                --width-factor {width_factor} \
                --no-prune-linear \
                --reinitialize \
                --reinitialize-sparse 1 \
                --workers 8 '''

            if 'resnet' not in args.arch:
                command += '--prune-pw-only '
            # NOTE(review): --no-prune-linear is already in the base command,
            # so this repeats the flag; kept as-is, presumably harmless for a
            # boolean switch -- confirm against main.py.
            if args.scope != 'global':
                command += '--no-prune-linear '

            commands.append(command)
    return commands


def launch_jobs(jobs, gpus, device_map, poll_seconds=20):
    """Start each queued command on an idle GPU, one launch per poll.

    Args:
        jobs: ``queue.Queue`` of command strings; drained by this call.
        gpus: GPU indices (as reported by GPUtil) that may be used.
        device_map: list mapping a GPUtil index to the value exported as
            ``CUDA_VISIBLE_DEVICES``.
        poll_seconds: pause between two availability checks.

    Returns:
        list[Process]: every started process, in launch order.
    """
    processes = []
    running = {}  # GPUtil index -> last Process started on that GPU
    count = 0
    while not jobs.empty():
        # Ask for every idle GPU. A smaller limit can return only ids that
        # are not in `gpus`, in which case no job would ever be launched.
        avai_devices = GPUtil.getAvailable(maxLoad=0.05,
                                           maxMemory=0.05,
                                           limit=len(device_map))
        device = None
        slot = None
        for d in avai_devices:
            if d not in gpus or d >= len(device_map):
                continue
            # A job that is still starting has not loaded the GPU yet, so
            # the device looks idle; skip it while its process is alive.
            previous = running.get(d)
            if previous is not None and previous.is_alive():
                continue
            device = device_map[d]
            slot = d
            break
        if device is not None:
            count += 1
            print(f'INFO: Lauching {count} jobs in queue')
            command = jobs.get()
            command = f'export CUDA_VISIBLE_DEVICES={device};' + command
            p = Process(target=os.system, args=(command, ))
            p.start()
            running[slot] = p
            processes.append(p)
        sleep(poll_seconds)
    return processes


if __name__ == '__main__':
    args = arg_parse()

    device_map_x3 = [0, 1, 2, 3, 4, 5, 6, 7]
    dict_map = {
        'x3': device_map_x3,
    }
    # Reject an unknown machine name before any work is done.
    if args.device not in dict_map:
        raise SystemExit(f'unknown --device {args.device!r}; '
                         f'expected one of {sorted(dict_map)}')
    device_map = dict_map[args.device]

    today = date.today().strftime("%b%d")
    gather_result_path = f'{today}_c10_{args.arch}_{args.scope}_{args.compression_type}_new'
    # Append underscores until the path does not exist yet.
    while os.path.exists(gather_result_path):
        gather_result_path += '_'
    print(f'gather result saved at {gather_result_path}\n')

    prune_epochs = 100
    post_epochs = 160
    grids = 8
    repeat = args.repeat

    jobs = Queue()
    for command in build_commands(args,
                                  gather_result_path,
                                  grids=grids,
                                  repeat=repeat,
                                  prune_epochs=prune_epochs,
                                  post_epochs=post_epochs):
        jobs.put(command)

    processes = launch_jobs(jobs, args.gpus, device_map)
    # Wait for every job so the closing message is printed after they finish.
    for p in processes:
        p.join()
    print(f'gather result saved at {gather_result_path}\n')
