import torch
import torch.nn as nn


def mlp_weight_damper(module, depth, factor):
    """Dampen the ``nn.Linear`` weights and biases in ``module`` in place.

    Each ``nn.Linear`` found by iterating ``module`` directly (e.g. the items
    of an ``nn.Sequential``) has its weight and bias multiplied by
    ``factor ** (1 / (depth - 1))``. Across ``depth - 1`` such layers, the
    per-layer scales multiply back to ``factor``.

    Scaling happens in place under ``torch.no_grad()``. The parameter objects
    therefore keep their identity, so existing references such as optimizer
    param groups stay valid, and their ``requires_grad`` flag is unchanged.

    Args:
        module: Iterable of sub-modules; only ``nn.Linear`` items are scaled.
        depth: Depth used to derive the per-layer factor; must be >= 2.
        factor: Overall damping factor.

    Raises:
        ValueError: If ``depth < 2``, because ``1 / (depth - 1)`` is then
            undefined (depth == 1) or inverts the damping into an amplification.
    """
    if depth < 2:
        raise ValueError(
            f"depth must be >= 2 to derive a per-layer damping factor, got {depth}"
        )
    per_layer = factor ** (1 / (depth - 1))
    with torch.no_grad():
        for layer in module:
            if isinstance(layer, nn.Linear):
                layer.weight.mul_(per_layer)
                # nn.Linear(..., bias=False) stores bias as None.
                if layer.bias is not None:
                    layer.bias.mul_(per_layer)


def positional_encoding_init(seq_len, d, n):
    """Initialize a sinusoidal positional-encoding matrix as a parameter.

    For row ``k`` (0-based) and pair index ``i`` in ``[0, d // 2)``::

        period_k    = ((d / 2) / (2 * pi)) / (k + 1)
        P[k, 2*i]   = sin(i / period_k)
        P[k, 2*i+1] = cos(i / period_k)

    When ``d`` is odd, the last column is left at zero.

    Args:
        seq_len: Number of rows (positions).
        d: Number of columns (encoding width).
        n: Unused; kept so existing call sites keep working.

    Returns:
        ``torch.nn.Parameter`` of shape ``(seq_len, d)`` in the default dtype.
    """
    P = torch.zeros((seq_len, d))
    half = int(d / 2)
    # Broadcast a (seq_len, 1) column of periods against a (1, half) row of
    # pair indices instead of looping element by element. Angles are computed
    # in float64 and cast at the end, so large positions lose less precision.
    k = torch.arange(1, seq_len + 1, dtype=torch.float64).unsqueeze(1)
    period = ((d / 2) / (2 * torch.pi)) / k
    i = torch.arange(half, dtype=torch.float64).unsqueeze(0)
    angle = i / period
    # An explicit stop of 2 * half keeps the trailing column of an odd d at zero.
    P[:, 0 : 2 * half : 2] = torch.sin(angle).to(P.dtype)
    P[:, 1 : 2 * half : 2] = torch.cos(angle).to(P.dtype)
    return torch.nn.Parameter(P)


def build_mlp(dim_in, dim_hid, dim_out, depth):
    """Build an ``nn.Sequential`` MLP with in-place ReLU between Linear layers.

    The network is ``Linear(dim_in, dim_hid) -> ReLU``, followed by
    ``depth - 2`` repetitions of ``Linear(dim_hid, dim_hid) -> ReLU``, and
    finally ``Linear(dim_hid, dim_out)``. For ``depth < 2`` no hidden
    repetitions are added, so the result still contains two Linear layers.
    """

    def _linear_relu(n_in, n_out):
        # One Linear followed by an in-place ReLU.
        return [nn.Linear(n_in, n_out), nn.ReLU(True)]

    body = _linear_relu(dim_in, dim_hid)
    for _ in range(depth - 2):
        body += _linear_relu(dim_hid, dim_hid)
    # The output projection is created last, after all hidden layers.
    return nn.Sequential(*body, nn.Linear(dim_hid, dim_out))


class SkipMLP(nn.Module):
    """Linear projection followed by a residual nonlinear branch.

    ``forward(x) = nonlinear_layer(linear_layer(x)) + linear_layer(x)``,
    where the linear projection is computed once and reused for the skip term.
    """

    def __init__(self, dim_in, dim_out, nonlinear_layer):
        """Create the projection and store the nonlinear branch.

        Args:
            dim_in: Input feature size of the linear projection.
            dim_out: Output feature size of the linear projection. The
                nonlinear branch receives tensors of this width, and its
                output must broadcast against them.
            nonlinear_layer: Module applied to the projection output. It
                should not modify its input in place, because that input is
                reused as the skip term.
        """
        super().__init__()
        self.linear_layer = nn.Linear(dim_in, dim_out)
        self.nonlinear_layer = nonlinear_layer

    def forward(self, x):
        """Return ``nonlinear_layer(linear_layer(x)) + linear_layer(x)``."""
        linear = self.linear_layer(x)
        out = self.nonlinear_layer(linear)
        # Out-of-place add: an in-place ``out += linear`` would mutate a
        # tensor that autograd may have saved (e.g. Sigmoid/Tanh outputs),
        # failing on backward. It would also forbid broadcasting ``out``.
        return out + linear


def build_mlp_with_linear_skipcon(
    dim_in, dim_hid, dim_out, depth, weight_damper_factor=0.1
):
    """Build a ``SkipMLP`` whose residual branch is a damped MLP.

    The residual branch is ``build_mlp(dim_out, dim_hid, dim_out, depth - 1)``.
    Its Linear layers are scaled by ``mlp_weight_damper`` using ``depth`` and
    ``weight_damper_factor``. That helper computes ``1 / (depth - 1)``, so
    ``depth`` should be at least 2.

    NOTE(review): for ``depth == 2`` the branch is ``build_mlp(..., 1)``,
    which still contains two Linear layers, each scaled by the full factor.
    """
    residual_branch = build_mlp(dim_out, dim_hid, dim_out, depth - 1)
    mlp_weight_damper(residual_branch, depth=depth, factor=weight_damper_factor)
    return SkipMLP(dim_in, dim_out, residual_branch)


def expand_kv_heads(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along dim 1.

    ``x`` must be 4-D, ``(b, kvh, t, d)``. The result has shape
    ``(b, kvh * n_rep, t, d)``, and each original head's copies are adjacent
    (head ``h`` fills indices ``h * n_rep`` through ``h * n_rep + n_rep - 1``).
    When ``n_rep == 1``, the input tensor itself is returned unchanged.
    """
    if n_rep != 1:
        batch, n_kv, seq, head_dim = x.shape
        # Insert a broadcast axis after the head axis, then fold it into heads.
        repeated = x.unsqueeze(2).expand(batch, n_kv, n_rep, seq, head_dim)
        x = repeated.reshape(batch, n_kv * n_rep, seq, head_dim)
    return x
