import torch
from torch import nn
import math
# a normal neural network
class DefaultNet(nn.Module):
    """Plain fully connected ReLU network stored as one ``nn.Sequential``.

    The stack is ``Linear(input_size, width) -> ReLU``, then ``depth - 1``
    repetitions of ``Linear(width, width) -> ReLU``, and finally
    ``Linear(width, output_size)`` with no activation after it.
    """

    def __init__(self, width, depth, input_size, output_size):
        super().__init__()
        # Modules are created strictly front-to-back so that parameter
        # initialisation consumes the RNG in layer order.
        modules = [nn.Linear(input_size, width), nn.ReLU()]
        for _ in range(depth - 1):
            modules += [nn.Linear(width, width), nn.ReLU()]
        modules.append(nn.Linear(width, output_size))
        self.network = nn.Sequential(*modules)

    def forward(self, f):
        """Apply the full stack to ``f`` and return the result."""
        return self.network(f)

    def layer_output(self, f, n):
        """Return the pre-ReLU output of the ``n``-th linear layer (1-based).

        Linear and ReLU modules alternate in ``self.network``, so applying
        the first ``2 * n - 1`` modules stops right after the ``n``-th
        linear layer. For ``n == 0`` nothing is applied and ``f`` is
        returned as is.
        """
        activation = f
        for position in range(2 * n - 1):
            activation = self.network[position](activation)
        return activation


#Block Diagonal Things------------------------------------------------------------------

def create_block_diagonal_mask(size, block_size):
    """Build a square 0/1 mask with ones on ``block_size`` x ``block_size`` diagonal blocks.

    Args:
        size: Side length of the returned mask. Must be a non-negative
            multiple of ``block_size``.
        block_size: Side length of each all-ones diagonal block. Must be
            positive.

    Returns:
        A ``torch.float32`` tensor of shape ``(size, size)`` that is 1 inside
        the diagonal blocks and 0 elsewhere.

    Raises:
        ValueError: If ``block_size`` is not positive, or ``size`` is negative
            or not an exact multiple of ``block_size``.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    # Without this check a non-multiple size silently yields a mask smaller
    # than (size, size), which only surfaces later as a broadcasting error.
    if size < 0 or size % block_size != 0:
        raise ValueError(
            f"size ({size}) must be a non-negative multiple of "
            f"block_size ({block_size})"
        )

    num_blocks = size // block_size
    if num_blocks == 0:
        # Keep the documented (size, size) shape for the empty case.
        return torch.zeros((0, 0), dtype=torch.float32)

    block = torch.ones(block_size, block_size, dtype=torch.float32)
    # Repeat the same all-ones block along the diagonal.
    return torch.block_diag(*[block] * num_blocks)

class Block_Diagonal_Linear(nn.Module):
    """Square linear layer whose weight matrix is constrained to be block diagonal.

    The ``(num_features, num_features)`` weight matrix consists of
    ``num_features // block_size`` independent ``block_size`` x ``block_size``
    blocks on the diagonal; everything outside the blocks is zero. Each block
    (and the matching bias slice) is initialised from a fresh
    ``nn.Linear(block_size, block_size)``.

    Args:
        num_features: Number of input and output features. Must be a
            non-negative multiple of ``block_size``.
        bias: If ``False``, the layer has no additive bias and ``self.bias``
            is ``None``.
        block_size: Side length of each diagonal block (default 4).

    Raises:
        ValueError: If ``block_size`` is not positive or ``num_features`` is
            negative or not a multiple of ``block_size``.
    """

    def __init__(self, num_features: int, bias: bool = True, block_size: int = 4) -> None:
        super().__init__()
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")
        if num_features < 0 or num_features % block_size != 0:
            raise ValueError(
                f"num_features ({num_features}) must be a non-negative "
                f"multiple of block_size ({block_size})"
            )
        self.num_features = num_features
        self.block_size = block_size

        weights = torch.zeros((num_features, num_features))
        mask = torch.zeros_like(weights)
        bias_values = torch.zeros(num_features) if bias else None
        for i in range(num_features // block_size):
            span = slice(block_size * i, block_size * (i + 1))
            # Borrow nn.Linear's default initialisation for each block.
            init = nn.Linear(block_size, block_size, bias=bias)
            weights[span, span] = init.weight.detach()
            mask[span, span] = 1.0
            if bias_values is not None:
                bias_values[span] = init.bias.detach()

        # Registered as parameters so the optimiser, .to(device) and
        # state_dict() all see them; plain tensor attributes are invisible
        # to nn.Module.
        self.weights = nn.Parameter(weights)
        if bias_values is None:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(bias_values)
        # The mask is constant, so build it once. As a non-persistent buffer
        # it follows the module across devices without entering state_dict().
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the block-diagonal affine map to the last dimension of ``input``."""
        # Masking on every call keeps off-block entries (and their gradients)
        # at exactly zero throughout training.
        out = input @ (self.mask * self.weights).T
        if self.bias is not None:
            out = out + self.bias
        return out


class BlockDiagonalNet(nn.Module):
    """ReLU network whose hidden-to-hidden layers are ``Block_Diagonal_Linear``.

    The stack is ``Linear(input_size, width) -> ReLU``, then ``depth - 1``
    repetitions of ``Block_Diagonal_Linear(width) -> ReLU``, and finally
    ``Linear(width, output_size)`` with no activation after it.
    """

    def __init__(self, width, depth, input_size, output_size):
        super().__init__()
        # Modules are instantiated in network order (input layer, hidden
        # layers, output layer) so parameter initialisation draws from the
        # RNG in that order.
        stack = [nn.Linear(input_size, width), nn.ReLU()]
        stack.extend(
            module
            for _ in range(depth - 1)
            for module in (Block_Diagonal_Linear(width), nn.ReLU())
        )
        stack.append(nn.Linear(width, output_size))
        self.network = nn.Sequential(*stack)

    def forward(self, f):
        """Apply the full stack to ``f`` and return the result."""
        return self.network(f)

    def layer_output(self, f, n):
        """Return the pre-ReLU output of the ``n``-th layer (1-based).

        Every other module in ``self.network`` is a ReLU, so running the
        first ``2 * n - 1`` modules ends on the ``n``-th non-ReLU layer.
        With ``n == 0`` no module is applied and ``f`` is returned as is.
        """
        hidden = f
        for index in range(2 * n - 1):
            hidden = self.network[index](hidden)
        return hidden