"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""


import builtins
import collections
import importlib
import pickle  # nosec


# Regular unpickling is prone to arbitrary code execution attacks.
# This module implements the objects required by torch.load from a
# pickler module to be able to work, but the Unpickler object provided
# here is restricted with respect to the classes it can load, to safeguard
# against attacks described above.
# ** WARNING **: there is no final guarantee that using this module in
# conjunction with torch.load mitigates all possible routes of arbitrary code
# execution attacks possible while using Python's `pickle`.
# ** Only load the data you trust **

def load(file, **kwargs):
    """
    Restricted drop-in replacement for `pickle.load`.

    A pickler module handed to `torch.load` is expected to expose a module-level
    `load` next to `Unpickler`, and `torch.load` may call it directly on parts of
    the stream. Re-exporting the plain `pickle.load` here would let such records
    bypass the class whitelist, so the data is routed through the restricted
    `Unpickler` defined in this module instead.

    :param file: Binary file-like object to read the pickled data from.
    :param kwargs: Keyword arguments accepted by `pickle.load`
        (e.g. `fix_imports`, `encoding`, `errors`); forwarded unchanged to the
        `Unpickler` constructor.
    :return: The object reconstructed from the pickle stream.
    :raises pickle.UnpicklingError: If the stream references a global that is
        not whitelisted by `Unpickler.find_class`.
    """
    return Unpickler(file, **kwargs).load()


class Unpickler(pickle.Unpickler):
    """
    `pickle.Unpickler` that resolves only an explicit whitelist of globals.

    Every global referenced by a pickle stream is looked up through `find_class`;
    anything that is not listed in one of the class-level whitelists below is
    rejected with `pickle.UnpicklingError` instead of being imported.
    """

    # Names that may be resolved from the `builtins` module.
    safe_builtins = {
        'range',
        'complex',
        'set',
        'frozenset',
        'slice',
        'dict'
    }
    # Names that may be resolved from the `collections` module.
    safe_collections = {'OrderedDict', 'defaultdict'}

    #pylint:disable=protected-access
    # All other permitted globals: module name -> names allowed from that module.
    allowed_classes = {
        'torch': {'Tensor', 'FloatStorage', 'LongStorage', 'IntStorage'},
        'torch._utils': {'_rebuild_tensor', '_rebuild_tensor_v2', '_rebuild_parameter'},
        'torch.nn': {'Module'},
        'torch.optim.adam': {'Adam'},

        'nncf.api.compression': {'CompressionStage', 'CompressionLevel'},

        'numpy.core.multiarray': {'scalar'},
        'numpy': {'dtype'},

        '_codecs': {'encode'}
    }

    def find_class(self, module_name, class_name):
        """
        Resolve a global referenced by the pickle stream if it is whitelisted.

        :param module_name: Name of the module the pickle stream refers to.
        :param class_name: Name of the attribute requested from that module.
        :return: The attribute `class_name` of the module `module_name`.
        :raises pickle.UnpicklingError: If the (module, name) pair is not present
            in any of the whitelists.
        """
        # Only allow safe classes from builtins ("__builtin__" is the Python 2
        # name of the same module).
        if module_name in ("builtins", "__builtin__") and class_name in Unpickler.safe_builtins:
            return getattr(builtins, class_name)
        if module_name == "collections" and class_name in Unpickler.safe_collections:
            return getattr(collections, class_name)

        # Direct dictionary lookup instead of scanning every whitelist entry.
        allowed_names = Unpickler.allowed_classes.get(module_name)
        if allowed_names is not None and class_name in allowed_names:
            # The module is imported only after the whitelist check has passed,
            # so a forbidden module is never imported as a side effect.
            module = importlib.import_module(module_name)
            return getattr(module, class_name)

        # Forbid everything else.
        raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                     (module_name, class_name))
