import torch
import torch.nn as nn
from typing import Literal


def unsqueeze_like(tensor: torch.Tensor, like: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Insert singleton dimensions into ``tensor`` so it has as many dimensions as ``like``.

    All ``like.ndim - tensor.ndim`` new axes are inserted contiguously at
    position ``dim``. With the default ``dim=0`` they are prepended (leading
    dimensions). Pass ``dim=tensor.ndim`` or ``dim=-1`` to append them as
    trailing dimensions, e.g. ``(B,)`` -> ``(B, 1, 1, 1)`` to broadcast against
    a 4-D ``like``.

    Args:
        tensor (torch.Tensor): tensor to unsqueeze
        like (torch.Tensor): tensor whose number of dimensions to match
        dim (int): position at which to insert the new axes, following the
            ``torch.unsqueeze`` convention. The valid range is
            ``[-tensor.ndim - 1, tensor.ndim]``, and negative values count
            from the end (``-1`` = after the last existing dimension).
            Default: 0.

    Returns:
        torch.Tensor: ``tensor`` indexed with ``None`` entries (a view) so that
        it has ``like.ndim`` dimensions, or ``tensor`` itself if the dimension
        counts already match.

    Raises:
        ValueError: if ``tensor`` has more dimensions than ``like``.
        IndexError: if ``dim`` is outside the valid range.
    """
    n_unsqueezes = like.ndim - tensor.ndim
    if n_unsqueezes < 0:
        raise ValueError(f"tensor.ndim={tensor.ndim} > like.ndim={like.ndim}")
    if n_unsqueezes == 0:
        return tensor

    ndim = tensor.ndim
    # Normalise negative positions like torch.unsqueeze. Previously a negative
    # dim silently collapsed to 0 because `-k * (slice(None),) == ()`.
    pos = dim + ndim + 1 if dim < 0 else dim
    if not 0 <= pos <= ndim:
        raise IndexError(
            f"dim={dim} out of range for tensor.ndim={ndim} "
            f"(expected in [{-ndim - 1}, {ndim}])"
        )
    # Keep the first `pos` axes untouched, then insert `n_unsqueezes` new axes;
    # any remaining axes follow implicitly.
    return tensor[(slice(None),) * pos + (None,) * n_unsqueezes]

class NormedLinear(nn.Module):
    """Linear layer followed by an optional normalization of its output.

    Args:
        in_channels (int): input feature size (last dimension of the input).
        out_channels (int): output feature size of the linear map.
        norm (str): normalization applied after the linear map:
            ``'none'`` (identity), ``'ln'`` (``nn.LayerNorm(out_channels)``)
            or ``'bn'`` (not implemented yet). Default: ``'ln'``.

    Raises:
        NotImplementedError: if ``norm == 'bn'``.
        ValueError: if ``norm`` is not one of ``'none'``, ``'bn'``, ``'ln'``.
    """

    def __init__(self, in_channels: int,
                 out_channels: int,
                 norm: Literal['none', 'bn', 'ln'] = 'ln'):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_type = norm

        self.lin = nn.Linear(in_channels, out_channels)
        if norm == 'none':
            self.norm = nn.Identity()
        elif norm == 'ln':
            self.norm = nn.LayerNorm(out_channels)
        elif norm == 'bn':
            raise NotImplementedError("norm='bn' is not implemented yet")
        else:
            # Fail fast. Without this branch `self.norm` would never be set,
            # and the first access (in reset_parameters) would raise a
            # confusing AttributeError instead.
            raise ValueError(
                f"norm must be one of 'none', 'bn', 'ln'; got {norm!r}"
            )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the linear layer and, when supported, the norm layer."""
        self.lin.reset_parameters()
        # nn.Identity has no parameters and no reset_parameters method.
        if hasattr(self.norm, 'reset_parameters'):
            self.norm.reset_parameters()

    def forward(self, x):
        """Apply the linear map to the last dimension of ``x``, then the norm."""
        return self.norm(self.lin(x))
