# %reload_ext autoreload
# %autoreload 2
# %matplotlib inline

from utils import *
from meta import *
from pruning_utils import *
from utils import __WELL_TRAINED__



# Compute device: CUDA when PyTorch reports it as available, CPU otherwise.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'DEVICE={DEVICE}')

# Optimizee class (instantiated further down to size args['lum_layers']) and
# the optimizer class that is stored under args['OPT'].
optimizee_train = MLP_MNIST
OPT = RPOptimizer

# Index into the two-element choices inside `args`
# (presumably 0 = quick local run, 1 = full server run -- TODO confirm).
SERVER = 0

# Index selecting the data generator; the same index also picks args['pixels'].
which_DataGen = 1
Target_DataGen = (MNISTLoss, Cifar_half)[which_DataGen]

# Single configuration mapping: unpacked into RPOptimizer(**args) further down
# and echoed to stdout once built.  Entries written as (a, b)[SERVER] take `a`
# when SERVER == 0 and `b` when SERVER == 1.
args = {
    # Training
    'prune_rounds': 10,
    'prune_amount': 3,
    'n_epochs': (1, 100)[SERVER],
    'epoch_len': (5, 200)[SERVER],
    'unroll': (1, 20)[SERVER],

    'n_tests': (2, 10)[SERVER],
    'eva_epoch_len': 800,
    'random_scale': 0.0,
    'lr': 0.001,
    'Target_DataGen': Target_DataGen,

    'only_want_last': 1,

    # Model
    'OPT': OPT,
    'preproc': [6, 16],
    'hidden_layers_zer': [7, 17, 27],
    'Ngf': 2,
    'beta1': 0.95,
    'beta2': 0.95,
    'pre_tahn_scale': 0,
    # Index 1 picks None, so this entry currently resolves to None.
    'LUM': (LUM, None)[1],
    # First entry: how many items optimizee_train().parameters() yields.
    'lum_layers': [sum(1 for _ in optimizee_train().parameters()), 20, 1],

    # MLP_MNIST optimizee
    # Follows the dataset index: 28*28 for index 0, 3*32*32 for index 1.
    'pixels': (28 * 28, 3 * 32 * 32)[which_DataGen],
}
print(f'\nargs:\n{args}\n\n')


# =========== Debug ===========

# len(list(MLP_MNIST().parameters()))
# optimizee_train().hidden_layers_zee

# l2o_net = OPT().to(DEVICE)


# # =========== Training ===========
# args['Target_Optimizee']= optimizee_train


# l2o_net = RPOptimizer(**args).to(DEVICE)
# if args.get('LUM'): args['lum'] = args['LUM'](args['lum_layers'])
# viz(l2o_net,'l2o')

# if args.get('lum'):
#     params = itertools.chain(l2o_net.parameters(),kwargs['lum'].parameters())
# else:
#     params = l2o_net.parameters()
# if args['pre_tahn_scale']:
#     params = itertools.chain([l2o_net.pre_tahn_scale,], params)
# meta_opt = optim.Adam(params, lr=args['lr'])  # TF Adam defaults: lr=0.001, b1=0.9, b2=0.999 -- the same √

# for ip in range(args['prune_rounds']):
#     l2o_net, lum = train_optimizer(args, l2o_net, meta_opt, **args)
#     pruning_model(l2o_net, args['prune_amount'])
#     print(l2o_net)
#     check_sparsity(l2o_net)




def _print_state_dict_keys(model):
    """Print the state-dict keys of `model` and their count; return the state dict."""
    state = model.state_dict()
    print(state.keys())
    print(len(state))
    return state


# Build the optimizer network from the config and move it to DEVICE.
l2o_net = RPOptimizer(**args).to(DEVICE)

# Deep-copied snapshot of the freshly built weights, taken before any pruning.
original_weight = copy.deepcopy(l2o_net.state_dict())

# State-dict keys before pruning, followed by a visualisation of the network.
sd = _print_state_dict_keys(l2o_net)
viz(l2o_net)

# Prune (0.5 is presumably the fraction to remove -- TODO confirm against
# pruning_utils.pruning_model), then print the keys again for comparison.
pruning_model(l2o_net, 0.5)
sd = _print_state_dict_keys(l2o_net)

check_sparsity(l2o_net)

def load_state_dict(model, state_dict):
    """Load `state_dict` into `model` and re-apply the model's current masks.

    Steps, in order: pull the masks out of the model's own state dict, strip
    the pruning from the model, load the given weights, then prune the model
    again with the saved masks.

    NOTE(review): extract_mask, remove_prune and prune_model_custom are not
    defined in this file (they arrive through the wildcard imports, presumably
    from pruning_utils); the summary above assumes they behave as their names
    suggest -- confirm against their definitions.

    Args:
        model: Network to update; it is modified in place.
        state_dict: Weights handed to ``model.load_state_dict``.

    Returns:
        None.
    """
    # Capture the masks first, while the model is still in its current form.
    current_state = model.state_dict()
    kept_masks = extract_mask(current_state)

    remove_prune(model)
    model.load_state_dict(state_dict)

    # Put the previously captured masks back on the freshly loaded weights.
    prune_model_custom(model, kept_masks)

# Load the snapshot taken before pruning_model() ran back into the pruned
# network; the helper above extracts the current masks first and re-applies
# them after loading.
load_state_dict(l2o_net, original_weight)



# # =========== Evaluation ===========

# # == load model ==
# l2o_net = OPT().to(DEVICE)

# load_model(l2o_net, cwd)
# if args['LUM']:
#     lum = LUM(args['lum_layers'])
# cwd = WELL_TRAINED[0]

#     load_model(lum, cwd)
#     args['lum'] = lum


# # == Eva model ==
# optimizee_test = MLP_MNIST2
# dic2 = {'l2o_net':l2o_net, 
#         'Target_Optimizee':optimizee_test,
#         }; args.update(dic2)
# eva_l2o_optimizer(args, **args)



























