"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import sys

import tensorflow as tf

from nncf.tensorflow.helpers.model_creation import create_compressed_model
from nncf.tensorflow.utils.state import TFCompressionState
from nncf.tensorflow.utils.state import TFCompressionStateLoader
from examples.tensorflow.common.logger import logger
from examples.tensorflow.common.sample_config import create_sample_config
from examples.tensorflow.common.argparser import get_common_argument_parser
from examples.tensorflow.object_detection.models.model_selector import get_predefined_config as\
                                                                       get_predefined_od_config
from examples.tensorflow.object_detection.models.model_selector import get_model_builder as\
                                                                       get_model_od_builder
from examples.tensorflow.segmentation.models.model_selector import get_predefined_config as\
                                                                   get_predefined_seg_config
from examples.tensorflow.segmentation.models.model_selector import get_model_builder as\
                                                                   get_model_seg_builder
from examples.tensorflow.common.object_detection.checkpoint_utils import get_variables


class ModelType:
    """String identifiers of the model families whose checkpoints can be re-saved."""

    segmentation = 'segmentation'
    object_detection = 'object_detection'


def get_config_and_model_type_from_argv(argv, parser):
    """
    Parse the command line and assemble the sample configuration.

    The JSON-based sample config built from the parsed arguments is merged on top
    of the predefined config of the requested model family.

    Args:
        argv: List of command-line arguments (without the program name).
        parser: Argument parser that defines `--model-type` among its options.

    Returns:
        Tuple `(config, model_type)`.

    Raises:
        RuntimeError: If `--model-type` names an unsupported model family.
    """
    parsed_args = parser.parse_args(args=argv)
    json_config = create_sample_config(parsed_args, parser)

    # Each model family provides its own source of predefined defaults.
    predefined_config_getters = {
        ModelType.object_detection: get_predefined_od_config,
        ModelType.segmentation: get_predefined_seg_config,
    }
    model_type = parsed_args.model_type
    if model_type not in predefined_config_getters:
        raise RuntimeError('Wrong model type specified')

    merged_config = predefined_config_getters[model_type](json_config.model)
    # Values from the JSON config override the predefined defaults.
    merged_config.update(json_config)
    return merged_config, model_type


def load_checkpoint(checkpoint, ckpt_path):
    """
    Restore `checkpoint` from `ckpt_path` if a checkpoint can be found there.

    Args:
        checkpoint: `tf.train.Checkpoint` whose tracked objects are restored.
        ckpt_path: Either a directory, in which case the latest checkpoint in it is
            used, or a checkpoint prefix, which is used only if `<prefix>.index` exists.

    Returns:
        The path of the restored checkpoint, or None when no checkpoint was found.
        A missing checkpoint is only logged, not treated as an error. Previously
        this returned 0 on a miss and None on success; both are falsy, so callers
        could not tell whether anything was restored.
    """
    logger.info('Load from checkpoint is enabled')
    if tf.io.gfile.isdir(ckpt_path):
        path_to_checkpoint = tf.train.latest_checkpoint(ckpt_path)
        logger.info('Latest checkpoint: {}'.format(path_to_checkpoint))
    else:
        # A TF checkpoint prefix is considered valid only if its index file exists.
        path_to_checkpoint = ckpt_path if tf.io.gfile.exists(ckpt_path + '.index') else None
        logger.info('Provided checkpoint: {}'.format(path_to_checkpoint))

    if not path_to_checkpoint:
        logger.info('No checkpoint detected')
        return None

    logger.info('Checkpoint file {} found and restoring from checkpoint'.format(path_to_checkpoint))
    status = checkpoint.restore(path_to_checkpoint)
    # The given `checkpoint` may track only a subset of the saved values
    # (e.g. no optimizer); silence warnings about the unmatched ones.
    status.expect_partial()
    logger.info('Completed loading from checkpoint')

    return path_to_checkpoint


def load_compression_state(ckpt_path: str):
    """
    Read only the NNCF compression state stored in the checkpoint at `ckpt_path`.

    The temporary checkpoint tracks nothing but a `TFCompressionStateLoader`,
    so no model weights are restored here.

    Returns:
        The loader's `state` attribute. If no checkpoint is found, the loader is
        never restored and its initial value is returned. Presumably that is an
        empty state; this should be confirmed against `TFCompressionStateLoader`.
    """
    state_ckpt = tf.train.Checkpoint(compression_state=TFCompressionStateLoader())
    load_checkpoint(state_ckpt, ckpt_path)
    restored_loader = state_ckpt.compression_state
    return restored_loader.state


def od_checkpoint_saver(config):
    """
    Load object detection checkpoint and re-save it without optimizer (memory footprint is reduced).

    The model is rebuilt and wrapped with NNCF compression using the compression
    state stored in `config.ckpt_path`. A checkpoint tracking only the compressed
    model and its compression state is then restored and written back out.
    Any optimizer values in the source file are therefore not carried over.
    """
    base_model = get_model_od_builder(config).build_model()
    saved_state = load_compression_state(config.ckpt_path)
    ctrl, compressed = create_compressed_model(base_model, config.nncf_config, saved_state)

    load_and_save_checkpoint(
        tf.train.Checkpoint(model=compressed, compression_state=TFCompressionState(ctrl)),
        config)


def seg_checkpoint_saver(config):
    """
    Load segmentation checkpoint and re-save it without optimizer (memory footprint is reduced).

    The model is rebuilt and wrapped with NNCF compression using the compression
    state stored in `config.ckpt_path`. Unlike the object-detection path, the
    checkpoint tracks the model's variables (from `get_variables`) instead of the
    model object, plus a `step` counter. Presumably this mirrors the layout used
    by the segmentation training scripts; this should be confirmed.
    """
    base_model = get_model_seg_builder(config).build_model()
    saved_state = load_compression_state(config.ckpt_path)
    ctrl, compressed = create_compressed_model(base_model, config.nncf_config, saved_state)

    tracked_objects = dict(
        variables=get_variables(compressed),
        compression_state=TFCompressionState(ctrl),
        step=tf.Variable(0),
    )
    load_and_save_checkpoint(tf.train.Checkpoint(**tracked_objects), config)


def load_and_save_checkpoint(checkpoint, config):
    """
    Restore `checkpoint` from `config.ckpt_path`, then write it to a new location.

    The output directory is `config.checkpoint_save_dir`. When that is None,
    `config.log_dir` is used instead and stored back into
    `config.checkpoint_save_dir`.
    """
    load_checkpoint(checkpoint, config.ckpt_path)

    if config.checkpoint_save_dir is None:
        config.checkpoint_save_dir = config.log_dir
    output_dir = config.checkpoint_save_dir

    # max_to_keep=None: the manager never deletes older checkpoints in output_dir.
    manager = tf.train.CheckpointManager(checkpoint, output_dir, max_to_keep=None)
    written_path = manager.save()
    logger.info('Saved checkpoint: {}'.format(written_path))


def _build_checkpoint_optimizer_parser():
    """Create the CLI parser: common options minus unused ones, plus --model-type and --resume."""
    arg_parser = get_common_argument_parser(metrics_dump=False,
                                            resume_args=False,
                                            execution_args=False,
                                            epochs=False,
                                            precision=False,
                                            dataset_dir=False,
                                            dataset_type=False,
                                            log_dir=False,
                                            save_checkpoint_freq=False,
                                            export_args=False,
                                            print_freq=False)
    arg_parser.add_argument(
        '--model-type',
        choices=[ModelType.object_detection, ModelType.segmentation],
        help='Type of the model which checkpoint is being provided.',
        required=True)
    # Stored as `ckpt_path`, which the saver functions read from the config.
    arg_parser.add_argument(
        '--resume',
        metavar='PATH',
        type=str,
        default=None,
        dest='ckpt_path',
        help='Specifies the path to the checkpoint which should be optimized.',
        required=True)
    return arg_parser


def main(argv):
    """
    Parse `argv` and re-save the given checkpoint with the saver for its model type.

    Args:
        argv: Command-line arguments without the program name.
    """
    config, model_type = get_config_and_model_type_from_argv(
        argv, _build_checkpoint_optimizer_parser())

    savers_by_type = {
        ModelType.object_detection: od_checkpoint_saver,
        ModelType.segmentation: seg_checkpoint_saver,
    }
    saver = savers_by_type.get(model_type)
    if saver is not None:
        saver(config)


if __name__ == '__main__':
    # Drop argv[0] (the script path); main() expects only the user-supplied arguments.
    cli_args = sys.argv[1:]
    main(cli_args)
