"""
Linear Probe with sklearn Logistic Regression or linear model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import logging

import numpy as np
import random

from vision_benchmark.utils import comm, create_logger
from vision_benchmark.evaluation import construct_dataloader, full_model_finetune
from vision_benchmark.config import config, update_config
# Workaround for the "Too many open files" error: import torch.multiprocessing here
# and set its sharing strategy to 'file_system' right after the imports.
# Refer: https://github.com/pytorch/pytorch/issues/11201
import torch.multiprocessing
from vision_benchmark.common.utils import log_arg_env_config, submit_predictions

# Runs at import time: switch torch.multiprocessing to the 'file_system' sharing
# strategy, a workaround for "Too many open files" errors
# (https://github.com/pytorch/pytorch/issues/11201).
torch.multiprocessing.set_sharing_strategy('file_system')

# Presumably the names of datasets that are multi-label.
# NOTE(review): not referenced anywhere in this script; confirm whether it is used
# elsewhere or can be removed.
MULTILABEL_DATASETS = {"chestx-ray8"}


def _str2bool(value):
    """Convert a command-line string such as ``true``/``false`` to a bool.

    Case-insensitively accepts ``true``/``t``/``yes``/``y``/``1``/``on`` as True
    and ``false``/``f``/``no``/``n``/``0``/``off`` as False. A value that is
    already a bool (e.g. an argparse default) is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean string,
            which argparse reports as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value (true/false), got {value!r}.')


def add_linear_probing_args(parser):
    """Register the linear-probing command-line options on *parser* (in place).

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Adds:
    - the dataset and model config file options;
    - leaderboard-submission flags;
    - hyper-parameters used when tuning is disabled;
    - run/seed controls;
    - a trailing ``opts`` positional that collects all remaining command-line
      tokens as config overrides.
    """
    parser.add_argument('--ds', required=False, help='Evaluation dataset configure file name.', type=str)
    parser.add_argument('--model', required=True, help='Evaluation model configure file name', type=str)
    parser.add_argument('--submit-predictions', help='submit predictions and model info to leaderboard.', default=False, action='store_true')
    parser.add_argument('--submit-by', help='Person who submits the results.', type=str)

    # Only the exact string "true" (any case) enables this; kept unchanged for
    # backward compatibility with existing invocations.
    parser.add_argument('--no-tuning', help='No hyperparameter-tuning.', default=False, type=lambda x: x.lower() == "true")
    # Previously ``type=str``, which made "--emulate-zeroshot false" truthy and
    # silently enabled zero-shot emulation; parse it as a real boolean instead.
    parser.add_argument('--emulate-zeroshot', help='Emulate zero shot learning.', default=False, type=_str2bool)
    parser.add_argument('--l2', help='(Inverse) L2 regularization strength. This option is only useful when option --no-tuning is True.', default=0.316, type=float)
    parser.add_argument('--lr', help='Test with a specific learning rate. This option is only useful when option --no-tuning is True.', default=0.001, type=float)
    parser.add_argument('--run', help='Run id', default=1, type=int)
    parser.add_argument('--fix_seed', help='Fix the random seed. [-1] not fixing the seeds', default=0, type=int)

    # Everything after the recognised options is passed through as config overrides.
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)


def main():
    """Run linear-probing evaluation from the command line.

    Parses the CLI options registered by ``add_linear_probing_args``, then loads
    the dataset config (``--ds``) followed by the model config (``--model``)
    into the global ``config``. It applies the seed, zero-shot, few-shot,
    MAE-specific and patch-camelyon overrides below, and builds the
    train/val/test dataloaders. Finally it passes them to
    ``full_model_finetune``.
    """
    parser = argparse.ArgumentParser(description='Test a classification model, with linear probing.')
    add_linear_probing_args(parser)
    args = parser.parse_args()

    # Validate CLI arguments before doing any work. ``parser.error`` prints a
    # usage message and exits; unlike ``assert`` it is not stripped under -O.
    if args.submit_predictions and not args.submit_by:
        parser.error('--submit-by is required when --submit-predictions is set.')

    # update_config receives the config file path via ``args.cfg``: first the
    # dataset config, then the model config.
    # NOTE(review): --ds is optional, so args.cfg may be None on the first call;
    # confirm update_config tolerates that.
    args.cfg = args.ds
    update_config(config, args)
    args.cfg = args.model
    update_config(config, args)
    config.defrost()
    config.NAME = ''
    config.freeze()

    # A seed of -1 means "do not fix the seeds".
    if args.fix_seed != -1:
        random.seed(args.fix_seed)
        np.random.seed(args.fix_seed)
        torch.manual_seed(args.fix_seed)
        torch.cuda.manual_seed_all(args.fix_seed)

    # Zero-shot emulation: skip hyper-parameter tuning, train for a single epoch
    # with no extra final epoch and zero samples per class.
    if args.emulate_zeroshot:
        args.no_tuning = True
        config.defrost()
        config.TRAIN.END_EPOCH = 1
        config.TRAIN.EXTRA_FINAL_TRAIN_EPOCH = 0
        config.DATASET.NUM_SAMPLES_PER_CLASS = 0
        config.TRAIN.EMULATE_ZERO_SHOT = True
        config.freeze()

    # The experiment name is computed before the 1-shot adjustment below, so a
    # 1-shot run is still named 'linear_probe_1'. Negative counts mean 'full'.
    n_samples = str(config.DATASET.NUM_SAMPLES_PER_CLASS) if config.DATASET.NUM_SAMPLES_PER_CLASS >= 0 else 'full'
    exp_name = 'linear_probe_' + n_samples

    # 1-shot runs actually sample 2 per class and do not merge train/val for the
    # final run (presumably so a train/val split is possible; TODO confirm).
    if config.DATASET.NUM_SAMPLES_PER_CLASS == 1:
        config.defrost()
        config.DATASET.NUM_SAMPLES_PER_CLASS = 2
        config.DATASET.MERGE_TRAIN_VAL_FINAL_RUN = False
        config.freeze()

    # Follow MAE's design choice: not using global pool in linear probe
    if config.MODEL.NAME.startswith('mae_'):
        config.defrost()
        config.MODEL.SPEC.GLOBAL_POOL = False
        config.freeze()

    final_output_dir = create_logger(config, exp_name)
    if comm.is_main_process():
        log_arg_env_config(args, config, final_output_dir)

    if config.DATASET.DATASET == 'patch-camelyon' and config.DATASET.NUM_SAMPLES_PER_CLASS == -1:
        # deal with patch camelyon large dataset (search using 10000-shot subset, final run with the full dataset)
        logging.info(f'Detecting large dataset with {config.DATASET.NUM_SAMPLES_PER_CLASS}-shot.')
        config.defrost()
        config.DATASET.NUM_SAMPLES_PER_CLASS = 10000
        config.freeze()
        logging.info(f'Used the subset ({config.DATASET.NUM_SAMPLES_PER_CLASS}-shot) to train the model.')

    # Run linear probe
    train_dataloader, val_dataloader, test_dataloader = construct_dataloader(config)

    full_model_finetune(train_dataloader, val_dataloader, test_dataloader, args.no_tuning, args.lr, args.l2, config)

    # Submission is not supported yet: the return value of full_model_finetune
    # is not used, so there are no predictions to submit.
    test_predictions = None  # submission not supported yet
    dataset_info = None

    if args.submit_predictions:
        # Use explicit ``is None`` checks rather than truthiness: the
        # ``.tolist()`` call suggests test_predictions is a NumPy array, and
        # ``bool()`` of a multi-element array raises ValueError.
        if test_predictions is None or dataset_info is None:
            logging.warning('--submit-predictions was requested, but submitting predictions is not supported for linear probing yet; nothing was submitted.')
        else:
            submit_predictions(test_predictions.tolist(), args.submit_by, config, 'linear_probing', dataset_info.type)


# Script entry point: run main() only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()
