#!/usr/bin/env python3

import sys, os

# Import PyTorch root package
import torch                        

# Import PyTorch layers, activations and more
import torch.nn.functional as F

from utils.logger import Logger

def torch_layers_info(model:torch.nn.Module):
    """
        Collect per-class statistics about the modules found inside the model.

        Every module yielded by model.modules() is grouped by its fully qualified class name.
        Modules whose class name starts with "torch.nn.modules.container.Sequential" are skipped.
        Parameters are collected with module.parameters(), so a module that owns sub-modules
        also accounts for the parameters of those sub-modules.

        Parameters:
            model (nn.Module): Neural Network model

        Returns:
            dict: key: <module class name>, value: dict with
                  "count" - number of modules of that class,
                  "parameters_size_bytes_<cpu|gpu>_<train|frozen>" - parameter size in bytes,
                  "parameters_size_numel_<cpu|gpu>_<train|frozen>" - parameter size in scalar items.
                  "cpu" means param.device.type == "cpu", every other device is reported as "gpu".
                  "train" means param.requires_grad is set, otherwise the parameter is "frozen".
    """
    units = ("bytes", "numel")
    groups = ("cpu_train", "gpu_train", "cpu_frozen", "gpu_frozen")

    summary = {}

    for layer in model.modules():
        type_name = str(type(layer)).replace("class ", "").replace("<'", "").replace("'>", "")

        # Sequential containers are only wrappers around other layers
        if type_name.startswith("torch.nn.modules.container.Sequential"):
            continue

        if type_name not in summary:
            fresh = {"count": 0}
            for unit in units:
                for group in groups:
                    fresh[f"parameters_size_{unit}_{group}"] = 0
            summary[type_name] = fresh

        entry = summary[type_name]
        entry["count"] += 1

        for weight in layer.parameters():
            items = weight.numel()
            place = "cpu" if weight.device.type == "cpu" else "gpu"
            mode = "train" if weight.requires_grad else "frozen"

            entry[f"parameters_size_bytes_{place}_{mode}"] += items * weight.element_size()
            entry[f"parameters_size_numel_{place}_{mode}"] += items

    return summary

def print_current_gpu_context(device, args):
    """
    Log the current CUDA stream and the current BLAS handle.

    Nothing is logged (and the logger is not even requested) for the CPU device.

    Parameters:
        device: Target device. Either the string "cpu", a torch.device object or any other
                device specification accepted by torch.cuda.current_stream()
        args: Command line arguments. Only args.run_id is read, to select the logger
    """
    # A torch.device object is not guaranteed to compare equal to the plain string "cpu",
    # so its type is inspected explicitly in addition to the string comparison.
    if device == "cpu" or (isinstance(device, torch.device) and device.type == "cpu"):
        return

    logger = Logger.get(args.run_id)
    current_stream = torch.cuda.current_stream(device)
    # NOTE: current_blas_handle() is called without a device argument
    blas_handle = torch.cuda.current_blas_handle()
    logger.info(f"Current Stream: {current_stream}, BLAS handle: {hex(blas_handle)}")

def print_models_info(model:torch.nn.Module, args):
    """
        Log a table with per-layer-class statistics of the model followed by the model class name.

        Parameters:
            model (nn.Module): Neural Network model
            args: Command line arguments. Only args.run_id is read, to select the logger

        Returns:
            None. All information is emitted through the logger obtained via Logger.get(args.run_id)
    """
    logger = Logger.get(args.run_id)

    columns = ("Name", "Layers", "Learnable Parameters(Frozen)", "Learnable Parameters(Train)")

    logger.info("----------------- Information about the model start ---------------------------------------------")
    logger.info('{0:44s} | {1:3s} | {2:s} | {3:s}'.format(*columns))
    logger.info("-------------------------------------------------------------------------------------------------")

    for layer, info in torch_layers_info(model).items():
        # CPU and non-CPU statistics are reported together; sizes are shown in KBytes
        frozen_kbytes = (info["parameters_size_bytes_cpu_frozen"] + info["parameters_size_bytes_gpu_frozen"]) / 1024.0
        frozen_items = info["parameters_size_numel_cpu_frozen"] + info["parameters_size_numel_gpu_frozen"]
        train_kbytes = (info["parameters_size_bytes_cpu_train"] + info["parameters_size_bytes_gpu_train"]) / 1024.0
        train_items = info["parameters_size_numel_cpu_train"] + info["parameters_size_numel_gpu_train"]

        frozen_column = f'{layer:44s} | {info["count"]:6d} | {frozen_kbytes:8g} KBytes / {frozen_items} elements'
        train_column = f'| {train_kbytes:8g} KBytes / {train_items} elements'
        logger.info(frozen_column + train_column)

    model_class_name = str(type(model)).replace("class ", "").replace("<'", "").replace("'>", "")

    logger.info("-------------------------------------------------------------------------------------------------")
    logger.info("    Model class:" + model_class_name)
    logger.info("-------------------------------------------------------------------------------------------------")

#===================================================================================================================

def number_of_params(model:torch.nn.Module, skipFrozen:bool = True)->int:
    """
        Count scalar items stored in the parameters of the model.

        Parameters:
            model (torch.nn.Module): Neural Network model
            skipFrozen (bool): If True, parameters with requires_grad == False are not counted

        Returns:
            integer: total number of scalar items over all counted parameter tensors
    """
    return sum(weight.numel() for weight in model.parameters() if weight.requires_grad or not skipFrozen)

def set_params_to_zero(model, skipFrozen:bool = True, param_predicate = None):
    """
    Fill model parameters with zeros in-place.

    The update is executed under torch.no_grad(), so it is not tracked by autograd.

    Parameters:
        model (torch.nn.Module): Neural Network model
        skipFrozen (bool): If True, parameters with requires_grad == False are left untouched
        param_predicate(function(i,param)): Optional filter. If it is None every (non skipped) parameter is zeroed,
                                            otherwise only parameters for which param_predicate(i, param) == True.
                                            Index 'i' is the position of the parameter in model.parameters()
    """
    with torch.no_grad():
        for index, weight in enumerate(model.parameters()):
            if skipFrozen and not weight.requires_grad:
                continue

            if param_predicate is None or param_predicate(index, weight) == True:
                weight.zero_()

# Presumably not used outside of the unit tests yet. Once it is, please be careful with the random generators' states
def set_params_uniform_random(model, a = 0.0, b = 1.0, skipFrozen:bool = True, param_predicate = None):
    """
    Setup model parameters independently uniformly at random U(a,b).

    The update is executed under torch.no_grad(), so it is not tracked by autograd.
    One torch.rand_like() call is issued per updated parameter, in the order of model.parameters().

    Parameters:
        model (torch.nn.Module): Neural Network model
        a(float): 'a' parameter of distribution
        b(float): 'b' parameter of distribution
        skipFrozen (bool): If True, parameters with requires_grad == False are left untouched
        param_predicate(function(i,param)): Optional filter. If it is None every (non skipped) parameter is updated,
                                            otherwise only parameters for which param_predicate(i, param) == True.
                                            Index 'i' is the position of the parameter in model.parameters()
    """
    with torch.no_grad():
        for i, p in enumerate(model.parameters()):
            if skipFrozen and not p.requires_grad:
                continue

            if param_predicate is None or param_predicate(i, p) == True:
                # copy_() is used instead of "p[:] = ..." because slicing is not
                # available for 0-dim (scalar) parameters and raises IndexError.
                p.copy_(a + (b-a) * torch.rand_like(p))

def get_buffers(model:torch.nn.Module):
    """
    Get independent copies of all buffers of the model.

    Parameters:
        model (torch.nn.Module): Neural Network model

    Returns:
        list: detached clones of the buffers, in the order of model.buffers()
    """
    return [item.detach().clone() for item in model.buffers()]

def set_buffers(model:torch.nn.Module, buffer_list:list):
    """
    Overwrite the buffers of the model in-place with the provided values.

    The update is executed under torch.no_grad(), so it is not tracked by autograd.

    Parameters:
        model (torch.nn.Module): Neural Network model
        buffer_list (list): Tensors with new buffer values, in the order of model.buffers()
                            (e.g. the result of get_buffers()). Item 'i' must have the same number
                            of elements as the i-th buffer. Extra items at the tail are ignored.

    Raises:
        IndexError: if buffer_list has fewer items than the model has buffers
    """
    with torch.no_grad():
        for index, buf in enumerate(model.buffers()):
            # copy_() writes into the storage of the buffer itself. Assigning through
            # buf.flatten(0)[:] silently updates a temporary copy when the buffer is not
            # contiguous, because flatten() can return a view only for compatible strides.
            buf.copy_(buffer_list[index].reshape(buf.shape))

def get_params(model:torch.nn.Module, skipFrozen:bool = True, param_predicate = None):
    """
    Get model parameters as a single dense vector.

    Selected parameters are flattened and concatenated in the order of model.parameters().
    If you are interested only in a subset of parameters use param_predicate.

    Parameters:
        model (torch.nn.Module): Neural Network model
        skipFrozen (bool): If True, parameters with requires_grad == False are not included
        param_predicate(function(i,param)): Optional filter. If it is None every (non skipped) parameter is taken,
                                            otherwise only parameters for which param_predicate(i, param) == True.
                                            Index 'i' is the position of the parameter in model.parameters()

    Returns:
        torch.Tensor: selected parameters in a form of a single 1D tensor
    """
    # No explicit clone() per tensor: only light (detached, flattened) handles are collected here,
    # the actual copy of the data is produced by torch.cat(...) which creates a new tensor.
    pieces = [
        weight.detach().flatten(0)
        for index, weight in enumerate(model.parameters())
        if (weight.requires_grad or not skipFrozen)
        and (param_predicate is None or param_predicate(index, weight) == True)
    ]

    # Concatenate along dim=0
    return torch.cat(tuple(pieces))


def set_params(model:torch.nn.Module, parameters, skipFrozen:bool = True, param_predicate = None):
    """
    Set model parameters from a single dense vector.

    Consecutive chunks of the vector are written in-place into the selected parameters,
    in the order of model.parameters(). If you are interested only in a subset of parameters use param_predicate.
    The update is executed under torch.no_grad(), so it is not tracked by autograd.

    Parameters:
        model (torch.nn.Module): Neural Network model
        parameters(torch.Tensor): Dense vector with parameters. It must contain at least as many items
                                  as the selected parameters have in total
        skipFrozen (bool): If True, parameters with requires_grad == False are left untouched and consume
                           no items of the vector
        param_predicate(function(i,param)): Optional filter. If it is None every (non skipped) parameter is set,
                                            otherwise only parameters for which param_predicate(i, param) == True.
                                            Index 'i' is the position of the parameter in model.parameters()
    """
    with torch.no_grad():
        offset = 0
        for i, p in enumerate(model.parameters()):
            if skipFrozen and not p.requires_grad:
                continue

            selected = param_predicate is None or param_predicate(i, p) == True
            if not selected:
                continue

            sz = p.numel()
            # copy_() writes into the storage of the parameter itself. Assigning through
            # p.flatten(0)[:] silently updates a temporary copy when the parameter is not
            # contiguous, because flatten() can return a view only for compatible strides.
            p.copy_(parameters[offset:(offset + sz)].reshape(p.shape))
            offset += sz


def get_gradient(model:torch.nn.Module, skipFrozen:bool = True):
    """
    Get gradients of the model parameters as a single dense vector.

    Gradients are flattened and concatenated in the order of model.parameters().
    A parameter without gradient (p.grad is None) contributes a block of zeros of its size.

    Parameters:
        model (torch.nn.Module): Neural Network model
        skipFrozen (bool): If True, parameters with requires_grad == False are not included

    Returns:
        torch.Tensor: all gradients in a form of a single 1D tensor
    """
    pieces = []
    for weight in model.parameters():
        if not weight.requires_grad and skipFrozen:
            continue

        current = weight.grad
        if current is None:
            pieces.append(torch.zeros_like(weight).flatten(0))
        else:
            # No clone() here, torch.cat(...) creates a new tensor anyway
            pieces.append(current.detach().flatten(0))

    # Concatenate along dim = 0
    return torch.cat(tuple(pieces))

def get_zero_gradient_compatible_with_model(model:torch.nn.Module, skipFrozen:bool = True):
    """
    Get zero vector whose layout matches the gradient of the model stored as a single dense vector.

    Parameters:
        model (torch.nn.Module): Neural Network model
        skipFrozen (bool): If True, parameters with requires_grad == False are not accounted

    Returns:
        torch.Tensor: 1D zero vector with one item per scalar of every accounted parameter
    """
    zero_blocks = tuple(
        torch.zeros_like(weight).flatten(0)
        for weight in model.parameters()
        if weight.requires_grad or not skipFrozen
    )
    return torch.cat(zero_blocks)

def add_to_gradient(model:torch.nn.Module, extra_grad, skipFrozen:bool = True):
    """
    Add an extra dense vector to the gradient of the model in-place.

    Consecutive chunks of the vector are added to the gradients of the parameters, in the order of
    model.parameters(). A parameter without gradient (p.grad is None) first gets a zero gradient.
    The update is executed under torch.no_grad(), so it is not tracked by autograd.

    Parameters:
        model (nn.Module): Neural Network model
        extra_grad (torch.Tensor): Dense vector with gradients for all components
        skipFrozen (bool): If True, parameters with requires_grad == False are left untouched and consume
                           no items of the vector
    """
    with torch.no_grad():
        offset = 0
        for p in model.parameters():

            if skipFrozen and not p.requires_grad:
                continue

            if p.grad is None:
                p.grad = torch.zeros_like(p)

            sz = p.grad.numel()
            # add_() updates the storage of the gradient itself. Updating through
            # p.grad.flatten(0)[:] silently modifies a temporary copy when the gradient is not
            # contiguous, because flatten() can return a view only for compatible strides.
            p.grad.add_(extra_grad[offset:(offset + sz)].reshape(p.grad.shape))
            offset += sz

def set_gradient(model:torch.nn.Module, grad, skipFrozen:bool = True):
    """
    Set the gradient of the model from a single dense vector.

    Consecutive chunks of the vector are written in-place into the gradients of the parameters, in the
    order of model.parameters(). A parameter without gradient (p.grad is None) first gets a new gradient tensor.
    The update is executed under torch.no_grad(), so it is not tracked by autograd.

    Parameters:
        model (nn.Module): Neural Network model
        grad (torch.Tensor): Dense vector with gradients for all components
        skipFrozen (bool): If True, parameters with requires_grad == False are left untouched and consume
                           no items of the vector
    """
    with torch.no_grad():
        offset = 0
        for p in model.parameters():
            if skipFrozen and not p.requires_grad:
                continue

            if p.grad is None:
                p.grad = torch.empty_like(p)

            sz = p.grad.numel()
            # copy_() writes into the storage of the gradient itself. Assigning through
            # p.grad.flatten(0)[:] silently updates a temporary copy when the gradient is not
            # contiguous (flatten() can return a view only for compatible strides), which
            # would leave a gradient freshly created by empty_like() uninitialized.
            p.grad.copy_(grad[offset:(offset + sz)].reshape(p.grad.shape))
            offset += sz

def l2_norm_of_gradient_m(model:torch.nn.Module, skipFrozen:bool = True):
    """
    Compute L2 norm of the gradient stored inside the model.

    Parameters without gradient (p.grad is None) contribute nothing.

    Parameters:
        model (torch.nn.Module): Neural Network model
        skipFrozen (bool): If True, parameters with requires_grad == False are not accounted

    Returns:
        float: square root of the sum of squared gradient items
    """
    squares = (
        weight.grad.pow(2).sum().item()
        for weight in model.parameters()
        if (weight.requires_grad or not skipFrozen) and weight.grad is not None
    )
    return sum(squares, 0.0) ** 0.5

def l2_norm_of_vec(grad):
    """
    Compute L2 norm of a tensor.

    Parameters:
        grad (torch.Tensor): Tensor of any shape

    Returns:
        float: square root of the sum of squared items
    """
    sum_of_squares = grad.pow(2).sum().item()
    return sum_of_squares ** 0.5

def turn_off_batch_normalization_and_dropout(model:torch.nn.Module):
    """
    Switch batch normalization and dropout modules of the model into evaluation mode.

    Only instances of torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d and torch.nn.Dropout
    are affected, all other modules keep their current mode.

    Parameters:
        model (torch.nn.Module): Neural Network model
    """
    affected_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.Dropout)

    for layer in model.modules():
        if isinstance(layer, affected_types):
            layer.eval()
#===========================================================================================================================
# Unit tests. To launch them please use: "pytest -v mutils.py"
# https://docs.pytest.org/en/stable/getting-started.html


def test_get_set_for_model():
    """Check parameter counting, model info logging and get/set of the gradient for a fully trainable model."""
    width_1, width_2 = 20, 22
    net = torch.nn.Sequential(
        torch.nn.Linear(in_features=1, out_features=width_1, bias = False),
        torch.nn.Linear(in_features=width_1, out_features=width_2, bias = False),
        torch.nn.Linear(in_features=width_2, out_features=1, bias = False),
    )

    # Three weight matrices without biases: 1 x w1, w1 x w2, w2 x 1
    expected_params = width_1 + width_1 * width_2 + width_2
    assert number_of_params(net) == expected_params

    # Minimal stand-in for command line arguments
    class Empty: pass
    args = Empty()
    args.run_id = "test"
    Logger.setup_logging()

    print_models_info(net, args)

    net.train(True)
    # Before the first backward pass there are no gradients at all
    for weight in net.parameters():
        assert weight.grad is None

    loss = (10.0 - net(torch.Tensor([[3]])))**2
    loss.backward()
    for weight in net.parameters():
        assert weight.grad is not None
        assert weight.grad.shape == weight.shape

    flat_grad = get_gradient(net)
    assert abs(l2_norm_of_gradient_m(net) - l2_norm_of_vec(flat_grad)) < 1.0e-4
    assert flat_grad.numel() == number_of_params(net)

    # Doubling the gradient has to double its norm
    set_gradient(net, 2.0 * flat_grad)
    assert abs(l2_norm_of_gradient_m(net) - 2.0 * l2_norm_of_vec(flat_grad)) < 1.0e-4

def test_grad_addition_for_model():
    """
    Check get/set/add of the gradient for a model with a frozen first layer.

    The checks address the complete parameter vector (frozen parameters included), therefore
    skipFrozen=False is passed explicitly: with the default skipFrozen=True the frozen layer
    is excluded from counts and vectors and the expectations below do not hold.
    """
    hidden_features_layer_1 = 20
    hidden_features_layer_2 = 22
    model = torch.nn.Sequential(
        torch.nn.Linear(in_features=1, out_features=hidden_features_layer_1, bias = False),
        torch.nn.Linear(in_features=hidden_features_layer_1, out_features=hidden_features_layer_2, bias = True),
        torch.nn.Linear(in_features=hidden_features_layer_2, out_features=1, bias = True),
    )
    n1 = number_of_params(model, skipFrozen=False)
    model[0].requires_grad_(False)
    n2 = number_of_params(model, skipFrozen=False)
    # Freezing a layer must not change the total number of parameters
    assert n1 == n2
    # With frozen parameters skipped the first layer is not counted
    assert number_of_params(model) == n1 - hidden_features_layer_1

    z = (10.0 - model(torch.Tensor([[3]])))**2
    z.backward()
    g = get_gradient(model, skipFrozen=False)
    # Verify that requires_grad_(False) is correctly working with get_gradient() functionality
    assert l2_norm_of_vec(g[0:hidden_features_layer_1]) < 1.0e-4

    # Verify that gradient has correct number of items
    assert g.numel() == 1*hidden_features_layer_1 + hidden_features_layer_1 * hidden_features_layer_2 + hidden_features_layer_2 * 1 + hidden_features_layer_2 + 1
    assert g.dim() == 1
    get_zero_gradient_compatible_with_model(model)

    assert abs(l2_norm_of_gradient_m(model) - l2_norm_of_vec(g)) < 1.0e-4
    set_gradient(model, g, skipFrozen=False)
    add_to_gradient(model, torch.ones(n1), skipFrozen=False)
    g1 = get_gradient(model, skipFrozen=False)
    # The first item belongs to the frozen layer: zero gradient plus one
    assert g1[0].item() == 1.0
    add_to_gradient(model, -torch.ones(n1), skipFrozen=False)
    g2 = get_gradient(model, skipFrozen=False)
    assert l2_norm_of_vec(g2 - g) < 1.0e-5
    assert l2_norm_of_vec(g2 - g1) > 1.0e-5

    # Sequential container is skipped, all three layers share the single Linear class entry
    assert len(torch_layers_info(model)) == 1

def test_get_set_params_and_grad_inf():
    """Check gradient accumulation and zero/random setup of parameters for a model with a frozen first layer."""
    width_1, width_2 = 20, 22
    net = torch.nn.Sequential(
        torch.nn.Linear(in_features=1, out_features=width_1, bias = False),
        torch.nn.Linear(in_features=width_1, out_features=width_2, bias = True),
        torch.nn.Linear(in_features=width_2, out_features=1, bias = True),
    )
    net[0].requires_grad_(False)

    loss = (10.0 - net(torch.Tensor([[3]])))**2
    loss.backward()

    # Adding the gradient to itself has to double it
    grad_once = get_gradient(net)
    add_to_gradient(net, grad_once)
    grad_twice = get_gradient(net)
    assert l2_norm_of_vec(grad_twice - 2*grad_once) < 1.0e-5

    zero_vec = get_zero_gradient_compatible_with_model(net)
    assert l2_norm_of_vec(zero_vec) < 1.0e-5
    assert zero_vec.size() == grad_twice.size()

    set_params_to_zero(net)
    assert l2_norm_of_vec(get_params(net)) < 1.0e-5

    set_params_uniform_random(net)
    assert l2_norm_of_vec(get_params(net)) > 1.0

    # With zero trainable parameters the output is zero, so the loss is (10 - 0)^2
    set_params_to_zero(net)
    loss = (10.0 - net(torch.Tensor([[0]])))**2
    loss.backward()
    assert loss.item() == 100.0

#===========================================================================================================================
