import torch


def softrank(M):
    """
    Computes the softrank of a matrix, i.e. the squared Frobenius norm divided by
    the squared largest singular value (spectral norm).

    An all-zero matrix yields 0 / 0, i.e. nan.

    :param M: matrix M (a tensor, or any array-like accepted by torch.as_tensor)
    :return: softrank of M as a float32 tensor (0-dim for a single 2-D matrix)
    """
    # as_tensor avoids the extra copy (and the UserWarning) that torch.tensor()
    # produces when M is already a tensor; detach() keeps the result out of the
    # autograd graph, which the copying constructor did implicitly.
    M = torch.as_tensor(M, dtype=torch.float32).detach()

    squared_frobenius = torch.linalg.matrix_norm(M, ord="fro") ** 2
    squared_spectral = torch.linalg.matrix_norm(M, ord=2) ** 2

    # each norm is squared exactly once; squaring the Frobenius term a second
    # time would give ||M||_F^4 / sigma_max^2 instead of the softrank
    return squared_frobenius / squared_spectral

def multihead_softrank(attn):
    """
    Applies softrank to every item obtained by iterating over attn.

    :param attn: iterable of matrices (presumably one attention matrix per head
        -- confirm against callers)
    :return: tensor built with torch.tensor from the individual softrank values
    """
    per_item = []
    for matrix in attn:
        per_item.append(softrank(matrix))
    return torch.tensor(per_item)

def layer_softrank(attn):
    """
    Applies multihead_softrank to every item obtained by iterating over attn
    and stacks the results along a new leading dimension.

    :param attn: iterable whose items are each passed to multihead_softrank
        (presumably one group of attention heads per layer -- confirm against callers)
    :return: stacked tensor of the per-item multihead_softrank results
    """
    per_group = list(map(multihead_softrank, attn))
    return torch.stack(per_group)


def compute_attn_entropies(attn):
    """
    Computes the entropy of attention head probabilities at a layer.

    The entropy is -sum(p * log(p)) over the last dimension, evaluated in float32.

    :param attn: attention values at a layer
        shape: n_head x n_token x n_token
    :return: entropy of attention scores
        shape: n_head x n_token
    """
    probs = attn.to(torch.float32)

    # elementwise -log(p) * p; a zero probability gives inf * 0 = nan here
    neg_p_log_p = torch.log(probs).neg().mul(probs)

    # map those nan entries to 0 (nan_to_num also clamps any infinities
    # to the largest finite float32 values)
    cleaned = neg_p_log_p.nan_to_num(nan=0.0)

    # reduce over the last dimension, i.e. one entropy per token
    return cleaned.sum(dim=-1)



def compute_head_avg_attn_entropies(attn):
    """
    Per-token attention entropy, averaged over the heads of one layer.

    :param attn: attention values at a layer
        shape: n_head x n_token x n_token
    :return: head-averaged entropy of attention scores
        shape: n_token
    """
    # compute_attn_entropies gives n_head x n_token; dim 0 is the head axis
    return compute_attn_entropies(attn).mean(dim=0)

def compute_avg_attn_entropies(attn_list, stacked=False):
    """
    Head-averaged attention entropies for every entry of attn_list.

    :param attn_list: iterable of attention values, each passed to
        compute_head_avg_attn_entropies
    :param stacked: if True, return the results stacked into a single tensor
        instead of a list
    :return: list of per-entry results, or their torch.stack when stacked is True
    """
    per_layer = []
    for layer_attn in attn_list:
        per_layer.append(compute_head_avg_attn_entropies(layer_attn))

    return torch.stack(per_layer) if stacked else per_layer


def compute_token_avg_attn_entropies(attn_list, normalized=False, drop_first=True):
    """
    Average attention head entropies across tokens and across heads at each layer.

    :param attn_list: list of attn heads at each layer
        - len(attn_list) == n_layer
        - attn_list[i] is n_head x n_token x n_token
    :param normalized: normalized by maximal possible entropy at each token
        - delegates to normalize_token_avg_attn_entropies, which always drops
          the first token, so drop_first has no effect in this mode
    :param drop_first: averaging without first token (only used when
        normalized is False)
    :return: tuple (token_avg, global_avg)
        - token_avg: average across tokens across heads at each layer,
          len(token_avg) == n_layer
        - global_avg: mean of token_avg over all layers
    """
    # n_layer x n_token: head-averaged entropy of every token at every layer
    avg_entropies = compute_avg_attn_entropies(attn_list, stacked=True)

    if normalized:
        # the helper already returns (per-layer average, global average)
        return normalize_token_avg_attn_entropies(avg_entropies)

    if drop_first:
        avg_entropies = avg_entropies[:, 1:]

    token_avg = torch.mean(avg_entropies, dim=-1)
    global_avg = torch.mean(token_avg)

    return token_avg, global_avg



def normalize_token_avg_attn_entropies(avg_entropies):
    """
    Normalizes per-token entropies by log(position) and averages them.

    The first column is dropped (its divisor would be log(1) = 0); column j of
    the remainder (0-based in the original tensor, j >= 1) is divided by
    log(j + 1). Presumably this is the maximal entropy of a causal attention
    row over j + 1 positions -- confirm against the model in use.

    :param avg_entropies: 2-D tensor of entropies; the second dimension is
        treated as the token axis (rows presumably correspond to layers)
    :return: tuple (token_avg, global_avg)
        - token_avg: mean of the normalized entropies over the last dimension
        - global_avg: mean of token_avg
        Both are computed on the CPU in float32.
    """
    entropies = avg_entropies.to(torch.float32)
    seq_len = entropies.shape[1]

    # log(2), log(3), ..., log(seq_len): one divisor per remaining column
    log_positions = torch.log(torch.arange(2, seq_len + 1))
    normalized = entropies[:, 1:].cpu() / log_positions

    per_row_mean = normalized.mean(dim=-1)
    overall_mean = per_row_mean.mean()

    return per_row_mean, overall_mean



def compute_avg_attn(attn_list, stacked=False, std=False):
    """
    Compute the average across heads (dim 1 of the stacked layers) for each layer.

    :param attn_list: list of attn heads at each layer
        - len(attn_list) == n_layer
        - attn_list[i] is n_head x n_token x n_token
    :param stacked: return the averages as one stacked tensor instead of a list
    :param std: additionally return the standard deviation across heads
    :return: avg_attn: list/tensor of average attention at each layer
        - len(avg_attn) == n_layer
        - avg_attn[i] is n_token x n_token
        When std is True, a tuple (avg_attn, std_attn) is returned instead;
        std_attn is always a tensor, regardless of stacked.
    """
    all_layers = torch.stack(attn_list)

    # dim 1 of the stacked tensor is the head axis
    head_mean = all_layers.mean(dim=1)
    result = head_mean if stacked else list(head_mean)

    if not std:
        return result

    return result, all_layers.std(dim=1)


