import logging
import os
import torch
import torch_pruning as tp
from functools import partial
import engine.utils as utils


def get_pruner(model, example_inputs, config):
    """Build a torch_pruning pruner for ``model`` according to ``config``.

    Args:
        model: Network to prune. Every ``torch.nn.Linear`` submodule whose
            ``out_features`` equals ``config["num_class"]`` is passed to the
            pruner as an ignored layer (i.e. the final classifier is kept).
        example_inputs: Example input forwarded to the pruner constructor.
        config: Mapping read for the keys ``prune_method``,
            ``global_pruning``, ``layer_wise_imp``, ``arch``, ``num_class``,
            ``iterative_steps``, ``pruning_ratio`` and ``max_pruning_ratio``.

    Returns:
        A ``(pruner, post_process_func)`` tuple. ``post_process_func`` is
        always ``None`` for the architectures handled here.

    Raises:
        NotImplementedError: If ``config["prune_method"]`` or
            ``config["arch"]`` is not one of the supported values.
    """
    # Factories are lazy on purpose: only the importance class that was
    # actually requested is looked up on ``tp.importance``.
    importance_factories = {
        "random": lambda: tp.importance.RandomImportance(),
        "l1": lambda: tp.importance.MagnitudeImportance(p=1),
        "l2": lambda: tp.importance.MagnitudeImportance(p=2),
        "FPGM": lambda: tp.importance.FPGMImportance(p=2),
        "Hessian": lambda: tp.importance.HessianImportance(),
    }
    supported_archs = ("resnet18", "resnet34", "resnet50", "resnet56", "resnet110")

    prune_method = config["prune_method"]
    if prune_method not in importance_factories:
        raise NotImplementedError(
            f"Unsupported prune method: {prune_method!r}; "
            f"expected one of {sorted(importance_factories)}"
        )

    imp = importance_factories[prune_method]()
    # Every supported method uses the same pruner class; only the importance
    # criterion differs.
    pruner_entry = partial(
        tp.pruner.MagnitudePruner, global_pruning=config["global_pruning"]
    )

    if config["layer_wise_imp"]:
        imp.group_reduction = "first"
    logger = logging.getLogger("train_logger")

    if config["arch"] not in supported_archs:
        raise NotImplementedError(f"Unsupported architecture: {config['arch']}")

    # DO NOT prune the final classifier: skip any Linear layer whose output
    # width matches the number of classes.
    ignored_layers = [
        module
        for module in model.modules()
        if isinstance(module, torch.nn.Linear)
        and module.out_features == config["num_class"]
    ]

    pruner = pruner_entry(
        model,
        example_inputs,
        importance=imp,
        iterative_steps=config["iterative_steps"],
        pruning_ratio=config["pruning_ratio"],
        max_pruning_ratio=config["max_pruning_ratio"],
        ignored_layers=ignored_layers,
    )
    post_process_func = None

    logger.info("ignored_layers:")
    logger.info(ignored_layers)
    return pruner, post_process_func


def _evaluate_on_test(model, finetune_config, with_per_class):
    """Evaluate ``model`` on ``finetune_config["test_loader"]``.

    Returns:
        ``(top1, top5, loss, per_class_acc)``. When ``with_per_class`` is
        falsy, ``utils.training.eval`` is used and ``per_class_acc`` is
        ``None``; otherwise ``utils.training.fairness_eval`` is used.
    """
    eval_kwargs = {
        "test_loader": finetune_config["test_loader"],
        "device": finetune_config["device"],
    }
    if with_per_class:
        top1, top5, loss, per_class_acc = utils.training.fairness_eval(
            model, **eval_kwargs
        )
        return top1, top5, loss, per_class_acc
    top1, top5, loss = utils.training.eval(model, **eval_kwargs)
    return top1, top5, loss, None


def _log_test_metrics(logger, step, stage, top1, top5, loss, fairness=None):
    """Log test metrics for iteration ``step`` under the label ``stage``.

    ``fairness`` is an optional ``(fairness_score, per_class_acc)`` pair; when
    given, both values are logged after the loss.
    """
    prefix = f"Iterative_steps:{step}, {stage} test"
    logger.info(f"{prefix} top-1 accuracy: {top1}")
    logger.info(f"{prefix} top-5 accuracy: {top5}")
    logger.info(f"{prefix} loss: {loss}")
    if fairness is not None:
        fairness_score, per_class_acc = fairness
        logger.info(f"{prefix} fairness score: {fairness_score}")
        logger.info(f"{prefix} per-class accuracy: {per_class_acc}")


def _accumulate_hessian_importance(model, pruner, config):
    """Feed per-sample gradients to the pruner's importance criterion.

    Uses at most ``config["batch_num_Hessian"]`` batches from
    ``config["train_loader"]``, moved to CUDA device ``config["gpu"]``.

    Raises:
        AttributeError: If ``config["train_loader"]`` is ``None``.
    """
    train_loader = config["train_loader"]
    if train_loader is None:
        raise AttributeError("When using Hessian imp in pruner, data are needed")

    max_batches = config["batch_num_Hessian"]
    gpu = config["gpu"]
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        if batch_idx >= max_batches:
            # Stop early instead of draining the rest of the loader.
            break
        inputs = inputs.cuda(gpu)
        labels = labels.cuda(gpu)
        output = model(inputs)
        # reduction="none" keeps one loss per sample so that each sample's
        # gradient can be accumulated separately.
        per_sample_loss = torch.nn.functional.cross_entropy(
            output, labels, reduction="none"
        )
        for sample_loss in per_sample_loss:
            model.zero_grad()
            # The graph is shared by all samples of the batch, so it must be
            # retained across the individual backward passes.
            sample_loss.backward(retain_graph=True)
            pruner.importance.accumulate_grad(model)


def do(
    model,
    pruner,
    example_inputs,
    config,
    finetune=False,
    finetune_config=None,
    post_process_func=None,
):
    """Iteratively prune ``model`` and optionally fine-tune after each step.

    For each of ``config["iterative_steps"]`` iterations the function
    (when ``finetune`` is true) evaluates the model, runs one pruning step,
    logs the resulting parameter/MAC counts to the ``train_logger`` and
    ``result_logger`` loggers, evaluates again and then calls
    ``utils.training.train_model``.

    Args:
        model: Model to prune; it is put in eval mode and pruned in place.
        pruner: Pruner object providing ``step()`` (and, for the Hessian
            method, ``step(interactive=True)`` and ``importance``).
        example_inputs: Example input used to count MACs and parameters.
        config: Mapping read for ``iterative_steps`` and ``prune_method``;
            the ``"Hessian"`` method additionally reads ``train_loader``,
            ``batch_num_Hessian`` and ``gpu``.
        finetune: Whether to evaluate and fine-tune around every step.
        finetune_config: Required when ``finetune`` is true. Read for
            ``save_dir``, ``fairness_eval_flag``, ``test_loader``, ``device``
            and (with fairness evaluation) ``fairness_type``; the whole
            mapping is also forwarded as keyword arguments to
            ``utils.training.train_model``. Its ``"save_dir"`` entry is
            overwritten with the per-iteration directory and is left
            pointing at the last one when the function returns.
        post_process_func: Optional callable invoked as
            ``post_process_func(model, pruner)`` after every pruning step.

    Returns:
        The pruned ``model``.

    Raises:
        TypeError: If ``finetune`` is true but ``finetune_config`` is None.
        AttributeError: If the Hessian method is selected and
            ``config["train_loader"]`` is ``None``.
    """
    if finetune and finetune_config is None:
        raise TypeError("finetune=True requires a finetune_config mapping")

    logger = logging.getLogger("train_logger")
    logger2 = logging.getLogger("result_logger")

    # Baseline (pre-pruning) metrics, captured on the first fairness
    # evaluation and reused for every later fairness comparison.
    initial_per_class_acc_on_test = None
    test_acc_top1_on_test = None

    model.eval()
    ori_macs, ori_nparams = tp.utils.count_ops_and_params(
        model, example_inputs=example_inputs
    )
    logger.info(
        "Before Pruning | Ori_macs: {} | Ori_params: {}".format(ori_macs, ori_nparams)
    )

    logger.info(model)
    logger.info("\n")
    # Only needed (and only available) when fine-tuning; reading it
    # unconditionally would crash for the default finetune_config=None.
    base_dir = finetune_config["save_dir"] if finetune else None

    total_steps = config["iterative_steps"]
    for i in range(total_steps):
        step = i + 1

        if finetune:
            fairness_enabled = finetune_config["fairness_eval_flag"]
            test_acc_top1, test_acc_top5, test_loss, per_class_acc = (
                _evaluate_on_test(model, finetune_config, fairness_enabled)
            )
            fairness = None
            if fairness_enabled:
                if initial_per_class_acc_on_test is None:
                    initial_per_class_acc_on_test = per_class_acc
                    test_acc_top1_on_test = test_acc_top1
                fairness_score = utils.training.calculate_fairness(
                    initial_per_class_acc_on_test,
                    per_class_acc,
                    fairness_type=finetune_config["fairness_type"],
                )
                fairness = (fairness_score, per_class_acc)
            _log_test_metrics(
                logger, step, "Initial", test_acc_top1, test_acc_top5, test_loss,
                fairness,
            )

        logger.info("Pruning...")
        if config["prune_method"] in ["Hessian"]:
            _accumulate_hessian_importance(model, pruner, config)
            for group in pruner.step(interactive=True):
                group.prune()
        else:
            pruner.step()
        if post_process_func is not None:
            post_process_func(model, pruner)

        macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
        params_line = "  Iter %d/%d, Params: %.2f M => %.2f M (%.2f%%)" % (
            step,
            total_steps,
            ori_nparams / 1e6,
            nparams / 1e6,
            (nparams / ori_nparams) * 100,
        )
        macs_line = "  Iter %d/%d, MACs: %.2f G => %.2f G (%.2f%%)" % (
            step,
            total_steps,
            ori_macs / 1e9,
            macs / 1e9,
            (macs / ori_macs) * 100,
        )
        # Same three records go to both loggers, train logger first.
        for target_logger in (logger, logger2):
            target_logger.info(params_line)
            target_logger.info(macs_line)
            target_logger.info(model)

        if finetune:
            fairness_enabled = finetune_config["fairness_eval_flag"]
            test_acc_top1, test_acc_top5, test_loss, per_class_acc = (
                _evaluate_on_test(model, finetune_config, fairness_enabled)
            )
            fairness = None
            if fairness_enabled:
                fairness_score = utils.training.calculate_fairness(
                    initial_per_class_acc_on_test,
                    per_class_acc,
                    fairness_type=finetune_config["fairness_type"],
                )
                fairness = (fairness_score, per_class_acc)
            _log_test_metrics(
                logger, step, "Before FT", test_acc_top1, test_acc_top5, test_loss,
                fairness,
            )

            cur_save_dir = os.path.join(base_dir, "iterative_{}".format(step))
            # exist_ok avoids the check-then-create race of a manual
            # os.path.exists() test.
            os.makedirs(cur_save_dir, exist_ok=True)
            logger.info(f"iter {step} finetuning:")
            logger2.info(f"iter {step} finetuning:")

            # NOTE: this mutates the caller's mapping (pre-existing behavior).
            finetune_config["save_dir"] = cur_save_dir
            # The returned values were never used by this function.
            # NOTE(review): presumably train_model updates `model` in place —
            # confirm against utils.training.train_model.
            utils.training.train_model(
                model=model,
                initial_per_class_acc=initial_per_class_acc_on_test,
                initial_acc_top1=test_acc_top1_on_test,
                **finetune_config,
            )

        logger.info("\n")
        logger2.info("\n")

    pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(
        model, example_inputs=example_inputs
    )
    # The label spelling below is kept unchanged so existing log consumers
    # keep matching it.
    logger.info(
        "After Pruning | Pruned_macs: {} ({:.4f}) | Pruned_noarams: {} ({:.4f})".format(
            pruned_macs,
            pruned_macs / ori_macs,
            pruned_nparams,
            pruned_nparams / ori_nparams,
        )
    )
    logger.info(model)
    return model
