import argparse
import os
from multiprocessing import Pool

import torch

# Print whether CUDA is usable and how many devices this process can see,
# so the information appears at the top of the run's output.
print(torch.cuda.is_available())
print(torch.cuda.device_count())

parser = argparse.ArgumentParser(
    description='parameterised slurm runs')

# NOTE(review): --epochs is parsed but not used by any live code below; it is
# only referenced from the commented-out command templates further down.
parser.add_argument('--epochs', default=200, type=int, help='epochs')
parser.add_argument('--run_sh', default=None, type=str, help='a sh file')
args = parser.parse_args()

if args.run_sh is None:
    raise ValueError('no run_sh file provided')

# NOTE(review): the value is interpolated into a shell command unquoted, so a
# path containing spaces or shell metacharacters will be split/interpreted by
# the shell. Left as-is because callers may rely on passing extra arguments
# in the same string.
cmd = f'sh {args.run_sh}'
print(cmd)

# os.system returns a non-zero status when the command fails. Surface that
# as a non-zero exit of this launcher instead of silently reporting success.
status = os.system(cmd)
if status != 0:
    raise SystemExit(f'command failed with status {status}: {cmd}')
# default= 
#     f'CUDA_VISIBLE_DEVICES=0,1 python faug/main.py -a gpu -n 2 -b 256 -w 16 -m {args.epochs} --save resnet50aug-blockdrop-cifar10-sgd.1 --model resnet50_aug --lr 0.1 --optimizer sgd --feature_aug_config ./configs/cifar10/mid_prob_mid_block.yaml --feature_aug config -db milestone --dataset cifar10 --seed 1',
#     f'CUDA_VISIBLE_DEVICES=2,3 python faug/main.py -a gpu -n 2 -b 256 -w 16 -m {args.epochs} --save resnet50aug-blockdrop-cifar10-sgd.2 --model resnet50_aug --lr 0.1 --optimizer sgd --feature_aug_config ./configs/cifar10/mid_prob_mid_block.yaml --feature_aug config -db milestone --dataset cifar10 --seed 2']


# def run_p(cmd):
#     os.system(cmd)

# with Pool(processes=2) as pool:
#     pool.map(run_p, cmds)

# 1. Train your RoBERTa model normally.
# 2. Call the prune function in Roberta to produce weights.ckpt and layer_info.yaml.
# 3. Fine-tune Roberta_flexible by loading weights.ckpt and layer_info.yaml.

class Roberta():
    """Placeholder sketch of the model class that exposes a pruning step.

    NOTE(review): the class body is elided (``...``). The ``manipulate`` and
    ``save`` methods that ``prune`` calls are not defined in the visible
    code and are presumably meant to be supplied by the real implementation.
    """
    ...

    def prune(self) -> None:
        """Save pruned weights and the per-layer head counts.

        As written this is pseudo-code: it builds a hard-coded mapping of
        layer name to head count, passes ``weights`` through
        ``self.manipulate``, saves the result to ``pruned_weights.ckpt`` via
        ``self.save``, and hands the mapping to ``save_as_yaml``.
        """
        # Intended step: determine which heads are more important.
        # The mapping below records how many heads each layer has
        # (hard-coded example values).
        layers = {
            'layer1': 3,
            'layer2': 4,
        }
        # Save the right weights.
        # NOTE(review): `weights` is assigned in this method, so it is a local
        # name, and it is read here before it has ever been bound. This line
        # cannot succeed as written (UnboundLocalError). The source of the
        # original weights still needs to be decided.
        weights = self.manipulate(weights)
        self.save('pruned_weights.ckpt', weights)
        # NOTE(review): `save_as_yaml` is neither defined nor imported in the
        # visible code; no output path is passed to it.
        save_as_yaml(layers)
        # Expected YAML content
        # (original note also mentions "toml" - presumably an alternative
        # format; TODO confirm):
        # layer1: 3
        # layer2: 4

