import logging
import torch
from torch_geometric.data import InMemoryDataset

logger = logging.getLogger(__name__)


@torch.no_grad()
def bts(
    dataset: InMemoryDataset,
    num: int,
    window_size: int,
    *,
    device: "str | torch.device | None" = None,
) -> InMemoryDataset:
    """Keep the ``num`` eigenpairs whose eigenvectors carry the most label energy.

    For every eigenvector, its entries on train/val nodes are unit-normalized,
    summed per class, and the L2 norm of those per-class sums is taken as its
    score. The resulting score spectrum (indexed by eigenvector column) is
    smoothed by ``smooth_and_normalize_spectrum`` and the ``num`` highest
    scoring eigenpairs are kept.

    Args:
        dataset: Dataset exposing ``train_mask``, ``val_mask``, ``y``,
            ``eigenvecs`` (rows are nodes, columns are eigenvectors) and
            ``eigenvals``. Labels in ``y`` are used as column indices, so they
            are assumed to be non-negative integer class ids.
        num: Number of eigenpairs to keep; ``-1`` returns ``dataset`` unchanged.
        window_size: Smoothing window forwarded to
            ``smooth_and_normalize_spectrum``.
        device: Device used for the projection. Defaults to CUDA when it is
            available, otherwise CPU.

    Returns:
        The same ``dataset`` object, whose ``_data.eigenvals`` and
        ``_data.eigenvecs`` have been replaced in place by the selected
        eigenpairs, ordered by decreasing score (not by eigenvalue).
    """
    if num == -1:
        return dataset

    logger.info(f"Computing {num} most important eigenvectors. {window_size=}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Only train/val labels are used, so test labels never influence the
    # selection.
    mask = dataset.train_mask | dataset.val_mask
    V = dataset.eigenvecs[mask].to(device)
    y = dataset.y[mask].to(device)

    # Unit-normalize each eigenvector (column) over the masked nodes.
    V = torch.nn.functional.normalize(V, dim=0)

    # Drop a trailing singleton label dim, e.g. (N, 1) -> (N,). Unlike
    # squeeze(), squeeze(-1) keeps a single-node label vector 1-D.
    if y.dim() > 1:
        y = y.squeeze(-1)

    # One-hot class assignment matrix, shape (N, num_classes).
    # Sized by the largest label (not the number of distinct labels) so that
    # non-contiguous labels, or classes missing from the mask, stay in bounds;
    # absent classes simply give all-zero columns. dtype follows V so the
    # matmul below also works for float64 eigenvectors.
    num_nodes = y.shape[0]
    num_classes = int(y.max()) + 1 if num_nodes > 0 else 0
    C = torch.zeros((num_nodes, num_classes), device=device, dtype=V.dtype)
    C[torch.arange(num_nodes, device=device), y] = 1

    # (V.T @ C)[j, c] is the sum of eigenvector j over the nodes of class c;
    # the norm across classes gives one score per eigenvector.
    y_spec = (V.T @ C).norm(dim=-1).cpu()
    y_spec = smooth_and_normalize_spectrum(y_spec, window_size)

    # Pick the `num` eigenpairs with the highest smoothed score.
    indices = y_spec.argsort(descending=True)[:num]
    dataset._data.eigenvals = dataset._data.eigenvals[indices]
    dataset._data.eigenvecs = dataset._data.eigenvecs[:, indices]
    return dataset


def smooth_and_normalize_spectrum(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """Return a moving average of ``x**2`` scaled by ``len(x) / sum(x**2)``.

    For each index ``i`` the output is the mean of ``x[j]**2`` over
    ``j`` in ``[i - window_size // 2, i + window_size // 2]``, clipped to the
    valid range (windows shrink at the edges). That mean is divided by the
    total energy ``sum(x**2)`` and multiplied by ``len(x)``. With no smoothing
    (``window_size`` 0 or 1) the output therefore averages to 1.

    Args:
        x: 1-D tensor (the spectrum).
        window_size: Nominal window width; the effective width is
            ``2 * (window_size // 2) + 1``. Must be non-negative.

    Returns:
        Tensor with the same shape as ``x``. Its dtype matches ``x`` for
        floating-point input; otherwise it is the default float dtype.
        All zeros if ``x`` is empty or has zero total energy, which would
        otherwise produce NaN.

    Raises:
        ValueError: If ``window_size`` is negative.
    """
    if window_size < 0:
        raise ValueError(f"window_size must be non-negative, got {window_size}")

    out_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()

    # Accumulate in float64 so prefix-sum differences stay accurate.
    energy = torch.square(x.to(torch.float64))
    total_energy = energy.sum()
    if total_energy == 0:
        return torch.zeros_like(x, dtype=out_dtype)

    n = energy.shape[0]
    half = window_size // 2

    # prefix[k] = sum(energy[:k]), so any window sum is one subtraction (O(n)).
    prefix = torch.cat([energy.new_zeros(1), torch.cumsum(energy, dim=0)])
    idx = torch.arange(n, device=energy.device)
    start = (idx - half).clamp(min=0)
    end = (idx + half + 1).clamp(max=n)
    window_mean = (prefix[end] - prefix[start]) / (end - start)

    return (window_mean / total_energy * n).to(out_dtype)
