import os
import sys

# Add the parent of this script's directory to the import path so the
# project-local modules imported further down (`registry`, `engine.utils`)
# can be found when the script is run directly. This must stay ABOVE those
# imports. NOTE(review): assumes those modules live under BASE_DIR — confirm.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision
import warnings

# Suppress every Python warning (any category, any module) for the rest of the
# process, including warnings raised by the libraries imported below.
warnings.filterwarnings("ignore")
import torch_pruning as tp
import registry
import engine.utils as utils
# NOTE(review): `default_collate` is not referenced in the visible code; the
# loaders below use the DataLoader default collate function.
from torch.utils.data.dataloader import default_collate


def _configure_reproducibility(args):
    """Seed the RNGs and set PyTorch/cuDNN determinism flags from ``args``."""
    utils.set_random_seed(args.random_seed)
    if args.use_deterministic_algorithms:
        # Since CUDA 10.2, some cuBLAS ops raise a RuntimeError under
        # torch.use_deterministic_algorithms(True) unless CUBLAS_WORKSPACE_CONFIG
        # is set. setdefault keeps any value the user already exported.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        # Even without strict deterministic algorithms, cuDNN is pinned to its
        # deterministic kernels and autotuning is turned off.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _build_data(args):
    """Load the dataset and wrap its splits in DataLoaders.

    Returns:
        (num_classes, example_inputs, train_loader, test_loader), where
        ``example_inputs`` is a random tensor shaped like ``input_size`` from
        the registry, used only for op/param counting.
    """
    num_classes, train_dst, val_dst, input_size, _ = registry.get_dataset(
        args.dataset, data_root=args.data_root, args=args
    )
    print(num_classes, len(train_dst), len(val_dst), input_size)
    # Created here, before the model is built, to keep the original global RNG
    # consumption order (the random draw happens before model construction).
    example_inputs = torch.randn(*input_size)
    train_loader = torch.utils.data.DataLoader(
        train_dst,
        batch_size=args.batch_size,
        num_workers=args.workers,
        drop_last=True,
        pin_memory=True,
        shuffle=True,
        collate_fn=None,  # None -> DataLoader's default collate function
    )
    test_loader = torch.utils.data.DataLoader(
        val_dst,
        batch_size=args.batch_size,
        num_workers=args.workers,
        shuffle=False,
        pin_memory=True,
    )
    return num_classes, example_inputs, train_loader, test_loader


def _build_model(args, num_classes, logger):
    """Create the model, resize the CUB-200 head, and optionally load weights."""
    is_cub200 = args.dataset == "cub200"
    # Only CUB-200 starts from pretrained weights.
    model = registry.get_model(
        args.arch,
        num_classes=num_classes,
        pretrained=is_cub200,
        target_dataset=args.dataset,
    )
    if is_cub200:
        # Replace the pretrained classifier head. The output size comes from the
        # dataset rather than a hard-coded 200, so the head always matches the
        # label space reported by registry.get_dataset.
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    if args.ckpt_load_dir:
        utils.load_model(model, args.ckpt_load_dir)
        logger.info(f"load model done from {args.ckpt_load_dir}")
    else:
        logger.info("prune from init model")
    return model


def main(args):
    """Train (fine-tune) a base model for ``args.arch`` on ``args.dataset``.

    Args:
        args: namespace produced by ``utils.get_base_args_parser()``. Among
            others it reads ``random_seed``, ``use_deterministic_algorithms``,
            ``ckpt_save_dir``, ``dataset``, ``data_root``, ``arch``,
            ``batch_size``, ``workers``, ``ckpt_load_dir``,
            ``label_smoothing``, ``gpu``, ``finetune`` and the LR-schedule
            fields forwarded to ``utils.training.train_model``.

    Side effects:
        ``args.ckpt_save_dir`` is rewritten in place to
        ``<ckpt_save_dir>/base/<dataset>/<arch>/<random_seed>``, and that
        directory is created if missing. Logging is configured through
        ``utils.set_logger(args)``.
    """
    _configure_reproducibility(args)

    args.ckpt_save_dir = os.path.join(
        args.ckpt_save_dir, "base", args.dataset, args.arch, str(args.random_seed)
    )
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.ckpt_save_dir, exist_ok=True)

    utils.set_logger(args)
    logger = logging.getLogger("train_logger")

    num_classes, example_inputs, train_loader, test_loader = _build_data(args)
    model = _build_model(args, num_classes, logger)
    logger.info(model)

    if args.label_smoothing > 0:
        criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

    # Without an explicit GPU index everything stays on the CPU.
    if args.gpu is not None:
        model = model.cuda(args.gpu)
        example_inputs = example_inputs.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    ops, params = tp.utils.count_ops_and_params(
        model,
        example_inputs=example_inputs,
    )
    logger.info("Params: {:.2f} M".format(params / 1e6))
    logger.info("ops: {:.2f} M".format(ops / 1e6))

    if not args.finetune:
        return

    utils.training.train_model(
        model=model,
        epochs=args.ft_epochs,
        lr=args.lr,
        lr_step_size=args.lr_step_size,
        lr_warmup_epochs=args.lr_warmup_epochs,
        train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        save_dir=args.ckpt_save_dir,
        device=args.gpu,
        args=args,
        pruner=None,  # plain training: no pruning during this run
        lr_decay_milestones=args.lr_decay_milestones,
        save_every=args.save_every,
        return_best=False,
    )


if __name__ == "__main__":
    # Parse the shared base CLI options and hand them to the training entry point.
    cli_args = utils.get_base_args_parser().parse_args()
    main(cli_args)
