# The following function is used for ablations that identify the best layers for the eigenscore computation.

import torch
import numpy as np
from typing import List, Literal

def _last_token_rows(
    hidden_states: List[torch.Tensor],
    num_tokens: List[int],
    layer_idx: int,
    batch_size: int,
    hidden_dim: int,
    device: torch.device,
) -> torch.Tensor:
    """Return one row per sequence, read from timestep ``num_tokens[i] - 2``."""
    rows = torch.zeros(batch_size, hidden_dim, device=device)
    for i in range(batch_size):
        rows[i] = hidden_states[int(num_tokens[i]) - 2][layer_idx][i, 0, :]
    return rows


def _avg_generated_rows(
    hidden_states: List[torch.Tensor],
    num_tokens: List[int],
    layer_idx: int,
    batch_size: int,
    hidden_dim: int,
    device: torch.device,
) -> torch.Tensor:
    """Return one row per sequence: the mean vector over timesteps 1..limit."""
    rows = torch.zeros(batch_size, hidden_dim, device=device)
    final_step = len(hidden_states) - 1
    for i in range(batch_size):
        # Timestep 0 is always skipped; timesteps past ``num_tokens[i] - 1``
        # (or past the end of ``hidden_states``) are ignored.
        stop = min(final_step, int(num_tokens[i]) - 1)
        for t in range(1, stop + 1):
            rows[i] += hidden_states[t][layer_idx][i, 0, :]
        if stop > 0:
            rows[i] /= stop
    return rows


def _all_token_rows(
    hidden_states: List[torch.Tensor],
    layer_idx: int,
    dim_stride: int,
) -> torch.Tensor:
    """Concatenate the vectors of every timestep after the first, keeping every
    ``dim_stride``-th hidden dimension."""
    stacked = torch.cat(
        [step[layer_idx][:, 0, :] for step in hidden_states[1:]], dim=0
    )  # (batch * (timesteps - 1), dim)
    return stacked[:, ::dim_stride]


def _eigenscore(rows: torch.Tensor, alpha: float) -> float:
    """Mean log10 singular value of the regularized covariance of ``rows``' columns."""
    cov = torch.cov(rows.T).detach().cpu().numpy().astype(float)
    # A single remaining column may come back as a 0-d covariance; promote it
    # to 1x1 so the regularization and SVD below always see a matrix.
    cov = np.atleast_2d(cov)
    # Only the singular values are needed, so skip computing U and V.
    singular_values = np.linalg.svd(
        cov + alpha * np.eye(cov.shape[0]), compute_uv=False
    )
    return np.mean(np.log10(singular_values))


def compute_layerwise_eigenscores(
    hidden_states: List[torch.Tensor],
    num_tokens: List[int],
    token_strategy: Literal["last", "avg_generated", "all_tokens"],
    short: bool = False
) -> List[tuple]:
    """
    Computes per-layer eigenscore using different token aggregation strategies.

    For every selected layer, a matrix of embedding rows is built according to
    ``token_strategy``; the covariance across its columns (hidden dimensions)
    is regularized with ``alpha * I`` and the eigenscore is the mean of the
    log10 singular values of that matrix.

    Args:
        hidden_states: List of hidden states, one entry per timestep. Each
                       entry is indexable by layer, and each layer entry is
                       indexed as ``[sequence, 0, :]`` (only position 0 of the
                       middle axis is read). Batch size and hidden size are
                       taken from ``hidden_states[0][0]``.
        num_tokens: Per-sequence token counts, used to pick/limit timesteps.
        token_strategy: Strategy for token aggregation:
                        - "last": vector at timestep ``num_tokens[i] - 2``
                        - "avg_generated": mean over timesteps 1 up to
                          ``min(len(hidden_states), num_tokens[i]) - 1``
                        - "all_tokens": concatenate the vectors of every
                          timestep after the first for all sequences, keeping
                          every 40th hidden dimension. ``num_tokens`` is not
                          consulted by this strategy.
        short: If True, only compute scores every 8 layers (plus the final
               layer); otherwise every layer is scored.

    Returns:
        A list of (layer_index, eigenscore) tuples.

    Raises:
        AssertionError: If fewer than two timesteps are supplied.
        ValueError: If ``token_strategy`` is not one of the supported values.
    """
    # Raised explicitly (rather than via ``assert``) so the check survives
    # ``python -O`` while callers still see the same exception type.
    if len(hidden_states) < 2:
        raise AssertionError("Not enough hidden states")
    if token_strategy not in ("last", "avg_generated", "all_tokens"):
        raise ValueError(f"Unknown token strategy: {token_strategy}")

    alpha = 1e-3       # ridge term keeping the covariance full rank
    short_stride = 8   # layer step used when ``short`` is set
    dim_stride = 40    # hidden-dimension downsampling for "all_tokens"

    reference = hidden_states[0][0]
    num_layers = len(hidden_states[0])
    batch_size = reference.shape[0]
    hidden_dim = reference.shape[-1]
    # Work on whatever device the activations already live on instead of
    # assuming a GPU is present.
    device = reference.device

    print(f"Number of layers: {num_layers}")

    if short:
        layer_indices = list(range(0, num_layers, short_stride))
        if (num_layers - 1) not in layer_indices:
            layer_indices.append(num_layers - 1)
    else:
        layer_indices = range(num_layers)

    results = []
    for layer_idx in layer_indices:
        if token_strategy == "last":
            rows = _last_token_rows(
                hidden_states, num_tokens, layer_idx, batch_size, hidden_dim, device
            )
        elif token_strategy == "avg_generated":
            rows = _avg_generated_rows(
                hidden_states, num_tokens, layer_idx, batch_size, hidden_dim, device
            )
        else:  # "all_tokens" (validated above)
            rows = _all_token_rows(hidden_states, layer_idx, dim_stride)

        eigenscore = _eigenscore(rows, alpha)
        print(f"Layer {layer_idx} eigenscore: {eigenscore}")

        results.append((layer_idx, eigenscore))

    return results
