import tensorflow as tf

# Enable memory growth on every GPU TensorFlow reports. This is best effort:
# the script continues with TensorFlow's default allocation if it fails.
physical_devices = tf.config.list_physical_devices('GPU')
try:
    for d in physical_devices:
        tf.config.experimental.set_memory_growth(d, True)
except (ValueError, RuntimeError):
    # set_memory_growth is documented to raise ValueError for an invalid
    # device and RuntimeError once the runtime is already initialized.
    # Only those are ignored, so KeyboardInterrupt/SystemExit still propagate.
    pass


import sys
import random
import numpy as np
import os
import time
import pickle
from MultitaskGPModel import MultitaskGPModel
import datetime
from gpytorch.distributions import MultivariateNormal
from blist import blist
import bisect

# Command-line arguments, in order:
#   1: sync_dir        directory used for prune/instrument files and copied checkpoints
#   2: sparsity        float
#   3: sparsity_block  float
#   4: penalty         float
sync_dir = sys.argv[1]
sparsity, sparsity_block, penalty = (float(sys.argv[pos]) for pos in (2, 3, 4))

# Network name substituted into the training command (--model_name) and
# into the prune file name.
model = 'resnet_v1_50'


def _pick_submodular_inner_memoize(samples, curr, cand, limit, memo = None, toadd = None):
    #this function  maintains a sorted list and is used to efficiently compute the utility
    num_dims = samples.shape[1]
        
    taken = set(curr)
    left = list(set([i for i in range(num_dims)]) - taken)
    
    if limit > len(curr):
        limit = len(curr)
    if limit == 0:
        improved_dict = {}
        for idx, cand in enumerate(left):
            imps = [s[cand] for s in samples]
            improved_dict[cand] = np.mean(imps)
        return improved_dict, memo
    
    if memo is None:
        ssorted = []
        for s in samples:
            ss = blist(np.sort(-1*s[curr])[0:limit])
            ssorted.append(ss)
    else:
        ssorted = memo
        ssorted_new = []
        if toadd is not None:
            for i, s in enumerate(samples):
                scand = -1*s[toadd]
                ss = ssorted[i]
                bisect.insort(ss, scand)
                ssorted_new.append(ss[0:limit])
        ssorted = ssorted_new

    imps = [max(0, s[cand] + ssorted[i][limit-1]) for i, s in enumerate(samples)]
    
    return np.mean(imps), ssorted

def _do_prune(bs, penalty, rem_epochs, compute_rem, samples):
    """Choose which dimensions to prune given sampled per-dimension values.

    Dimensions are ranked by their mean over ``samples``. The top ``bs`` are
    always kept. Further dimensions are then considered in rank order and
    kept while either the remaining budget stays positive after charging
    ``rem_epochs`` for them, or their estimated gain exceeds a penalty
    threshold. Selection stops at the first dimension that is rejected.

    Args:
        bs: number of top-ranked dimensions that are kept unconditionally;
            also passed as the rank limit to the gain estimator.
        penalty: base penalty used in the acceptance threshold.
        rem_epochs: amount charged against the budget per kept dimension.
        compute_rem: starting budget.
        samples: 2-D array, one row per sample, one column per dimension.

    Returns:
        The set of indices in ``range(number of dimensions)`` that were not
        kept.
    """
    mean_gain, memo = _pick_submodular_inner_memoize(samples, [], None, bs)

    # Ascending sort followed by an explicit reversal (rather than
    # reverse=True), so equal values come out in reversed insertion order.
    ranked = sorted(mean_gain.items(), key=lambda kv: kv[1])
    ranked.reverse()
    order = [dim for dim, _ in ranked]

    keep = []
    budget = compute_rem
    for dim in order[0:bs]:
        keep.append(dim)
        budget -= rem_epochs

    last_added = None
    while len(keep) < len(order):
        cand = order[len(keep)]
        gain, memo = _pick_submodular_inner_memoize(samples, keep, cand, bs, memo, last_added)
        # dynamic penalty: scales with how much of the budget is already used
        scale = np.power((1.0 / penalty), ((rem_epochs * len(keep)) / max(compute_rem, .01)) - 1.0)
        budget_after = budget - rem_epochs
        if not budget_after > 0:
            # Out of budget: the candidate must be worth more than the
            # penalty for the part of its cost that exceeds the budget.
            threshold = scale * penalty * min(-1 * budget_after, rem_epochs)
            if not gain > threshold:
                break
        keep.append(cand)
        last_added = cand
        budget = budget_after

    everything = set(range(len(order)))
    return everything - set(keep)

import subprocess
def exec_blocking(cmd):
    """Run ``cmd`` as a child process and wait for it to finish.

    The command string is split on single spaces and executed directly
    (no shell is involved). The exit status is discarded; the function
    always returns None.
    """
    argv = cmd.split(' ')
    subprocess.Popen(argv).wait()

class resnet_earlyprune(object):
    """State and helpers for pruning a network while it is being trained.

    The object reads gradient records written by the training script,
    feeds them to one ``MultitaskGPModel`` per layer group, and uses
    samples from those models to decide which dimensions of each layer
    group to prune. It also builds the training command line and the
    file names shared with the training script.
    """

    def __init__(self, bs, penalty, bc, steps = 600000):
        """Store the configuration; per-layer state is filled in by collect_data.

        Args:
            bs: multiplier applied to each layer size to obtain the number
                of dimensions kept unconditionally (see collect_data).
            penalty: penalty forwarded to ``_do_prune``.
            bc: stored as ``self.bc``; not used by this class.
            steps: stored as ``self.steps``; not used by this class.
        """
        self.bs = bs
        self.bc = bc
        self.penalty = penalty
        self.steps = steps
        self.prune_dict = {}
        self.raw_data = []
        # Per-layer-group state, populated by collect_data().
        self.layers = None
        self.layer_sizes = None
        self.layer_mults = None
        self.mogp_models = None
        self.mults = None
        self.bs_applied = None
        self.compute_rem = None

    def do_prune(self, rem_epochs, compute_rem, next_epoch_prune):
        """Select and apply pruning for every layer group.

        For each model, values are sampled from its prediction, passed to
        ``_do_prune`` and the resulting indices are pruned from the model.
        The result is stored in ``self.prune_dict`` as
        ``{layer: (size, size - n_pruned, pruned_indices)}`` and each
        layer's remaining budget in ``self.compute_rem`` is reduced by
        ``next_epoch_prune * (size - n_pruned)``.

        Args:
            rem_epochs: forwarded to ``_do_prune``.
            compute_rem: unused; the per-layer budgets in
                ``self.compute_rem`` are used instead.
            next_epoch_prune: multiplier for the budget charged per
                remaining dimension.
        """
        prune_dict = {}
        for idx, mdl in enumerate(self.mogp_models):
            # NOTE(review): [800] is presumably the query point for the
            # prediction -- confirm against MultitaskGPModel.predict_y.
            qi = mdl.predict_y([800], full_cov = True, full_output_cov = True)
            ls = self.layer_sizes[idx]
            mean = qi[0][:, 0:ls].reshape(-1)
            cov = qi[1][:, 0:ls, :, 0:ls].reshape((mean.shape[0], mean.shape[0]))
            samples = np.random.multivariate_normal(mean, cov, 15000)
            to_prune = _do_prune(int(self.bs_applied[idx]), self.penalty, rem_epochs, self.compute_rem[idx], samples)
            prune_idx = sorted(to_prune)
            mdl.prune(prune_idx)
            remaining = ls - len(prune_idx)
            prune_dict[self.layers[idx]] = (ls, remaining, prune_idx)
            self.compute_rem[idx] -= next_epoch_prune * remaining
        self.prune_dict = prune_dict

    def save_prune_file(self, curr_step):
        """Pickle ``self.prune_dict`` into sync_dir and return the file path."""
        f = os.path.join(sync_dir, 'prune_%s_%d.p' % (model, curr_step))
        # Close the file before returning so the data is on disk by the
        # time the path is handed to another process.
        with open(f, 'wb') as fh:
            pickle.dump(self.prune_dict, fh, protocol = 2)
        return f

    def instrument_file(self, final_step):
        """Return the path of the instrument file for ``final_step``."""
        f = os.path.join(sync_dir, 'instrument_file_%d.p' % final_step)
        return f

    def cost_file(self, final_step):
        """Return the path of the cost file for ``final_step``."""
        f = os.path.join(sync_dir, 'cost_file_%d.p' % final_step)
        return f

    def get_batch_size(self, curr_step):
        """Return the batch size; constant regardless of ``curr_step``."""
        return 64

    def get_learning_rate(self, step):
        """Return the learning rate for ``step`` (piecewise-constant schedule)."""
        if step < (5000*30):
            return 0.1
        elif step < (5000*60):
            return 0.01
        elif step < (5000*80):
            return 0.001
        elif step < (5000*90):
            return 1e-4
        else:
            return 1e-5

    def get_cmd_str(self, final_step, curr_step, steps, checkpoint_path = None, prune_file = None):
        """Build the training command line.

        Args:
            final_step: step used to name the instrument file.
            curr_step: step used to look up the learning rate and batch
                size; if non-zero and no ``prune_file`` is given, the
                current prune dict is saved and that file is used.
            steps: value for --max_number_of_steps.
            checkpoint_path: optional value for --checkpoint_path.
            prune_file: optional value for --prune_mask.

        Returns:
            The command as a single string.
        """
        batch_size = self.get_batch_size(curr_step)
        if curr_step != 0 and prune_file is None:
            prune_file = self.save_prune_file(curr_step)
        instrument_file = self.instrument_file(final_step)
        lr = self.get_learning_rate(curr_step)

        if checkpoint_path is not None:
            # NOTE: if prune_file is still None here (curr_step == 0), the
            # literal text 'None' is passed to --prune_mask.
            cmdline = 'python -u ./models/research/slim/train_image_classifier.py --num_clones 4 --optimizer momentum --learning_rate_decay_type fixed --log_every_n_steps 200 --num_readers 16 --num_preprocessing_threads 16 --dataset_dir /tmp/imgnet --model_name %s --max_number_of_steps %d --train_dir /tmp/tfmodel_PYTHONRUN --batch_size %d --prune_mask %s --instrument_file %s --learning_rate %f --checkpoint_path %s &> /dev/null'
            return cmdline % (model, steps, batch_size, prune_file, instrument_file, lr, checkpoint_path)

        if prune_file is None:
            cmdline = 'python -u ./models/research/slim/train_image_classifier.py --num_clones 4 --learning_rate_decay_type fixed --optimizer momentum --log_every_n_steps 200 --num_readers 16 --num_preprocessing_threads 16 --dataset_dir /tmp/imgnet --model_name %s --max_number_of_steps %d --train_dir /tmp/tfmodel_PYTHONRUN --batch_size %d --instrument_file %s --learning_rate %f &> /dev/null'
            return cmdline % (model, steps, batch_size, instrument_file, lr)

        cmdline = 'python -u ./models/research/slim/train_image_classifier.py --num_clones 4 --optimizer momentum --learning_rate_decay_type fixed --log_every_n_steps 200 --num_readers 16 --num_preprocessing_threads 16 --dataset_dir /tmp/imgnet --model_name %s --max_number_of_steps %d --train_dir /tmp/tfmodel_PYTHONRUN --batch_size %d --prune_mask %s --instrument_file %s --learning_rate %f &> /dev/null'
        return cmdline % (model, steps, batch_size, prune_file, instrument_file, lr)

    def collect_data(self, step, prune_built = None):
        """Load the instrument file for ``step`` and feed it to the models.

        The file is expected to unpickle to a pair: ``dat[0]`` is a
        sequence of layer names and ``dat[1]`` a sequence of
        ``(step, grads)`` records where ``grads[i]`` belongs to layer
        ``dat[0][i]``. Layers whose name ends in 'shortcut' or
        'bottleneck_v1' are grouped by the part of the name before
        '/unit_' and their gradients are summed; all other layers form a
        group of their own.

        On the first call the per-group models, the unconditional keep
        counts (``self.bs_applied``) and the budgets (``self.compute_rem``)
        are created.

        Args:
            step: step whose instrument file is read.
            prune_built: unused.

        Raises:
            ValueError: if no layer sizes are known because the file holds
                no gradient records and none were seen earlier.
        """
        instr_file = self.instrument_file(step)
        # NOTE(review): unpickling can execute arbitrary code; the file must
        # come from a trusted training run.
        with open(instr_file, 'rb') as fh:
            dat = pickle.load(fh, encoding = 'latin1')

        # separate layers into 'normal' layers and block pruning layers
        block_layers = {}
        for idx, di in enumerate(dat[0]):
            if di.endswith('shortcut') or di.endswith('bottleneck_v1'):
                key_str = di.split('/unit_')[0]
                block_layers.setdefault(key_str, []).append(idx)
            else:
                block_layers[di] = [idx]

        layer_keys = sorted(block_layers)

        # build the per-group gradient dataset, summing grouped layers
        dat1_builder = []
        layers_sizes = {}
        for grad_step, grad in dat[1]:
            layer_grad_builder = []
            for k in layer_keys:
                grad_accum = np.sum(np.array([grad[vi] for vi in block_layers[k]]), axis = 0)
                layer_grad_builder.append(grad_accum)
                layers_sizes[k] = grad_accum.shape[0]
            dat1_builder.append((grad_step, layer_grad_builder))

        self.layers = layer_keys
        self.layer_mults = [len(block_layers[k]) for k in layer_keys]
        # layer_sizes is read below and in do_prune(), so it must be set
        # here, in the same order as self.layers.
        # NOTE(review): refreshed on every call like layers/layer_mults --
        # confirm this is intended once layers have been pruned.
        if layers_sizes:
            self.layer_sizes = [layers_sizes[k] for k in layer_keys]
        if self.layer_sizes is None:
            raise ValueError('no gradient records in %s; layer sizes unknown' % instr_file)

        if self.mogp_models is None:
            self.mogp_models = [MultitaskGPModel(ls, opt_max_iter = 1000, opt_tol = 9.5, L = 4, num_inducing_points = 60) for ls in self.layer_sizes]

        for _, grads in dat1_builder:
            for idx, grad in enumerate(grads):
                self.mogp_models[idx].append_data(grad)

        if self.bs_applied is None:
            # Groups made of several layers use sparsity_block in place of
            # self.bs (the multiplier cancels self.bs in bs_applied below).
            self.mults = [sparsity_block/self.bs if lmult > 1 else 1.0 for lmult in self.layer_mults]
            self.bs_applied = [int(self.bs * size*self.mults[i]) for i, size in enumerate(self.layer_sizes)]
            self.compute_rem = [85 * bsi for bsi in self.bs_applied]

    def compute_step(self):
        """Call ``compute_step()`` on every per-layer model."""
        for mi in self.mogp_models:
            mi.compute_step()

# Start from a clean training directory.
exec_blocking('rm -rf /tmp/tfmodel_PYTHONRUN')
layers=[]
# The constructor's bs argument (7) is immediately replaced by the
# command-line sparsity value.
resnet_ = resnet_earlyprune(7,penalty, 1.0)
resnet_.bs = sparsity


exec_blocking('rm -rf /tmp/tfmodel_PYTHONRUN')

# Schedule, in epochs (multiplied by 5000 below to obtain steps):
#   prune_epochs   - epochs after which pruning is performed and the training
#                    directory is copied to sync_dir and wiped
#   process_epochs - epochs at which training stops so data can be collected
#   ep_run         - value passed as --max_number_of_steps for each stage; the
#                    entries equal the epochs elapsed since the preceding
#                    prune epoch (i.e. since the training directory was wiped)
prune_epochs =   [15, 20, 25, 35, 45, 55, 75, 100]
process_epochs = [15, 20, 25, 30, 35, 45, 55, 60, 75, 80, 100]
ep_run =         [15, 5, 5, 5, 10, 10, 10, 5, 20, 5, 25]

# Main loop: run training, collect data, prune (BEP), then resume training.
for i, pe in enumerate(process_epochs):
    epr = ep_run[i]
    if i == 0:
        # First stage: no checkpoint and no prune file.
        cmd_str = resnet_.get_cmd_str(pe*5000, 0, epr*5000, checkpoint_path = None)
    else:
        # Later stages point at the checkpoint named after the previous stage.
        peprev = process_epochs[i-1]
        eprprev = ep_run[i-1]
        checkpoint_path = os.path.join(sync_dir, 'tfmodel_PYTHONRUN_%d/model.ckpt-%d' % (peprev*5000, eprprev*5000))
        cmd_str = resnet_.get_cmd_str(pe*5000, peprev*5000, epr*5000, checkpoint_path)

    exec_blocking('./launchpad.sh %s' % cmd_str)

    if pe == 100:
        # Final stage: archive the training directory and stop.
        exec_blocking('cp -r /tmp/tfmodel_PYTHONRUN/ ' + os.path.join(sync_dir, 'tfmodel_PYTHONRUN_%d' % (pe*5000)))
        exec_blocking('rm -rf /tmp/tfmodel_PYTHONRUN')
        break

    resnet_.collect_data(pe*5000)

    if pe in prune_epochs:
        # Archive and wipe the training directory, then prune.
        exec_blocking('cp -r /tmp/tfmodel_PYTHONRUN/ ' + os.path.join(sync_dir, 'tfmodel_PYTHONRUN_%d' % (pe*5000)))
        exec_blocking('rm -rf /tmp/tfmodel_PYTHONRUN')
        ci = prune_epochs.index(pe)
        # ci + 1 is always in range: the last prune epoch (100) exits above.
        ni = prune_epochs[ci + 1]
        resnet_.compute_step()
        resnet_.do_prune(100-pe, resnet_.bs*85, ni-pe)

