import torch
import torch.nn as nn
import torch.nn.functional as F

# SA-HiPPO Layer (Spike-Aware HiPPO Layer)
class SAHiPPOLayer(nn.Module):
    """Spike-aware linear state update ``x_next = (A * exp(-alpha * dt)) @ x + B``.

    The state matrix ``A`` is scaled by a decay factor computed from the time
    elapsed since the previous event, ``dt``; for a positive ``decay_rate``
    a longer gap shrinks the contribution of the previous state. ``A`` and
    ``B`` are learnable and initialised with ``torch.randn`` (not with a
    closed-form HiPPO matrix).

    Args:
        state_size: Dimension ``N`` of the state vector.
        decay_rate: Decay factor ``alpha`` in ``exp(-alpha * dt)``.
    """

    def __init__(self, state_size, decay_rate):
        super(SAHiPPOLayer, self).__init__()
        self.state_size = state_size
        self.decay_rate = decay_rate  # Decay factor α

        # Learnable state matrix A (N, N) and additive offset B (N,).
        self.A = nn.Parameter(torch.randn(state_size, state_size))
        self.B = nn.Parameter(torch.randn(state_size))

    def forward(self, x, delta_t):
        """Advance the state by one step.

        Args:
            x: State tensor whose last dimension is ``N`` (shape ``(..., N)``).
            delta_t: Elapsed time. Either a Python number shared by every
                state, or a tensor whose shape broadcasts against
                ``x.shape[:-1]``.

        Returns:
            The next state, shape ``(..., N)``.
        """
        if not torch.is_tensor(delta_t):
            # torch.exp only accepts tensors; also accept plain Python numbers.
            delta_t = torch.as_tensor(delta_t, dtype=x.dtype, device=x.device)

        # F(Δt) gets two trailing singleton dims so it scales each (N, N) copy of A.
        F_delta_t = torch.exp(-self.decay_rate * delta_t).unsqueeze(-1).unsqueeze(-1)
        A_S = self.A * F_delta_t  # Spike-aware state matrix, (..., N, N)

        # Batched matrix-vector product: (..., N, N) @ (..., N, 1) -> (..., N).
        return torch.matmul(A_S, x.unsqueeze(-1)).squeeze(-1) + self.B

# FLAMES Convolution Layer
class FLAMESConvLayer(nn.Module):
    """Convolution followed by a per-pixel linear channel map.

    At every spatial position the convolution's channel vector ``u`` is
    mapped to ``self.C * (self.A @ u + self.B)`` (the product with ``C`` is
    element-wise). ``A`` has shape (out_channels, out_channels); ``B`` and
    ``C`` have shape (out_channels,). Pixels are processed independently and
    nothing is carried over between calls.

    Args:
        input_channels: Channels of the input feature map.
        output_channels: Channels produced by the convolution (and kept by
            the channel map).
        kernel_size: Convolution kernel size. Padding is ``kernel_size // 2``,
            which preserves H x W for odd sizes; even sizes add one row and
            one column.
    """

    def __init__(self, input_channels, output_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size,
            padding=kernel_size // 2,
        )

        # Learnable channel-mixing matrix, additive offset and output gain.
        self.A = nn.Parameter(torch.randn(output_channels, output_channels))
        self.B = nn.Parameter(torch.randn(output_channels))
        self.C = nn.Parameter(torch.randn(output_channels))

    def forward(self, x):
        """Convolve ``x``, then apply the channel map at every pixel.

        Args:
            x: Input of shape (batch, input_channels, H, W).

        Returns:
            Tensor with the same shape as the convolution output,
            (batch, output_channels, H', W').
        """
        features = self.conv(x)

        # (B, C, H', W') -> (B, C, H'*W'): each column is one pixel's channel vector.
        pixels = features.flatten(start_dim=2)

        mixed = self.A @ pixels + self.B.unsqueeze(-1)
        gated = mixed * self.C.unsqueeze(-1)
        return gated.view_as(features)

# Dendritic Attention Layer
class DendriticAttentionLayer(nn.Module):
    """Convolution whose output channels are gated by a leaky dendrite/soma integrator.

    A 3x3 convolution produces feature maps. Their spatial means drive one
    leaky current trace per entry of ``tau_d_list`` (decay
    ``alpha_d = exp(-1 / tau_d)``). The traces are combined with learnable
    weights ``g_d`` into a leaky somatic potential (decay
    ``beta = exp(-1 / tau_s)``). A channel whose potential reaches ``v_th``
    "spikes": its feature map is passed through and its potential is reset
    to 0. All other channels are zeroed.

    The dendritic currents ``i_d`` and potential ``V`` are persistent buffers.
    They are overwritten with their batch mean (detached) on every forward
    call, so outputs depend on earlier calls until :meth:`reset_state` is
    used.

    Args:
        input_channels: Channels of the input feature map.
        output_channels: Channels produced by the convolution.
        tau_d_list: Dendritic time constants, one dendrite per entry. Must be
            non-empty.
        tau_s: Somatic time constant.
        v_th: Spike threshold on the somatic potential (default ``1.0``).

    Raises:
        ValueError: If ``tau_d_list`` is empty.
    """

    def __init__(self, input_channels, output_channels, tau_d_list, tau_s, v_th=1.0):
        super(DendriticAttentionLayer, self).__init__()
        tau_d_list = list(tau_d_list)
        if not tau_d_list:
            raise ValueError("tau_d_list must contain at least one time constant")
        self.num_dendrites = len(tau_d_list)
        self.v_th = v_th
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        # Per-dendrite gain, shaped (num_dendrites, 1, 1) to broadcast over (batch, channels).
        self.g_d = nn.Parameter(torch.ones(self.num_dendrites, 1, 1))

        # Persistent state: dendritic currents (num_dendrites, C) and somatic potential (C,).
        self.register_buffer('i_d', torch.zeros(self.num_dendrites, output_channels))
        self.register_buffer('V', torch.zeros(output_channels))

        # Decay factors, built directly as tensors. Re-wrapping an existing
        # tensor with torch.tensor() triggers a copy-construct UserWarning.
        alpha_d = torch.exp(torch.tensor([-1.0 / tau_d for tau_d in tau_d_list]))
        self.register_buffer('alpha_d_list', alpha_d.view(self.num_dendrites, 1))
        self.register_buffer('beta', torch.exp(torch.tensor(-1.0 / tau_s)))

    def reset_state(self):
        """Zero the persistent dendritic currents ``i_d`` and somatic potential ``V``."""
        self.i_d.zero_()
        self.V.zero_()

    def forward(self, x):
        """Convolve ``x`` and keep only the channels that spike this step.

        Args:
            x: Input of shape (batch, input_channels, H, W).

        Returns:
            Tensor of shape (batch, output_channels, H, W): the convolution
            output with non-spiking channels set to zero.
        """
        x = self.conv(x)  # (batch, C, H, W)
        x_mean = x.mean(dim=(2, 3))  # (batch, C): spatial mean drives the dendrites

        # Update every dendrite at once:
        # (D, 1, 1) * (D, 1, C) + (1, batch, C) -> (D, batch, C).
        i_d_stack = (
            self.alpha_d_list.unsqueeze(1) * self.i_d.unsqueeze(1)
            + x_mean.unsqueeze(0)
        )
        # The state is shared across the batch, so store the detached batch mean.
        self.i_d.copy_(i_d_stack.mean(dim=1).detach())

        # Soma: weighted sum over dendrites, then leaky integration -> (batch, C).
        i_d_sum = (self.g_d * i_d_stack).sum(dim=0)
        V = self.beta * self.V.unsqueeze(0) + i_d_sum

        # Spike where the threshold is reached; reset those potentials to zero.
        spikes = (V >= self.v_th).to(x.dtype)
        V = V * (1 - spikes)
        self.V.copy_(V.mean(dim=0).detach())

        # Broadcast each channel's spike over all spatial positions.
        return spikes[:, :, None, None] * x

# FLAMES Model
class FLAMESModel(nn.Module):
    """Image classifier: dendritic attention -> two FLAMES conv blocks -> linear head.

    The convolutional layers used here are stride 1 with ``kernel_size // 2``
    padding and kernel size 3, so they preserve H x W. The final feature map
    is therefore adaptively average-pooled to ``pooled_size`` before the
    linear classifier. This fixes the classifier input at
    ``256 * pooled_size[0] * pooled_size[1]`` for any input resolution
    (e.g. 32x32 images). When the feature map already has size
    ``pooled_size``, the pooling is an exact identity.

    Args:
        input_channels: Channels of the input images.
        num_classes: Number of output logits.
        decay_rate: Accepted for interface compatibility; not used by this model.
        tau_d_list: Dendritic time constants for the dendritic attention layer.
        tau_s: Somatic time constant for the dendritic attention layer.
        pooled_size: Spatial size (int or (H, W)) the features are pooled to
            before the classifier. Default ``(8, 8)``.
    """

    def __init__(self, input_channels, num_classes, decay_rate, tau_d_list, tau_s, pooled_size=(8, 8)):
        super(FLAMESModel, self).__init__()
        if isinstance(pooled_size, int):
            pooled_size = (pooled_size, pooled_size)
        self.pooled_size = tuple(pooled_size)

        self.dendritic_layer = DendriticAttentionLayer(
            input_channels=input_channels,
            output_channels=64,
            tau_d_list=tau_d_list,
            tau_s=tau_s,
        )
        self.FLAMES_conv1 = FLAMESConvLayer(64, 128, kernel_size=3)
        self.norm1 = nn.BatchNorm2d(128)  # Batch normalization layer for stability
        self.FLAMES_conv2 = FLAMESConvLayer(128, 256, kernel_size=3)
        self.norm2 = nn.BatchNorm2d(256)  # Batch normalization layer for stability

        # Classifier input size is fixed by pooled_size, independent of image size.
        self.fc = nn.Linear(256 * self.pooled_size[0] * self.pooled_size[1], num_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Images of shape (batch, input_channels, H, W).

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.dendritic_layer(x)
        x = F.relu(self.norm1(self.FLAMES_conv1(x)))
        x = F.relu(self.norm2(self.FLAMES_conv2(x)))

        # The conv stack keeps H x W, so reduce to the size fc was built for.
        # Previously a 32x32 input reached fc as 256*32*32 features and crashed.
        x = F.adaptive_avg_pool2d(x, self.pooled_size)
        x = torch.flatten(x, 1)  # (batch, 256 * pooled_H * pooled_W)
        return self.fc(x)
