import torch
import torch.nn as nn
import torch.nn.functional as F

class TTTModule(nn.Module):
    """Residual two-layer MLP followed by LayerNorm, with a hand-derived SGD step.

    The network computes ``LayerNorm(x + W2 @ gelu(W1 @ x + b1) + b2)`` using
    the tanh approximation of GELU. ``manual_update`` performs one gradient
    descent step on a mean-squared-error loss using explicitly derived
    gradients instead of autograd.
    """

    def __init__(self, input_dim: int):
        """Create the layers.

        Args:
            input_dim: Size of the last (feature) dimension of the input.
                The hidden layer is four times as wide.
        """
        super().__init__()
        self.linear1 = nn.Linear(input_dim, 4 * input_dim)  # D -> 4D expansion
        self.linear2 = nn.Linear(4 * input_dim, input_dim)  # 4D -> D projection
        self.layer_norm = nn.LayerNorm(input_dim)  # post-residual normalization

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Input tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor of the same shape as ``x``: LayerNorm applied to the sum
            of ``x`` and the MLP output.
        """
        hidden = F.gelu(self.linear1(x), approximate="tanh")
        return self.layer_norm(x + self.linear2(hidden))

    def _layer_norm_forward(self, pre_norm: torch.Tensor):
        """Apply ``self.layer_norm`` by hand, keeping the backward intermediates.

        Args:
            pre_norm: Tensor to normalize over its last dimension.

        Returns:
            Tuple ``(y, x_hat, std)``: the affine-transformed output, the
            normalized activations, and the per-row ``sqrt(var + eps)``.
        """
        mu = pre_norm.mean(dim=-1, keepdim=True)
        var = pre_norm.var(dim=-1, keepdim=True, unbiased=False)
        # Use the module's own eps so this matches what forward() computes;
        # a separately hard-coded constant would silently diverge from it.
        std = (var + self.layer_norm.eps).sqrt()
        x_hat = (pre_norm - mu) / std
        y = self.layer_norm.weight * x_hat + self.layer_norm.bias
        return y, x_hat, std

    def manual_update(self, x: torch.Tensor, target: torch.Tensor, lr: float) -> float:
        """Take one SGD step on ``mse_loss(forward(x), target)`` with manual gradients.

        All six parameter tensors (both linear layers and the LayerNorm
        scale/shift) are updated in place.

        Args:
            x: Input tensor of shape ``(..., D)``, typically ``(B, T, D)``.
            target: Target tensor with the same shape as the network output
                (or broadcastable to it).
            lr: Learning rate for the update.

        Returns:
            The loss evaluated with the parameters as they were *before*
            this update was applied.
        """

        def as_rows(t: torch.Tensor) -> torch.Tensor:
            # Collapse every leading dimension so parameter gradients can be
            # accumulated over all positions; reshape (unlike view) also
            # accepts non-contiguous tensors.
            return t.reshape(-1, t.shape[-1])

        # Gradients are derived by hand below, so recording an autograd graph
        # for this pass would only cost memory.
        with torch.no_grad():
            # Forward pass, keeping the intermediates the backward pass needs.
            hidden = F.linear(x, self.linear1.weight, self.linear1.bias)  # (..., 4D)
            act = F.gelu(hidden, approximate="tanh")
            pre_norm = x + F.linear(act, self.linear2.weight, self.linear2.bias)  # (..., D)
            y_norm, x_hat, std = self._layer_norm_forward(pre_norm)

            loss = F.mse_loss(y_norm, target)

            # d(loss)/d(y_norm). mse_loss averages over the broadcast element
            # count, which is the size of the difference tensor.
            diff = y_norm - target
            grad_y = 2 * diff / diff.numel()

            # LayerNorm backward.
            grad_gamma = as_rows(grad_y * x_hat).sum(dim=0)
            grad_beta = as_rows(grad_y).sum(dim=0)
            grad_x_hat = grad_y * self.layer_norm.weight
            grad_pre_norm = (
                grad_x_hat
                - grad_x_hat.mean(dim=-1, keepdim=True)
                - x_hat * (grad_x_hat * x_hat).mean(dim=-1, keepdim=True)
            ) / std

            # Second linear layer. The residual branch adds x directly, so the
            # gradient reaching linear2's output equals grad_pre_norm.
            grad_pre_norm_rows = as_rows(grad_pre_norm)
            grad_W2 = grad_pre_norm_rows.T @ as_rows(act)
            grad_b2 = grad_pre_norm_rows.sum(dim=0)

            # Back through GELU into the first linear layer.
            grad_act = grad_pre_norm @ self.linear2.weight
            grad_hidden_rows = as_rows(grad_act * gelu_bwd(hidden))
            grad_W1 = grad_hidden_rows.T @ as_rows(x)
            grad_b1 = grad_hidden_rows.sum(dim=0)

            # Apply the step only after every gradient has been computed from
            # the pre-update parameters.
            self.linear2.weight -= lr * grad_W2
            self.linear2.bias -= lr * grad_b2
            self.linear1.weight -= lr * grad_W1
            self.linear1.bias -= lr * grad_b1
            self.layer_norm.weight -= lr * grad_gamma
            self.layer_norm.bias -= lr * grad_beta

        return loss.item()

def gelu_bwd(x):
    """
    Evaluates d/dx of the tanh-approximated GELU at each element of ``x``.

    With ``u = k * x * (1 + 0.044715 * x^2)`` and ``k = 0.79788456``
    (approximately sqrt(2 / pi)), GELU is ``0.5 * x * (1 + tanh(u))`` and its
    derivative is ``0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * du/dx``.

    Args:
        x (Tensor): Pre-activation values.

    Returns:
        Tensor: Element-wise derivative, same shape as ``x``.
    """
    inner = 0.79788456 * x * (1 + 0.044715 * x * x)
    t = torch.tanh(inner)

    # du/dx = k * (1 + 3 * 0.044715 * x^2); the second coefficient is pre-multiplied.
    d_inner = 0.79788456 + 0.1070322243 * x ** 2

    # Product-rule pieces: the x * d(tanh)/dx part, then the (1 + tanh) part.
    through_tanh = 0.5 * x * (1 - t ** 2) * d_inner
    through_gate = 0.5 * (1 + t)
    return through_tanh + through_gate
