import os
import copy
import random
import numpy as np
import logging
import torch_pruning as tp
import engine.utils as utils


def low_level_prune(model, example_inputs, pruning_idxs, prune_layer_name, prune_func):
    """Prune channels of one named layer of ``model`` in place via torch_pruning.

    A fresh ``tp.DependencyGraph`` is traced on every call, so the pruning
    group returned for the target layer reflects the model's current shapes.

    Args:
        model: the network to prune; it is modified in place.
        example_inputs: dummy inputs used to trace the dependency graph.
        pruning_idxs: channel indices to remove from the target layer.
        prune_layer_name: fully qualified name (as in ``model.named_modules()``)
            of the layer whose channels are pruned.
        prune_func: torch_pruning pruning function, e.g.
            ``tp.prune_conv_out_channels``.

    Raises:
        ValueError: if ``model`` has no submodule named ``prune_layer_name``.
    """
    # Resolve the target first so a bad name fails fast with a clear message,
    # before paying for the dependency-graph trace.
    target_module = dict(model.named_modules()).get(prune_layer_name)
    if target_module is None:
        raise ValueError(f"Layer '{prune_layer_name}' not found in model")

    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=example_inputs)
    pruning_group = DG.get_pruning_group(target_module, prune_func, idxs=pruning_idxs)
    pruning_group.prune()


class FPVE:
    """Layer-wise evolutionary filter pruning for ResNet-style models.

    For every listed conv layer an independent sub-EA searches for the subset
    of output filters to keep. The winning subsets are then pruned from
    ``self.model`` in place through torch_pruning's dependency graph, which is
    optionally followed by finetuning. This repeats for ``iterative_steps``
    rounds, and each round lowers the per-layer target filter count.
    """

    # Residual blocks per stage for every supported architecture.
    _RESNET_BLOCK_NUM = {
        "resnet18": [2, 2, 2, 2],
        "resnet34": [3, 4, 6, 3],
        "resnet50": [3, 4, 6, 3],
        "resnet56": [9, 9, 9],
        "resnet110": [18, 18, 18],
    }

    def __init__(
        self,
        model,
        train_loader,
        test_loader,
        example_inputs,
        criterion,
        args,
        FPVE_fitness_loader=None,
    ):
        # Dataloaders
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.example_inputs = example_inputs
        self.model = model
        self.criterion = criterion

        # Optional dedicated loader for fitness evaluation; the train loader
        # is used when it is not given.
        self.FPVE_fitness_loader = FPVE_fitness_loader
        self.use_FPVE_fitness_loader = FPVE_fitness_loader is not None

        self.gpu = args.gpu
        self.arch = args.arch
        self.pop_size = args.pop_size
        self.pop_init_rate = args.pop_init_rate
        self.pruning_ratio = args.pruning_ratio
        self.mutation_rate = args.mutation_rate
        self.evolution_epoch = args.evolution_epoch
        self.fitness_mode = args.fitness_mode
        self.iterative_steps = args.iterative_steps
        self.finetune = args.finetune
        self.initial_fitness = None
        # Keyword arguments forwarded to utils.training.train_model; None when
        # finetuning is disabled.
        self.finetune_config = (
            {
                "epochs": args.ft_epochs,
                "lr": args.lr,
                "lr_step_size": args.lr_step_size,
                "lr_warmup_epochs": args.lr_warmup_epochs,
                "train_loader": train_loader,
                "test_loader": test_loader,
                "criterion": criterion,
                "save_dir": args.ckpt_save_dir,
                "device": args.gpu,
                "args": args,
                "pruner": None,
                "lr_decay_milestones": args.lr_decay_milestones,
                "save_every": args.save_every,
                "return_best": False,
                "fairness_eval_flag": args.fairness_eval_flag,
                "fairness_type": args.fairness_type,
                "scalar": args.scalar,
            }
            if args.finetune
            else None
        )

        # Test-set baseline captured on the first test evaluation in run().
        self.initial_acc_top1_on_test = None
        self.initial_per_class_acc_on_test = None

        self.fairness_eval_flag = args.fairness_eval_flag
        self.fairness_type = args.fairness_type
        # Fitness-data baseline captured on the first fitness_with_fairness call.
        self.initial_acc_top1 = None
        self.initial_per_class_acc = None
        self.scalar = args.scalar

    def _fitness_loader(self):
        """Return the loader used for fitness evaluation during the search."""
        if self.use_FPVE_fitness_loader:
            return self.FPVE_fitness_loader
        return self.train_loader

    def fitness_with_fairness(self, test_model):
        """Evaluate ``test_model`` on the fitness data, including a fairness score.

        The first model evaluated becomes the accuracy/per-class baseline. In
        ``run()`` that is the unpruned model.

        Returns:
            ``(fitness, acc_top1, acc_top5, loss, per_class_acc, fairness_score)``

        Raises:
            ValueError: if ``fitness_mode`` is not ACC_FAIRNESS/ACC_TOP1/ACC_TOP5.
        """
        cur_acc_top1, cur_acc_top5, cur_loss, per_class_acc = (
            utils.training.fairness_eval(test_model, self._fitness_loader(), self.gpu)
        )
        if self.initial_per_class_acc is None:
            self.initial_per_class_acc = per_class_acc
            self.initial_acc_top1 = cur_acc_top1
        fairness_score = utils.training.calculate_fairness(
            self.initial_per_class_acc, per_class_acc, fairness_type=self.fairness_type
        )

        if self.fitness_mode == "ACC_FAIRNESS":
            # Relative top-1 change versus the baseline, plus the weighted fairness.
            acc_change_ratio = (
                cur_acc_top1 - self.initial_acc_top1
            ) / self.initial_acc_top1
            fitness = acc_change_ratio + self.scalar * fairness_score
        elif self.fitness_mode == "ACC_TOP1":
            fitness = cur_acc_top1
        elif self.fitness_mode == "ACC_TOP5":
            fitness = cur_acc_top5
        else:
            raise ValueError(f"Invalid fitness mode: {self.fitness_mode}")

        return (
            fitness,
            cur_acc_top1,
            cur_acc_top5,
            cur_loss,
            per_class_acc,
            fairness_score,
        )

    def fitness(self, test_model):
        """Evaluate ``test_model`` on the fitness data (accuracy-only modes).

        Returns:
            ``(fitness, acc_top1, acc_top5, loss)``

        Raises:
            ValueError: if ``fitness_mode`` is not ACC_TOP1/ACC_TOP5.
        """
        cur_acc_top1, cur_acc_top5, cur_loss = utils.training.eval(
            test_model, self._fitness_loader(), self.gpu
        )
        if self.fitness_mode == "ACC_TOP1":
            fitness = cur_acc_top1
        elif self.fitness_mode == "ACC_TOP5":
            fitness = cur_acc_top5
        else:
            raise ValueError(
                f"Invalid fitness mode: {self.fitness_mode}, if you want to use fairness, please set --fitness_mode ACC_FAIRNESS and --fairness_eval_flag True and set --fairness_type"
            )
        return fitness, cur_acc_top1, cur_acc_top5, cur_loss

    def _evaluate(self, test_model):
        """Dispatch to the fairness-aware or the plain fitness evaluation.

        The first four items of the returned tuple are always
        ``(fitness, acc_top1, acc_top5, loss)``. With fairness enabled they are
        followed by ``(per_class_acc, fairness_score)``.
        """
        if self.fairness_eval_flag:
            return self.fitness_with_fairness(test_model=test_model)
        return self.fitness(test_model=test_model)

    @staticmethod
    def _pruned_indices(filter_num, kept):
        """Return the sorted indices in ``range(filter_num)`` that are not in ``kept``."""
        kept_set = set(kept)
        return [idx for idx in range(filter_num) if idx not in kept_set]

    @staticmethod
    def _rank_key(record):
        """Sort key for ``[idx, fitness, kept, n_kept]`` records.

        Higher fitness ranks first, and fewer kept filters breaks ties when the
        key is used with ``reverse=True``.
        """
        return record[1], -record[3]

    @classmethod
    def _resnet_layer_config(cls, arch):
        """Return ``(layers_p1, filter_num_p1, layers_p2, filter_num_p2)`` for ``arch``.

        ``layers_p1`` holds the per-block inner convs. ``layers_p2`` holds one
        representative conv per stage (plus the stem conv for resnet50). Each
        ``filter_num_*`` list gives the original output width of the matching
        layer.

        Raises:
            ValueError: if ``arch`` is not a supported ResNet variant.
        """
        block_num = cls._RESNET_BLOCK_NUM.get(arch)
        if block_num is None:
            raise ValueError(f"Invalid arch: {arch}, please check the arch")
        stages = range(len(block_num))

        if arch == "resnet50":
            # Bottleneck blocks: conv1 and conv2 of every block are searched
            # independently; conv3 widths are handled per stage.
            layers_p1 = [
                f"layer{i+1}.{j}.conv{k}"
                for i in stages
                for j in range(block_num[i])
                for k in range(1, 3)
            ]
            filter_num_p1 = [
                64 * (2**i)
                for i in stages
                for _block in range(block_num[i])
                for _conv in range(2)
            ]
            layers_p2 = ["conv1"] + [f"layer{i+1}.0.conv3" for i in stages]
            filter_num_p2 = [64] + [256 * (2**i) for i in stages]
        else:
            # Basic blocks: CIFAR-style ResNets start at 16 filters, ImageNet-style at 64.
            base = 16 if arch in ("resnet56", "resnet110") else 64
            layers_p1 = [
                f"layer{i+1}.{j}.conv1" for i in stages for j in range(block_num[i])
            ]
            filter_num_p1 = [
                base * (2**i) for i in stages for _block in range(block_num[i])
            ]
            layers_p2 = [f"layer{i+1}.0.conv2" for i in stages]
            filter_num_p2 = [base * (2**i) for i in stages]

        return layers_p1, filter_num_p1, layers_p2, filter_num_p2

    def generate_initial_pop(
        self,
        pop_size,
        pop_init_rate,
        filter_num,
        add_unpruned_indiv=False,
        target_filter_num=None,
    ):
        """Generate the initial population for a sub-EA.

        Each individual is a sorted list of kept filter indices. All of them
        are mutations of the unpruned individual ``[0, ..., filter_num - 1]``
        with flip rate ``pop_init_rate``. When ``add_unpruned_indiv`` is set,
        the unpruned individual itself takes one slot. The population list is
        returned sorted.
        """
        full_indiv = list(range(filter_num))
        population = []
        remaining = pop_size
        if add_unpruned_indiv:
            population.append(full_indiv)
            remaining -= 1
        for _ in range(remaining):
            population.append(
                self.mutation(
                    full_indiv,
                    filter_num,
                    mutation_rate=pop_init_rate,
                    target_count=target_filter_num,
                )
            )
        population.sort()
        return population

    def mutation(
        self, indiv, filter_num, mutation_rate=0.1, mode="shuffle", target_count=None
    ):
        """Return a mutated copy of ``indiv`` (a list of kept filter indices).

        Each filter's keep/drop bit is flipped with probability
        ``mutation_rate``. If ``target_count`` is given, randomly chosen filters
        are then added or dropped so that exactly ``target_count`` are kept
        (capped by ``filter_num``). The result is sorted. If no filter would be
        kept, ``indiv`` itself is returned instead.

        Raises:
            ValueError: if ``mode`` is not ``"shuffle"``.
        """
        if mode != "shuffle":
            raise ValueError(f"Invalid mutation mode: {mode}")

        # genes[k] = [filter index, 1 if kept else 0]; shuffled so that the
        # per-position flips below visit filters in random order.
        genes = [[i, 0] for i in range(int(filter_num))]
        for idx in indiv:
            genes[idx][1] = 1

        random.shuffle(genes)

        current_count = len(indiv)
        for gene in genes:
            if random.random() < mutation_rate:
                gene[1] = 1 - gene[1]
                current_count += 1 if gene[1] == 1 else -1

        # Repair to the exact target size (only when a target is requested).
        if target_count is not None and current_count != target_count:
            kept_positions = [i for i, g in enumerate(genes) if g[1] == 1]
            random.shuffle(kept_positions)
            if current_count > target_count:
                for pos in kept_positions[: current_count - target_count]:
                    genes[pos][1] = 0
            else:
                dropped_positions = [i for i, g in enumerate(genes) if g[1] == 0]
                random.shuffle(dropped_positions)
                for pos in dropped_positions[: target_count - current_count]:
                    genes[pos][1] = 1

        new_indiv = sorted(g[0] for g in genes if g[1] == 1)
        return new_indiv if new_indiv else indiv

    def _prune_and_evaluate(self, group_name, filter_num, kept):
        """Prune a deep copy of ``self.model`` down to ``kept`` filters of ``group_name`` and evaluate it.

        ``self.model`` itself is left untouched. The return value matches
        ``_evaluate``.
        """
        test_model = copy.deepcopy(self.model)
        low_level_prune(
            model=test_model,
            example_inputs=self.example_inputs,
            pruning_idxs=self._pruned_indices(filter_num, kept),
            prune_layer_name=group_name,
            prune_func=tp.prune_conv_out_channels,
        )
        return self._evaluate(test_model)

    def evolution_step(self, group_name, filter_num, target_filter_num=None):
        """Run one sub-EA that chooses which filters of ``group_name`` to keep.

        A population of ``pop_size`` individuals is evolved for
        ``evolution_epoch`` generations with (mu + lambda) survivor selection:
        parents and children are ranked together and the top ``pop_size``
        survive.

        Returns:
            The sorted kept-filter list of the best individual found. If
            ``evolution_epoch`` is 0, this is the best initial individual.
        """
        pop = self.generate_initial_pop(
            self.pop_size,
            self.pop_init_rate,
            filter_num,
            target_filter_num=target_filter_num,
        )
        logger = logging.getLogger("train_logger")
        logger.info("Group:{0} | filter_num:{1}".format(group_name, filter_num))
        # Records: [index within its generation, fitness, kept indices, #kept]
        parent_fitness = []
        logger.info("father population init:")

        for i in range(self.pop_size):
            result = self._prune_and_evaluate(group_name, filter_num, pop[i])
            fitness_i, cur_acc_top1, cur_acc_top5, cur_loss = result[:4]
            logger.info(f"Current fa{i} top-1 accuracy: {cur_acc_top1}")
            logger.info(f"Current fa{i} top-5 accuracy: {cur_acc_top5}")
            logger.info(f"Current fa{i} loss: {cur_loss}")
            logger.info(f"Current fa{i} fitness: {fitness_i}")
            if self.fairness_eval_flag:
                per_class_acc, fairness_score = result[4:]
                logger.info(f"Current fa{i} Fairness score: {fairness_score}")
                logger.info(f"Current fa{i} per-class accuracy: {per_class_acc}")

            parent_fitness.append([i, fitness_i, pop[i], len(pop[i])])
            logger.info(
                [
                    i,
                    fitness_i,
                    self._pruned_indices(filter_num, pop[i]),
                    len(pop[i]),
                ]
            )
            logger.info("\n")

        # Defined even when evolution_epoch == 0.
        best_ind = max(parent_fitness, key=self._rank_key)

        logger.info("population evolution:")
        for epoch in range(self.evolution_epoch):
            child_fitness = []
            for j in range(self.pop_size):
                parent = pop[random.randint(0, self.pop_size - 1)]
                child_indiv = self.mutation(
                    parent,
                    filter_num,
                    self.mutation_rate,
                    target_count=target_filter_num,
                )
                fitness_j = self._prune_and_evaluate(
                    group_name, filter_num, child_indiv
                )[0]
                child_fitness.append([j, fitness_j, child_indiv, len(child_indiv)])

            ranked = sorted(
                parent_fitness + child_fitness, key=self._rank_key, reverse=True
            )
            logger.info(
                f"Population at evolution epoch {epoch+1}/{self.evolution_epoch}:"
            )
            for j in range(self.pop_size):
                pop[j] = ranked[j][2]
                parent_fitness[j] = ranked[j]
                logger.info(
                    [
                        parent_fitness[j][0],
                        parent_fitness[j][1],
                        self._pruned_indices(filter_num, parent_fitness[j][2]),
                        len(parent_fitness[j][2]),
                    ]
                )

            best_ind = parent_fitness[0]
            logger.info(
                f"Best fitness so far {best_ind[1]}, Initial fitness: {self.initial_fitness}, Filter now: {best_ind[3]}, Pruning ratio: {1 - best_ind[3] / filter_num}"
            )
            logger.info("\n")

        return best_ind[2]

    def _log_test_eval(self, step, phase):
        """Evaluate ``self.model`` on the test set and log its metrics under ``phase``.

        With fairness enabled, the first call records the test-set baseline
        that every later fairness score on the test set is compared against.
        """
        logger = logging.getLogger("train_logger")
        prefix = f"Iterative_steps:{step}, {phase} test"
        if self.fairness_eval_flag:
            test_acc_top1, test_acc_top5, test_loss, per_class_acc = (
                utils.training.fairness_eval(
                    self.model, self.test_loader, device=self.gpu
                )
            )
            if self.initial_per_class_acc_on_test is None:
                self.initial_per_class_acc_on_test = per_class_acc
                self.initial_acc_top1_on_test = test_acc_top1
            fairness_score = utils.training.calculate_fairness(
                self.initial_per_class_acc_on_test,
                per_class_acc,
                fairness_type=self.fairness_type,
            )
        else:
            test_acc_top1, test_acc_top5, test_loss = utils.training.eval(
                self.model, test_loader=self.test_loader, device=self.gpu
            )

        logger.info(f"{prefix} top-1 accuracy: {test_acc_top1}")
        logger.info(f"{prefix} top-5 accuracy: {test_acc_top5}")
        logger.info(f"{prefix} loss: {test_loss}")
        if self.fairness_eval_flag:
            logger.info(f"{prefix} fairness score: {fairness_score}")
            logger.info(f"{prefix} per-class accuracy: {per_class_acc}")

    def _log_fitness_eval(self, step, phase):
        """Evaluate ``self.model`` on the fitness data, log it under ``phase``, and return its fitness."""
        logger = logging.getLogger("train_logger")
        prefix = f"Iterative_steps:{step}, {phase} fitness_data"
        result = self._evaluate(self.model)
        cur_fitness, acc_top1, acc_top5, loss = result[:4]
        logger.info(f"{prefix} top-1 accuracy: {acc_top1}")
        logger.info(f"{prefix} top-5 accuracy: {acc_top5}")
        logger.info(f"{prefix} loss: {loss}")
        logger.info(f"{prefix} fitness: {cur_fitness}")
        if self.fairness_eval_flag:
            per_class_acc, fairness_score = result[4:]
            logger.info(f"{prefix} fairness score: {fairness_score}")
            logger.info(f"{prefix} per-class accuracy: {per_class_acc}")
        return cur_fitness

    def _log_complexity(self, step, ori_macs, ori_nparams, macs, nparams):
        """Log the parameter/MAC reduction and the model to both the train and result loggers."""
        params_msg = "  Iter %d/%d, Params: %.2f M => %.2f M (%.2f%%)" % (
            step,
            self.iterative_steps,
            ori_nparams / 1e6,
            nparams / 1e6,
            (nparams / ori_nparams) * 100,
        )
        macs_msg = "  Iter %d/%d, MACs: %.2f G => %.2f G (%.2f%%)" % (
            step,
            self.iterative_steps,
            ori_macs / 1e9,
            macs / 1e9,
            (macs / ori_macs) * 100,
        )
        for log in (
            logging.getLogger("train_logger"),
            logging.getLogger("result_logger"),
        ):
            log.info(params_msg)
            log.info(macs_msg)
            log.info(self.model)

    def run(self):
        """Iteratively evolve and prune ``self.model`` in place.

        Each of the ``iterative_steps`` rounds does the following:
        1. Evaluates the current model on the test set and on the fitness data.
        2. Runs one sub-EA per prunable layer, targeting
           ``int(original_width * (1 - step * pruning_ratio / iterative_steps))``
           kept filters.
        3. Prunes every layer down to its winning subset.
        4. Logs the parameter/MAC reduction and the post-pruning metrics.
        5. Finetunes the model if ``self.finetune`` is set.

        Raises:
            ValueError: if ``self.arch`` is not a supported ResNet variant.
        """
        logger = logging.getLogger("train_logger")
        logger2 = logging.getLogger("result_logger")
        self.model.eval()
        ori_macs, ori_nparams = tp.utils.count_ops_and_params(
            self.model, example_inputs=self.example_inputs
        )
        logger.info(
            "Original Model Before Pruning | Ori_macs: {} | Ori_noarams: {}".format(
                ori_macs, ori_nparams
            )
        )
        logger.info(self.model)
        logger.info("\n")
        # finetune_config is None when finetuning is disabled, so read the
        # checkpoint root only when it will be used.
        base_dir = self.finetune_config["save_dir"] if self.finetune else None

        layers_p1, filter_num_p1, layers_p2, filter_num_p2 = (
            self._resnet_layer_config(self.arch)
        )
        all_layer_names = layers_p1 + layers_p2
        ori_filter_num = filter_num_p1 + filter_num_p2
        # Current (post-pruning) width of each layer, updated every round.
        all_filter_num = list(ori_filter_num)
        all_sol = [[] for _ in all_layer_names]
        num_p1 = len(layers_p1)

        for i in range(self.iterative_steps):
            step = i + 1
            logger.info(f"Iterative_steps:{step}")

            # Before pruning.
            self._log_test_eval(step, "Initial")
            self.initial_fitness = self._log_fitness_eval(step, "Initial")

            logger.info(f"Iterative_steps:{step}, Start Evolution")
            logger.info(f"We have {len(all_layer_names)} evolution groups")
            decrement_factor = step * self.pruning_ratio / self.iterative_steps

            for idx, layer_name in enumerate(all_layer_names):
                # Targets are relative to the ORIGINAL width, so the kept
                # fraction shrinks linearly with the round number.
                target_num = int(ori_filter_num[idx] * (1 - decrement_factor))
                all_sol[idx] = self.evolution_step(
                    group_name=layer_name,
                    filter_num=all_filter_num[idx],
                    target_filter_num=target_num,
                )

            logger.info("Start Pruning")
            for idx, layer_name in enumerate(all_layer_names):
                logger.info(idx)
                logger.info(layer_name)
                low_level_prune(
                    model=self.model,
                    example_inputs=self.example_inputs,
                    pruning_idxs=self._pruned_indices(
                        all_filter_num[idx], all_sol[idx]
                    ),
                    prune_layer_name=layer_name,
                    prune_func=tp.prune_conv_out_channels,
                )
                all_filter_num[idx] = len(all_sol[idx])

            macs, nparams = tp.utils.count_ops_and_params(
                self.model, self.example_inputs
            )
            self._log_complexity(step, ori_macs, ori_nparams, macs, nparams)
            logger.info(f"filter_num_p1 {all_filter_num[:num_p1]}")
            logger.info(f"filter_num_p2 {all_filter_num[num_p1:]}")

            # After pruning, before finetuning.
            self._log_test_eval(step, "Before FT")
            self._log_fitness_eval(step, "Before FT")

            if self.finetune:
                cur_save_dir = os.path.join(base_dir, "iterative_{}".format(step))
                os.makedirs(cur_save_dir, exist_ok=True)
                logger.info(f"iter {step} finetuning:")
                logger2.info(f"iter {step} finetuning:")

                self.finetune_config["save_dir"] = cur_save_dir
                # NOTE(review): the returned model is discarded, so this relies
                # on train_model updating self.model in place — confirm.
                utils.training.train_model(
                    model=self.model,
                    initial_per_class_acc=self.initial_per_class_acc_on_test,
                    initial_acc_top1=self.initial_acc_top1_on_test,
                    **self.finetune_config,
                )

            logger.info("\n")
            logger2.info("\n")

        pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(
            self.model, example_inputs=self.example_inputs
        )
        logger.info(
            "After Pruning | Pruned_macs: {} ({:.4f}) | Pruned_noarams: {} ({:.4f})".format(
                pruned_macs,
                pruned_macs / ori_macs,
                pruned_nparams,
                pruned_nparams / ori_nparams,
            )
        )
        logger.info(self.model)