import collections
import hashlib
import io
from abc import abstractmethod
from typing import Any, MutableMapping, Optional, Type

import base58
import dill

# Optional third-party array/tensor classes. Each name is bound to the real
# class when its library is installed and to None otherwise, so the pickler
# can check for these types without a hard dependency on numpy or torch.
ndarray: Optional[Type]
TorchTensor: Optional[Type]

try:
    from numpy import ndarray
except ModuleNotFoundError:
    ndarray = None

try:
    from torch import Tensor as TorchTensor
except ModuleNotFoundError:
    TorchTensor = None


class CustomDetHash:
    """
    Mixin for objects that want to control how :func:`det_hash` hashes them.

    By default an object is hashed from its pickled bytes. A subclass instead
    supplies a stand-in value from :meth:`det_hash_object`. That value is
    pickled, together with the class's module and qualified name, in place
    of the object itself.

    NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` is not
    enforced at instantiation time. Calling the base implementation raises
    ``NotImplementedError``.
    """

    @abstractmethod
    def det_hash_object(self) -> Any:
        """
        Return the value to hash in place of ``self``.

        Returning ``None`` makes the hasher fall back to pickling the object
        normally.
        """
        raise NotImplementedError


class DetHashFromInitParams(CustomDetHash):
    """
    Hash objects of this class by the arguments they were constructed with.

    ``__new__`` records the positional and keyword arguments of the
    constructor call, and :meth:`det_hash_object` returns them. The recorded
    arguments, rather than the instance's other attributes, are therefore
    what gets hashed.
    """

    # The ``(args, kwargs)`` pair captured by ``__new__``.
    _det_hash_object: Any

    def __new__(cls, *args, **kwargs):
        parent_new = super(DetHashFromInitParams, cls).__new__
        plain_object_new = parent_new is object.__new__
        has_own_init = cls.__init__ is not object.__init__
        if plain_object_new and has_own_init:
            # object.__new__ rejects extra arguments once a class overrides
            # __new__. The class's own __init__ will receive them instead.
            obj = parent_new(cls)
        else:
            obj = parent_new(cls, *args, **kwargs)
        obj._det_hash_object = (args, kwargs)
        return obj

    def det_hash_object(self) -> Any:
        """Return the ``(args, kwargs)`` tuple recorded when this object was created."""
        return self._det_hash_object


class DetHashWithVersion(CustomDetHash):
    """
    Hash objects of this class by value, tagged with a class-level version string.

    Set ``VERSION`` on a subclass to a non-``None`` string. The object is then
    hashed together with that string, so changing ``VERSION`` changes the hash
    of every instance. If ``VERSION`` is left as ``None``, the object is
    pickled normally.
    """

    # Version tag mixed into the hash; None disables the special handling.
    VERSION: Optional[str] = None

    def det_hash_object(self) -> Any:
        """Return ``(VERSION, self)``, or ``None`` when no version is set."""
        version = self.VERSION
        if version is None:
            # No version: let the pickler fall back to its default behaviour.
            return None
        return version, self


_PICKLE_PROTOCOL = 4


class _DetHashPickler(dill.Pickler):
    """
    A ``dill`` pickler whose output is the input to :func:`det_hash`.

    It overrides :meth:`persistent_id` so that some objects are replaced by
    stable stand-ins instead of being pickled by value:

    - :class:`CustomDetHash` instances by their ``det_hash_object()``;
    - classes and named callables by module and qualified name;
    - numpy arrays and torch tensors by their own serialized bytes.
    """

    def __init__(self, buffer: io.BytesIO):
        super().__init__(buffer, protocol=_PICKLE_PROTOCOL)

        # Maps id(obj) -> how many times ``obj`` is currently on the save() stack.
        #
        # A CustomDetHash's det_hash_object() may contain the object itself
        # (DetHashWithVersion returns ``(VERSION, self)``). Pickling that
        # persistent id saves ``self`` again while the outer save is still in
        # progress. persistent_id() substitutes det_hash_object() only when the
        # object is at depth <= 1. The nested save therefore falls through to
        # ordinary pickling instead of recursing forever.
        self.recursively_pickled_ids: MutableMapping[int, int] = collections.Counter()

    def save(self, obj, save_persistent_id=True):
        """Pickle ``obj``, tracking its nesting depth for :meth:`persistent_id`."""
        obj_id = id(obj)
        self.recursively_pickled_ids[obj_id] += 1
        try:
            super().save(obj, save_persistent_id)
        finally:
            # Always undo the increment. Suppose the save raised and the
            # exception is handled further up. A leftover count would make
            # later saves of this object skip its det_hash_object(), so it
            # would hash differently.
            self.recursively_pickled_ids[obj_id] -= 1

    def persistent_id(self, obj: Any) -> Any:
        """
        Return a stand-in to pickle instead of ``obj``, or ``None`` to pickle
        ``obj`` normally.
        """
        if isinstance(obj, CustomDetHash) and self.recursively_pickled_ids[id(obj)] <= 1:
            # Outermost save of this object: hash it via its own det_hash_object().
            det_hash_object = obj.det_hash_object()
            if det_hash_object is not None:
                return obj.__class__.__module__, obj.__class__.__qualname__, det_hash_object
            else:
                return None
        elif isinstance(obj, type):
            # Classes are identified by module and qualified name, not by contents.
            return obj.__module__, obj.__qualname__
        elif callable(obj):
            # Callables are identified by name only, so editing a function body
            # does not change the hash. NOTE: all lambdas defined in one scope
            # share the qualname "<lambda>" and therefore hash identically.
            if hasattr(obj, "__module__") and hasattr(obj, "__qualname__"):
                return obj.__module__, obj.__qualname__
            else:
                return None
        elif ndarray is not None and isinstance(obj, ndarray):
            # Use numpy's own serialization of the array. Presumably the
            # default pickling path was not byte-stable; TODO confirm.
            return obj.dumps()
        elif TorchTensor is not None and isinstance(obj, TorchTensor):
            # Likewise, serialize tensors with torch.save using the pinned protocol.
            import torch

            with io.BytesIO() as buffer:
                torch.save(obj, buffer, pickle_protocol=_PICKLE_PROTOCOL)
                return buffer.getvalue()
        else:
            return None


def det_hash(o: Any) -> str:
    """
    Return a deterministic hash of ``o`` as a base58 string.

    ``o`` is pickled with :class:`_DetHashPickler`, which substitutes stable
    stand-ins for CustomDetHash objects, classes, callables, arrays and
    tensors. The resulting bytes are hashed with BLAKE2b, and the digest is
    base58-encoded.
    """
    m = hashlib.blake2b()
    with io.BytesIO() as buffer:
        pickler = _DetHashPickler(buffer)
        pickler.dump(o)
        # Release the memoryview explicitly. A BytesIO cannot be closed (by
        # the enclosing ``with``) while a view of its buffer is still exported,
        # and not every Python implementation frees a temporary view immediately.
        with buffer.getbuffer() as view:
            m.update(view)
    return base58.b58encode(m.digest()).decode()
