import logging
import os
import torch
import torch_pruning as tp
from functools import partial
import prune_finetune
from utils.train_eval import fairness_validate


def get_pruner(model, example_inputs, config):
    """Build a torch_pruning pruner for ``model`` as configured by ``config``.

    Args:
        model: The network to prune.
        example_inputs: Dummy inputs handed to torch_pruning together with
            ``model``.
        config: Dict read for the keys ``prune_method`` (one of "random",
            "l1", "l2", "FPGM", "Hessian"), ``global_pruning``,
            ``num_class``, ``layer_wise_imp``, ``iterative_steps``,
            ``pruning_ratio`` and ``max_pruning_ratio``.

    Returns:
        A ``tp.pruner.MagnitudePruner`` driven by the selected importance.

    Raises:
        NotImplementedError: If ``config["prune_method"]`` is not supported.
    """
    # Factories rather than instances so that only the selected importance
    # criterion is ever constructed.
    importance_factories = {
        "random": lambda: tp.importance.RandomImportance(),
        "l1": lambda: tp.importance.MagnitudeImportance(p=1),
        "l2": lambda: tp.importance.MagnitudeImportance(p=2),
        "FPGM": lambda: tp.importance.FPGMImportance(p=2),
        "Hessian": lambda: tp.importance.HessianImportance(),
    }
    method = config["prune_method"]
    if method not in importance_factories:
        raise NotImplementedError(
            "Unsupported prune_method {!r}; expected one of: {}".format(
                method, ", ".join(importance_factories)
            )
        )
    imp = importance_factories[method]()
    # Every criterion uses the same pruner class; only the importance differs.
    pruner_entry = partial(
        tp.pruner.MagnitudePruner, global_pruning=config["global_pruning"]
    )

    logger = logging.getLogger("train_logger")
    # DO NOT prune the final classifier! Any Linear layer whose output size
    # equals the number of classes is treated as the classifier head.
    ignored_layers = [
        m
        for m in model.modules()
        if isinstance(m, torch.nn.Linear) and m.out_features == config["num_class"]
    ]

    logger.info("ignored_layers:")
    logger.info(ignored_layers)

    if config["layer_wise_imp"]:
        # NOTE(review): presumably makes group importance come from the first
        # layer of each coupled group only — confirm against torch_pruning docs.
        imp.group_reduction = "first"

    return pruner_entry(
        model,
        example_inputs,
        importance=imp,
        iterative_steps=config["iterative_steps"],
        pruning_ratio=config["pruning_ratio"],
        max_pruning_ratio=config["max_pruning_ratio"],
        ignored_layers=ignored_layers,
    )


def _accumulate_hessian_grads(model, pruner, train_loader, config):
    """Feed training batches through ``model`` and accumulate per-sample
    gradients into ``pruner.importance`` for the Hessian criterion.

    At most ``config["batch_num_Hessian"]`` batches are used. Inputs are moved
    to GPU ``config["gpu"]`` and targets are taken from column
    ``config["target_idx"]`` of each label tensor.

    Raises:
        AttributeError: If ``train_loader`` is None.
    """
    if train_loader is None:
        raise AttributeError("When using Hessian imp in pruner, data are needed")
    importance = pruner.importance
    # Discard statistics gathered during a previous pruning step: the model has
    # been pruned since, so old accumulated gradients must not be reused.
    importance.zero_grad()
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        # Stop once enough batches were used instead of draining the loader.
        if batch_idx >= config["batch_num_Hessian"]:
            break
        inputs_var = inputs.cuda(config["gpu"])
        # NOTE(review): labels appear to be (batch, num_attributes); the
        # target attribute is selected by column — confirm against the dataset.
        labels_var = labels[:, config["target_idx"]].cuda(config["gpu"])
        output = model(inputs_var)
        loss = torch.nn.functional.cross_entropy(
            output, labels_var, reduction="none"
        )
        # accumulate_grad is fed one sample at a time, so each per-sample loss
        # gets its own backward pass; retain_graph keeps the shared graph alive
        # for the remaining samples of the batch.
        for sample_loss in loss:
            model.zero_grad()
            sample_loss.backward(retain_graph=True)
            importance.accumulate_grad(model)


def _finetune_and_evaluate(
    model,
    step,
    config,
    finetune_config,
    train_loader,
    valid_loader,
    test_loader,
    target_idx,
    sensitive_idx,
):
    """Fine-tune ``model`` after pruning step ``step`` (0-based) and evaluate it.

    The state dict returned by the fine-tuner is loaded back into ``model`` in
    place. Fine-tuning output goes to ``<save_dir>/iterative_<step + 1>/``.

    Returns:
        ``(valid_metrics, test_metrics)``, each the ``(acc, DI, DEO)`` tuple
        returned by ``fairness_validate`` on the respective loader.
    """
    logger = logging.getLogger("train_logger")
    # NOTE(review): assumes a ResNet-style model exposing layer4[0].conv2; its
    # (possibly pruned) output channel count is handed to the fine-tuner.
    total_feature = model.layer4[0].conv2.out_channels
    logger.info("feature num: {}".format(total_feature))
    finetune_config["num_features"] = int(total_feature)
    # The trailing "" produces a trailing path separator, matching the
    # "<save_dir>/iterative_N/" form used for this directory.
    cur_save_dir = os.path.join(
        config["save_dir"], "iterative_{}".format(step + 1), ""
    )
    os.makedirs(cur_save_dir, exist_ok=True)
    logger.info("iter i: {}".format(step))
    finetune_worker = prune_finetune.Finetune(
        train_loader,
        valid_loader,
        test_loader,
        target_idx,
        sensitive_idx,
        finetune_config,
    )

    _, best_model_dict = finetune_worker.do(
        model, cur_save_dir, return_best_valid=False, FPVE_flag=False
    )
    tp.load_state_dict(model, state_dict=best_model_dict)

    criterion = torch.nn.CrossEntropyLoss()
    valid_metrics = fairness_validate(
        valid_loader,
        model,
        criterion,
        finetune_config["args"],
        target_idx,
        sensitive_idx,
        mode="Valid",
    )
    test_metrics = fairness_validate(
        test_loader,
        model,
        criterion,
        finetune_config["args"],
        target_idx,
        sensitive_idx,
        mode="Test",
    )
    return valid_metrics, test_metrics


def do(
    model,
    pruner,
    example_inputs,
    config,
    finetune=False,
    train_loader=None,
    valid_loader=None,
    test_loader=None,
    target_idx=None,
    sensitive_idx=None,
    finetune_config=None,
):
    """Iteratively prune ``model`` with ``pruner``, optionally fine-tuning and
    evaluating accuracy / DI / DEO after every step.

    Args:
        model: Network to prune (modified in place).
        pruner: Pruner (e.g. from ``get_pruner``) stepped once per iteration.
        example_inputs: Dummy inputs used to count MACs and parameters.
        config: Dict read for ``iterative_steps``, ``method`` (falls back to
            ``prune_method`` when absent), ``save_dir`` when fine-tuning, and
            ``gpu`` / ``batch_num_Hessian`` / ``target_idx`` for the Hessian
            method.
        finetune: If True, fine-tune and evaluate after every pruning step.
        train_loader: Required for the Hessian method and for fine-tuning.
        valid_loader: Validation data for fine-tuning and evaluation.
        test_loader: Test data for fine-tuning and evaluation.
        target_idx: Forwarded to the fine-tuner and ``fairness_validate``.
        sensitive_idx: Forwarded to the fine-tuner and ``fairness_validate``.
        finetune_config: Dict passed to ``prune_finetune.Finetune``; its
            ``num_features`` entry is overwritten every step and its ``args``
            entry is passed to ``fairness_validate``.

    Returns:
        The pruned (and possibly fine-tuned) ``model``.

    Raises:
        AttributeError: If the Hessian method is used without ``train_loader``.
    """
    logger = logging.getLogger("train_logger")
    model.eval()
    ori_macs, ori_nparams = tp.utils.count_ops_and_params(
        model, example_inputs=example_inputs
    )
    logger.info(
        "Before Pruning | Ori_macs: {} | Ori_nparams: {}".format(ori_macs, ori_nparams)
    )
    logger.info(model)
    logger.info("Pruning...")

    # Per-step metric histories, only filled when fine-tuning.
    acc, DI, DEO = [], [], []
    valid_acc, valid_DI, valid_DEO = [], [], []

    # get_pruner selects the criterion via "prune_method"; honour "method" when
    # present and fall back to the same key otherwise.
    use_hessian = config.get("method", config.get("prune_method")) == "Hessian"
    total_steps = config["iterative_steps"]

    for i in range(total_steps):
        if use_hessian:
            # Fine-tuning in an earlier step may have switched the model to
            # train mode; the gradient pass should not update normalization
            # statistics, so return to eval mode first.
            model.eval()
            _accumulate_hessian_grads(model, pruner, train_loader, config)
            for group in pruner.step(interactive=True):
                group.prune()
        else:
            pruner.step()

        macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
        logger.info(
            "  Iter %d/%d, Params: %.2f M => %.2f M"
            % (i + 1, total_steps, ori_nparams / 1e6, nparams / 1e6)
        )
        logger.info(
            "  Iter %d/%d, MACs: %.2f G => %.2f G"
            % (i + 1, total_steps, ori_macs / 1e9, macs / 1e9)
        )
        logger.info(model)

        if not finetune:
            continue

        valid_metrics, test_metrics = _finetune_and_evaluate(
            model,
            i,
            config,
            finetune_config,
            train_loader,
            valid_loader,
            test_loader,
            target_idx,
            sensitive_idx,
        )
        acc_val_top1, DI_val, DEO_val = valid_metrics
        acc_test_top1, DI_test, DEO_test = test_metrics
        acc.append(acc_test_top1)
        DI.append(DI_test)
        DEO.append(DEO_test)
        valid_acc.append(acc_val_top1)
        valid_DI.append(DI_val)
        valid_DEO.append(DEO_val)

        logger.info(f"Test set ACC:{acc_test_top1}")
        logger.info(f"Test set DEO:{DEO_test}")
        logger.info(f"Test set DI:{DI_test}")
        logger.info(f"Valid set ACC:{acc_val_top1}")
        logger.info(f"Valid set DEO:{DEO_val}")
        logger.info(f"Valid set DI:{DI_val}")

    pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(
        model, example_inputs=example_inputs
    )
    logger.info(model)
    logger.info(
        "After Pruning | Pruned_macs: {} ({:.4f}) | Pruned_nparams: {} ({:.4f})".format(
            pruned_macs,
            pruned_macs / ori_macs,
            pruned_nparams,
            pruned_nparams / ori_nparams,
        )
    )

    logger.info(f"Test set ACC:{acc}")
    logger.info(f"Test set DEO:{DEO}")
    logger.info(f"Test set DI:{DI}")
    logger.info(f"Valid set ACC:{valid_acc}")
    logger.info(f"Valid set DEO:{valid_DEO}")
    logger.info(f"Valid set DI:{valid_DI}")
    return model
