import os
import sys

# Parent of the directory containing this script. It is appended to the
# import path before the project-local imports below (TP_pruner,
# utils.tp_utils) run; presumably those modules live there -- verify against
# the repository layout.
BASE_DIR = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
sys.path.append(BASE_DIR)

import argparse
import torch
import torch_pruning as tp
import numpy as np
import random

import TP_pruner
import utils.tp_utils as tp_utils
from thop import profile

# Command-line interface read by main(): model/dataset selection, where to
# load from and save to, pruning settings, and finetuning hyper-parameters.
parser = argparse.ArgumentParser(description="ResNets for CelebA in pytorch")
# Architecture & Dataset
parser.add_argument("--arch", type=str, default="resnet18", help="model architecture")
parser.add_argument("--dataset", type=str, default="celeba", help="dataset")
parser.add_argument("--dataset_dir", type=str, default="../datasets/")
parser.add_argument(
    "--target_attr",
    type=str,
    default="Attractive",
    help="target-attr: Attractive, Blond_Hair, etc.",
)
parser.add_argument(
    "--sensitive_attr", type=str, default="Male", help="sensitive-attr: Male"
)
parser.add_argument("--num_class", type=int, default=2, help="num of classes")
parser.add_argument(
    "--num_sensitive_class", type=int, default=2, help="num of sensitive classes"
)
# Storage
parser.add_argument(
    "--load_dir",
    dest="load_dir",
    type=str,
    help="The directory used to load the models to be pruned",
    default="../prune_resouce/test",
)
parser.add_argument(
    "--save_dir",
    dest="save_dir",
    type=str,
    help="The directory used to save the models and logs",
    default="../prune_results/test",
)
# Prune
# When omitted (None), main() falls back to ["l1"].
parser.add_argument(
    "--prune_methods", nargs="+", type=str, help="List of pruning methods"
)
parser.add_argument("--global_pruning", action="store_true", default=False)
parser.add_argument("--pruning_ratio", type=float, default=0.9)
parser.add_argument("--max_pruning_ratio", type=float, default=1.0)
parser.add_argument("--iterative_steps", default=9, type=int)
parser.add_argument("--layer_wise_imp", action="store_true", default=False)
parser.add_argument(
    "--batch_num_Hessian", type=int, default=10, help="batch nums for Hessian pruning"
)
# When omitted (None), main() sweeps only over [--pruning_ratio].
parser.add_argument(
    "--pruning_ratio_list", nargs="+", type=float, help="List of pruning ratios"
)
# Basics
parser.add_argument(
    "--batch_size", default=64, type=int, help="batch size (default: 64)"
)
# Others
# Used as the CUDA device index in main() (tensor.cuda(args.gpu)).
parser.add_argument("--gpu", type=int, default=1, help="cuda training")
parser.add_argument(
    "--random_seed", default=2023, type=int, help="seed for dataset split"
)
# Finetune
parser.add_argument("--adv_mode", action="store_true", default=False, help="using adversarial finetuning mode")
parser.add_argument("--w", type=float, default=0.5, help="debias weight")
parser.add_argument("--ft_epochs", type=int, default=20)
# Dashed option names: argparse stores these as args.p_lr / args.a_lr.
parser.add_argument("--p-lr", type=float, default=1e-4, help="predictor learning rate")
parser.add_argument("--a-lr", type=float, default=1e-4, help="adversary learning rate")
parser.add_argument("--w_decay", action="store_true", default=False)
parser.add_argument("--use_projection", action="store_true", default=False)
parser.add_argument(
    "--train_data_ratio",
    default=0,
    type=float,
    help="Rate of training data utilization",
)
parser.add_argument("--noadv_add_schedular", action="store_true", default=False)


parser.add_argument(
    "-p",
    "--print-freq",
    default=50,
    type=int,
    metavar="N",
    # %(default)s is expanded by argparse, so the help text always matches
    # the real default (it previously claimed 500 while the default was 50).
    help="print frequency (default: %(default)s)",
)


def _num_features(arch, pruning_ratio):
    """Return the ``num_features`` value handed to finetuning for a ratio.

    The unpruned width is taken as 256 for ``resnet18_half`` and 512 for any
    other ``arch``. It is scaled by the kept fraction ``1 - pruning_ratio``
    and truncated with ``int``.
    """
    full_width = 256 if arch == "resnet18_half" else 512
    return int(full_width * (1 - pruning_ratio))


def _seed_everything(seed):
    """Seed ``random``, NumPy and PyTorch, and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main():
    """Prune and finetune a model for every (method, pruning ratio) pair.

    Parses the command line with the module-level ``parser``. If
    ``--prune_methods`` is omitted, ``["l1"]`` is used. If
    ``--pruning_ratio_list`` is omitted, ``[--pruning_ratio]`` is used.

    The data loaders are built once and shared by every run. Each run:

    * builds a fresh model with ``tp_utils.get_model``;
    * logs its FLOPs/params (thop);
    * prunes and finetunes it with ``TP_pruner.do``;
    * saves the pruned state dict to
      ``<save_dir>/<method>/<ratio>/seed_<seed>/model_<method>_<ratio>_seed-<seed>.pth``.

    Inputs are moved to ``cuda(args.gpu)``, so a CUDA device is required.

    Raises:
        NotImplementedError: if ``--dataset`` is not ``"celeba"``.
    """
    args = parser.parse_args()
    methods_list = ["l1"] if args.prune_methods is None else args.prune_methods
    pruning_ratio_list = (
        [args.pruning_ratio]
        if args.pruning_ratio_list is None
        else args.pruning_ratio_list
    )

    # "method", "pruning_ratio", "target_idx", "prune_method" and "save_dir"
    # are (re)assigned below before each run.
    pruner_config = {
        "method": methods_list[0],
        "num_class": args.num_class,
        "gpu": args.gpu,
        "global_pruning": args.global_pruning,
        "layer_wise_imp": args.layer_wise_imp,
        "iterative_steps": args.iterative_steps,
        "pruning_ratio": args.pruning_ratio,
        "max_pruning_ratio": args.max_pruning_ratio,
        "batch_num_Hessian": args.batch_num_Hessian,
        # Presumably module names that TP_pruner must leave unpruned -- confirm.
        "ignored_layers": ["model.fc"],
    }
    finetune_config = {
        "arch": args.arch,
        "num_class": args.num_class,
        "w": args.w,
        "num_sensitive_class": args.num_sensitive_class,
        # Recomputed per pruning ratio inside the sweep below.
        "num_features": _num_features(args.arch, args.pruning_ratio),
        "gpu": args.gpu,
        "epochs": args.ft_epochs,
        "p_lr": args.p_lr,
        'a_lr': args.a_lr,
        "adv_mode": args.adv_mode,
        "save_every": 5,
        "noadv_add_schedular": args.noadv_add_schedular,
        "args": args,
    }

    # The output directory is created before the logger is configured.
    os.makedirs(args.save_dir, exist_ok=True)
    logger = tp_utils.set_logger(
        args, name=str(args.pruning_ratio) + "seed_{}".format(args.random_seed)
    )

    _seed_everything(args.random_seed)

    logger.info(args.__dict__)

    if args.dataset != "celeba":
        raise NotImplementedError("Not supported dataset")
    # Dummy batch of shape (1, 3, 224, 224), passed to TP_pruner.get_pruner/do.
    example_inputs = torch.randn(1, 3, 224, 224).cuda(args.gpu)

    train_loader, valid_loader, test_loader, target_idx, sensitive_idx = (
        tp_utils.get_fairness_data(args)
    )
    pruner_config["target_idx"] = target_idx

    # NOTE(review): RNGs are seeded only once for the whole sweep, and every
    # run draws from them (e.g. torch.randn below). The randomness a given
    # (method, ratio) run sees therefore depends on the runs before it.
    logger.info("Methods list: {}".format(methods_list))
    for prune_method in methods_list:
        logger.info("Prune method: {}".format(prune_method))
        pruner_config["method"] = prune_method
        logger.info("Pruning ratio list: {}".format(pruning_ratio_list))
        for pruning_ratio in pruning_ratio_list:
            pruner_config["pruning_ratio"] = pruning_ratio
            finetune_config["num_features"] = _num_features(args.arch, pruning_ratio)

            # Built by concatenation with a trailing "/". This exact string is
            # handed to TP_pruner as save_dir, which may rely on the separator.
            cur_save_dir = args.save_dir + "/{}/{}/seed_{}/".format(
                prune_method, pruning_ratio, args.random_seed
            )
            os.makedirs(cur_save_dir, exist_ok=True)
            model = tp_utils.get_model(args)

            # Count FLOPs/params of the unpruned model on a dummy batch.
            model_input = torch.randn(1, 3, 224, 224)
            i_flops, i_params = profile(
                model, inputs=(model_input.cuda(args.gpu),), verbose=False
            )

            logger.info(
                "Original model profile: FLOPs: {} | Params: {}".format(
                    i_flops, i_params
                )
            )

            # prune
            # If you want iterative pruning, apply finetune_worker in TP_pruner.do() after pruner.step()
            pruner_config["prune_method"] = prune_method
            pruner_config["save_dir"] = cur_save_dir
            pruner = TP_pruner.get_pruner(model, example_inputs, pruner_config)
            pruned_model = TP_pruner.do(
                model,
                pruner,
                example_inputs,
                pruner_config,
                finetune=True,
                train_loader=train_loader,
                valid_loader=valid_loader,
                test_loader=test_loader,
                target_idx=target_idx,
                sensitive_idx=sensitive_idx,
                finetune_config=finetune_config,
            )
            print("Done")
            pruned_model.zero_grad()
            # tp.state_dict (torch_pruning) is used here, not model.state_dict().
            state_dict = tp.state_dict(pruned_model)
            torch.save(
                state_dict,
                os.path.join(
                    cur_save_dir,
                    "model_{}_{}_seed-{}.pth".format(
                        prune_method, pruning_ratio, args.random_seed
                    ),
                ),
            )


if __name__ == "__main__":
    main()
