from __future__ import division

import os, time, random, math
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
from torchvision import models
from utils.options import args
import utils.common as utils
import numpy as np
import heapq
from data import cub200, cifar100
from importlib import import_module
from thop import profile
import copy
from torch.utils.tensorboard import SummaryWriter
from methods import networkslimming, epruner, depgraph

# Restrict the GPUs visible to this process. ``args.gpus`` may be None (the
# ``use_cuda`` check below explicitly allows for it); os.environ only accepts
# strings, so skip the assignment in that case instead of raising TypeError.
if args.gpus is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss_func = nn.CrossEntropyLoss()

args.use_cuda = args.gpus is not None and torch.cuda.is_available()

# Seed Python's and PyTorch's RNGs; a seed is drawn at random when none is
# given and stored back on ``args``.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if args.use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
cudnn.deterministic = True
cudnn.benchmark = False

# Rewrite the compress-rate spec into a bracketed tag used in directory
# names: every '*' becomes ']', every '+' becomes '+[', and a leading '[' is
# added.
compress_rate_str = '[' + args.compress_rate.replace('*', ']').replace(
    '+', '+[')


def _prune_rule_suffix():
    """Return the last job-dir component for the active ``args.prune_rule``."""
    if args.prune_rule == 'NS_pretrain':
        return 'PR_' + str(args.channel_PR)
    if args.prune_rule == 'epruner_pretrain':
        return 'beta_' + str(args.preference_beta)
    if args.prune_rule == 'depgraph_pretrain':
        return 'flops_PR_' + str(args.target_flops_PR)
    return compress_rate_str


# Extend args.job_dir with a sub-directory describing the experiment setup.
# When none of the branches applies, args.job_dir is left untouched.
if args.train_slim:
    args.job_dir = os.path.join(args.job_dir, 'SP^rT', 'lr' + str(args.lr),
                                compress_rate_str)
elif not args.use_pretrain:
    # Without a pretrained model the scheme tag depends on whether both
    # transfer and hard inheritance are enabled.
    _scheme_tag = 'STP^wT' if (args.transfer
                               and args.hard_inherit) else 'TP^wT'
    args.job_dir = os.path.join(args.job_dir, _scheme_tag, args.prune_rule,
                                'lr' + str(args.lr), _prune_rule_suffix())
elif args.transfer:
    args.job_dir = os.path.join(args.job_dir, 'ST', 'lr' + str(args.lr))
elif args.hard_inherit:
    args.job_dir = os.path.join(args.job_dir, 'SP^wT', args.prune_rule,
                                'lr' + str(args.lr), _prune_rule_suffix())

checkpoint = utils.checkpoint(args)

# exist_ok=True replaces the check-then-create sequence, which could fail if
# the directory appeared between the check and the makedirs call.
os.makedirs(args.job_dir, exist_ok=True)

# File logger writing to <job_dir>/logger.log.
get_logger = utils.get_logger()
get_logger.add_logger(os.path.join(args.job_dir, 'logger.log'))
logger = get_logger.logger


def adjust_learning_rate(optimizer, epoch, args, lr_decay_step, epochs):
    """Set the learning rate of every parameter group for ``epoch``.

    The schedule is chosen by ``args.lr_type``:

    * ``'step'``: ``args.lr`` is multiplied by 0.1 once for each entry of
      ``lr_decay_step`` that falls in ``1..epoch`` (inclusive).
    * ``'cos'``: ``0.5 * args.lr * (1 + cos(pi * epoch / epochs))``.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts whose
            ``'lr'`` entry is overwritten.
        epoch: current epoch index.
        args: namespace providing ``lr`` (base rate) and ``lr_type``.
        lr_decay_step: collection of epochs at which the step schedule
            decays; ignored by the cosine schedule.
        epochs: total number of epochs; only used by the cosine schedule.

    Raises:
        ValueError: if ``args.lr_type`` is neither ``'step'`` nor ``'cos'``.
    """
    if args.lr_type == 'step':
        decay_steps = set(lr_decay_step)
        factor = sum(1 for i in range(1, epoch + 1) if i in decay_steps)
        lr = args.lr * (0.1**factor)
    elif args.lr_type == 'cos':
        lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / epochs))
    else:
        # Without this branch ``lr`` would be unbound below and the caller
        # would get an opaque UnboundLocalError.
        raise ValueError(
            "unsupported lr_type {!r}; expected 'step' or 'cos'".format(
                args.lr_type))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


# Training
def train(model, optimizer, loader_train, args, epoch, lr_decay_step, epochs):
    """Train ``model`` for one epoch.

    Args:
        model: network to optimise; switched to training mode here.
        optimizer: optimiser whose learning rate is set for this epoch via
            ``adjust_learning_rate`` before the first batch.
        loader_train: iterable of ``(inputs, targets)`` batches exposing a
            ``dataset`` attribute with a length.
        args: namespace providing ``train_batch_size`` plus the fields read
            by ``adjust_learning_rate``.
        epoch: current epoch index (used for the schedule and for logging).
        lr_decay_step: decay epochs forwarded to ``adjust_learning_rate``.
        epochs: total number of epochs forwarded to ``adjust_learning_rate``.

    Returns:
        ``(accuracy.avg, losses.avg)``: the running averages of the value
        logged as Top1 and of the loss over the epoch.

    Side effects:
        Updates the module-level ``time_per_epoch`` meter with the epoch's
        wall-clock duration (the meter is expected to be created by the
        caller before the first call).
    """
    model.train()
    losses = utils.AverageMeter()
    accuracy = utils.AverageMeter()
    # Log about ten times per epoch. Clamp to 1: with fewer than ten batches
    # the integer division yields 0 and ``batch % 0`` would raise
    # ZeroDivisionError.
    print_freq = max(
        1,
        len(loader_train.dataset) // args.train_batch_size // 10)
    start_time = time.time()
    start_time_epoch = start_time

    # The schedule depends only on the epoch, so set the rate once instead of
    # recomputing the same value for every batch.
    adjust_learning_rate(optimizer, epoch, args, lr_decay_step, epochs)

    for batch, (inputs, targets) in enumerate(loader_train):

        inputs = inputs.to(device)
        targets = targets.to(device)

        output = model(inputs)
        loss = loss_func(output, targets)
        optimizer.zero_grad()
        loss.backward()
        losses.update(loss.item(), inputs.size(0))
        optimizer.step()

        prec1 = utils.accuracy(output, targets)
        accuracy.update(prec1[0], inputs.size(0))

        if batch % print_freq == 0 and batch != 0:
            current_time = time.time()
            cost_time = current_time - start_time
            logger.info('Epoch[{}] ({}/{}):\t'
                        'Learning Rate {:.6f}\t'
                        'Loss {:.4f}\t'
                        'Top1 {:.3f}%\t'
                        'Time {:.2f}s'.format(
                            epoch, batch * args.train_batch_size,
                            len(loader_train.dataset),
                            float(optimizer.param_groups[0]['lr']),
                            float(losses.avg), float(accuracy.avg), cost_time))
            # Report the time elapsed since the previous log line.
            start_time = current_time

    time_per_epoch.update(time.time() - start_time_epoch, 1)

    return accuracy.avg, losses.avg


# Testing
def test(model, loader_test, best_acc):
    """Evaluate ``model`` on ``loader_test`` and log the result.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader_test: iterable of ``(inputs, targets)`` batches.
        best_acc: best accuracy so far; only printed next to the current one.

    Returns:
        ``(accuracy average, loss average)`` over all evaluated batches.
    """
    model.eval()

    loss_meter = utils.AverageMeter()
    top1_meter = utils.AverageMeter()

    began_at = time.time()
    with torch.no_grad():
        for inputs, targets in loader_test:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            n_samples = inputs.size(0)

            batch_loss = loss_func(logits, targets)
            loss_meter.update(batch_loss.item(), n_samples)
            top1_meter.update(utils.accuracy(logits, targets)[0], n_samples)

    elapsed = time.time() - began_at
    message = 'Test Loss {:.4f}\tTop1 {:.3f}% / {:.3f}%\tTime {:.2f}s\n'.format(
        float(loss_meter.avg), float(top1_meter.avg), float(best_acc),
        elapsed)
    logger.info(message)

    return top1_meter.avg, loss_meter.avg


def _copy_selected_filters(dst, src, out_index=None, in_index=None):
    """Copy a sub-tensor of ``src`` into the leading corner of ``dst`` in place.

    ``out_index`` selects entries along dim 0 and ``in_index`` along dim 1;
    ``None`` keeps every entry of that dim. The k-th selected entry is written
    to position k of ``dst``.
    """
    picked = src
    if out_index is not None:
        rows = torch.as_tensor(out_index, dtype=torch.long, device=src.device)
        picked = picked.index_select(0, rows)
    if in_index is None:
        dst[:picked.size(0)] = picked
    else:
        cols = torch.as_tensor(in_index, dtype=torch.long, device=src.device)
        picked = picked.index_select(1, cols)
        dst[:picked.size(0), :picked.size(1)] = picked


def _select_kept_filters(prune_rule, conv_weight_name, oriweight, select_num,
                         reference_state):
    """Return the sorted indices of the ``select_num`` filters to inherit."""
    if prune_rule == 'random_pretrain':
        # Sample from every filter; the previous range(0, n - 1) could never
        # pick the last one.
        select_index = random.sample(range(oriweight.size(0)), select_num)
    elif prune_rule == 'l1_pretrain':
        # Largest L1 norms. topk returns distinct positions even when norms
        # tie, unlike a value -> list.index lookup.
        l1_norm = torch.sum(torch.abs(oriweight), [1, 2, 3])
        select_index = torch.topk(l1_norm, select_num).indices.tolist()
    elif prune_rule == 'NS_pretrain':
        # Rank filters by the magnitude of the matching BatchNorm scale.
        if 'conv' in conv_weight_name:
            bn_weight_name = conv_weight_name.replace('conv', 'bn')
        elif 'downsample' in conv_weight_name:
            # Turn the second '0' of the name into '1' (e.g.
            # '...0.downsample.0.weight' -> '...0.downsample.1.weight'),
            # using '*' to shield the first '0' temporarily.
            bn_weight_name = conv_weight_name.replace('0', '*', 1)
            bn_weight_name = bn_weight_name.replace('0', '1', 1)
            bn_weight_name = bn_weight_name.replace('*', '0', 1)
        else:
            raise ValueError(
                'cannot derive BatchNorm name for {!r}'.format(
                    conv_weight_name))
        ascending = torch.argsort(reference_state[bn_weight_name].abs())
        select_index = ascending.tolist()[::-1][:select_num]
    else:
        raise ValueError('unsupported prune_rule {!r}'.format(prune_rule))
    select_index.sort()
    return select_index


def load_resnet(model, prune_rule, init_state_dict=None):
    """Initialise a (possibly pruned) ResNet from a full-size state dict.

    Modules of ``model`` are visited in ``named_modules()`` order. For a
    convolution whose filter count differs from the module-level
    ``oristate_dict``, the filters to keep are chosen according to
    ``prune_rule`` (``'random_pretrain'``, ``'l1_pretrain'`` or
    ``'NS_pretrain'``); the chosen output filters, and the input channels
    chosen for the preceding convolution, are copied over. BatchNorm
    parameters/statistics follow the most recent filter selection, and Linear
    layers are copied unchanged.

    Args:
        model: network to initialise; its state dict is modified and loaded
            back.
        prune_rule: filter-selection rule. With any other value no filters
            are selected and tensors are copied as-is.
        init_state_dict: optional state dict to copy values from. Filter
            counts and selection scores are always taken from
            ``oristate_dict``; only the copied values come from
            ``init_state_dict`` when it is given.

    Raises:
        KeyError: if ``args.cfg`` is not ``'resnet50'``.
    """
    global oristate_dict
    cfg = {'resnet50': [3, 4, 6, 3]}
    current_cfg = cfg[args.cfg]
    # Last conv of every stage except the final one; its selection feeds the
    # next stage's downsample conv.
    block_last_name = [
        'layer{}.{}.conv3'.format(str(layer + 1), str(num - 1))
        for layer, num in enumerate(current_cfg)
    ][:-1]
    inherit_rules = ('random_pretrain', 'l1_pretrain', 'NS_pretrain')

    state_dict = model.state_dict()
    source = oristate_dict if init_state_dict is None else init_state_dict

    last_select_index = None  # filters kept by the previous main-path conv
    block_select_index = None  # filters kept at the input of the next stage
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            conv_weight_name = name + '.weight'
            oriweight = oristate_dict[conv_weight_name]
            orifilter_num = oriweight.size(0)
            currentfilter_num = state_dict[conv_weight_name].size(0)
            is_downsample = 'downsample' in conv_weight_name

            if orifilter_num != currentfilter_num and prune_rule in inherit_rules:
                select_index = _select_kept_filters(prune_rule,
                                                    conv_weight_name,
                                                    oriweight,
                                                    currentfilter_num,
                                                    oristate_dict)

                is_stem = ('layer' not in conv_weight_name
                           and 'conv1' in conv_weight_name)
                if is_stem or any(tail in conv_weight_name
                                  for tail in block_last_name):
                    block_select_index = select_index

                if not is_downsample:
                    _copy_selected_filters(state_dict[conv_weight_name],
                                           source[conv_weight_name],
                                           select_index, last_select_index)
                    last_select_index = select_index
                else:
                    _copy_selected_filters(state_dict[conv_weight_name],
                                           source[conv_weight_name],
                                           select_index, block_select_index)
                    block_select_index = None

            elif 'conv' in conv_weight_name and last_select_index is not None:
                # Filter count unchanged, but the previous conv was pruned:
                # keep only the matching input channels.
                _copy_selected_filters(state_dict[conv_weight_name],
                                       source[conv_weight_name],
                                       in_index=last_select_index)
                last_select_index = None

            elif is_downsample and block_select_index is not None:
                _copy_selected_filters(state_dict[conv_weight_name],
                                       source[conv_weight_name],
                                       in_index=block_select_index)
                block_select_index = None

            else:
                state_dict[conv_weight_name] = source[conv_weight_name]
                last_select_index = None

        elif isinstance(module, nn.BatchNorm2d):
            for part in ('.weight', '.bias', '.running_mean', '.running_var'):
                key = name + part
                if last_select_index is None:
                    state_dict[key] = source[key]
                else:
                    _copy_selected_filters(state_dict[key], source[key],
                                           last_select_index)

        elif isinstance(module, nn.Linear):
            state_dict[name + '.weight'] = source[name + '.weight']
            state_dict[name + '.bias'] = source[name + '.bias']

    model.load_state_dict(state_dict)


def _copy_ranked_filters(dst, src, out_index=None, in_index=None):
    """Copy a sub-tensor of ``src`` into the leading corner of ``dst`` in place.

    ``out_index`` selects entries along dim 0 and ``in_index`` along dim 1;
    ``None`` keeps every entry of that dim. The k-th selected entry is written
    to position k of ``dst``.
    """
    picked = src
    if out_index is not None:
        rows = torch.as_tensor(out_index, dtype=torch.long, device=src.device)
        picked = picked.index_select(0, rows)
    if in_index is None:
        dst[:picked.size(0)] = picked
    else:
        cols = torch.as_tensor(in_index, dtype=torch.long, device=src.device)
        picked = picked.index_select(1, cols)
        dst[:picked.size(0), :picked.size(1)] = picked


def load_resnet_hrank(model, init_state_dict=None):
    """Initialise a (possibly pruned) ResNet-50 using per-layer rank files.

    Convolutions are visited in a fixed order: the stem ``conv1``, then for
    every bottleneck ``conv1``, ``conv2``, (``downsample.0`` for the first
    block of a stage only) and ``conv3``. The i-th visited convolution
    (1-based) is associated with ``<args.rank_path>/rank_conv<i>.npy``, which
    is loaded only when that convolution's filter count differs from the
    module-level ``oristate_dict``. The filters with the largest rank values
    are kept, together with the input channels kept for the previously
    recorded convolution. A downsample convolution never updates that
    record. Convolutions not covered by the order above and Linear layers
    are copied unchanged.

    Args:
        model: network to initialise; its state dict is modified and loaded
            back.
        init_state_dict: optional state dict to copy values from. Filter
            counts are always compared against ``oristate_dict``; only the
            copied values come from ``init_state_dict`` when it is given.

    Raises:
        KeyError: if ``args.cfg`` is not ``'resnet50'``.
    """
    global oristate_dict
    cfg = {'resnet50': [3, 4, 6, 3]}
    current_cfg = cfg[args.cfg]
    # Defensive guard: the traversal below is specific to bottleneck blocks.
    if args.cfg != 'resnet50':
        raise ValueError("args.cfg must be resnet50")

    state_dict = model.state_dict()
    source = oristate_dict if init_state_dict is None else init_state_dict

    bn_part_name = ['.weight', '.bias', '.running_mean', '.running_var']
    prefix = f'{args.rank_path}/rank_conv'
    subfix = ".npy"

    # (conv module, bn module, record_last) in rank-file order.
    schedule = [('conv1', 'bn1', True)]
    for layer, num in enumerate(current_cfg):
        for k in range(num):
            block = 'layer' + str(layer + 1) + '.' + str(k)
            schedule.append((block + '.conv1', block + '.bn1', True))
            schedule.append((block + '.conv2', block + '.bn2', True))
            if k == 0:
                schedule.append((block + '.downsample.0',
                                 block + '.downsample.1', False))
            schedule.append((block + '.conv3', block + '.bn3', True))

    handled_conv_weights = set()
    last_select_index = None
    for cnt, (conv_name, bn_name, record_last) in enumerate(schedule,
                                                            start=1):
        conv_weight_name = conv_name + '.weight'
        handled_conv_weights.add(conv_weight_name)
        orifilter_num = oristate_dict[conv_weight_name].size(0)
        currentfilter_num = state_dict[conv_weight_name].size(0)

        if orifilter_num != currentfilter_num:
            rank_file = prefix + str(cnt) + subfix
            logger.info('loading rank from: ' + rank_file)
            rank = np.load(rank_file)
            # Indices of the highest-ranked filters, in ascending index order.
            select_index = np.argsort(rank)[orifilter_num -
                                            currentfilter_num:]
            select_index.sort()

            _copy_ranked_filters(state_dict[conv_weight_name],
                                 source[conv_weight_name], select_index,
                                 last_select_index)
            for bn_part in bn_part_name:
                _copy_ranked_filters(state_dict[bn_name + bn_part],
                                     source[bn_name + bn_part], select_index)
            next_select_index = select_index
        else:
            if last_select_index is not None:
                # Same filter count, but the recorded predecessor was pruned:
                # keep only the matching input channels.
                _copy_ranked_filters(state_dict[conv_weight_name],
                                     source[conv_weight_name],
                                     in_index=last_select_index)
            else:
                state_dict[conv_weight_name] = source[conv_weight_name]
            for bn_part in bn_part_name:
                state_dict[bn_name + bn_part] = source[bn_name + bn_part]
            next_select_index = None

        if record_last:
            last_select_index = next_select_index

        state_dict[bn_name + '.num_batches_tracked'] = source[
            bn_name + '.num_batches_tracked']

    for name, module in model.named_modules():
        name = name.replace('module.', '')
        if isinstance(module, nn.Conv2d):
            conv_name = name + '.weight'
            if conv_name not in handled_conv_weights:
                state_dict[conv_name] = source[conv_name]

        elif isinstance(module, nn.Linear):
            state_dict[name + '.weight'] = source[name + '.weight']
            state_dict[name + '.bias'] = source[name + '.bias']

    model.load_state_dict(state_dict)


def _parse_compress_rate(cprate_str):
    """Expand a compress-rate spec string into a flat list of rates.

    The spec is split on '+'. Each term must contain exactly one decimal
    number (digits, a dot, optional digits) and may contain one '*<count>'
    suffix that repeats the rate <count> times.

    Args:
        cprate_str: the raw spec string (``args.compress_rate``).

    Returns:
        A list of floats, one per repeated term, each multiplied by 10.

    Raises:
        ValueError: if a term has more than one repeat suffix or does not
            contain exactly one decimal rate.
    """
    import re

    rate_pattern = re.compile(r'\d+\.\d*')
    repeat_pattern = re.compile(r'\*\d+')

    rates = []
    for term in cprate_str.split('+'):
        repeats = repeat_pattern.findall(term)
        if len(repeats) > 1:
            raise ValueError(
                'compress_rate term {!r} has more than one repeat '
                'count'.format(term))
        count = int(repeats[0].replace('*', '')) if repeats else 1

        found = rate_pattern.findall(term)
        if len(found) != 1:
            raise ValueError(
                'compress_rate term {!r} must contain exactly one decimal '
                'rate'.format(term))
        # NOTE(review): rates are scaled by 10 before being handed to the
        # resnet builder as `honey` -- presumably it expects a 0-10 scale;
        # confirm against model.resnet.
        rates += [float(found[0]) * 10] * count
    return rates


def _wrap_data_parallel(model):
    """Wrap ``model`` in ``nn.DataParallel`` when several GPUs are listed.

    ``args.gpus`` is the string assigned to CUDA_VISIBLE_DEVICES; the device
    count is derived as ``(len(args.gpus) + 1) // 2``.
    NOTE(review): this assumes single-digit, comma-separated ids such as
    '0,1' -- confirm before using ids >= 10.

    Returns:
        The wrapped model, or ``model`` unchanged for a single device.
    """
    if len(args.gpus) > 1:
        device_ids = list(range((len(args.gpus) + 1) // 2))
        model = nn.DataParallel(model, device_ids=device_ids).to(device)
    return model


def _log_epoch(writer, train_loader, optimizer, epoch, train_los, train_acc,
               test_los, test_acc):
    """Write one epoch of scalars (and, at epoch 1, an image grid) to ``writer``."""
    if epoch == 1:
        images, _ = next(iter(train_loader))
        img_grid = torchvision.utils.make_grid(images)
        writer.add_image('Image', img_grid)
    writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
    writer.add_scalar('train_loss', train_los, epoch)
    writer.add_scalar('train_acc', train_acc, epoch)
    writer.add_scalar('test_loss', test_los, epoch)
    writer.add_scalar('test_acc', test_acc, epoch)


def _build_pruned_model(origin_model, ori_state_dict, compress_rate,
                        num_classes, input_image, show_honey=False):
    """Construct the pruned network selected by ``args.prune_rule``.

    Args:
        origin_model: the full-size model the pruning criteria are read from.
        ori_state_dict: state dict of ``origin_model`` (used by epruner).
        compress_rate: per-layer rates for the preset-rate builder; only used
            when the rule is not NS/epruner/depgraph.
        num_classes: number of output classes of the pruned model.
        input_image: example input forwarded to the depgraph pruner.
        show_honey: when True, print the honey list computed for
            'NS_pretrain' before building the model.

    Returns:
        The pruned model, moved to ``device``.
    """
    rule = args.prune_rule
    if rule == 'NS_pretrain':
        honey = networkslimming.get_resnet_honey(origin_model,
                                                 args.channel_PR)
        if show_honey:
            print('honey:', honey)
        model = import_module('model.resnet_2').resnet(
            args.cfg, honey=honey, num_classes=num_classes).to(device)
        load_resnet(model, rule)
    elif rule == 'epruner_pretrain':
        model, _ = epruner.cluster_resnet(ori_state_dict, args.cfg,
                                          num_classes)
        model = model.to(device)
    elif rule == 'depgraph_pretrain':
        model = depgraph.prune_groupnorm(origin_model, input_image,
                                         num_classes)
        model = model.to(device)
    else:
        # Every other rule starts from the preset-rate architecture and then
        # inherits weights through the matching loader.
        model = import_module('model.resnet').resnet(
            args.cfg, honey=compress_rate,
            num_classes=num_classes).to(device)
        if rule == 'hrank_pretrain':
            load_resnet_hrank(model)
        else:
            load_resnet(model, rule)
    return model


def main():
    """Build, optionally pretrain, prune and finetune a ResNet-50.

    Stages, all driven by the module-level ``args``:
      1. Create a torchvision ResNet-50 (ImageNet weights if
         ``args.use_pretrain``) and replace its classifier for the dataset.
      2. Unless pretrained/resumed-finetune/train-slim is requested, train
         the full model on the target dataset, checkpointing under
         ``<job_dir>/pretrained_model``.
      3. Build the pruned model according to the transfer / hard_inherit /
         train_slim flags and ``args.prune_rule``.
      4. Finetune the resulting model, checkpointing under
         ``<job_dir>/finetune`` and logging to TensorBoard.

    Raises:
        ValueError: if ``args.data_set`` is not 'cub200' or 'cifar100', or
            if ``args.compress_rate`` is malformed.
    """
    global oristate_dict, time_per_epoch

    # Model
    if args.use_pretrain:
        print('==> Loading Pretrained Model..')
        origin_model = models.resnet50(pretrained=True)
    else:
        print('==> Loading Model..')
        origin_model = models.resnet50(pretrained=False)

    # Profile the stock network before its classifier is replaced.
    input_image = torch.randn(1, 3, 224, 224).to(device)
    full_flops, full_params = profile(origin_model.to(device),
                                      inputs=(input_image, ))
    logger.info('full_flops: {}\tfull_params: {}\n'.format(
        full_flops, full_params))

    # Data
    print('==> Loading Data..')
    if args.data_set == 'cub200':
        loader = cub200.Data(args)
        num_classes = 200
    elif args.data_set == 'cifar100':
        loader = cifar100.Data(args)
        num_classes = 100
    else:
        # Previously this fell through and crashed later with a NameError.
        raise ValueError('Unsupported data_set: {!r}'.format(args.data_set))
    input_image_size = 224
    origin_model.fc = nn.Linear(origin_model.fc.in_features, num_classes)
    origin_model = origin_model.to(device)
    oristate_dict = origin_model.state_dict()

    # calculate the parameters and FLOPs of original full-size model
    input_image = torch.randn(1, 3, input_image_size,
                              input_image_size).to(device)
    full_flops, full_params = profile(origin_model, inputs=(input_image, ))
    logger.info('full_flops: {}\tfull_params: {}\n'.format(
        full_flops, full_params))

    model = copy.deepcopy(origin_model)
    logger.info('Full model: {}'.format(model))

    lr_decay_step_pretrain = [
        int(step) for step in args.lr_decay_step_pretrain.split(',')
    ]

    # compression rate
    # Bound up front so rules that never use it still work when the spec is
    # empty.
    compress_rate = None
    if args.compress_rate:
        compress_rate = _parse_compress_rate(args.compress_rate)
        print('compress_rate:', compress_rate)

    # pretrained on B (target-domain dataset)
    if (not args.use_pretrain) and (not args.resume_finetune) and (
            not args.train_slim):
        print('==> Pretrain Model on Target-domain dataset..')

        # set optimizer
        optimizer_train = optim.SGD(model.parameters(),
                                    lr=args.lr_train,
                                    momentum=args.momentum,
                                    weight_decay=args.train_weight_decay)

        best_acc_train = 0
        start_epoch = 0
        time_per_epoch = utils.AverageMeter()

        if args.resume_pretrain:
            # Model
            resumeckpt = torch.load(args.resume_pretrain)
            model.load_state_dict(resumeckpt['state_dict'], strict=False)
            optimizer_train.load_state_dict(resumeckpt['optimizer'])
            start_epoch = resumeckpt['epoch']
            best_acc_train = resumeckpt['best_acc']
            time_per_epoch = resumeckpt['time_per_epoch']

        model = _wrap_data_parallel(model)

        os.makedirs(f'{args.job_dir}/pretrained_model/', exist_ok=True)

        # tensorboard
        logdir_pretrain = Path(args.job_dir) / 'logs_pretrain'
        logdir_pretrain.mkdir(parents=True, exist_ok=True)
        summary_writer_pretrain = SummaryWriter(logdir_pretrain,
                                                flush_secs=120)

        for epoch in range(start_epoch, args.train_epochs):
            # store the initialization weights of full-size model
            # Guard on `epoch`, not `start_epoch`: the old check was true on
            # every iteration of a fresh run, so the "initial" snapshot was
            # overwritten with trained weights after each epoch.
            if epoch == 0:
                model_state_dict = model.module.state_dict() if len(
                    args.gpus) > 1 else model.state_dict()
                torch.save(
                    model_state_dict,
                    f'{args.job_dir}/pretrained_model/model_initial.pt')

            train_acc, train_los = train(model, optimizer_train,
                                         loader.loader_train, args, epoch,
                                         lr_decay_step_pretrain,
                                         args.train_epochs)
            test_acc, test_los = test(model, loader.loader_test,
                                      best_acc_train)
            logger.info('Time per epoch for training: {:.2f}'.format(
                float(time_per_epoch.avg)))

            is_best = best_acc_train < test_acc
            best_acc_train = max(best_acc_train, test_acc)

            model_state_dict = model.module.state_dict() if len(
                args.gpus) > 1 else model.state_dict()

            state = {
                'state_dict': model_state_dict,
                'best_acc': best_acc_train,
                'optimizer': optimizer_train.state_dict(),
                'epoch': epoch + 1,
                'time_per_epoch': time_per_epoch
            }

            # Always keep the latest checkpoint; copy it when it is the best.
            save_path = f'{args.job_dir}/pretrained_model/model.pt'
            torch.save(state, save_path)
            if is_best:
                torch.save(state,
                           f'{args.job_dir}/pretrained_model/model_best.pt')

            # tensorboard
            _log_epoch(summary_writer_pretrain, loader.loader_train,
                       optimizer_train, epoch, train_los, train_acc,
                       test_los, test_acc)

        summary_writer_pretrain.close()

        logger.info('Best accuracy of full model: {:.3f}'.format(
            float(best_acc_train)))

    # Reload the best target-domain checkpoint into origin_model and refresh
    # oristate_dict; both are what the pruning step below reads from.
    if args.use_pretrain is False and args.train_slim is False:
        if args.resume_pretrain:
            path = args.resume_pretrain.rsplit("/", 1)[0]
            ckpt = torch.load(f'{path}/model_best.pt')
        else:
            ckpt = torch.load(f'{args.job_dir}/pretrained_model/model_best.pt')
        origin_model.load_state_dict(ckpt['state_dict'], strict=False)
        oristate_dict = origin_model.state_dict()

    # tensorboard
    logdir_finetune = Path(args.job_dir) / 'logs_finetune'
    logdir_finetune.mkdir(parents=True, exist_ok=True)
    summary_writer_finetune = SummaryWriter(logdir_finetune, flush_secs=120)

    # prune model and finetune from scratch on target-domain dataset
    if args.use_pretrain and args.transfer:
        # No pruning here: `model` stays the deep copy of the full network.
        logger.info('Transfer model from ImageNet to target-domain dataset.')
    elif args.hard_inherit and args.transfer:
        logger.info(
            'Prune the transferred model. Then it inherits the unpruned weights and is finetined on target-domain dataset.'
        )
        logger.info('Criterion for Filter Pruning: {}'.format(args.prune_rule))
        model = _build_pruned_model(origin_model, oristate_dict,
                                    compress_rate, num_classes, input_image)
    # prune model and finetune with inherited weights on target-domain dataset
    elif args.use_pretrain and args.hard_inherit:
        logger.info(
            'Model is pretrained on ImageNet and pruned according to the pretrained weights. The pruned model inherits pretrained weights and is finetuned on target-domain dataset.'
        )
        logger.info('Criterion for Filter Pruning: {}'.format(args.prune_rule))
        model = _build_pruned_model(origin_model, oristate_dict,
                                    compress_rate, num_classes, input_image)
        print('model:', model)
    elif args.train_slim:
        logger.info(
            'Model is structured by preset pruning rate and is trained from scratch on the target-domain dataset.'
        )
        model = import_module('model.resnet').resnet(
            args.cfg, honey=compress_rate, num_classes=num_classes).to(device)
    elif not args.use_pretrain:
        logger.info(
            'Model is pretrained on target-domain dataset and pruned according to the pretrained weights. The pruned model inherits pretrained weights and is finetuned on target-domain dataset.'
        )
        logger.info('Criterion for Filter Pruning: {}'.format(
            args.prune_rule))
        model = _build_pruned_model(origin_model,
                                    oristate_dict,
                                    compress_rate,
                                    num_classes,
                                    input_image,
                                    show_honey=True)
        print('model:', model)

    lr_decay_step_finetune = [
        int(step) for step in args.lr_decay_step_finetune.split(',')
    ]
    optimizer_finetune = optim.SGD(model.parameters(),
                                   lr=args.lr,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay)

    # Evaluate once before any finetuning.
    test(model, loader.loader_test, 0)

    pruned_flops, pruned_params = profile(model, inputs=(input_image, ))

    logger.info('pruned_flops: {}\tpruned_params: {}\n'.format(
        pruned_flops, pruned_params))

    best_acc_finetune = 0
    start_epoch = 0
    time_per_epoch = utils.AverageMeter()

    if args.resume_finetune:
        # Model
        resumeckpt = torch.load(args.resume_finetune)
        model.load_state_dict(resumeckpt['state_dict'], strict=False)
        optimizer_finetune.load_state_dict(resumeckpt['optimizer'])
        start_epoch = resumeckpt['epoch']
        best_acc_finetune = resumeckpt['best_acc']
        time_per_epoch = resumeckpt['time_per_epoch']

    # fintune on target-domain dataset
    model = _wrap_data_parallel(model)

    os.makedirs(f'{args.job_dir}/finetune/', exist_ok=True)

    for epoch in range(start_epoch, args.finetune_epochs):
        train_acc, train_los = train(model, optimizer_finetune,
                                     loader.loader_train, args, epoch,
                                     lr_decay_step_finetune,
                                     args.finetune_epochs)
        test_acc, test_los = test(model, loader.loader_test, best_acc_finetune)
        logger.info('Time per epoch for training: {:.2f}'.format(
            float(time_per_epoch.avg)))

        is_best = best_acc_finetune < test_acc
        best_acc_finetune = max(best_acc_finetune, test_acc)

        model_state_dict = model.module.state_dict() if len(
            args.gpus) > 1 else model.state_dict()

        state = {
            'state_dict': model_state_dict,
            'best_acc': best_acc_finetune,
            'optimizer': optimizer_finetune.state_dict(),
            'epoch': epoch + 1,
            'time_per_epoch': time_per_epoch
        }

        save_path = f'{args.job_dir}/finetune/model.pt'
        torch.save(state, save_path)
        if is_best:
            torch.save(state, f'{args.job_dir}/finetune/model_best.pt')

        # tensorboard
        _log_epoch(summary_writer_finetune, loader.loader_train,
                   optimizer_finetune, epoch, train_los, train_acc, test_los,
                   test_acc)

    summary_writer_finetune.close()

    logger.info('Finetuned Model: {}'.format(model))
    logger.info('pruned_flops: {}\tpruned_params: {}\n'.format(
        pruned_flops, pruned_params))

    get_logger.remove_logger()


# Run the full pipeline only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
    main()