from argparse import ArgumentParser

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torchvision import transforms

from dataset import CUBFSLDataModule, CARFSLDataModule, DOGFSLDataModule,AircraftFSLDataModule
from model import FSLFeatModel
# from model_resnet12 import  FSLFeatModel


def main(hparams):
    """Set up trainer, data and model from ``hparams``, then train or evaluate.

    Args:
        hparams: ``argparse.Namespace`` produced by the ``__main__`` parser.
            It is passed whole to ``Trainer.from_argparse_args`` and to
            ``FSLFeatModel``; the dataset/episode fields are forwarded
            positionally to the selected data module.

    Control flow:
        * ``hparams.test_ckpt != 'none'``: load weights from that checkpoint
          (``model.load_all``) and run ``trainer.test`` only.
        * otherwise: optionally load a backbone checkpoint, then either run
          ``trainer.tune`` (``--only_find_lr``) or ``trainer.fit`` followed
          by ``trainer.test``.

    Raises:
        NotImplementedError: if ``hparams.data_name`` does not name a
            supported dataset (cub, car, dog, aircraft; case-insensitive).
    """
    seed_everything(hparams.seed)

    # Dataset names are matched case-insensitively so that e.g. 'CUB' and
    # 'cub' both select the CUB data module. The name is validated before the
    # Trainer is built so that an unknown dataset fails fast.
    data_key = hparams.data_name.lower()
    datamodule_classes = {
        'cub': CUBFSLDataModule,
        'car': CARFSLDataModule,
        'dog': DOGFSLDataModule,
        'aircraft': AircraftFSLDataModule,
    }
    if data_key not in datamodule_classes:
        raise NotImplementedError(f'{hparams.data_name} dataset not implemented')  # TODO: omniglot
    datamodule_cls = datamodule_classes[data_key]

    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    # Keep the 3 best checkpoints by validation accuracy, plus the last one.
    ckpt_monitor = ModelCheckpoint(save_last=True, monitor='val_acc',
                                   save_top_k=3, mode='max',
                                   filename='{epoch:04d}-{val_acc:.2%}')
    trainer = Trainer.from_argparse_args(hparams, callbacks=[lr_monitor, ckpt_monitor],
                                         auto_lr_find=True)

    # Per-channel normalisation statistics.
    # NOTE: 'mini' and 'fc100' are not in ``datamodule_classes`` above, so
    # those two branches are currently unreachable; they are kept for when
    # data modules for those datasets are added.
    if data_key == 'mini':
        mean_pix = [x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
        std_pix = [x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
        normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
    elif data_key == 'fc100':
        mean_pix = [x / 255.0 for x in [129.37731888, 124.10583864, 112.47758569]]
        std_pix = [x / 255.0 for x in [68.20947949, 65.43124043, 70.45866994]]
        normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
    else:
        # Standard ImageNet mean/std (as used by torchvision pretrained models).
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    # Both pipelines resize to a square input_size x input_size image; only
    # the training pipeline adds random flip/perspective augmentation.
    train_transform = transforms.Compose([
        transforms.Resize((hparams.input_size, hparams.input_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomPerspective(),
        transforms.ToTensor(),
        normalize
    ])
    test_transform = transforms.Compose([
        transforms.Resize((hparams.input_size, hparams.input_size)),
        transforms.ToTensor(),
        normalize
    ])

    # All supported data modules share the same positional constructor.
    dm = datamodule_cls(
        hparams.data_dir, hparams.split_dir, hparams.iterations_tr, hparams.iterations_val,
        hparams.iterations_test, hparams.classes_per_it_tr, hparams.num_support_tr,
        hparams.num_query_tr, hparams.classes_per_it_val,
        hparams.num_support_val, hparams.num_query_val, hparams.train_mode,
        train_transform, test_transform, hparams.batch_size)

    model = FSLFeatModel(hparams)

    if hparams.test_ckpt != 'none':
        # Evaluation only: restore weights from the given checkpoint and test.
        model.load_all(baseline_ckpt=hparams.test_ckpt)
        trainer.test(model, datamodule=dm)
        return

    if hparams.baseline_backbone_ckpt != 'none':
        model.load_backbone(hparams.baseline_backbone_ckpt, hparams.backbone_feti_pretrained)

    if hparams.only_find_lr:
        # With auto_lr_find=True on the Trainer, tune() runs Lightning's
        # learning-rate finder instead of a full training run.
        trainer.tune(model, datamodule=dm)
    else:
        trainer.fit(model, datamodule=dm)
        trainer.test(model, datamodule=dm)


def str2bool(value):
    """Parse a command-line boolean option value.

    ``argparse``'s ``type=bool`` treats every non-empty string (including
    ``'False'``) as ``True``, so such flags can never be switched off. This
    converter accepts the usual spellings instead.

    Args:
        value: the raw command-line string; a ``bool`` is returned unchanged.

    Returns:
        bool: the parsed value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = ArgumentParser()

    # add program arguments
    parser.add_argument('--default_root_dir', type=str, default='.')
    parser.add_argument('--seed', type=int, default=1234)

    # add model specific args
    parser = FSLFeatModel.add_model_arch_args(parser)
    parser = FSLFeatModel.add_model_loss_args(parser)
    parser = FSLFeatModel.add_model_train_args(parser)

    # add data specific args
    parser.add_argument('--data_name', type=str, default='cub',
                        help="dataset to use: one of cub, CAR, DOG, Aircraft")
    parser.add_argument('--data_dir', type=str, default='')
    parser.add_argument('--split_dir', type=str, default='')
    parser.add_argument('--iterations_tr', type=int, default=100)
    parser.add_argument('--iterations_val', type=int, default=200)
    parser.add_argument('--iterations_test', type=int, default=600)
    parser.add_argument('--classes_per_it_tr', type=int, default=20)
    parser.add_argument('--num_support_tr', type=int, default=5)
    parser.add_argument('--num_query_tr', type=int, default=5)
    parser.add_argument('--classes_per_it_val', type=int, default=5)
    parser.add_argument('--num_support_val', type=int, default=5)
    parser.add_argument('--num_query_val', type=int, default=15)
    parser.add_argument('--train_mode', type=str, default='batch')
    parser.add_argument('--batch_size', type=int, default=64)

    # add trainer options
    # Boolean flags accept an explicit value (``--flag False``) or a bare
    # ``--flag`` meaning True.
    parser.add_argument('--only_find_lr', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--deterministic', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--max_epochs', type=int, default=20)
    parser.add_argument('--gpus', type=int, default=1)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
    parser.add_argument('--baseline_backbone_ckpt', type=str, default='none')
    parser.add_argument('--backbone_feti_pretrained', type=str2bool, nargs='?', const=True,
                        default=False)

    parser.add_argument('--test_ckpt', type=str, default='none')

    args = parser.parse_args()

    main(args)