import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import lava.lib.dl.slayer as slayer
from typing import Callable, Optional, Tuple

# Implementation of a Spike Response Model (SRM) layer for SLAYER training using Lava-DL


from slayer_model.kernels import Kernel


class SRMLayer(nn.Module):
    """
    Lava-DL enhanced SRM layer with SLAYER delay gradients.

    Implements: V^{l}_i(t) = sum_j w_ij [phi_k^i * S_j^{l-1}](t) + w^out_i[phi_q^i * S_i^{l}](t) + b_i
    where S_i^{l} = 1 if V_i^{l}(t) > 1 else 0

    The input path is: weighted sum of input channels -> per-neuron causal
    convolution with ``phi_k`` -> bias -> SLAYER ``Delay``. When ``phi_q`` is
    given, the result is fed through a time-stepped loop in which every
    emitted spike subtracts a ``phi_q``-shaped contribution from the neuron's
    own future voltage.
    """

    def __init__(
            self,
            n_neurons: int,
            n_inputs: int,
            phi_k: Kernel,
            phi_q: Optional[Kernel] = None,
            max_delays_in: int = 50,
            k_centers: Optional[torch.Tensor] = None,
            k_widths: Optional[torch.Tensor] = None,
            q_widths: Optional[torch.Tensor] = None,
            in_weights: Optional[torch.Tensor] = None,
            out_weights: Optional[torch.Tensor] = None,
            bias_init: Optional[torch.Tensor] = None,
            dt: float = 0.01,
    ):
        """
        Args:
            n_neurons: Number of neurons in this layer
            n_inputs: Number of input channels
            phi_k: Input kernel function (callable)
            phi_q: Output kernel function (callable) - None for feedforward only
            max_delays_in: Passed as ``max_delay`` to the SLAYER input delay
            k_centers: Initial delays (tau) for input kernels; zeros if None
            k_widths: Initial widths (sigma) for input kernels; 15 * dt if None
            q_widths: Initial widths (sigma) for output kernels; 15 * dt if None
            in_weights: Initial input weights (n_inputs x n_neurons);
                drawn from N(0.5, 0.5) if None
            out_weights: Initial output weights (n_neurons,); ones if None
            bias_init: Initial bias values (n_neurons,); zeros if None
            dt: Time step size
        """
        super().__init__()

        self.n_neurons = n_neurons
        self.n_inputs = n_inputs
        self.phi_k = phi_k
        self.phi_q = phi_q
        # The recurrent (spiking) path only exists when an output kernel is given.
        self.spike = phi_q is not None
        self.dt = dt

        # Use SLAYER's delay mechanism for proper temporal credit assignment
        self.slayer_input_delay = slayer.axon.Delay(sampling_time=1, max_delay=max_delays_in, grad_scale=1)
        initial_delays = k_centers if k_centers is not None else torch.zeros(size=(n_neurons,))
        with torch.no_grad():
            self.slayer_input_delay.delay.data = initial_delays
        # Flag the delays as already initialised (presumably stops lava-dl's
        # Delay from overwriting them on first use - confirm against lava-dl).
        self.slayer_input_delay.init = True

        # Kernel widths are stored unconstrained; they are clamped to >= dt
        # where they are used (clamp instead of softplus for cleaner Lava-DL
        # integration).
        self.k_widths = nn.Parameter(
            k_widths if k_widths is not None
            else torch.full((n_neurons,), 15 * dt)
        )

        # Recurrent kernel parameters
        if self.spike:
            self.q_widths = nn.Parameter(
                q_widths if q_widths is not None
                else torch.full((n_neurons,), 15 * dt)
            )
            self.out_weights = nn.Parameter(
                out_weights if out_weights is not None
                else torch.full((n_neurons,), 1.0)
            )

        # Main weights and bias - ALWAYS create these for gradient tracking
        self.in_weights = nn.Parameter(
            in_weights if in_weights is not None
            else torch.normal(0.5, 0.5, size=(n_inputs, n_neurons))
        )

        self.bias = nn.Parameter(
            bias_init if bias_init is not None
            else torch.zeros(n_neurons)
        )

    def get_regularisation_weights(self, neuron_wise=False):
        """
        Return the sum of squared input weights.

        Args:
            neuron_wise: If True, sum over the input dimension only and return
                one value per neuron; otherwise return a single scalar.
        """
        squared = torch.square(self.in_weights)
        if neuron_wise:
            return squared.sum(dim=0)
        return squared.sum()

    def load_from_sswim(self, npz_path):
        """
        Replace this layer's parameters with arrays stored in an ``.npz`` file.

        Reads the keys ``k_delays``, ``k_widths``, ``in_weights`` and ``bias``
        and, for spiking layers, ``q_widths`` and ``out_weights``. All arrays
        are converted to float32 tensors.

        Note that the parameters are replaced by *new* ``nn.Parameter``
        objects, so an optimizer created before this call will still point at
        the old ones.

        Args:
            npz_path: Path (or file object) accepted by ``numpy.load``.

        Raises:
            KeyError: If a required key is missing. No parameter is modified
                in that case.
        """
        wanted = ['k_delays', 'k_widths', 'in_weights', 'bias']
        if self.spike:
            wanted += ['q_widths', 'out_weights']

        # NpzFile reads members lazily, so pull everything out while the
        # archive is open and let the context manager close the file handle.
        with np.load(npz_path) as sswim_weights:
            loaded = {
                key: torch.from_numpy(sswim_weights[key]).float()
                for key in wanted
            }

        with torch.no_grad():
            self.slayer_input_delay.delay.data = loaded['k_delays']
        self.slayer_input_delay.init = True

        self.k_widths = nn.Parameter(loaded['k_widths'])
        self.in_weights = nn.Parameter(loaded['in_weights'])
        self.bias = nn.Parameter(loaded['bias'])

        if self.spike:
            self.q_widths = nn.Parameter(loaded['q_widths'])
            self.out_weights = nn.Parameter(loaded['out_weights'])

    def forward(self, s_in: torch.Tensor, return_v: bool = False) -> torch.Tensor:
        """
        Forward pass with SLAYER enhanced gradients

        Args:
            s_in: Input spike trains (batch_size, n_inputs, n_timesteps)
            return_v: Only used by spiking layers; if True, the membrane
                voltages are returned alongside the spikes.

        Returns:
            Feedforward layer (``phi_q`` is None): the voltage,
            (batch_size, n_neurons, n_timesteps).
            Spiking layer: the output spikes of the same shape, or the tuple
            ``(spikes, voltages)`` when ``return_v`` is True.

        Raises:
            ValueError: If ``s_in`` is not 3-dimensional.
        """
        if s_in.dim() != 3:
            raise ValueError(
                f"s_in must be (batch_size, n_inputs, n_timesteps), got shape {tuple(s_in.shape)}"
            )
        device = s_in.device

        # Always use manual processing to ensure gradient flow
        # Mix input channels into one current per neuron: (b, j, t) x (j, i) -> (b, i, t)
        current_3d = torch.einsum('bjt,ji->bit', s_in, self.in_weights)

        # Apply custom SRM kernels
        input_voltage = self._apply_kernels(
            current_3d, self.phi_k, self.slayer_input_delay, self.k_widths, device
        )

        # Add bias (ensures gradient connection)
        # NOTE(review): the bias is added *before* the delay, so it is shifted
        # together with the signal - confirm this is intended.
        input_voltage = input_voltage + self.bias[None, :, None]

        input_voltage = self.slayer_input_delay(input_voltage)

        if self.spike:
            return self._forward_recurrent(input_voltage, device, return_v)
        return input_voltage

    def compute_kernels(self, phi, delays, widths, device):
        """
        Sample ``phi`` on a regular time grid, once per neuron.

        The grid starts at 0, has step ``dt`` and covers the largest (clamped)
        width plus, when ``delays`` is given, the largest delay. The kernel is
        evaluated at ``t / width``; the delay itself only influences the grid
        length, not the kernel argument.

        Args:
            phi: Kernel callable applied element-wise to ``t / width``.
            delays: Object exposing a ``delay`` tensor, or None.
            widths: Per-neuron widths; values below ``dt`` are clamped to ``dt``.
            device: Device on which the time grid is created.

        Returns:
            Tuple ``(kernels, n_steps)`` where ``kernels`` is ``phi`` evaluated
            on a (n_neurons, n_steps) grid and ``n_steps`` is its last dimension.
        """
        # Apply constraints to ensure non-negativity
        widths_pos = torch.clamp(widths, min=self.dt)

        # Determine kernel duration
        kernel_duration = torch.max(widths_pos)
        if delays is not None:
            kernel_duration = kernel_duration + torch.max(delays.delay)
        n_steps = max(1, int(torch.ceil(kernel_duration / self.dt).item()))

        # Build the grid from integer indices: arange with a fractional step
        # can gain or lose its last element through floating-point rounding.
        t_kernel = torch.arange(n_steps, device=device, dtype=widths_pos.dtype) * self.dt

        # (1, n_steps) / (n_neurons, 1) -> (n_neurons, n_steps)
        args = t_kernel.unsqueeze(0) / widths_pos.unsqueeze(1)
        kernels = phi(args)
        return kernels, kernels.shape[-1]

    def _apply_kernels(self, signal: torch.Tensor, phi: Callable,
                       delays: Optional[slayer.axon.Delay], widths: torch.Tensor,
                       device: torch.device) -> torch.Tensor:
        """
        Causally convolve every channel of ``signal`` with its own kernel.

        Args:
            signal: Tensor of shape (batch_size, n_channels, n_timesteps).
            phi, delays, widths, device: Forwarded to ``compute_kernels``.

        Returns:
            Tensor with the same shape as ``signal``.
        """
        n_channels = signal.shape[1]

        kernels, n_steps_k = self.compute_kernels(phi, delays, widths, device)

        # conv1d computes a cross-correlation, so flip the kernel in time to
        # obtain a convolution; one single-channel filter per group.
        kernels = torch.flip(kernels, dims=(-1,)).unsqueeze(1)  # (n_neurons, 1, n_steps_k)

        # Left-pad only, so output[t] depends on inputs at times <= t and the
        # output keeps the input length.
        signal_padded = F.pad(signal, (n_steps_k - 1, 0), mode='constant', value=0)

        # Grouped convolution
        return F.conv1d(
            signal_padded,
            kernels,
            groups=n_channels,
            padding=0
        )

    def _forward_recurrent(self, input_voltage: torch.Tensor, device: torch.device, return_v: bool = False) -> torch.Tensor:
        """
        Handle recurrent forward pass with SLAYER enhancements

        Steps through time; at every step the membrane voltage is the input
        voltage plus the pending recurrent contribution, a spike is generated
        from it, and each spike subtracts a scaled copy of the output kernel
        from the neuron's future voltage.

        Args:
            input_voltage: Tensor of shape (batch_size, n_neurons, n_timesteps).
            device: Device for the recurrent buffer.
            return_v: If True, also return the membrane voltages.

        Returns:
            Spikes (batch_size, n_neurons, n_timesteps), or the tuple
            ``(spikes, voltages)`` when ``return_v`` is True.
        """
        batch_size, n_neurons, n_timesteps = input_voltage.shape

        # compute_kernels clamps the widths to >= dt itself.
        q_kernels, n_steps_q = self.compute_kernels(
            self.phi_q, None, self.q_widths, device
        )
        q_kernels = q_kernels[None, :, :]  # (1, n_neurons, n_steps_q)

        # Per-neuron feedback strength, floored at 0.5. It does not change
        # inside the loop, so compute it once.
        if hasattr(self, 'out_weights'):
            feedback_gain = torch.clamp(self.out_weights, min=0.5)[None, :]
        else:
            # Default behavior for recurrent connections
            feedback_gain = 1.0

        # Sliding window of pending recurrent contributions; index 0 is "now".
        # It is a plain buffer: gradients reach the parameters through the
        # contributions accumulated into it, so it need not be a grad leaf.
        future_vout = torch.zeros((batch_size, n_neurons, n_steps_q), device=device)

        voltages_list = []
        spikes_list = []

        # Time-stepped simulation (necessary for recurrent dynamics)
        for t in range(n_timesteps):
            # Total membrane potential = input + pending recurrent contribution
            v_total = input_voltage[:, :, t] + future_vout[:, :, 0]
            voltages_list.append(v_total)

            # Generate spikes. The positional arguments are passed to lava-dl's
            # Spike function unchanged; see lava-dl for their meaning.
            spike_t = slayer.spike.Spike.apply(v_total, 1, 1, 1, False, 0, 1)
            spikes_list.append(spike_t)

            # Advance the window by one step. roll returns a new tensor, so the
            # slice used for v_total above is not affected by the in-place
            # updates below.
            future_vout = torch.roll(future_vout, shifts=-1, dims=2)
            future_vout[:, :, -1] = 0

            # Each spike lowers the neuron's own future voltage.
            spike_contribution = -spike_t * feedback_gain
            future_vout += spike_contribution[:, :, None] * q_kernels

        spikes = torch.stack(spikes_list, dim=2)

        if return_v:
            return spikes, torch.stack(voltages_list, dim=2)
        return spikes

    def init_fluct_rg(self, x: torch.Tensor, mu_u: float, xi: float):
        """
        Re-draw the input weights from a normal distribution whose mean and
        standard deviation are derived from the input statistics.

        The input rate is estimated as ``x.mean() / dt``; the kernel enters
        through ``phi_k.area`` and ``phi_k.area_square`` evaluated at the mean
        input width. If ``mu_u`` exceeds the admissible value computed from
        these quantities, the bias is shifted and ``mu_u`` is reflected around
        that value before the weight statistics are computed.

        Args:
            x: Sample input; only its mean and device are used.
            mu_u: Target value used for the weight mean (presumably the target
                mean membrane potential - confirm against the derivation).
            xi: Scaling factor in the admissibility bound and the weight
                variance (meaning not visible here - confirm).

        NOTE(review): no guard exists for ``x.mean() == 0`` or for a negative
        variance term; both lead to non-finite values being passed to
        ``torch.normal``.
        """
        with torch.no_grad():
            nu = x.mean() / self.dt
            n = self.n_inputs
            mean_width = self.k_widths.mean()
            eps = self.phi_k.area(mean_width)
            eps_hat = self.phi_k.area_square(mean_width)

            snve = torch.sqrt(eps_hat / (n * nu))
            sig_term = 1 / (xi / eps * snve + 1)
            if mu_u > sig_term:  # Check and correct admissibility by shifting bias
                self.bias.data[:] = 2 * (sig_term - mu_u)
                mu_u = 2 * sig_term - mu_u

            mu_w = mu_u / (n * nu * eps)

            sig_w = torch.sqrt(
                ((1 - mu_u) / xi) ** 2 / (n * nu * eps_hat) - mu_w ** 2
            )

            # main weights (created on the device of x)
            self.in_weights.data = torch.normal(
                mu_w.item(), sig_w.item(),
                size=(self.n_inputs, self.n_neurons),
                device=x.device
            )

    def load_weights_from_numpy(self, weight_dict: dict):
        """
        Load weights from numpy arrays (matching your interface)

        Keys that are absent are skipped. ``out_weights`` is only loaded for
        spiking layers.

        Args:
            weight_dict: Dictionary with keys like 'in_weights', 'bias', 'out_weights'
        """
        with torch.no_grad():
            if 'in_weights' in weight_dict:
                # Always load into main weights
                self.in_weights.data = torch.from_numpy(weight_dict['in_weights']).float()

            if 'bias' in weight_dict:
                self.bias.data = torch.from_numpy(weight_dict['bias']).float()

            if 'out_weights' in weight_dict and self.spike:
                weight_tensor = torch.from_numpy(weight_dict['out_weights']).float()
                if hasattr(self, 'out_weights'):
                    self.out_weights.data = weight_tensor

                # `use_slayer_delays` / `slayer_recurrent_synapse` are never set
                # by this class, so look them up defensively instead of
                # raising AttributeError after the weights were loaded.
                if getattr(self, 'use_slayer_delays', False) and hasattr(self, 'slayer_recurrent_synapse'):
                    # Load into SLAYER recurrent synapse
                    # Create diagonal matrix for self-connections
                    self.slayer_recurrent_synapse.weight.data = torch.diag(weight_tensor)

