import tensorflow as tf
import numpy as np
import callbacks
import pruning
import data
import os
import models
import argparse
from effective_masks import *
from utils import *
import logging

# Ask TensorFlow's native runtime to hide INFO and WARNING messages.
# NOTE(review): tensorflow is already imported at the top of this file; the
# variable may need to be set before that import to take full effect - confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Command-line interface of the experiment script.
parser = argparse.ArgumentParser()
parser.add_argument(
  '--sample', type=str, default='0',
  help='seed code')
parser.add_argument(
  '--path_to_data', type=str,
  help='path to tinyimagenet folder')
parser.add_argument(
  '--save', type=int, default=1,
  help='whether to save output files (choose one of: 0, 1)')
parser.add_argument(
  '--architecture', type=str, default='lenet300100',
  help='network type (choose one of: lenet300100, lenet5, vgg16, vgg19, resnet18)')
parser.add_argument(
  '--dataset', type=str, default='mnist',
  help='dataset to use if not default (choose one of: mnist, fmnist, kmnist for lenets so far)')
parser.add_argument(
  '--pruner', type=str, default='snip',
  help='pruner (choose one of: dense, lamp, snip, snip/iterative, synflow, random/uniform, random/erk, random/igq, random/uniform_plus, random/synflow, magnitude/global, magnitude/uniform, magnitude/erk, magnitude/igq, magnitude/uniform_plus)')
parser.add_argument(
  '--com_exp', type=float, default=None,
  help='target compression = 10 ** com_exp (overwrites --target_sparsity)')
parser.add_argument(
  '--target_sparsity', type=float, default=0.9,
  help='target sparsity (overwritten by --com_exp if given)')
parser.add_argument(
  '--pruning_type', type=str, default='direct',
  help='choose one of: direct, effective')
parser.add_argument(
  '--train', type=int, default=1,
  help='whether to train a subnetwork (choose one of: 0, 1)')
parser.add_argument(
  '--out_path', type=str, default='EffectiveSparsity',
  help='path to directory for outputs')
args = parser.parse_args()

# The dense baseline is unpruned: force zero sparsity and drop the pruning-type
# path component so its outputs land directly in the pruner's directory.
if args.pruner == 'dense':
  args.target_sparsity = 0
  args.pruning_type = ''

# Training hyperparameters, one dict per supported (architecture, dataset) pair.
# Keys, as consumed by this script:
#   'data'             dataset name handed to data.get_data and used in output file names
#   'lr'               initial learning rate of the SGD schedule
#   'batch_size_train' mini-batch size used for training
#   'iterations'       number of optimizer steps (decay horizon; also determines the epoch count)
#   'weight_decay'     passed to models.get_model as `decay`
#   'batchnorm'        passed to models.get_model as `batchnorm`
#   'momentum'         SGD momentum
#   'lr_decay'         step boundaries of a piecewise-constant schedule; not applied to the
#                      optimizer at present (training uses cosine decay)
#   'batch_size_snip'  not read in this file; the whole dict is handed to the pruner,
#                      which presumably uses it for SNIP scoring - TODO confirm in pruning.py
lenet300100_mnist_config = {'data': 'mnist', 'lr': 0.15, 'batch_size_train': 256, 'iterations': 17579, 'weight_decay': 0.0005, 'batchnorm': False, 'momentum': 0.9, 'lr_decay': [25000, 50000, 75000, 100000], 'batch_size_snip': 256}  # cf. Lee et al., 2018
lenet300100_fmnist_config = {'data': 'fmnist', 'lr': 0.15, 'batch_size_train': 256, 'iterations': 17579, 'weight_decay': 0.0005, 'batchnorm': False, 'momentum': 0.9, 'lr_decay': [25000, 50000, 75000, 100000], 'batch_size_snip': 256}  # cf. Lee et al., 2018
lenet300100_kmnist_config = {'data': 'kmnist', 'lr': 0.15, 'batch_size_train': 256, 'iterations': 15821, 'weight_decay': 0.0005, 'batchnorm': False, 'momentum': 0.9, 'lr_decay': [25000, 50000, 75000, 100000], 'batch_size_snip': 256}  # cf. Lee et al., 2018
lenet5_mnist_config = {'data': 'mnist', 'lr': 0.15, 'batch_size_train': 256, 'iterations': 17579, 'weight_decay': 0.0005, 'batchnorm': True, 'momentum': 0.9, 'lr_decay': [30000, 60000, 90000, 120000], 'batch_size_snip': 256}  # cf. Lee et al., 2018
lenet5_fmnist_config = {'data': 'fmnist', 'lr': 0.15, 'batch_size_train': 256, 'iterations': 17579, 'weight_decay': 0.0005, 'batchnorm': True, 'momentum': 0.9, 'lr_decay': [30000, 60000, 90000, 120000], 'batch_size_snip': 256}  # cf. Lee et al., 2018
lenet5_kmnist_config = {'data': 'kmnist', 'lr': 0.15, 'batch_size_train': 256, 'iterations': 15821, 'weight_decay': 0.0005, 'batchnorm': True, 'momentum': 0.9, 'lr_decay': [30000, 60000, 90000, 120000], 'batch_size_snip': 256}  # cf. Lee et al., 2018
lenet5_cifar10_config = {'data': 'cifar10', 'lr': 0.15, 'batch_size_train': 256, 'iterations': 23438, 'weight_decay': 0.0005, 'batchnorm': True, 'momentum': 0.9, 'lr_decay': [30000, 60000, 90000, 120000], 'batch_size_snip': 256}
vgg16_cifar10_config = {'data': 'cifar10', 'lr': 0.1, 'batch_size_train': 128, 'iterations': 62500, 'weight_decay': 0.000, 'batchnorm': True, 'momentum': 0.9, 'lr_decay': [31250, 46875], 'batch_size_snip': 128}  # cf. Frankle et al., 2020
vgg16_cifar100_config = {'data': 'cifar100', 'lr': 0.1, 'batch_size_train': 128, 'iterations': 62500, 'weight_decay': 0.000, 'batchnorm': True, 'momentum': 0.9, 'lr_decay': [31250, 46875], 'batch_size_snip': 128}
vgg19_config = {'data': 'cifar100', 'lr': 0.1, 'batch_size_train': 128, 'iterations': 62500, 'weight_decay': 0.0001, 'batchnorm': True, 'momentum': 0.9, 'lr_decay': [31250, 46875], 'batch_size_snip': 1280}  # cf. Wang et al., 2020
resnet18_tinyimagenet_config = {'data': 'tinyimagenet', 'lr': 0.2, 'batch_size_train': 256, 'iterations': 78200, 'weight_decay': 0.0001, 'batchnorm': True, 'momentum': 0.9, 'lr_decay': [39100, 58650], 'batch_size_snip': 2560}  # cf. Frankle et al., 2020
resnet18_cifar10_config = {'data': 'cifar10', 'lr': 0.2, 'batch_size_train': 256, 'iterations': 39075, 'weight_decay': 0.000, 'batchnorm': True, 'momentum': 0.9, 'lr_decay': [39100, 58650], 'batch_size_snip': 128}
# Same hyperparameters as the CIFAR-10 ResNet-18 run, but on CIFAR-100. The 'data'
# entry must name cifar100: it selects the dataset that is actually loaded.
resnet18_cifar100_config = {'data': 'cifar100', 'lr': 0.2, 'batch_size_train': 256, 'iterations': 39075, 'weight_decay': 0.000, 'batchnorm': True, 'momentum': 0.9, 'lr_decay': [39100, 58650], 'batch_size_snip': 128}


# Hyperparameter dict for every supported (architecture, dataset) pair.
_CONFIG_TABLE = {
  ('lenet300100', 'mnist'): lenet300100_mnist_config,
  ('lenet300100', 'fmnist'): lenet300100_fmnist_config,
  ('lenet300100', 'kmnist'): lenet300100_kmnist_config,
  ('lenet5', 'mnist'): lenet5_mnist_config,
  ('lenet5', 'fmnist'): lenet5_fmnist_config,
  ('lenet5', 'kmnist'): lenet5_kmnist_config,
  ('lenet5', 'cifar10'): lenet5_cifar10_config,
  ('vgg16', 'cifar10'): vgg16_cifar10_config,
  ('vgg16', 'cifar100'): vgg16_cifar100_config,
  ('resnet18', 'cifar10'): resnet18_cifar10_config,
  ('resnet18', 'cifar100'): resnet18_cifar100_config,
  ('resnet18', 'tinyimagenet'): resnet18_tinyimagenet_config,
}

if args.architecture == 'vgg19':
  # vgg19 has a single configuration and is selected whatever --dataset says
  # (so `--architecture vgg19` works with the default --dataset value).
  config = vgg19_config
elif (args.architecture, args.dataset) in _CONFIG_TABLE:
  config = _CONFIG_TABLE[(args.architecture, args.dataset)]
else:
  # Fail here with an actionable message instead of leaving `config` unbound,
  # which would only surface later as a NameError inside main().
  _supported = ', '.join(f'{arch}/{dataset}' for arch, dataset in sorted(_CONFIG_TABLE))
  raise ValueError(
    f'no configuration for architecture {args.architecture!r} with dataset '
    f'{args.dataset!r}; supported architecture/dataset pairs: {_supported}, '
    f'vgg19/<any>')

def _save_per_layer(path, value):
  """Save ``value`` to ``path`` with ``np.save``, tolerating mixed item shapes.

  ``np.save`` first converts its input to an array. A list of arrays with
  different shapes (e.g. one weight matrix per layer) cannot form a regular
  array, and NumPy >= 1.24 raises ``ValueError`` for it. Such lists are
  therefore packed element by element into a 1-D object array first. Any
  other input is passed to ``np.save`` untouched.

  Object arrays are stored pickled, so they must be read back with
  ``np.load(path, allow_pickle=True)``.
  """
  if isinstance(value, (list, tuple)) and len({np.shape(item) for item in value}) > 1:
    packed = np.empty(len(value), dtype=object)
    # Assign one slot at a time: bulk assignment would try to broadcast.
    for index, item in enumerate(value):
      packed[index] = item
    value = packed
  np.save(path, value)


def main(args):
  """Prune a network, log/save its sparsities and optionally train it.

  Reads the module-level ``config`` dict for hyperparameters.

  Args:
    args: parsed command-line namespace. Fields read: ``com_exp``,
      ``target_sparsity``, ``sample``, ``out_path``, ``architecture``,
      ``pruner``, ``pruning_type``, ``path_to_data``, ``dataset``, ``save``
      and ``train``. ``args.out_path`` is overwritten with the run-specific
      output directory ``<out_path>/<architecture>/<pruner>/<pruning_type>``.

  Side effects:
    Creates the output directories, writes ``<prefix>info.log`` and, depending
    on ``args.save`` / ``args.train`` / ``args.pruner``, several ``.npy``
    files, where ``<prefix>`` is ``<sample>_<data>_<rounded compression>_``.
  """
  # --com_exp, when given, takes precedence over --target_sparsity.
  if args.com_exp is not None:
    target_compression = 10 ** args.com_exp
  else:
    target_compression = 1. / (1 - args.target_sparsity)
  extension = f"{args.sample}_{config['data']}_{round(target_compression)}_"

  # Resolve the dense directory before args.out_path is narrowed to this run.
  path_to_dense = os.path.join(args.out_path, args.architecture, 'dense')
  args.out_path = os.path.join(args.out_path, args.architecture, args.pruner, args.pruning_type)
  os.makedirs(args.out_path, exist_ok=True)
  os.makedirs(path_to_dense, exist_ok=True)
  out_prefix = os.path.join(args.out_path, extension)
  logging.basicConfig(filename=out_prefix + 'info.log', level=logging.INFO, filemode='w')

  datagen, train_X, train_y, test_X, test_y = data.get_data(config['data'], path_to_data=args.path_to_data)
  # Number of whole passes over the training set that cover config['iterations'] steps.
  epochs = int(config['batch_size_train'] * config['iterations'] / len(train_X))
  model, tensors = models.get_model(
    shape=train_X[0].shape,
    architecture=args.architecture,
    batchnorm=config['batchnorm'],
    decay=config['weight_decay'],
    output_classes=len(train_y[0]))

  # Cosine decay from config['lr'] to 0 over config['iterations'] steps. The
  # piecewise-constant boundaries in config['lr_decay'] are not used here.
  learningrate = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=config['lr'],
    decay_steps=config['iterations'],
    alpha=0)
  model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=learningrate, momentum=config['momentum']),
    loss='categorical_crossentropy',
    metrics=['accuracy'])
  log_list = np.arange(0, config['iterations'], 1000)

  pruner = pruning.Pruner(args.pruner)
  masks = pruner.prune(
    model, tensors, 1 - 1. / target_compression, args.pruning_type,
    train_X=train_X,
    train_y=train_y,
    out_path=out_prefix,
    config=config,
    sample=args.sample,
    dataset=args.dataset,
    path_to_dense=path_to_dense)
  # First weight array of every layer listed in `tensors`, captured before training.
  inits = [model.layers[layer].get_weights()[0] for layer in tensors]
  log_cb = callbacks.LogCallback(model, tensors, masks, log_list, (test_X, test_y))
  fit_callbacks = [callbacks.SubnetworkCallback(model, tensors, masks), log_cb]

  eff_masks_custom = effective_masks_custom(model.name, masks)
  eff_masks_synflow = effective_masks_synflow(model, tensors, masks)
  # The "effective sparsity" reported is the direct sparsity of the synflow effective masks.
  logging.info(f'<main> [direct sparsity: {get_overall_direct_sparsity(masks):.6f}][effective sparsity: {get_overall_direct_sparsity(eff_masks_synflow):.6f}][epochs to train: {epochs}][iterations to train: {config["iterations"]}][pruner: {args.pruner}][sample: {args.sample}]')

  if args.save and args.pruner != 'dense':
    np.save(out_prefix + 'sparsities_effective_synflow.npy', get_direct_sparsity(eff_masks_synflow))
    np.save(out_prefix + 'sparsities_effective_custom.npy', get_direct_sparsity(eff_masks_custom))
    np.save(out_prefix + 'sparsities_direct.npy', get_direct_sparsity(masks))
    # Per-layer weight counts, stored alongside the dense baseline.
    np.save(os.path.join(path_to_dense, 'counts.npy'), [np.prod(model.layers[layer].get_weights()[0].shape) for layer in tensors])

  if args.train:
    model.fit(
      datagen.flow(train_X, train_y, batch_size=config['batch_size_train']),
      steps_per_epoch=len(train_X) // config['batch_size_train'],
      epochs=epochs,
      shuffle=True,
      verbose=True,
      validation_data=(test_X, test_y),
      callbacks=fit_callbacks)
    # NOTE(review): these two files are written even when --save is 0; kept
    # as-is to preserve existing behaviour - confirm whether that is intended.
    np.save(out_prefix + 'accuracies.npy', log_cb.accuracies)
    np.save(out_prefix + 'losses.npy', log_cb.losses)

  if args.save and args.train and args.pruner == 'dense':
    # Layers have different shapes, so these cannot be saved as a regular array.
    _save_per_layer(out_prefix + 'inits.npy', inits)
    _save_per_layer(out_prefix + 'final_weights.npy', log_cb.final_weights)

# Run the experiment only when this file is executed as a script, using the
# command-line arguments parsed at module level.
if __name__ == '__main__':
  main(args)
