#!/usr/bin/python3

import torch
import scipy
import numpy as np


# Every class exported by torch.nn.modules.activation, looked up on torch.nn
# and frozen into a tuple so it can be handed directly to isinstance().
activation_functions = tuple(
    map(
        lambda cls_name: getattr(torch.nn, cls_name),
        torch.nn.modules.activation.__all__,
    )
)

def get_effective_rank(matrix, return_singular_values=False):
    """Return the effective rank of ``matrix``.

    The singular values of ``matrix`` are normalised so that they sum to one,
    and the effective rank is the exponential of the (natural-log) entropy of
    that distribution, as computed by ``scipy.stats.entropy``.

    Args:
        matrix: Tensor accepted by ``torch.linalg.svdvals``.
        return_singular_values: When true, also return the un-normalised
            singular values.

    Returns:
        The effective rank as a NumPy scalar, with NaN replaced by 0 (this
        happens e.g. for an all-zero matrix, whose singular values cannot be
        normalised). If ``return_singular_values`` is true, a tuple
        ``(effective_rank, singular_values)`` where ``singular_values`` is a
        detached tensor.
    """
    # Imported explicitly: a bare ``import scipy`` does not guarantee that the
    # ``scipy.stats`` submodule has been loaded on every SciPy version.
    from scipy.stats import entropy

    singular_values = torch.linalg.svdvals(matrix).detach()
    # Out-of-place division keeps ``singular_values`` intact for the caller.
    probabilities = singular_values / torch.sum(singular_values)
    # SciPy works on NumPy arrays, which requires the data to live on the CPU.
    erank = np.nan_to_num(np.exp(entropy(probabilities.cpu().numpy())))
    if return_singular_values:
        return erank, singular_values
    return erank


def get_score_effective_rank(model, dataloader, max_token_id=50254):
    """Sum the effective ranks of activations captured from ``model``'s layers.

    A forward hook is attached to every submodule that has a ``weight``
    attribute or is an instance of one of ``activation_functions``. The model
    is then run on ``X["input_ids"]`` for the items of ``dataloader``. Each
    hook concatenates its layer's outputs along dim 0 until the collected
    matrix has at least as many rows as columns; it then keeps a window of
    ``columns`` rows (at a random offset when more than one is available) and
    detaches itself.

    Before being stored, a hook output is reduced as follows: a tuple is
    replaced by its first element, a leading dimension of size 1 is squeezed
    away, and anything that still has more than two dimensions is transposed
    on dims 1 and 3 and flattened over dims 0-2 (so such outputs must be 4-D).

    Args:
        model: Module to score; called as ``model(X["input_ids"])``.
        dataloader: Iterable of items indexable by ``"input_ids"``.
        max_token_id: Items containing an id greater than this value are
            skipped. ``None`` disables the filter.

    Returns:
        The sum of ``get_effective_rank`` over the collected activations of
        every hooked layer except the first one; layers that collected
        nothing are ignored.
    """
    handles = []
    activations = []
    # NOTE(review): the list starts with one entry (0) although no layer has
    # finished yet, so its length is always one more than the number of hooks
    # that completed. Presumably this accounts for the first hooked layer,
    # which is excluded from the score below - confirm the intent.
    finished = [0]
    called = set()

    def get_activation(layer_id):
        def hook(module, inputs, output):
            called.add(layer_id)
            if isinstance(output, tuple):
                output = output[0]
            # Width of the raw output, taken before any reshaping below.
            size = output.shape[-1]
            if output.shape[0] == 1:
                output = output.squeeze(0)
            if output.dim() > 2:
                output = torch.transpose(output, 1, 3).flatten(0, 2)

            collected = activations[layer_id]
            # The empty seed tensor is created on the CPU; align it with the
            # output so the concatenation never mixes devices.
            if collected.device != output.device:
                collected = collected.to(output.device)
            collected = torch.cat((collected, output), dim=0)
            activations[layer_id] = collected

            n_rows, n_cols = collected.shape[0], collected.shape[1]
            if n_rows >= n_cols:
                # Enough rows: keep a window of n_cols rows whose offset is a
                # multiple of `size`.
                n_offsets = n_rows // size - 1
                start = np.random.randint(0, n_offsets) * size if n_offsets > 0 else 0
                activations[layer_id] = collected[start:start + n_cols]
                handles[layer_id].remove()
                finished.append(layer_id)
        return hook

    layer_stack = [
        module
        for module in model.modules()
        if hasattr(module, "weight") or isinstance(module, activation_functions)
    ]

    try:
        for layer_id, layer in enumerate(layer_stack):
            activations.append(torch.tensor([]))
            handles.append(layer.register_forward_hook(get_activation(layer_id)))

        # Activations are detached before scoring, so no autograd graph is
        # needed for the forward passes.
        with torch.no_grad():
            for X in dataloader:
                input_ids = X["input_ids"]
                if max_token_id is not None and torch.any(input_ids > max_token_id):
                    continue
                model(input_ids)
                # For NATSBench-SSS, some networks never call all of their
                # hooks, so compare against the hooks actually called rather
                # than all registered ones. `>=` (not `==`) because several
                # hooks can finish within the same forward pass, which would
                # otherwise step over the equality and never break out early.
                if len(finished) >= len(called):
                    break
    finally:
        # Hooks that never collected enough rows are still attached; remove
        # them so they do not leak into later uses of the model. Removing an
        # already-removed handle is a no-op.
        for handle in handles:
            handle.remove()

    score = 0.
    # The first hooked layer is deliberately left out of the score.
    for activation in activations[1:]:
        if len(activation) == 0:
            continue
        score += get_effective_rank(activation.detach())
    print("DONE")
    return score

def get_average_score_effective_rank(model, dataloader, repetitions=1):
    """Average ``get_score_effective_rank`` over ``repetitions`` runs.

    Args:
        model: Forwarded unchanged to ``get_score_effective_rank``.
        dataloader: Forwarded unchanged to ``get_score_effective_rank``.
        repetitions: Number of times the score is computed.

    Returns:
        ``np.mean`` of the individual scores.
    """
    collected_scores = []
    for _ in range(repetitions):
        collected_scores.append(get_score_effective_rank(model, dataloader))
    return np.mean(collected_scores)

