import math

import torch
import torch.nn as nn
from torch.nn.functional import normalize


@torch.no_grad()
def orthogonal_matrix_chunk(cols, device=None):
    """Sample a random ``cols x cols`` orthogonal matrix.

    A square matrix of standard-normal entries is QR-factorised and the
    transpose of its Q factor is returned.

    Args:
        cols: Side length of the square matrix.
        device: Device passed to ``torch.randn`` when sampling.

    Returns:
        Tensor of shape ``(cols, cols)``.
    """
    gaussian = torch.randn((cols, cols), device=device)
    # Only the orthogonal factor is used; the triangular factor is dropped.
    orthogonal = torch.linalg.qr(gaussian, mode="reduced")[0]
    return orthogonal.t()  # [cols, cols]


@torch.no_grad()
def gaussian_orthonormal_random_matrix(nb_rows, nb_columns, scaling=0, device=None):
    """Create a 2D random matrix whose rows come from stacked orthogonal blocks.

    The matrix is assembled from ``nb_rows // nb_columns`` independent
    ``nb_columns x nb_columns`` orthogonal blocks, plus the leading rows of
    one extra block when ``nb_rows`` is not a multiple of ``nb_columns``.
    Rows taken from the same block are mutually orthogonal; rows from
    different blocks are sampled independently.

    Each row is then multiplied by a per-row factor selected by ``scaling``
    and finally L2-normalised along ``dim=1``.

    Args:
        nb_rows: Number of rows of the result (non-negative integer).
        nb_columns: Number of columns of the result (positive integer).
        scaling: ``0`` multiplies each row by the norm of a freshly sampled
            standard-normal vector of length ``nb_columns``; ``1`` multiplies
            each row by ``sqrt(nb_columns)``.
        device: Device on which all tensors are created.

    Returns:
        Tensor of shape ``(nb_rows, nb_columns)``.

    Raises:
        ValueError: If ``scaling`` is neither 0 nor 1, if ``nb_rows`` is
            negative, or if ``nb_columns`` is not positive.
    """
    # Validate up front so that bad arguments fail before any sampling work.
    if scaling not in (0, 1):
        raise ValueError(f"Invalid scaling {scaling}")
    if nb_rows < 0 or nb_columns <= 0:
        raise ValueError(
            f"Invalid shape: nb_rows={nb_rows} must be >= 0 and "
            f"nb_columns={nb_columns} must be > 0"
        )
    if nb_rows == 0:
        # Nothing to stack; torch.cat would reject an empty block list.
        return torch.empty((0, nb_columns), device=device)

    nb_full_blocks, remaining_rows = divmod(nb_rows, nb_columns)

    block_list = [
        orthogonal_matrix_chunk(nb_columns, device=device)
        for _ in range(nb_full_blocks)
    ]
    if remaining_rows > 0:
        # Partial block: keep only as many rows as are still needed.
        q = orthogonal_matrix_chunk(nb_columns, device=device)
        block_list.append(q[:remaining_rows])

    final_matrix = torch.cat(block_list)

    if scaling == 0:
        multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
    else:
        multiplier = math.sqrt(float(nb_columns)) * torch.ones(
            (nb_rows,), device=device
        )

    # Row-wise scaling via broadcasting; equivalent to diag(multiplier) @ matrix
    # without materialising a dense (nb_rows, nb_rows) diagonal matrix.
    result = multiplier.unsqueeze(1) * final_matrix

    # NOTE(review): normalising each row to unit L2 norm cancels any positive
    # per-row multiplier applied above, so ``scaling`` changes the output only
    # up to rounding. Kept as-is to preserve existing behaviour.
    return normalize(result, p=2.0, dim=1)


class OrthonormalRandomFeaturesPE(nn.Module):
    """Module holding a fixed random matrix built from orthogonal blocks.

    The matrix is sampled once at construction time by
    ``gaussian_orthonormal_random_matrix`` and stored as the buffer ``Q``,
    so it is not a trainable parameter.

    Args:
        num_rows: Number of rows of ``Q``.
        num_columns: Number of columns of ``Q``.
        scaling: Scaling mode forwarded to the matrix generator (0 or 1).
        device: Device on which ``Q`` is created.
    """

    def __init__(self, num_rows, num_columns, scaling=0, device=None):
        super().__init__()

        # Stored under the buffer name "Q"; the name is part of the module's
        # state and must not change.
        self.register_buffer(
            "Q",
            gaussian_orthonormal_random_matrix(
                num_rows, num_columns, scaling=scaling, device=device
            ),
        )

    def forward(self):
        """Return the stored matrix ``Q``; this method takes no inputs."""
        return self.Q
