import json
import torch
from ml_collections import ConfigDict
import hashlib

class TensorHash:
    """Static helpers for deterministic hashing of strings, int64 tensors and configs.

    Tensors are hashed with a linear-congruential style fold over each axis
    (constants below). Strings are hashed with SHA-256 and truncated to a hex suffix.
    """

    # Adapted from https://stackoverflow.com/questions/74805446/how-to-hash-a-pytorch-tensor
    MULTIPLIER = 6364136223846793005
    INCREMENT = 1
    # Kept for reference: the fold relies on int64 wrap-around instead of an explicit modulo.
    MODULUS = 2**64

    @staticmethod
    def hash_str(x, len=10):  # noqa: A002 - parameter name kept for backward compatibility
        """Return the last ``len`` characters of the SHA-256 of ``x`` rendered via ``hex()``.

        Args:
            x: String to hash (encoded with the default ``str.encode`` codec, UTF-8).
            len: Number of trailing characters to keep. Values <= 0 yield ``""``.

        Returns:
            A string suffix of ``hex(int(sha256(x)))``.
        """
        digest = hex(int(hashlib.sha256(x.encode()).hexdigest(), 16))
        # Explicit guard: ``digest[-0:]`` would return the *whole* string, not an empty suffix.
        return digest[-len:] if len > 0 else ""

    @staticmethod
    def hash_tensor(x: torch.Tensor, return_hex=True) -> "str | int":
        """Hash an int64 tensor by folding away its axes from last to first.

        Args:
            x: Tensor of dtype ``torch.int64`` (any shape, including 0-d or empty).
            return_hex: If True, return ``hex()`` of the hash; otherwise return the raw int.

        Returns:
            ``str`` when ``return_hex`` is True, else ``int``. Because the fold wraps in
            int64, the value may be negative (and the hex string may start with ``-0x``).

        Raises:
            AssertionError: if ``x`` is not int64.
        """
        assert x.dtype == torch.int64, f"hash_tensor expects torch.int64, got {x.dtype}"
        while x.ndim > 0:
            x = TensorHash._reduce_last_axis(x)
        value = x.item()
        return hex(value) if return_hex else value

    @staticmethod
    def hash_tensor_dict(tensor_dict, return_hex=True):
        """Hash a mapping of ``str`` keys to int64 tensors, skipping the ``"config"`` key.

        Each entry is serialized as ``key + str(int_hash)``. Entries are joined with ``"."``
        in the mapping's iteration order, and the result is passed through ``hash_str``.

        Args:
            tensor_dict: Mapping from string keys to int64 tensors.
            return_hex: Accepted for signature symmetry with ``hash_tensor``. It has no
                effect; a ``hash_str`` string is always returned, as before.

        Returns:
            The ``hash_str`` digest of the serialized entries.
        """
        # NOTE(review): key and hash are concatenated with no delimiter, so e.g. ("a1", 23)
        # and ("a", 123) serialize identically. Left as-is so previously computed hashes
        # remain valid.
        serialized_dict = ".".join(
            k + str(TensorHash.hash_tensor(v, return_hex=False))
            for k, v in tensor_dict.items()
            if k != "config"
        )
        return TensorHash.hash_str(serialized_dict)

    @staticmethod
    @torch.no_grad()
    def _reduce_last_axis(x: torch.Tensor) -> torch.Tensor:
        """Fold the last axis of an int64 tensor, returning a tensor of shape ``x.shape[:-1]``.

        Computes ``acc = acc * MULTIPLIER + INCREMENT + x[..., i]`` for each index ``i``.
        """
        assert x.dtype == torch.int64, f"_reduce_last_axis expects torch.int64, got {x.dtype}"
        # Allocate from the shape rather than ``x[..., 0]`` so a size-0 last axis yields
        # zeros instead of raising IndexError.
        acc = torch.zeros(x.shape[:-1], dtype=x.dtype, device=x.device)
        for i in range(x.shape[-1]):
            # int64 arithmetic wraps on overflow, which makes an explicit `% MODULUS`
            # unnecessary.
            acc *= TensorHash.MULTIPLIER
            acc += TensorHash.INCREMENT
            acc += x[..., i]
        return acc

    @staticmethod
    def hash_model_params(model_name, model_params):
        """Hash a model name together with its parameters.

        Args:
            model_name: JSON-serializable model identifier.
            model_params: Object exposing ``to_dict()`` (presumably an ``ml_collections``
                ``ConfigDict``) whose result is JSON-serializable.

        Returns:
            ``hash_str`` of ``json.dumps([model_name, model_params.to_dict()])``. Key order
            is not sorted, so it follows the order ``to_dict()`` produces.
        """
        return TensorHash.hash_str(json.dumps([model_name, model_params.to_dict()]))
    
