import torch
import torchvision
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision import models
from torch.utils.data import DataLoader
import torch.nn.parallel
import torch.utils.data.distributed
import torch.multiprocessing as mp
import torch.distributed as dist
from utils.options import args
import utils.common as utils
from importlib import import_module
from thop import profile

from data import iNat2018
from collections import OrderedDict

import numpy as np
from pathlib import Path
import os
import math
import time
import heapq
import random
from torch.utils.tensorboard import SummaryWriter
from methods import networkslimming, epruner, depgraph
import copy

# Expose only the GPUs listed in args.gpus to CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

# When no seed was given, draw one and store it back on args; either way the
# Python and torch RNGs are then seeded with the same value.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
cudnn.deterministic = True
cudnn.benchmark = False

# Bracketed form of the compress-rate spec, used as a directory name:
# every '*' becomes ']', every '+' becomes '+[', and a '[' is prepended.
compress_rate_str = '[' + args.compress_rate.replace('*', ']').replace(
    '+', '+[')


def _prune_rule_tag():
    """Return the final path component that identifies the pruning setting."""
    if args.prune_rule == 'NS_pretrain':
        return 'PR_' + str(args.channel_PR)
    if args.prune_rule == 'epruner_pretrain':
        return 'beta_' + str(args.preference_beta)
    if args.prune_rule == 'depgraph_pretrain':
        return 'flops_PR_' + str(args.target_flops_PR)
    return compress_rate_str


def _rule_job_dir(scheme):
    """Return job_dir/<scheme>/<prune_rule>/lr<lr>/<pruning tag>."""
    return os.path.join(args.job_dir, scheme, args.prune_rule,
                        'lr' + str(args.lr), _prune_rule_tag())


# Pick the experiment sub-directory from the training-mode flags. When
# use_pretrain is off and the transfer/hard_inherit pair is not both set,
# args.job_dir is left untouched.
if args.train_slim:
    args.job_dir = os.path.join(args.job_dir, 'SP^rT', 'lr' + str(args.lr),
                                compress_rate_str)
elif not args.use_pretrain:
    if args.transfer and args.hard_inherit:
        args.job_dir = _rule_job_dir('STP^wT')
elif args.transfer:
    args.job_dir = os.path.join(args.job_dir, 'ST', 'lr' + str(args.lr))
elif args.hard_inherit:
    args.job_dir = _rule_job_dir('SP^wT')

checkpoint = utils.checkpoint(args)

# exist_ok=True replaces the former "check, then create" pair, which could
# raise FileExistsError when several processes run this module-level code at
# the same time (e.g. workers started for multi-process training).
os.makedirs(args.job_dir, exist_ok=True)

# Module-wide logger writing to <job_dir>/logger.log.
get_logger = utils.get_logger()
get_logger.add_logger(os.path.join(args.job_dir, 'logger.log'))
logger = get_logger.logger


def adjust_learning_rate(optimizer, epoch, args, lr_decay_step, epochs, lr0):
    """Set the learning rate of every parameter group for the given epoch.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts); the
            ``'lr'`` entry of each group is overwritten.
        epoch: current epoch index.
        args: namespace providing ``lr_type``, either ``'step'`` or ``'cos'``.
        lr_decay_step: container of milestone epochs (``'step'`` only); the
            rate is multiplied by 0.1 for each milestone in ``[1, epoch]``.
        epochs: total number of epochs (``'cos'`` only).
        lr0: base learning rate.

    Raises:
        ValueError: if ``args.lr_type`` is not a supported schedule. The
            optimizer is left untouched in that case.
    """
    if args.lr_type == 'step':
        # Count how many milestones have been reached so far.
        factor = sum(1 for i in range(1, epoch + 1) if i in lr_decay_step)
        lr = lr0 * (0.1**factor)
    elif args.lr_type == 'cos':
        # Half-cosine decay from lr0 (epoch 0) down to 0 (epoch == epochs).
        lr = 0.5 * lr0 * (1 + math.cos(math.pi * epoch / epochs))
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on `lr`.
        raise ValueError(
            "unsupported lr_type {!r}; expected 'step' or 'cos'".format(
                args.lr_type))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def load_resnet(model, prune_rule, init_state_dict=None):
    """Fill a (possibly pruned) ResNet with weights inherited from a dense one.

    The modules of ``model`` are visited in ``named_modules()`` order. For a
    convolution whose number of output filters differs from the reference
    weights in the module-level ``oristate_dict``, a subset of filters is
    chosen according to ``prune_rule`` and only those filters (and the input
    channels kept by the preceding layer) are copied. BatchNorm layers follow
    the selection of the convolution visited before them; Linear layers are
    copied unchanged. The assembled state dict is loaded into ``model``.

    Relies on the module-level ``args.cfg`` (only ``'resnet50'`` is known)
    and ``oristate_dict``.

    Args:
        model: network to initialise.
        prune_rule: ``'random_pretrain'``, ``'l1_pretrain'``,
            ``'bn_pretrain'`` or ``'NS_pretrain'`` enable filter selection;
            with any other value every tensor is copied whole.
        init_state_dict: optional state dict to take the copied values from
            instead of ``oristate_dict``. Filter ranking always uses
            ``oristate_dict``.

    Raises:
        ValueError: if a BN-based rule meets a convolution whose name contains
            neither ``'conv'`` nor ``'downsample'``.
    """
    global oristate_dict
    cfg = {'resnet50': [3, 4, 6, 3]}
    current_cfg = cfg[args.cfg]
    # conv3 of the last block of every stage except the final one; its kept
    # filters define the input channels of the next stage's downsample conv.
    block_last_name = [
        'layer{}.{}.conv3'.format(layer + 1, num - 1)
        for layer, num in enumerate(current_cfg)
    ][:-1]

    state_dict = model.state_dict()
    # Tensor values come from init_state_dict when given, else oristate_dict.
    source = oristate_dict if init_state_dict is None else init_state_dict
    selection_rules = ('random_pretrain', 'l1_pretrain', 'bn_pretrain',
                       'NS_pretrain')

    def copy_selected(key, out_index=None, in_index=None):
        """Copy source[key][out_index][:, in_index] into state_dict[key]."""
        picked = source[key]
        with torch.no_grad():
            if out_index is not None:
                picked = picked[torch.as_tensor(out_index,
                                                dtype=torch.long,
                                                device=picked.device)]
            if in_index is not None:
                picked = picked[:,
                                torch.as_tensor(in_index,
                                                dtype=torch.long,
                                                device=picked.device)]
                state_dict[key][:picked.size(0), :picked.size(1)] = picked
            else:
                state_dict[key][:picked.size(0)] = picked

    def bn_weight_key(conv_weight_name):
        """Return the key of the BN scale that follows the given conv."""
        if 'conv' in conv_weight_name:
            return conv_weight_name.replace('conv', 'bn')
        if 'downsample' in conv_weight_name:
            # downsample.0 is the conv, downsample.1 its BatchNorm.
            return conv_weight_name.replace('downsample.0', 'downsample.1',
                                            1)
        raise ValueError(
            'cannot derive BatchNorm key for {!r}'.format(conv_weight_name))

    def select_filters(conv_weight_name, oriweight, select_num):
        """Return the sorted indices of the filters to keep."""
        if prune_rule == 'random_pretrain':
            # Sample over *all* filters; the previous upper bound of
            # orifilter_num - 1 could never pick the last one.
            return sorted(
                random.sample(range(oriweight.size(0)), select_num))
        if prune_rule == 'l1_pretrain':
            l1_sum = torch.sum(torch.abs(oriweight), [1, 2, 3])
            # topk yields distinct indices even when norms tie.
            return sorted(torch.topk(l1_sum, select_num).indices.tolist())
        # 'bn_pretrain' ranks by the signed BN scale, 'NS_pretrain' by its
        # magnitude.
        bn_scale = oristate_dict[bn_weight_key(conv_weight_name)]
        if prune_rule == 'NS_pretrain':
            bn_scale = bn_scale.abs()
        rank = torch.argsort(bn_scale).tolist()
        return sorted(rank[::-1][:select_num])

    last_select_index = None  # filters kept by the previous main-path conv
    block_select_index = None  # filters kept at the last stage boundary
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            conv_weight_name = name + '.weight'
            oriweight = oristate_dict[conv_weight_name]
            orifilter_num = oriweight.size(0)
            currentfilter_num = state_dict[conv_weight_name].size(0)
            is_downsample = 'downsample' in conv_weight_name

            if (orifilter_num != currentfilter_num
                    and prune_rule in selection_rules):
                select_index = select_filters(conv_weight_name, oriweight,
                                              currentfilter_num)

                is_stem = ('layer' not in conv_weight_name
                           and 'conv1' in conv_weight_name)
                if is_stem or any(last in conv_weight_name
                                  for last in block_last_name):
                    block_select_index = select_index

                if not is_downsample:
                    copy_selected(conv_weight_name, select_index,
                                  last_select_index)
                    last_select_index = select_index
                else:
                    # NOTE(review): last_select_index is left untouched here,
                    # so the BatchNorm after a pruned downsample conv reuses
                    # the selection of the preceding main-path conv - confirm
                    # this is intended.
                    copy_selected(conv_weight_name, select_index,
                                  block_select_index)
                    block_select_index = None

            elif 'conv' in conv_weight_name and last_select_index is not None:
                # All filters kept, but the input channels were pruned.
                copy_selected(conv_weight_name, in_index=last_select_index)
                last_select_index = None

            elif is_downsample and block_select_index is not None:
                copy_selected(conv_weight_name, in_index=block_select_index)
                block_select_index = None

            else:
                state_dict[conv_weight_name] = source[conv_weight_name]
                last_select_index = None

        elif isinstance(module, nn.BatchNorm2d):
            for part in ('.weight', '.bias', '.running_mean',
                         '.running_var'):
                key = name + part
                if last_select_index is None:
                    state_dict[key] = source[key]
                else:
                    copy_selected(key, out_index=last_select_index)

        elif isinstance(module, nn.Linear):
            state_dict[name + '.weight'] = source[name + '.weight']
            state_dict[name + '.bias'] = source[name + '.bias']

    model.load_state_dict(state_dict)


def load_resnet_hrank(model, init_state_dict=None):
    """Initialise a pruned ResNet-50 using per-layer rank files.

    Convolutions are visited in a fixed order: the stem ``conv1``, then for
    every block ``conv1``, ``conv2``, ``downsample.0`` (first block of a
    stage only) and ``conv3``. The i-th convolution in that order (1-based)
    is associated with ``<args.rank_path>/rank_conv<i>.npy``. When a
    convolution has fewer filters than the reference weights in the
    module-level ``oristate_dict``, the rank file is loaded and the filters
    with the largest rank values are kept. Their BatchNorm statistics are
    copied with the same selection, and the next convolution's input channels
    are restricted to the kept filters. Remaining Conv2d and Linear modules
    are copied unchanged, then the state dict is loaded into ``model``.

    Relies on the module-level ``args`` (``cfg``, ``rank_path``),
    ``oristate_dict`` and ``logger``.

    Args:
        model: network to initialise.
        init_state_dict: optional state dict to take the copied values from
            instead of ``oristate_dict``. Layer sizes are always compared
            against ``oristate_dict``.

    Raises:
        ValueError: if ``args.cfg`` is not ``'resnet50'``.
    """
    global oristate_dict
    cfg = {'resnet50': [3, 4, 6, 3]}
    # Validate first: the lookup below would otherwise fail with a bare
    # KeyError before the intended error could be raised.
    if args.cfg not in cfg:
        raise ValueError("args.cfg must be resnet50")
    current_cfg = cfg[args.cfg]

    state_dict = model.state_dict()
    # Tensor values come from init_state_dict when given, else oristate_dict.
    source = oristate_dict if init_state_dict is None else init_state_dict

    bn_part_name = ['.weight', '.bias', '.running_mean', '.running_var']
    prefix = f'{args.rank_path}/rank_conv'
    subfix = ".npy"

    def copy_selected(key, out_index=None, in_index=None):
        """Copy source[key][out_index][:, in_index] into state_dict[key]."""
        picked = source[key]
        with torch.no_grad():
            if out_index is not None:
                picked = picked[torch.as_tensor(out_index,
                                                dtype=torch.long,
                                                device=picked.device)]
            if in_index is not None:
                picked = picked[:,
                                torch.as_tensor(in_index,
                                                dtype=torch.long,
                                                device=picked.device)]
                state_dict[key][:picked.size(0), :picked.size(1)] = picked
            else:
                state_dict[key][:picked.size(0)] = picked

    def ranked_layers():
        """Yield (conv_name, bn_name, record_last) in rank-file order."""
        yield 'conv1', 'bn1', True
        for layer, num in enumerate(current_cfg):
            for k in range(num):
                block = 'layer' + str(layer + 1) + '.' + str(k)
                yield block + '.conv1', block + '.bn1', True
                yield block + '.conv2', block + '.bn2', True
                if k == 0:
                    # The downsample conv does not feed the next conv, so it
                    # must not change last_select_index.
                    # NOTE(review): it still takes its input selection from
                    # last_select_index, i.e. from the conv2 visited just
                    # before it - confirm this matches the pruned topology.
                    yield (block + '.downsample.0', block + '.downsample.1',
                           False)
                yield block + '.conv3', block + '.bn3', True

    handled_conv_weights = set()
    last_select_index = None  # filters kept by the previous recorded conv

    for cnt, (conv_name, bn_name, record_last) in enumerate(ranked_layers(),
                                                             start=1):
        conv_weight_name = conv_name + '.weight'
        handled_conv_weights.add(conv_weight_name)
        orifilter_num = oristate_dict[conv_weight_name].size(0)
        currentfilter_num = state_dict[conv_weight_name].size(0)

        if orifilter_num != currentfilter_num:
            rank_file = prefix + str(cnt) + subfix
            logger.info('loading rank from: ' + rank_file)
            rank = np.load(rank_file)
            # Drop the lowest-ranked filters, keep the rest in index order.
            select_index = np.sort(
                np.argsort(rank)[orifilter_num - currentfilter_num:])

            copy_selected(conv_weight_name, select_index, last_select_index)
            for bn_part in bn_part_name:
                copy_selected(bn_name + bn_part, out_index=select_index)

            if record_last:
                last_select_index = select_index
        else:
            if last_select_index is not None:
                # All filters kept, but the input channels were pruned.
                copy_selected(conv_weight_name, in_index=last_select_index)
            else:
                state_dict[conv_weight_name] = source[conv_weight_name]
            for bn_part in bn_part_name:
                state_dict[bn_name + bn_part] = source[bn_name + bn_part]

            if record_last:
                last_select_index = None

        state_dict[bn_name + '.num_batches_tracked'] = source[
            bn_name + '.num_batches_tracked']

    # Copy whatever Conv2d / Linear modules the ordered walk did not cover.
    for name, module in model.named_modules():
        name = name.replace('module.', '')
        if isinstance(module, nn.Conv2d):
            conv_name = name + '.weight'
            if conv_name not in handled_conv_weights:
                state_dict[conv_name] = source[conv_name]
        elif isinstance(module, nn.Linear):
            state_dict[name + '.weight'] = source[name + '.weight']
            state_dict[name + '.bias'] = source[name + '.bias']

    model.load_state_dict(state_dict)


def main(args, model, save_dir, lr, epochs, optimizer, scaler):
    """Launch training, either in-process or as one process per GPU.

    Sets ``args.distributed`` from ``args.world_size`` and
    ``args.multiprocessing_distributed``. With multiprocessing enabled,
    ``args.world_size`` is multiplied by the number of visible GPUs and
    ``main_worker`` is spawned once per GPU; otherwise ``main_worker`` is
    called directly with ``args.gpu``.

    Args:
        args: run configuration namespace (mutated as described above).
        model: network to train.
        save_dir: output directory forwarded to ``main_worker``.
        lr: base learning rate.
        epochs: number of epochs to train.
        optimizer: optimizer forwarded to ``main_worker``.
        scaler: gradient scaler forwarded to ``main_worker``.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    # The GPU count is needed on every path below; without CUDA the old code
    # failed later with an UnboundLocalError on ngpus_per_node.
    if not torch.cuda.is_available():
        raise RuntimeError('CUDA is not available; training requires at '
                           'least one visible GPU')
    ngpus_per_node = torch.cuda.device_count()

    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size

        logger.info(
            "Multiprocess distributed training, gpus:{}, total batch size:{}, epoch:{}, lr:{}"
            .format(ngpus_per_node, args.train_batch_size, epochs, lr))

        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker,
                 nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args, model, save_dir, epochs, lr,
                       optimizer, scaler))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args, model, save_dir, epochs,
                    lr, optimizer, scaler)


def main_worker(gpu, ngpus_per_node, args, model, save_dir, epochs, lr,
                optimizer, scaler):
    """Train and validate ``model`` in the current process.

    Optionally joins the distributed process group, wraps the model in
    DistributedDataParallel, resumes from a checkpoint, then runs
    ``epochs`` epochs of ``train`` + ``validate``. The process for which
    ``args.local_rank % ngpus_per_node == 0`` additionally writes
    checkpoints (``weights/model.pt``, ``weights/model_best.pt``),
    TensorBoard logs (``logs/``) and metric curves (``acc_loss/*.npy``)
    under ``save_dir``.

    Args:
        gpu: GPU index for this process (stored in ``args.gpu``); may be None.
        ngpus_per_node: number of GPUs on this node.
        args: run configuration namespace; ``gpu``, ``local_rank``,
            ``train_batch_size`` and ``workers`` may be rewritten here.
        model: network to train.
        save_dir: output directory; a path containing ``'finetune'`` selects
            the fine-tune resume path and decay steps, otherwise the
            pre-train ones.
        epochs: total number of epochs.
        lr: base learning rate handed to ``train`` as ``lr0``.
        optimizer: optimizer updating ``model``.
        scaler: gradient scaler; presumably a ``torch.cuda.amp.GradScaler``
            (verify against caller).
    """
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed and args.multiprocessing_distributed:
        # For multiprocessing distributed training, rank needs to be the
        # global rank among all the processes
        args.local_rank = args.local_rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.local_rank)

    # Only this process logs and writes files.
    is_main_process = args.local_rank % ngpus_per_node == 0

    if is_main_process:
        # weights
        save_path = Path(save_dir)
        weights = save_path / 'weights'
        weights.mkdir(parents=True, exist_ok=True)

        # accuracy / loss curves
        acc_loss = save_path / 'acc_loss'
        acc_loss.mkdir(parents=True, exist_ok=True)
        train_acc_top1_savepath = acc_loss / 'train_acc_top1.npy'
        train_acc_top5_savepath = acc_loss / 'train_acc_top5.npy'
        train_loss_savepath = acc_loss / 'train_loss.npy'
        val_acc_top1_savepath = acc_loss / 'val_acc_top1.npy'
        val_acc_top5_savepath = acc_loss / 'val_acc_top5.npy'
        val_loss_savepath = acc_loss / 'val_loss.npy'

        # tensorboard
        logdir = save_path / 'logs'
        logdir.mkdir(parents=True, exist_ok=True)
        summary_writer = SummaryWriter(logdir, flush_secs=120)

    # dataset
    train_dataset, val_dataset, _ = iNat2018()

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if torch.cuda.is_available():
            if args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                model.cuda(args.gpu)
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs of the current node.
                args.train_batch_size = int(args.train_batch_size /
                                            ngpus_per_node)
                args.workers = int(
                    (args.workers + ngpus_per_node - 1) / ngpus_per_node)
                model = torch.nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=[args.gpu])  # , find_unused_parameters=True
            else:
                model.cuda()
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                model = torch.nn.parallel.DistributedDataParallel(model)

    # loss
    if torch.cuda.is_available():
        # `is not None` rather than truthiness, so that GPU 0 is honoured.
        if args.gpu is not None:
            device = torch.device('cuda:{}'.format(args.gpu))
        else:
            device = torch.device("cuda")
    else:
        # Keeps `device` defined instead of failing with UnboundLocalError.
        device = torch.device('cpu')

    criterion = nn.CrossEntropyLoss().to(device)

    if 'finetune' in save_dir:
        resume = args.resume_finetune
        lr_decay_step = args.lr_decay_step_finetune
    else:
        resume = args.resume_pretrain
        lr_decay_step = args.lr_decay_step_pretrain

    if resume:
        # Map the checkpoint onto this process's GPU when one is assigned.
        map_location = None
        if args.gpu is not None and torch.cuda.is_available():
            map_location = 'cuda:{}'.format(args.gpu)
        ckpt = torch.load(resume, map_location=map_location)

        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        scaler_state = ckpt['scaler']
        if not isinstance(scaler_state, dict):
            # Older checkpoints stored the scaler object itself rather than
            # its state dict.
            scaler_state = scaler_state.state_dict()
        scaler.load_state_dict(scaler_state)
        best_acc1 = torch.tensor(ckpt['best_acc1'])
        best_acc5 = torch.tensor(ckpt['best_acc5'])
        if args.gpu is not None:
            # best_acc may be from a checkpoint from a different GPU
            best_acc1 = best_acc1.to(args.gpu)
            best_acc5 = best_acc5.to(args.gpu)

        train_acc_top1 = ckpt['train_acc_top1']
        train_acc_top5 = ckpt['train_acc_top5']
        train_loss = ckpt['train_loss']
        test_acc_top1 = ckpt['test_acc_top1']
        test_acc_top5 = ckpt['test_acc_top5']
        test_loss = ckpt['test_loss']
        if is_main_process:
            logger.info('Resuming training from {} epoch'.format(start_epoch))
    else:
        start_epoch = 0
        best_acc1 = 0
        best_acc5 = 0
        train_acc_top1 = []
        train_acc_top5 = []
        train_loss = []
        test_acc_top1 = []
        test_acc_top5 = []
        test_loss = []

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch_size,
                              shuffle=(train_sampler is None),
                              num_workers=args.workers,
                              pin_memory=True,
                              sampler=train_sampler,
                              drop_last=True)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.eval_batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True,
                            sampler=val_sampler,
                            drop_last=False)

    for epoch in range(start_epoch, epochs):
        if is_main_process:
            logger.info("Epoch {}/{}".format(epoch, epochs))
        if args.distributed:
            # Reshuffle the distributed sampler for this epoch.
            train_sampler.set_epoch(epoch)
        train_epoch_loss, train_acc1, train_acc5 = train(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            criterion=criterion,
            ngpus_per_node=ngpus_per_node,
            args=args,
            epoch=epoch,
            lr_decay_step=lr_decay_step,
            epochs=epochs,
            lr0=lr)

        val_epoch_loss, val_acc1, val_acc5 = validate(model=model,
                                                      val_loader=val_loader,
                                                      criterion=criterion,
                                                      args=args)
        if is_main_process:
            # The "best" values logged here are those from before this epoch.
            logger.info(
                'Test Loss {:.4f}\t Test Top1 {:.3f}% / {:.3f}%\tTest Top5 {:.3f}% / {:.3f}%\tLearning Rate {:.6f}\n'
                .format(float(val_epoch_loss),
                        float(val_acc1), float(best_acc1), float(val_acc5),
                        float(best_acc5), optimizer.param_groups[0]['lr']))

            train_loss.append(train_epoch_loss)
            train_acc_top1.append(train_acc1)
            train_acc_top5.append(train_acc5)
            test_loss.append(val_epoch_loss)
            test_acc_top1.append(val_acc1)
            test_acc_top5.append(val_acc5)

            # save model
            is_best = val_acc1 > best_acc1
            if is_best:
                best_acc1 = val_acc1
                best_acc5 = val_acc5

            state = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc1': best_acc1,
                'best_acc5': best_acc5,
                'train_acc_top1': train_acc_top1,
                'train_acc_top5': train_acc_top5,
                'train_loss': train_loss,
                'test_acc_top1': test_acc_top1,
                'test_acc_top5': test_acc_top5,
                'test_loss': test_loss,
                # Store the state dict: the resume path feeds this entry to
                # scaler.load_state_dict().
                'scaler': scaler.state_dict(),
            }

            torch.save(state, os.path.join(weights, 'model.pt'))
            if is_best:
                torch.save(state, os.path.join(weights, 'model_best.pt'))

            if epoch == 1:
                # Log one grid of training images to TensorBoard.
                images, labels = next(iter(train_loader))
                img_grid = torchvision.utils.make_grid(images)
                summary_writer.add_image('Image', img_grid)

            summary_writer.add_scalar('lr', optimizer.param_groups[0]['lr'],
                                      epoch)
            summary_writer.add_scalar('train_loss', train_epoch_loss, epoch)
            summary_writer.add_scalar('train_acc_top1', train_acc1, epoch)
            summary_writer.add_scalar('train_acc_top5', train_acc5, epoch)
            summary_writer.add_scalar('val_loss', val_epoch_loss, epoch)
            summary_writer.add_scalar('val_acc_top1', val_acc1, epoch)
            summary_writer.add_scalar('val_acc_top5', val_acc5, epoch)

    if is_main_process:
        summary_writer.close()
        # Curves are only written when they are not already on disk.
        if not os.path.exists(train_acc_top1_savepath) or not os.path.exists(
                train_loss_savepath):
            np.save(train_acc_top1_savepath, train_acc_top1)
            np.save(train_acc_top5_savepath, train_acc_top5)
            np.save(train_loss_savepath, train_loss)
            np.save(val_acc_top1_savepath, test_acc_top1)
            np.save(val_acc_top5_savepath, test_acc_top5)
            np.save(val_loss_savepath, test_loss)


def train(model, train_loader, optimizer, scaler, criterion, ngpus_per_node,
          args, epoch, lr_decay_step, epochs, lr0):
    """Run one mixed-precision training epoch and return running averages.

    Args:
        model: Network to optimise; it is switched to train mode here.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        scaler: Gradient scaler driving the scaled backward pass and the
            optimizer step.
        criterion: Loss callable applied to ``(logits, labels)``.
        ngpus_per_node: Combined with ``args.local_rank`` to decide whether
            this process writes progress logs.
        args: Parsed options; ``gpu`` and ``local_rank`` are read here.
        epoch: Current epoch number, used for the learning-rate schedule
            and in the log lines.
        lr_decay_step: Forwarded unchanged to ``adjust_learning_rate``.
        epochs: Forwarded unchanged to ``adjust_learning_rate``.
        lr0: Forwarded unchanged to ``adjust_learning_rate``.

    Returns:
        Tuple ``(loss_avg, top1_avg, top5_avg)`` taken from the ``avg``
        field of the three meters. The statistics are local to this
        process; no cross-process reduction is performed here.
    """
    train_loss = utils.AverageMeter()
    train_acc1 = utils.AverageMeter()
    train_acc5 = utils.AverageMeter()

    # Model on train mode
    model.train()
    step_per_epoch = len(train_loader)
    # Log roughly ten times per epoch. Clamp to 1 so that loaders with fewer
    # than ten batches do not trigger a modulo-by-zero below.
    print_freq = max(1, step_per_epoch // 10)
    is_logging_rank = args.local_rank % ngpus_per_node == 0

    for step, (images, labels) in enumerate(train_loader):
        start = time.time()
        optimizer.zero_grad()
        # The schedule is re-applied on every step, before the forward pass.
        adjust_learning_rate(optimizer, epoch, args, lr_decay_step, epochs,
                             lr0)

        if args.gpu is not None and torch.cuda.is_available():
            images = images.cuda(args.gpu, non_blocking=True)
            labels = labels.cuda(args.gpu, non_blocking=True)

        with torch.cuda.amp.autocast():
            # compute output
            logits = model(images)
            # loss
            loss = criterion(logits, labels)

        # compute gradient and do SGD step (through the gradient scaler)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # measure accuracy and record loss
        acc1, acc5 = utils.accuracy(logits, labels, topk=(1, 5))

        batch_size = images.size(0)
        train_loss.update(loss.item(), batch_size)
        train_acc1.update(acc1[0].item(), batch_size)
        train_acc5.update(acc5[0].item(), batch_size)

        # Step 0 is never logged; the reported time covers the current step
        # only (measured from the top of this loop iteration).
        if is_logging_rank and step != 0 and step % print_freq == 0:
            logger.info('Epoch[{}] ({}/{}):\t'
                        'Learning Rate {:.6f}\t'
                        'Loss {:.4f}\t'
                        'Top1 {:.3f}%\t'
                        'Time {:.2f}ms/step'.format(
                            epoch, step, step_per_epoch,
                            float(optimizer.param_groups[0]['lr']),
                            float(train_loss.avg), float(train_acc1.avg),
                            1000 * (time.time() - start)))

    if is_logging_rank:
        print()
    return train_loss.avg, train_acc1.avg, train_acc5.avg


def validate(model, val_loader, criterion, args):
    """Evaluate ``model`` on ``val_loader`` with gradient tracking disabled.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(logits, labels)``.
        args: Parsed options; ``gpu``, ``distributed`` and (through
            ``reduce_tensor``) ``world_size`` are read.

    Returns:
        Tuple ``(loss_avg, top1_avg, top5_avg)`` taken from the ``avg``
        field of the three meters.
    """
    loss_meter = utils.AverageMeter()
    top1_meter = utils.AverageMeter()
    top5_meter = utils.AverageMeter()

    # Switch to evaluation mode before iterating.
    model.eval()
    with torch.no_grad():
        for images, labels in val_loader:
            move_to_gpu = (args.gpu is not None
                           and torch.cuda.is_available())
            if move_to_gpu:
                images = images.cuda(args.gpu, non_blocking=True)
                labels = labels.cuda(args.gpu, non_blocking=True)

            # Forward pass and per-batch metrics.
            logits = model(images)
            batch_loss = criterion(logits, labels)
            top1, top5 = utils.accuracy(logits, labels, topk=(1, 5))

            # In distributed runs every metric is averaged over processes
            # before it is recorded.
            if args.distributed:
                batch_loss = reduce_tensor(batch_loss, args)
                top1 = reduce_tensor(top1, args)
                top5 = reduce_tensor(top5, args)

            batch_size = images.size(0)
            loss_meter.update(batch_loss.item(), batch_size)
            top1_meter.update(top1[0].item(), batch_size)
            top5_meter.update(top5[0].item(), batch_size)

    return loss_meter.avg, top1_meter.avg, top5_meter.avg


def reduce_tensor(tensor, args):
    """Return ``tensor`` summed over all processes, divided by world size.

    The input is cloned first, so the caller's tensor is left untouched.

    Args:
        tensor: Tensor to reduce.
        args: Parsed options; only ``world_size`` is read.

    Returns:
        A new tensor holding the all-reduced sum divided by
        ``args.world_size``.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(args.world_size)
    return averaged


def testmodel(model, test_data, args):
    """Measure top-1 / top-5 accuracy of ``model`` on ``test_data``.

    A non-shuffled ``DataLoader`` is built from ``test_data`` and every
    batch is moved to the current CUDA device before the forward pass.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        test_data: Dataset handed to ``DataLoader``.
        args: Parsed options; ``eval_batch_size`` and ``workers`` are read.

    Returns:
        Tuple ``(top1_avg, top5_avg)`` taken from the ``avg`` field of the
        two meters. The meters are fed ``acc[0]`` directly rather than
        ``acc[0].item()``, so the averages are presumably tensors rather
        than Python floats -- verify against ``utils.AverageMeter``.
    """
    top1_meter = utils.AverageMeter()
    top5_meter = utils.AverageMeter()

    # Switch to evaluation mode before building the loader.
    model.eval()

    loader = DataLoader(
        test_data,
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    with torch.no_grad():
        for images, labels in loader:
            images = images.cuda()
            labels = labels.cuda()

            # Forward pass, then top-k accuracy for this batch.
            logits = model(images)
            top1, top5 = utils.accuracy(logits, labels, topk=(1, 5))

            batch_size = images.size(0)
            top1_meter.update(top1[0], batch_size)
            top5_meter.update(top5[0], batch_size)

    return top1_meter.avg, top5_meter.avg


if __name__ == "__main__":
    # Script entry point: build the full ResNet-50, derive the (possibly
    # pruned) model selected by the command-line options, evaluate it once
    # and then finetune it.
    #
    # `oristate_dict` is bound at module level further down.
    # NOTE(review): presumably the load_resnet* helpers defined elsewhere in
    # this file read it as a global -- confirm before renaming it.

    # Model
    if args.use_pretrain:
        print('==> Loading Pretrained Model..')
        origin_model = models.resnet50(pretrained=True)
    else:
        print('==> Loading Model..')
        origin_model = models.resnet50(pretrained=False)

    # Data
    print('==> Loading Data..')
    if args.data_set == 'inaturalist2018':
        train_data, test_data, _ = iNat2018()
        input_image_size = 224
        num_classes = 8142
        # Replace the classifier head so it matches the target label space.
        num_fc_in = origin_model.fc.in_features
        origin_model.fc = nn.Linear(num_fc_in, num_classes)
    else:
        raise AssertionError('Unsupported data_set: {!r}'.format(
            args.data_set))

    # calculate the parameters and FLOPs of original full-size model
    input_image = torch.randn(1, 3, input_image_size, input_image_size)
    full_flops, full_params = profile(origin_model, inputs=(input_image, ))
    logger.info('full_flops: {}\tfull_params: {}\n'.format(
        full_flops, full_params))

    model = copy.deepcopy(origin_model)
    logger.info('Full model: {}'.format(model))

    # compression rate
    # The option is a '+'-separated list; each item holds one decimal rate
    # and an optional '*N' repeat count. Every rate is multiplied by 10
    # before being stored.
    if args.compress_rate:
        import re
        cprate_str = args.compress_rate
        cprate_str_list = cprate_str.split('+')
        pat_cprate = re.compile(r'\d+\.\d*')
        pat_num = re.compile(r'\*\d+')
        cprate = []
        for x in cprate_str_list:
            num = 1
            find_num = re.findall(pat_num, x)
            if find_num:
                assert len(find_num) == 1
                num = int(find_num[0].replace('*', ''))
            find_cprate = re.findall(pat_cprate, x)
            assert len(find_cprate) == 1
            cprate += [float(find_cprate[0]) * 10] * num

        compress_rate = cprate
        print('compress_rate:', compress_rate)

    if args.use_pretrain is False:
        if args.resume_pretrain:
            # Map to CPU: origin_model still lives on the CPU here, and this
            # avoids failing when the device the checkpoint was saved from
            # is not visible in this process.
            ckpt = torch.load(args.resume_pretrain,
                              map_location='cpu')['model_state_dict']
            print('Load pretrained model.')

            # Strip the 'module.' prefix only when it is actually present.
            # Slicing unconditionally would mangle the keys of a checkpoint
            # saved without the prefix, and strict=False would then silently
            # load nothing.
            prefix = 'module.'
            new_state_dict = OrderedDict()
            for k, v in ckpt.items():
                name = k[len(prefix):] if k.startswith(prefix) else k
                new_state_dict[name] = v

            load_result = origin_model.load_state_dict(new_state_dict,
                                                       strict=False)
            # strict=False hides mismatches, so report how many there were.
            if load_result.missing_keys or load_result.unexpected_keys:
                logger.info(
                    'Checkpoint key mismatch: {} missing, {} unexpected'.
                    format(len(load_result.missing_keys),
                           len(load_result.unexpected_keys)))
        else:
            raise AssertionError(
                'resume_pretrain must be set when use_pretrain is False')

    oristate_dict = origin_model.state_dict()
    print('Successful loading.')

    # Pruning rules that build their own model instead of starting from the
    # preset-rate 'model.resnet' architecture.
    self_built_rules = ('NS_pretrain', 'epruner_pretrain',
                        'depgraph_pretrain')

    # prune model and finetune from scratch on target-domain dataset
    if args.use_pretrain and args.transfer:
        # No pruning: the full model copied above is finetuned as-is.
        logger.info('Transfer model from ImageNet to target-domain dataset.')
    elif args.hard_inherit and args.transfer:
        logger.info(
            'Prune the transferred model. Then it inherits the unpruned weights and is finetined on target-domain dataset.'
        )
        logger.info('Criterion for Filter Pruning: {}'.format(args.prune_rule))
        if args.prune_rule not in self_built_rules:
            model = import_module('model.resnet').resnet(
                args.cfg, honey=compress_rate, num_classes=num_classes)

        if args.prune_rule == 'hrank_pretrain':
            load_resnet_hrank(model)
        elif args.prune_rule == 'NS_pretrain':
            honey = networkslimming.get_resnet_honey(origin_model,
                                                     args.channel_PR)
            model = import_module('model.resnet_2').resnet(
                args.cfg, honey=honey, num_classes=num_classes)
            load_resnet(model, args.prune_rule)
        elif args.prune_rule == 'epruner_pretrain':
            model, honey = epruner.cluster_resnet(oristate_dict, args.cfg,
                                                  num_classes)
        elif args.prune_rule == 'depgraph_pretrain':
            model = depgraph.prune_groupnorm(origin_model.cuda(),
                                             input_image.cuda(), num_classes)
        else:
            load_resnet(model, args.prune_rule)
    # prune model and finetune with inherited weights on target-domain dataset
    elif args.use_pretrain and args.hard_inherit:
        logger.info(
            'Model is pretrained on ImageNet and pruned according to the pretrained weights. The pruned model inherits pretrained weights and is finetuned on target-domain dataset.'
        )
        logger.info('Criterion for Filter Pruning: {}'.format(args.prune_rule))
        if args.prune_rule not in self_built_rules:
            model = import_module('model.resnet').resnet(
                args.cfg, honey=compress_rate, num_classes=num_classes)

        if args.prune_rule == 'hrank_pretrain':
            load_resnet_hrank(model)
        elif args.prune_rule == 'NS_pretrain':
            honey = networkslimming.get_resnet_honey(origin_model,
                                                     args.channel_PR)
            model = import_module('model.resnet_2').resnet(
                args.cfg, honey=honey, num_classes=num_classes)
            load_resnet(model, args.prune_rule)
        elif args.prune_rule == 'epruner_pretrain':
            model, honey = epruner.cluster_resnet(oristate_dict, args.cfg,
                                                  num_classes)
        elif args.prune_rule == 'depgraph_pretrain':
            # NOTE(review): the transfer branch above moves the model and
            # the dummy input to the GPU before pruning, this one does not
            # -- confirm which is intended.
            model = depgraph.prune_groupnorm(origin_model, input_image,
                                             num_classes)
        else:
            load_resnet(model, args.prune_rule)
        print('model:', model)
    elif args.train_slim:
        logger.info(
            'Model is structured by preset pruning rate and is trained from scratch on the target-domain dataset.'
        )
        model = import_module('model.resnet').resnet(args.cfg,
                                                     honey=compress_rate,
                                                     num_classes=num_classes)

    pruned_flops, pruned_params = profile(model.cuda(),
                                          inputs=(input_image.cuda(), ))
    logger.info('Pruned model: {}'.format(model))

    model_copy = copy.deepcopy(model)
    model = model.cuda()

    logger.info('pruned_flops: {}\tpruned_params: {}\n'.format(
        pruned_flops, pruned_params))

    # The decay steps arrive as a comma-separated string; convert in place.
    args.lr_decay_step_finetune = list(
        map(int, args.lr_decay_step_finetune.split(',')))
    optimizer_finetune = optim.SGD(model.parameters(),
                                   lr=args.lr,
                                   momentum=args.momentum,
                                   nesterov=True,
                                   weight_decay=args.weight_decay)
    scaler_finetune = torch.cuda.amp.GradScaler(enabled=True)

    # Accuracy of the model before any finetuning.
    acc1, acc5 = testmodel(model=model, test_data=test_data, args=args)
    logger.info('Acc Top1: {:.3f}, Acc Top5: {:.3f}'.format(acc1, acc5))

    save_dir = f'{args.job_dir}/finetune/'
    print('save_dir:', save_dir)
    main(args, model, save_dir, args.lr, args.finetune_epochs,
         optimizer_finetune, scaler_finetune)

    logger.info('Pruned model: {}'.format(model))
    logger.info('pruned_flops: {}\tpruned_params: {}\n'.format(
        pruned_flops, pruned_params))

    get_logger.remove_logger()
