import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import NamedTuple, Tuple
    
class WB_Batch(NamedTuple):
    """Per-layer weights and biases for a batch of separately parameterised MLPs.

    Both fields hold one tensor per layer, in the order the layers are applied.
    ``BatchedMLP.forward`` drops an optional trailing size-1 axis from every
    tensor, then contracts each weight tensor as
    ``(batch, in_features, out_features)`` and broadcasts each bias tensor as
    ``(batch, out_features)`` across the sample axis.
    """

    # weights[i]: weight tensor of layer i; the leading axis indexes the
    # network within the batch.
    weights: Tuple[torch.Tensor, ...]
    # biases[i]: bias tensor of layer i, paired with weights[i].
    biases: Tuple[torch.Tensor, ...]

class BatchedMLP(nn.Module):
    """Evaluate a batch of MLPs, each with its own parameters, on shared inputs.

    The module owns no parameters of its own: the weights and biases of every
    network in the batch are supplied to :meth:`forward` on each call.

    Args:
        activation_fn: Callable applied to the output of every layer except
            the last one. Defaults to ``torch.tanh``.
    """

    def __init__(self, activation_fn=torch.tanh):
        super().__init__()
        self.activation_fn = activation_fn

    def forward(self, images: torch.Tensor, w_b: WB_Batch) -> torch.Tensor:
        """Run every network in ``w_b`` on the same flattened ``images``.

        Args:
            images: Tensor of shape ``(N, ...)``. Everything after the first
                axis is flattened into a feature vector and cast to the dtype
                of the first weight tensor.
            w_b: Per-layer parameters; must contain at least one layer.
                ``weights[i]`` is ``(B, in_features, out_features)`` and
                ``biases[i]`` is ``(B, out_features)``; either may carry one
                extra trailing size-1 axis, which is dropped.

        Returns:
            Tensor of shape ``(B, N, out_features)`` of the last layer. No
            activation is applied after the last layer.
        """
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as transposed or permuted image tensors.
        x = images.reshape(images.shape[0], -1)
        x = x.to(w_b.weights[0].dtype)

        last_layer = len(w_b.weights) - 1
        for i, weights in enumerate(w_b.weights):
            biases = w_b.biases[i]

            # Drop the optional trailing singleton axis. Tensors that already
            # have the final rank are left untouched: squeezing them would
            # collapse a genuine out_features == 1 axis.
            if weights.dim() != 3:
                weights = weights.squeeze(-1)
            if biases.dim() != 2:
                biases = biases.squeeze(-1)

            if i == 0:
                # The flattened inputs are shared by all B networks.
                x = torch.einsum('ni,bio->bno', x, weights)
            else:
                x = torch.einsum('bni,bio->bno', x, weights)

            # (B, out) -> (B, 1, out) so the bias broadcasts over the N samples.
            x = x + biases.unsqueeze(1)

            if i < last_layer:
                x = self.activation_fn(x)
        return x
    
def calculate_l1_sum(w_b: WB_Batch) -> torch.Tensor:
    """Calculates the L1 sum for all weights and biases in a WB_Batch.

    Args:
        w_b: Container whose ``weights`` and ``biases`` tuples hold the
            tensors to sum over.

    Returns:
        A 0-dim tensor holding the sum of absolute values of every element of
        every weight and bias tensor. It lives on the device of the first
        parameter and uses the default dtype promoted with the parameter
        dtypes, so float64 parameters give a float64 result while float32 and
        lower-precision parameters still accumulate in the default dtype.
        With no parameters at all, a zero scalar on the default device is
        returned.
    """
    terms = [torch.abs(param).sum() for param in (*w_b.weights, *w_b.biases)]
    if not terms:
        return torch.tensor(0.0)

    # An in-place `+=` on a default-dtype accumulator would silently round
    # float64 terms to float32, so pick the accumulator dtype up front.
    acc_dtype = torch.get_default_dtype()
    for term in terms:
        acc_dtype = torch.promote_types(acc_dtype, term.dtype)

    total_l1 = torch.zeros((), dtype=acc_dtype, device=terms[0].device)
    for term in terms:
        total_l1 = total_l1 + term
    return total_l1