import torch
from torch import nn
#import torch.nn.functional
import numpy as np

class Phase1Layer(nn.Module):
    """Hidden layer of the Phase 1 network; its weights are rebuilt from ``x`` on every call.

    Each basis function ``i`` owns four neurons: a running sum (4i), two
    "V" neurons (4i+1, 4i+2) and a bias neuron (4i+3). The weight matrix is
    a differentiable function of the peak locations ``x[:, d]`` and
    ``x[:, d+1]``, so gradients flow back into ``x``.

    Args:
        d: depth index of this layer (selects columns ``d`` and ``d+1`` of x).
        width: number of basis functions (the layer has ``4*width`` neurons).
    """

    def __init__(self, d, width):
        super().__init__()
        self.d = d
        self.width = width
        self.activation = nn.ReLU()
        # Weight matrix produced by the most recent forward call (kept for inspection).
        self.weights = torch.zeros(4 * width, 4 * width, dtype=torch.float32)

    def forward(self, x, e, w, b, y):
        """Build this layer's weights from ``x`` and apply them to ``y``.

        Args:
            x: peak locations, indexed as ``x[i, d]`` (presumably shape (width, depth+1)).
            e: weights for the sum neuron (unused here; keeps the forward signatures uniform).
            w: input directions (unused here; keeps the forward signatures uniform).
            b: per-direction bias (unused here; keeps the forward signatures uniform).
            y: activations from the previous layer, last dimension ``4*width``.

        Returns:
            ``ReLU(y @ W.T)`` where ``W`` is built on x's device and dtype, so the
            layer also works for CUDA or float64 parameters.
        """
        d = self.d
        size = 4 * self.width
        weights = torch.zeros(size, size, dtype=x.dtype, device=x.device)
        indices = torch.arange(self.width, device=x.device)
        s = 4 * indices

        # sum neuron
        weights[s, s] = 1
        weights[s, s + 1] = 1
        weights[s, s + 2] = 1
        if d > 1:  # we want the V shapes alone (without bias) the first time we sum
            weights[s, s + 3] = -1

        # v neuron 1
        common_factor1 = 1 - x[indices, d + 1]
        weights[s + 1, s + 1] = -1 * common_factor1
        weights[s + 1, s + 2] = -1 * common_factor1
        weights[s + 1, s + 3] = x[indices, d] * common_factor1

        # v neuron 2
        common_factor2 = (x[indices, d] / (1 - x[indices, d])) * (1 - x[indices, d + 1])
        weights[s + 2, s + 1] = common_factor2
        weights[s + 2, s + 2] = common_factor2
        weights[s + 2, s + 3] = -1 * x[indices, d] * common_factor2

        # bias neurons
        weights[s + 3, s + 3] = (1 - x[indices, d + 1]) * x[indices, d]

        self.weights = weights
        return self.activation(y @ weights.T)

class Phase1Input(nn.Module):
    """First Phase 1 layer: turns each input direction into two ramps plus a bias.

    For basis function ``i``, neurons 4i+1 and 4i+2 see the input projected
    on ``w[i]``, scaled by ``1/x[i, 0]`` and ``1/(1 - x[i, 0])`` respectively.
    Neuron 4i+3 is a constant bias neuron that always outputs ``ReLU(1) = 1``.
    Neuron 4i gets no input.

    Args:
        width: number of basis functions (the layer has ``4*width`` neurons).
        input_size: dimensionality of the network input.
    """

    def __init__(self, width, input_size):
        super().__init__()
        self.width = width
        self.activation = nn.ReLU()
        # Templates; rows 4i+1 and 4i+2 are overwritten in forward.
        weights = torch.ones(4 * width, input_size, dtype=torch.float32)
        weights[0::4] = 0  # every index == 0 mod 4 is 0
        weights[3::4] = 0  # make bias neuron weights (3 mod 4) 0 as well
        self.weights = weights
        bias = torch.zeros(4 * width, dtype=torch.float32)
        bias[3::4] = 1
        self.bias = bias

    def forward(self, x, e, w, b, y):
        """Build the input weights/bias from ``x``, ``w``, ``b`` and apply them to ``y``.

        Args:
            x: peak locations; only column 0 is used here (division by ``x[i, 0]``
               and ``1 - x[i, 0]``, so those must be non-zero).
            e: unused here; keeps the forward signatures uniform.
            w: input directions, one row per basis function.
            b: extra bias per direction.
            y: network input, last dimension ``input_size``.

        Returns:
            ``ReLU(y @ W.T + bias)``. The result is computed on x's device and
            dtype, so CUDA and float64 parameters also work.
        """
        # Copy the templates onto x's device/dtype, then clone them: the in-place
        # writes below must never touch the stored templates, or autograd history
        # would leak into them across calls.
        new_bias = self.bias.to(device=x.device, dtype=x.dtype).clone()
        new_weights = self.weights.to(device=x.device, dtype=x.dtype).clone()
        indices = torch.arange(self.width, device=x.device)
        first_peak = x[indices, 0]
        directions = w[indices]
        # [:, None] is a reshaping operation so the division broadcasts right
        new_weights[4 * indices + 1] = -0.5 * directions / first_peak[:, None]
        new_weights[4 * indices + 2] = 0.5 * directions / (1 - first_peak[:, None])
        new_bias[4 * indices + 1] = 1 - 0.5 * (1 - b[indices]) / first_peak
        new_bias[4 * indices + 2] = 1 - 0.5 * (1 + b[indices]) / (1 - first_peak)

        return self.activation(y @ new_weights.T + new_bias)

class Phase1Output(nn.Module):
    """Last Phase 1 layer: a linear combination of the basis functions.

    Output ``o`` equals ``sum_i e[o, i] * (sum_i + v1_i + v2_i - bias_i)``,
    using the four neurons 4i..4i+3 of each basis function.

    Args:
        width: number of basis functions.
        output_size: number of network outputs.
    """

    def __init__(self, width, output_size):
        super().__init__()
        self.width = width
        self.output_size = output_size
        # Weight matrix produced by the most recent forward call (kept for inspection).
        self.weights = torch.zeros(output_size, 4 * width, dtype=torch.float32)

    def forward(self, x, e, w, b, y):
        """Build the output weights from ``e`` and apply them to ``y``.

        Args:
            x: unused here; keeps the forward signatures uniform.
            e: emphasis weights indexed as ``e[output, basis]``.
            w: unused here; keeps the forward signatures uniform.
            b: unused here; keeps the forward signatures uniform.
            y: activations of the last hidden layer, last dimension ``4*width``.

        Returns:
            ``y @ W.T`` (no activation). ``W`` is built on e's device and dtype.
        """
        basis_index = torch.arange(self.width, device=e.device).unsqueeze(0)
        output_index = torch.arange(self.output_size, device=e.device).unsqueeze(1)
        coeff = e[output_index, basis_index]
        weights = torch.zeros(self.output_size, 4 * self.width, dtype=e.dtype, device=e.device)
        weights[output_index, 4 * basis_index] = coeff
        weights[output_index, 4 * basis_index + 1] = coeff
        weights[output_index, 4 * basis_index + 2] = coeff
        weights[output_index, 4 * basis_index + 3] = -1 * coeff

        self.weights = weights
        return y @ weights.T

class Phase1(nn.Module):
    """Phase 1 network whose weights are all generated from a few shape parameters.

    Trainable parameters:
        x: (width, depth+1) peak locations, initialised uniformly in [0.05, 0.95]
           and clamped to [0.01, 0.99] before each pass.
        emphasis: (output_size, width) final weights for summing the basis
           functions, initialised with unit-norm rows.
        directions: (width, input_size) initial orientation of each basis
           function, initialised with unit-norm rows.
        bias: (width,) extra bias for each direction, initialised to zero.

    Layers: ``Phase1Input`` -> ``Phase1Layer`` for d = 1..depth-1 -> ``Phase1Output``.
    """

    def __init__(self, width, depth, input_size, output_size):
        super(Phase1, self).__init__()
        self.x = nn.Parameter(torch.tensor(np.random.uniform(0.05, 0.95, (width, depth + 1)), dtype=torch.float32))
        # these are the final weights for summing the basis functions
        emphasis = torch.tensor(np.random.randn(output_size, width), dtype=torch.float32)
        emphasis /= emphasis.norm(dim=1, keepdim=True)
        self.emphasis = nn.Parameter(emphasis)
        # these are the initial weights for orienting the basis functions
        directions = torch.randn(width, input_size)
        directions /= directions.norm(dim=1, keepdim=True)
        self.directions = nn.Parameter(directions)
        # extra bias for each direction
        self.bias = nn.Parameter(torch.zeros(width))
        # ModuleList (not a plain list) registers the layers as submodules, so they
        # appear in repr/modules() and are reached by train()/eval()/apply().
        layers = [Phase1Input(width, input_size)]
        layers.extend(Phase1Layer(i, width) for i in range(1, depth))
        layers.append(Phase1Output(width, output_size))
        self.layers = nn.ModuleList(layers)

    def _clamp_peaks(self):
        """Clamp ``x`` into [0.01, 0.99] in place so the 1/x and 1/(1-x) terms stay finite."""
        with torch.no_grad():
            self.x.clamp_(0.01, 0.99)

    def forward(self, f):
        """Run the full network on input ``f`` (after clamping the peaks)."""
        self._clamp_peaks()
        for layer in self.layers:
            # All layer forward signatures are identical to keep this loop simple.
            f = layer(self.x, self.emphasis, self.directions, self.bias, f)
        return f

    def layer_output(self, f, n):
        """Return the activations after the first ``n`` layers, using the same clamp as ``forward``."""
        self._clamp_peaks()
        for i in range(n):
            f = self.layers[i](self.x, self.emphasis, self.directions, self.bias, f)
        return f
  

# Conversion to a standard parameterization (ordinary weight matrices) for fine-tuning

class Phase2Layer(nn.Module):
    """Trainable copy of a ``Phase1Layer`` weight matrix, used for fine-tuning.

    The ``(4*width, 4*width)`` matrix is filled once from the peak locations
    ``x`` at depth ``d`` with the same formulas the Phase 1 layer uses. It is
    then stored as an ``nn.Parameter`` so it can be trained freely. Each basis
    function ``i`` owns four neurons: sum (4i), V-neuron 1 (4i+1),
    V-neuron 2 (4i+2) and bias (4i+3).

    Args:
        x: peak locations indexed as ``x[i, d]`` and ``x[i, d+1]``.
        width: number of basis functions.
        d: depth index of this layer.
    """

    def __init__(self, x, width, d):
        super().__init__()
        self.activation = nn.ReLU()
        n_units = 4 * width
        mat = torch.zeros((n_units, n_units), dtype=torch.float32)

        rows = torch.arange(width)
        s_row = 4 * rows      # sum neuron
        v1_row = s_row + 1    # first V neuron
        v2_row = s_row + 2    # second V neuron
        b_row = s_row + 3     # bias neuron

        peak = x[rows, d]
        next_scale = 1 - x[rows, d + 1]
        v2_scale = (peak / (1 - peak)) * next_scale

        # Sum neuron: carries the previous sum forward and adds both V neurons.
        for col in (s_row, v1_row, v2_row):
            mat[s_row, col] = 1
        if d > 1:  # the first sum keeps the bare V shapes (no bias subtracted)
            mat[s_row, b_row] = -1

        mat[v1_row, v1_row] = -next_scale
        mat[v1_row, v2_row] = -next_scale
        mat[v1_row, b_row] = peak * next_scale

        mat[v2_row, v1_row] = v2_scale
        mat[v2_row, v2_row] = v2_scale
        mat[v2_row, b_row] = -peak * v2_scale

        mat[b_row, b_row] = next_scale * peak

        self.weights = nn.Parameter(mat)

    def forward(self, x):
        """Return ``ReLU(x @ W.T)`` with the trainable matrix ``W``."""
        return self.activation(torch.matmul(x, self.weights.t()))

class Phase2Input(nn.Module):
    """Trainable input layer initialised from Phase 1's input parameters.

    The weights and bias are computed once, with the same formulas as
    ``Phase1Input.forward``, from the peak locations ``x`` (column 0), the
    directions ``w`` and the per-direction bias ``b``. Both are stored as
    ``nn.Parameter`` like the other Phase 2 layers, so the input layer is
    fine-tuned too and follows ``.to()`` / ``state_dict``.

    Args:
        x: peak locations; ``x[i, 0]`` and ``1 - x[i, 0]`` are divisors, so they must be non-zero.
        w: input directions, one row per basis function.
        b: extra bias per direction.
        width: number of basis functions (the layer has ``4*width`` neurons).
        input_size: dimensionality of the network input.
    """

    def __init__(self, x, w, b, width, input_size):
        super().__init__()
        self.activation = nn.ReLU()
        # Rows 0 and 3 mod 4 (sum and bias neurons) take no input; rows 1 and 2 mod 4
        # are all filled in below.
        weights = torch.zeros(4 * width, input_size, dtype=torch.float32)
        bias = torch.zeros(4 * width, dtype=torch.float32)
        bias[3::4] = 1  # bias neurons always output ReLU(1) = 1
        indices = torch.arange(width)
        first_peak = x[indices, 0]
        # [:, None] is a reshaping operation so the division broadcasts right
        weights[4 * indices + 1] = -0.5 * w[indices] / first_peak[:, None]
        weights[4 * indices + 2] = 0.5 * w[indices] / (1 - first_peak[:, None])
        bias[4 * indices + 1] = 1 - 0.5 * (1 - b[indices]) / first_peak
        bias[4 * indices + 2] = 1 - 0.5 * (1 + b[indices]) / (1 - first_peak)
        self.weights = nn.Parameter(weights)
        self.bias = nn.Parameter(bias)

    def forward(self, x):
        """Return ``ReLU(x @ W.T + bias)``."""
        return self.activation(x @ self.weights.T + self.bias)

class Phase2Output(nn.Module):
    """Trainable output layer initialised from Phase 1's emphasis weights.

    For output ``o`` and basis function ``j``, columns ``4j..4j+3`` of the
    weight matrix start as ``[e, e, e, -e]`` with ``e = e[o, j]``.

    Args:
        e: emphasis weights indexed as ``e[output, basis]``.
        width: number of basis functions.
        output_size: number of network outputs.
    """

    def __init__(self, e, width, output_size):
        super().__init__()
        out_rows = torch.arange(output_size).unsqueeze(1)
        basis_cols = torch.arange(width).unsqueeze(0)
        coeff = e[out_rows, basis_cols]
        # Stack the four per-basis entries along a new last axis, then flatten it
        # so that entry k of basis j lands in column 4*j + k.
        pattern = torch.stack((coeff, coeff, coeff, -coeff), dim=2)
        self.weights = nn.Parameter(pattern.reshape(output_size, 4 * width).to(torch.float32))

    def forward(self, x):
        """Return ``x @ W.T`` (linear, no activation)."""
        return torch.matmul(x, self.weights.t())

class Phase2(nn.Module):
    """Fine-tunable network initialised from a trained ``Phase1`` model.

    Reads the peak locations (``x``), emphasis, directions and bias from
    *model*, detaches them, and bakes them into ordinary weight matrices:
    ``Phase2Input`` -> ``Phase2Layer`` for d = 1..depth-1 -> ``Phase2Output``.
    The layers run through an ``nn.Sequential``.
    """

    def __init__(self, model):  # uses peaks, emphasis, directions and bias from phase 1
        super(Phase2, self).__init__()
        e = model.emphasis.detach()
        x = model.x.detach()
        w = model.directions.detach()
        b = model.bias.detach()
        width = x.size(0)
        depth = x.size(1) - 1
        input_size = w.size(1)
        output_size = e.size(0)
        # Plain list kept for index access; the modules are registered via self.network.
        self.layers = [Phase2Input(x, w, b, width, input_size)]
        self.layers.extend(Phase2Layer(x, width, i) for i in range(1, depth))
        self.layers.append(Phase2Output(e, width, output_size))
        self.network = nn.Sequential(*self.layers)

    def forward(self, f):
        """Run the full network on input ``f``."""
        return self.network(f)

    def layer_output(self, f, n):
        """Return the activations after the first ``n`` layers."""
        for i in range(n):
            # Call the module (not .forward) so registered hooks still run.
            f = self.layers[i](f)
        return f
  

# Block-diagonal variant: hidden-layer weights are masked to 4x4 diagonal blocks ----------

def create_block_diagonal_mask(size, block_size, *, device=None, dtype=torch.float32):
    """Return a ``(size, size)`` mask with all-ones ``block_size`` blocks on the diagonal.

    Args:
        size: side length of the square mask; must be a multiple of ``block_size``.
        block_size: side length of each all-ones diagonal block (must be positive).
        device: device to allocate the mask on (``None`` = torch's default device).
        dtype: element dtype of the mask.

    Raises:
        ValueError: if ``block_size`` is not positive or ``size`` is not an exact
            multiple of it (a smaller mask would not broadcast against the weights).
    """
    if block_size <= 0 or size % block_size != 0:
        raise ValueError(
            f"size ({size}) must be a multiple of a positive block_size ({block_size})"
        )
    block = torch.ones(block_size, block_size, dtype=dtype, device=device)
    return torch.block_diag(*[block] * (size // block_size))


class Phase2LayerBlockDiagonal(nn.Module):
    """``Phase2Layer`` variant whose effective weights stay 4x4 block-diagonal.

    The weights are initialised exactly like ``Phase2Layer``. In forward they are
    multiplied by a fixed block-diagonal mask, so entries coupling different
    basis functions never affect the output and receive zero gradient.

    Args:
        x: peak locations indexed as ``x[i, d]`` and ``x[i, d+1]``.
        width: number of basis functions (the layer has ``4*width`` neurons).
        d: depth index of this layer.
    """

    def __init__(self, x, width, d):
        super().__init__()
        self.activation = nn.ReLU()
        weights = torch.zeros(4 * width, 4 * width, dtype=torch.float32)
        indices = torch.arange(width)
        # sum neuron
        weights[4 * indices, 4 * indices] = 1
        weights[4 * indices, 4 * indices + 1] = 1
        weights[4 * indices, 4 * indices + 2] = 1
        if d > 1:  # we want the V shapes alone (without bias) the first time we sum
            weights[4 * indices, 4 * indices + 3] = -1

        # v neuron 1
        common_factor1 = 1 - x[indices, d + 1]
        weights[4 * indices + 1, 4 * indices + 1] = -1 * common_factor1
        weights[4 * indices + 1, 4 * indices + 2] = -1 * common_factor1
        weights[4 * indices + 1, 4 * indices + 3] = x[indices, d] * common_factor1

        # v neuron 2
        common_factor2 = (x[indices, d] / (1 - x[indices, d])) * (1 - x[indices, d + 1])
        weights[4 * indices + 2, 4 * indices + 1] = common_factor2
        weights[4 * indices + 2, 4 * indices + 2] = common_factor2
        weights[4 * indices + 2, 4 * indices + 3] = -1 * x[indices, d] * common_factor2

        # bias neurons
        weights[4 * indices + 3, 4 * indices + 3] = (1 - x[indices, d + 1]) * x[indices, d]

        self.weights = nn.Parameter(weights)
        # Build the mask once as a non-persistent buffer. It then follows
        # .to()/.cuda()/.double() together with the weights, and is not written
        # to state_dict, so existing checkpoints still load.
        self.register_buffer("mask", create_block_diagonal_mask(4 * width, 4), persistent=False)

    def forward(self, x):
        """Return ``ReLU(x @ (W * mask).T)``."""
        return self.activation(x @ (self.weights * self.mask).T)

class Phase2BlockDiagonal(nn.Module):
    """Fine-tunable network from a trained ``Phase1`` model, with block-diagonal hidden layers.

    This is the same conversion as ``Phase2``, except that the hidden layers are
    ``Phase2LayerBlockDiagonal``. Their effective weights stay masked to 4x4
    blocks, one block per basis function.
    """

    def __init__(self, model):  # uses peaks, emphasis, directions and bias from phase 1
        super(Phase2BlockDiagonal, self).__init__()
        e = model.emphasis.detach()
        x = model.x.detach()
        w = model.directions.detach()
        b = model.bias.detach()
        width = x.size(0)
        depth = x.size(1) - 1
        input_size = w.size(1)
        output_size = e.size(0)
        # Plain list kept for index access; the modules are registered via self.network.
        self.layers = [Phase2Input(x, w, b, width, input_size)]
        self.layers.extend(Phase2LayerBlockDiagonal(x, width, i) for i in range(1, depth))
        self.layers.append(Phase2Output(e, width, output_size))
        self.network = nn.Sequential(*self.layers)

    def forward(self, f):
        """Run the full network on input ``f``."""
        return self.network(f)

    def layer_output(self, f, n):
        """Return the activations after the first ``n`` layers."""
        for i in range(n):
            # Call the module (not .forward) so registered hooks still run.
            f = self.layers[i](f)
        return f