import math

import torch

def tensor_to_size_in_bytes(tensor, bytes_per_element=4):
    """Return the size of ``tensor`` in bytes.

    The element count is multiplied by a fixed ``bytes_per_element``
    (default 4, i.e. float32-sized elements); ``tensor.dtype`` is ignored.

    Args:
        tensor: A ``torch.Tensor``.
        bytes_per_element: Bytes charged per element.

    Returns:
        ``tensor.numel() * bytes_per_element`` as an ``int``.
    """
    return tensor.numel() * bytes_per_element

# Module-level accumulators: the forward/backward hooks add to these, and
# the training_mem* entry points reset them before each measurement.
# _g_mem is reset by those entry points but not otherwise used.
_g_mem = _g_mem_activations = 0
# Per-module log of (module class name, bytes) tuples.
_g_mem_dict = dict(activations=[])
# Largest output size (bytes) seen from any leaf module.
_g_max_activation = 0


def backward_hook(module, grad_input, grad_output):
    """Record output-gradient sizes of leaf modules during backward.

    ``Identity`` modules are ignored. For a leaf module (no registered
    submodules), the size of every gradient in ``grad_output`` is added to
    ``_g_mem_activations`` and appended to ``_g_mem_dict['activations']`` as
    ``(class name, bytes)``. A container module is logged as
    ``(class name, 0)``. Its children are counted individually.

    Always returns ``None``, so gradients are left unmodified.
    """
    global _g_mem_activations
    name = module.__class__.__name__
    if name == 'Identity':
        return
    if module._modules:
        _g_mem_dict['activations'].append((name, 0))
        return
    for grad in grad_output:
        # grad_output entries may be None (no gradient for that output);
        # there is nothing to measure, and sizing None would raise.
        if grad is None:
            continue
        size = tensor_to_size_in_bytes(grad)
        _g_mem_activations += size
        _g_mem_dict['activations'].append((name, size))

def _iter_output_tensors(output):
    """Yield every tensor in a module output (a tensor or nested tuples/lists)."""
    if isinstance(output, torch.Tensor):
        yield output
    elif isinstance(output, (tuple, list)):
        for item in output:
            yield from _iter_output_tensors(item)


def forward_hook(module, input, output):
    """Track the largest output produced by any leaf module.

    Container modules (with registered submodules) are skipped. For a leaf
    module, the sizes of all tensors in ``output`` are summed. This covers
    modules that return tuples, e.g. ``(out, (h, c))``. The running maximum
    is kept in ``_g_max_activation``.
    """
    global _g_max_activation
    if module._modules:
        return
    size = sum(tensor_to_size_in_bytes(t) for t in _iter_output_tensors(output))
    _g_max_activation = max(_g_max_activation, size)

def sizeof_state_dict(state_dict, bytes_per_element=4):
    """Return the total size in bytes of the entries of ``state_dict``.

    Each entry is charged ``bytes_per_element`` bytes per element (default 4,
    i.e. float32-sized) regardless of its dtype. Entries without a ``shape``
    attribute, such as arbitrary ``extra_state`` objects, are skipped.

    Args:
        state_dict: Mapping of names to shaped values (typically tensors).
        bytes_per_element: Bytes charged per element.

    Returns:
        Total size in bytes as an ``int`` (0 for an empty mapping).
    """
    # math.prod(()) == 1, so scalar tensors count as one element.
    element_count = sum(
        math.prod(value.shape)
        for value in state_dict.values()
        if hasattr(value, 'shape')
    )
    return bytes_per_element * element_count
    

def training_mem(Model, kwargs, input_shape, sd=None):
    """Estimate the memory of one training step, in units of 10**9 bytes.

    This performs the same measurement as ``training_mem_individual``,
    including resetting the module-level accumulators. It returns only the
    rounded total and drops the per-module log.

    Args:
        Model: Callable (usually an ``nn.Module`` subclass) built as ``Model(**kwargs)``.
        kwargs: Keyword arguments for ``Model``.
        input_shape: Shape of the random input passed to the network.
        sd: Optional state dict loaded into the network before measuring.

    Returns:
        The estimated total, rounded to 4 decimals.
    """
    # Delegate rather than duplicate the measurement logic.
    total_gb, _ = training_mem_individual(Model, kwargs, input_shape, sd=sd)
    return total_gb

def training_mem_individual(Model, kwargs, input_shape, sd= None):
    """Estimate the memory of one training step and return a per-module log.

    The function builds ``Model(**kwargs)`` and optionally loads ``sd``. It
    then runs a forward pass on ``torch.rand(input_shape)`` and backpropagates
    the summed output while module hooks record output and gradient sizes.

    Args:
        Model: Callable (usually an ``nn.Module`` subclass) built as ``Model(**kwargs)``.
        kwargs: Keyword arguments for ``Model``.
        input_shape: Shape of the random input passed to the network.
        sd: Optional state dict loaded into the network before measuring.

    Returns:
        ``(total, log)``.

        ``total`` is the sum of the largest leaf-module output, all recorded
        output gradients, the full state dict, and the gradients of trainable
        parameters. It is expressed in units of 10**9 bytes and rounded to
        4 decimals.

        ``log`` is ``_g_mem_dict``, i.e.
        ``{'activations': [(module class name, bytes), ...]}``.
    """
    global _g_mem, _g_mem_activations, _g_mem_dict, _g_max_activation
    _g_mem = 0
    _g_mem_activations = 0
    _g_mem_dict = {'activations': []}
    _g_max_activation = 0

    net = Model(**kwargs)

    if sd is not None:
        net.load_state_dict(sd)

    handles = []
    try:
        for module in net.modules():
            handles.append(module.register_backward_hook(backward_hook))
            handles.append(module.register_forward_hook(forward_hook))
        dummy_input = torch.rand(input_shape)
        out = net(dummy_input)
        out = out.sum()
        out.backward()
    finally:
        # Remove the hooks even if the pass fails. Otherwise modules that
        # outlive this call (e.g. instances supplied through kwargs) would
        # accumulate extra hooks and be double-counted on later calls.
        for handle in handles:
            handle.remove()

    size_of_loaded_parameters = sizeof_state_dict(net.state_dict())
    size_of_gradients = sizeof_state_dict(
        {key: param for (key, param) in net.named_parameters() if param.requires_grad}
    )

    mem_total = (
        _g_max_activation
        + _g_mem_activations
        + size_of_loaded_parameters
        + size_of_gradients
    )

    return round(mem_total / (10**9), 4), _g_mem_dict