import os
import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import torch
import torch.utils.data
import torchvision
import torch.nn as nn
import torch_pruning as tp
import logging
import TP_pruner
import registry
import engine.utils as utils


def main(args):
    """Run one pruning job configured by the parsed command-line options.

    Steps, in order: call ``utils.set_random_seed`` and force deterministic
    cuDNN, derive a run-specific checkpoint directory, build the data loaders
    and the model, optionally load a checkpoint, then build a pruner with
    ``TP_pruner.get_pruner`` and execute it with ``TP_pruner.do`` (which also
    receives the fine-tuning configuration).

    Args:
        args: Namespace produced by ``utils.get_tp_args_parser()``. Note that
            ``args.ckpt_save_dir`` is overwritten in place with the
            run-specific sub-directory built below.

    Returns:
        None. The value returned by ``TP_pruner.do`` is not used here.
    """
    # seed and random factors
    utils.set_random_seed(args.random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Nest the output directory by every setting that identifies this run.
    args.ckpt_save_dir = os.path.join(
        args.ckpt_save_dir,
        "prune",
        args.dataset,
        args.arch,
        args.prune_method,
        f"layer_wise_imp_{args.layer_wise_imp}",
        f"{args.pruning_ratio}-{args.iterative_steps}",
        str(args.random_seed),
    )
    # exist_ok avoids the check-then-create race when several runs that share
    # a directory prefix start at the same time.
    os.makedirs(args.ckpt_save_dir, exist_ok=True)

    # logger; set_logger is called after ckpt_save_dir has been rewritten
    # (presumably it writes its log files there -- confirm in engine.utils).
    utils.set_logger(args)
    logger = logging.getLogger("train_logger")

    # data (the fifth value returned by get_dataset is not needed here)
    num_classes, train_dst, val_dst, input_size, _ = registry.get_dataset(
        args.dataset, data_root=args.data_root, args=args
    )

    # Random tensor of the dataset's input size, handed to the pruner below.
    example_inputs = torch.randn(*input_size)
    train_loader = torch.utils.data.DataLoader(
        train_dst,
        batch_size=args.batch_size,
        num_workers=args.workers,
        drop_last=True,
        pin_memory=True,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        val_dst,
        batch_size=args.batch_size,
        num_workers=args.workers,
        shuffle=False,
        pin_memory=True,
    )

    # model: pretrained weights are requested only for imagenet
    model = registry.get_model(
        args.arch,
        num_classes=num_classes,
        pretrained=args.dataset == "imagenet",
        target_dataset=args.dataset,
    )
    if args.dataset == "cub200":
        # Replace the classifier head with a 200-way linear layer.
        # NOTE(review): 200 is hard-coded; presumably equals num_classes for
        # cub200 -- confirm against registry.get_dataset.
        in_channel = model.fc.in_features
        model.fc = nn.Linear(in_channel, 200)

    if args.ckpt_load_dir:
        utils.load_model(model, args.ckpt_load_dir)
        logger.info(f"load model done from {args.ckpt_load_dir}")
    else:
        logger.info("prune from init model")

    # loss: pass label_smoothing only when it was actually requested
    if args.label_smoothing > 0:
        criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

    # move to gpu
    if args.gpu is not None:
        model = model.cuda(args.gpu)
        example_inputs = example_inputs.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    pruner_config = {
        "arch": args.arch,
        "prune_method": args.prune_method,
        "num_class": num_classes,
        "gpu": args.gpu,
        "global_pruning": args.global_pruning,
        "layer_wise_imp": args.layer_wise_imp,
        "iterative_steps": args.iterative_steps,
        "pruning_ratio": args.pruning_ratio,
        "max_pruning_ratio": args.max_pruning_ratio,
        "batch_num_Hessian": args.batch_num_Hessian,
        # Only the "Hessian" method is given the training data.
        "train_loader": train_loader if args.prune_method == "Hessian" else None,
    }
    finetune_config = {
        "epochs": args.ft_epochs,
        "lr": args.lr,
        "lr_step_size": args.lr_step_size,
        "lr_warmup_epochs": args.lr_warmup_epochs,
        "train_loader": train_loader,
        "test_loader": test_loader,
        "criterion": criterion,
        "save_dir": args.ckpt_save_dir,
        "device": args.gpu,
        "args": args,
        "pruner": None,
        "lr_decay_milestones": args.lr_decay_milestones,
        "save_every": args.save_every,
        "return_best": False,
        "fairness_eval_flag": args.fairness_eval_flag,
        "fairness_type": args.fairness_type,
        "scalar": args.scalar,
    }

    logger.info(
        f"Prune method: {args.prune_method}, Pruning ratio : {args.pruning_ratio}, Pruning steps : {args.iterative_steps}"
    )

    # prune
    # If you want iterative pruning, apply finetune_worker in TP_pruner.do() after pruner.step()
    pruner, post_process_func = TP_pruner.get_pruner(
        model, example_inputs, pruner_config
    )
    TP_pruner.do(
        model,
        pruner,
        example_inputs,
        pruner_config,
        finetune=args.finetune,
        finetune_config=finetune_config,
        post_process_func=post_process_func,
    )
    print("Done")


if __name__ == "__main__":
    # Script entry point: parse the command line with the shared parser from
    # engine.utils and hand the resulting namespace to main().
    cli_args = utils.get_tp_args_parser().parse_args()
    main(cli_args)
