import logging
import FPVE
import registry
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import engine.utils as utils
import torchvision
from torch.utils.data.dataloader import default_collate


def main(args):
    """Build data loaders, model and loss from ``args`` and run FPVE.

    Args:
        args: Parsed command-line namespace. This function reads
            ``random_seed``, ``dataset``, ``data_root``,
            ``FPVE_fitness_data_ratio``, ``batch_size``, ``workers``,
            ``arch``, ``ckpt_load_dir``, ``label_smoothing`` and ``gpu``.
            The whole namespace is also forwarded to ``utils.set_logger``,
            ``registry.get_dataset`` and ``FPVE.FPVE``.
    """
    # Seed and random factors: trade cuDNN autotuning for repeatable runs.
    utils.set_random_seed(args.random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Logger and ckpt_path.
    utils.set_logger(args)
    logger = logging.getLogger("train_logger")

    # Data.
    num_classes, train_dst, val_dst, input_size, FPVE_fitness_dst = (
        registry.get_dataset(
            args.dataset,
            data_root=args.data_root,
            FPVE_fitness_data_ratio=args.FPVE_fitness_data_ratio,
            args=args,
        )
    )

    # Random tensor with the dataset's input size, handed to FPVE below.
    example_inputs = torch.randn(*input_size)
    train_loader = torch.utils.data.DataLoader(
        train_dst,
        batch_size=args.batch_size,
        num_workers=args.workers,
        drop_last=True,
        pin_memory=True,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        val_dst,
        batch_size=args.batch_size,
        num_workers=args.workers,
        shuffle=False,
        pin_memory=True,
    )

    # The fitness split is optional. Bind the name up front so the FPVE call
    # below cannot hit an UnboundLocalError when no fitness dataset exists.
    # NOTE(review): assumes FPVE.FPVE accepts FPVE_fitness_loader=None -- confirm.
    FPVE_fitness_loader = None
    if FPVE_fitness_dst is not None:
        FPVE_fitness_loader = torch.utils.data.DataLoader(
            FPVE_fitness_dst,
            batch_size=args.batch_size,
            num_workers=args.workers,
            shuffle=False,
            pin_memory=True,
        )

    # Model.
    model = registry.get_model(
        args.arch,
        num_classes=num_classes,
        pretrained=False,
        target_dataset=args.dataset,
    )
    if args.dataset == "cub200":
        # Replace the final fully connected layer with a 200-way classifier.
        in_channel = model.fc.in_features
        model.fc = nn.Linear(in_channel, 200)

    if args.ckpt_load_dir:
        utils.load_model(model, args.ckpt_load_dir)
        logger.info(f"load model done from {args.ckpt_load_dir}")
    else:
        logger.info("prune from init model")

    # Loss: only pass label_smoothing when it is actually requested.
    if args.label_smoothing > 0:
        criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

    # Move to gpu.
    if args.gpu is not None:
        model = model.cuda(args.gpu)
        example_inputs = example_inputs.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    logger.info("START:")
    alg = FPVE.FPVE(
        model,
        train_loader,
        test_loader,
        example_inputs,
        criterion,
        args,
        FPVE_fitness_loader=FPVE_fitness_loader,
    )
    alg.run()


if __name__ == "__main__":
    # Script entry point: parse the FPVE command-line options, then run main.
    cli_args = utils.get_FPVE_args_parser().parse_args()
    main(cli_args)
