import torch
import torch.nn.functional as F
import torch.nn.init as init

from torch import nn
from torch.nn.utils import parametrize
from torchdyn.core import NeuralODE
from typing import Dict, Optional, Any
from metrics import r2_score, regression_r2_score

def get_activation_function(name: Optional[str] = None) -> nn.Module:
    """Build a fresh activation module from its short name.

    Args:
        name: One of ``'tanh'``, ``'relu'``, ``'leaky_relu'``, ``'sigmoid'``,
            ``'softmax'`` (applied over the last dimension), or ``None`` for
            an identity mapping.

    Returns:
        A newly constructed ``nn.Module`` instance.

    Raises:
        ValueError: If ``name`` is not one of the recognised names.
    """
    if name is None:
        return nn.Identity()

    # Ordered (name, factory) pairs; compared with ``==`` so any value that
    # is not a recognised name falls through to the ValueError below.
    factories = (
        ('tanh', nn.Tanh),
        ('relu', nn.ReLU),
        ('leaky_relu', nn.LeakyReLU),
        ('sigmoid', nn.Sigmoid),
        ('softmax', lambda: nn.Softmax(dim=-1)),
    )
    for key, factory in factories:
        if name == key:
            return factory()

    raise ValueError(f"Unknown activation function: {name}")

# SoftPlus Parameterization function
class SoftplusParameterization(nn.Module):
    """Parametrization mapping an unconstrained tensor through softplus.

    Registered via ``torch.nn.utils.parametrize`` so the parametrized weight
    is ``softplus(original)`` and therefore strictly positive.
    """

    def forward(self, X):
        return F.softplus(X)

class RNN(nn.Module):
    """Unroll a single recurrent cell over the time axis of a batch-first input.

    Cells whose exact type is ``nn.GRUCell`` or ``nn.RNNCell`` are called as
    ``cell(x_t, h)`` and must return only the next hidden state. Any other
    cell is called as ``cell(x_t, h, t)`` (``t`` is the integer step index)
    and must return ``(next_hidden, extra)``; every ``extra`` that is not
    ``None`` is collected into the returned ``influences`` list.
    """

    def __init__(self, cell):
        super().__init__()
        self.cell = cell

    def forward(self, input, h_0):
        """Run the cell over every step of ``input``.

        Args:
            input: Tensor iterated over dim 1 (time); each step slice is
                passed to the cell.
            h_0: Initial hidden state.

        Returns:
            Tuple ``(states, hidden, hidden_input, hidden_output, influences)``:
            ``states``/``hidden_output`` stack the hidden state after each
            step along dim 1, ``hidden`` is the final hidden state,
            ``hidden_input`` stacks the hidden state fed into each step, and
            ``influences`` is the list of non-``None`` extras from custom cells.
        """
        # Exact type match on purpose: the dynamics cells in this file
        # subclass nn.RNNCell yet take a time argument and return a tuple,
        # so isinstance() would send them down the wrong branch.
        is_builtin_cell = type(self.cell) in (nn.GRUCell, nn.RNNCell)

        h = h_0
        fed_states = []
        produced_states = []
        extras = []
        for step, x_t in enumerate(input.transpose(0, 1)):
            fed_states.append(h)
            if is_builtin_cell:
                h = self.cell(x_t, h)
            else:
                h, extra = self.cell(x_t, h, step)
                if extra is not None:
                    extras.append(extra)
            produced_states.append(h)

        states = torch.stack(produced_states, dim=1)
        hidden_input = torch.stack(fed_states, dim=1)
        hidden_output = torch.stack(produced_states, dim=1)
        return states, h, hidden_input, hidden_output, extras

# Masked Linear Layer for Force Percent experiments
class MaskedLinear(nn.Linear):
    """``nn.Linear`` whose weight is multiplied by a fixed boolean mask.

    Used for the force-percent experiments. For each population in
    ``neurons`` (iterated in order), only the first
    ``int((end - start) * force_percent)`` of that population's output rows
    are connected, and only to that population's own block of latent
    columns. Every other weight entry is zeroed in the forward pass.

    Args:
        in_features: Number of input (latent) features, i.e. weight columns.
        out_features: Number of output units, i.e. weight rows.
        force_percent: Fraction of each population's rows that are connected
            (presumably in ``[0, 1]`` - TODO confirm).
        neurons: Mapping population name -> ``(start, end)`` row range.
        pop_latent_sizes: Mapping population name -> number of latent columns
            owned by that population. Blocks are laid out consecutively in
            ``neurons`` order.
        bias: Whether to include a bias term.
        device: Device on which the mask is initially created.
    """
    def __init__(self,
                 in_features,
                 out_features,
                 force_percent=None,
                 neurons=None,
                 pop_latent_sizes=None,
                 bias=True,
                 device='cpu'):
        # Initialize the base class
        super(MaskedLinear, self).__init__(in_features, out_features, bias)

        # Assert attributes are not None
        assert force_percent is not None, "force_percent must be provided"
        assert neurons is not None, "neurons must be provided"
        assert pop_latent_sizes is not None, "pop_latent_sizes must be provided"

        self.device = device
        # Register the mask as a buffer so .to()/.cuda() move it together with
        # the weight (a plain tensor attribute stays on its original device and
        # breaks forward after the model is moved). persistent=False keeps it
        # out of state_dict, so existing checkpoints still load unchanged.
        self.register_buffer(
            'mask',
            self.generate_mask(force_percent, neurons, pop_latent_sizes),
            persistent=False,
        )

    def generate_mask(self, force_percent, neurons, pop_latent_sizes):
        """Build the ``(out_features, in_features)`` boolean connectivity mask.

        NOTE: This implementation is not like the original implementation by
        Jiyul. (Logic is a little different, TODO: Come back to this if we want
        to keep using it).
        """
        mask = torch.zeros(
            (self.out_features, self.in_features), dtype=torch.bool,
            device=self.device
        )

        # Walk the populations in order; each owns the next
        # pop_latent_sizes[neuron] latent columns.
        latent_start = 0
        for neuron, pop_idx in neurons.items():
            latent_end = latent_start + pop_latent_sizes[neuron]
            pop_start, pop_end = pop_idx
            print(f"Neuron: {neuron}, Pop Start: {pop_start}, Pop End: {pop_end}, ")
            # Connect only the leading force_percent fraction of this
            # population's rows to its own latent block.
            mask[
                pop_start:pop_start+int((pop_end-pop_start)*force_percent),
                latent_start:latent_end
            ] = True

            latent_start = latent_end

        return mask

    def forward(self, input):
        """Apply the linear map using the masked weight.

        NOTE: Masking in the forward pass means a registered hook is not needed.
        """
        masked_weight = self.weight * self.mask
        return F.linear(input, masked_weight, self.bias)
     

class MLPLayer(nn.Module):
    """Single linear layer with an optional connectivity mask and positivity constraint.

    Wraps an ``nn.Linear`` or, when ``force_percent`` is given, a
    ``MaskedLinear``. The wrapped module is stored as ``log_weight``. Despite
    the name, it is the whole linear module, not a log-weight tensor. With
    ``monotonic=True`` its ``weight`` is reparameterized through
    ``SoftplusParameterization``, so the effective weights are positive.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        bias: Whether the linear layer has a bias.
        force_percent: If not ``None``, use ``MaskedLinear`` with this value.
        neurons: Forwarded to ``MaskedLinear`` (population -> row range).
        pop_latent_sizes: Forwarded to ``MaskedLinear``.
        monotonic: Register a softplus parametrization on the weight.
        device: Forwarded to ``MaskedLinear`` (mask creation device).
    """
    def __init__(self, input_dim, output_dim, bias=True, force_percent=None,
                 neurons=None, pop_latent_sizes=None, monotonic=False,
                 device='cpu'):

        super(MLPLayer, self).__init__()

        if force_percent is None:
            self.log_weight = nn.Linear(input_dim, output_dim, bias=bias)
        else:
            self.log_weight = MaskedLinear(
                input_dim, output_dim, bias=bias, force_percent=force_percent,
                neurons=neurons, pop_latent_sizes=pop_latent_sizes,
                device=device
            )

        # Initialized before any parametrization is registered, so this value
        # becomes the unconstrained "original" tensor in the monotonic case.
        nn.init.xavier_normal_(self.log_weight.weight)

        if monotonic:
            parametrize.register_parametrization(
                self.log_weight, 'weight', SoftplusParameterization()
            )

    def initialize_monotonic_params(self, init_value=0):
        """Re-initialize the layer's weights from ``N(init_value, 0.3**2)``.

        For a parametrized (monotonic) layer the raw, unconstrained parameter
        is initialized, so the effective weight is ``softplus`` of the sample.
        The choice of ``init_value`` should therefore depend on whether the
        layer is monotonic.
        """
        if parametrize.is_parametrized(self.log_weight, 'weight'):
            # ``log_weight.weight`` is recomputed from the original parameter
            # on each access, so an in-place init on it would be discarded.
            target = self.log_weight.parametrizations.weight.original
        else:
            target = self.log_weight.weight
        nn.init.normal_(target, init_value, 0.3)

    def forward(self, z):
        """Apply the wrapped (possibly masked / parametrized) linear layer."""
        return self.log_weight(z)


class AbstractLatentSAE(nn.Module):
    """Shared encoder, initial-condition and readout machinery for latent models.

    Builds:
      * a single-layer GRU encoder (bidirectional unless
        ``config['causal_model']``),
      * ``ic_linear`` mapping the final encoder state(s) to the latent
        initial condition,
      * a readout from latents to per-unit log-rates selected by
        ``config['readout_type']``: ``'monotonic'``, ``'linear'``,
        ``'separate'`` (default) or ``'separate_nonlinear'``,
      * a LayerNorm applied to latents before the readout, and dropout
        applied to the encoder state before ``ic_linear``.

    Required config keys: ``input_size``, ``encoder_size``, ``latent_size``,
    ``lr``, ``weight_decay``, ``dropout``, ``neurons``, ``causal_model``.
    ``neurons`` maps population name -> ``(start, end)`` unit index range.
    Optional keys include ``population_latent_sizes`` (default 1 latent per
    population), ``force_percent``, ``device``, ``nonlinear_readout``,
    ``readout_nonlinearity`` and ``readout_hidden_dim`` (width of the
    ``'separate_nonlinear'`` hidden layer, default 50).
    """
    def __init__(
        self,
        config: Dict[str, Any],
    ):
        super().__init__()
        self.input_size = config['input_size']
        self.encoder_size = config['encoder_size']
        self.latent_size = config['latent_size']
        self.learning_rate = config['lr']
        self.weight_decay = config['weight_decay']
        self.dropout_rate = config['dropout']
        self.neurons = config['neurons']
        self.nonlinear_readout = config.get('nonlinear_readout', False)
        self.points_per_group = config.get('points_per_group', 20)
        self.epochs_per_group = config.get('epochs_per_group', 500)
        self.force_percent = config.get('force_percent', None)
        device = config.get('device', 'cpu')

        # Activation placed between the two layers of the 'separate_nonlinear'
        # readout; identity unless a nonlinear readout is requested.
        if self.nonlinear_readout:
            nonlinear_activation = config.get('readout_nonlinearity', 'tanh')
            self.readout_act = get_activation_function(nonlinear_activation)
        else:
            self.readout_act = nn.Identity()

        # Population-specific latent sizes (one latent per population by default)
        self.population_latent_sizes = (
            config.get('population_latent_sizes')
            or {pop: 1 for pop in self.neurons}
        )

        # Single shared encoder for all neurons across populations
        self.encoder = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.encoder_size,
            batch_first=True,
            bidirectional=not config['causal_model']
        )
        # A bidirectional encoder contributes one final state per direction.
        ic_lin_in_dim = (
            self.encoder_size if config['causal_model']
            else self.encoder_size * 2
        )
        self.ic_linear = MLPLayer(ic_lin_in_dim,
                                  self.latent_size,
                                  bias=True)

        # --- Consolidated Readout Layer ---
        self.readout_type = config.get('readout_type', 'separate')

        if self.readout_type in ('monotonic', 'linear'):
            # One readout over all latents; 'monotonic' adds a softplus
            # weight parametrization. Both receive the configured device so
            # a force-percent mask is created where the model lives.
            self.readout = MLPLayer(
                input_dim=self.latent_size,
                output_dim=self.input_size,
                bias=True,
                force_percent=self.force_percent,
                neurons=self.neurons,
                pop_latent_sizes=self.population_latent_sizes,
                monotonic=(self.readout_type == 'monotonic'),
                device=device
            )

        elif self.readout_type == 'separate':
            # NOTE: Separate monotonic readout achieved with monotonic
            # readout with force_percent 1.0
            self.readout = nn.ModuleDict({})
            # One readout per population, from its own latent block to its units
            for neuron, output_size in self.neurons.items():
                self.readout[neuron] = MLPLayer(
                    input_dim=self.population_latent_sizes[neuron],
                    output_dim=output_size[1] - output_size[0],
                    bias=True,
                    monotonic=False
                )

        elif self.readout_type == 'separate_nonlinear':
            #TODO: Add option to make monotonic
            hidden_dim = config.get('readout_hidden_dim', 50)
            self.readout = nn.ModuleDict({})

            for neuron, output_size in self.neurons.items():
                pop_latent_size = self.population_latent_sizes[neuron]
                self.readout[f'{neuron}_0'] = MLPLayer(
                    input_dim=pop_latent_size,
                    output_dim=hidden_dim,
                    bias=True
                )
                self.readout[f'{neuron}_1'] = MLPLayer(
                    input_dim=hidden_dim,
                    output_dim=output_size[1] - output_size[0],
                    bias=True
                )
        else:
            raise ValueError(f"Unknown readout_type: {self.readout_type}")

        # Normalizes latents before the readout
        self.readout_layer_norm = nn.LayerNorm(self.latent_size)

        # Applied to the encoder state before the IC projection
        self.dropout = nn.Dropout(p=self.dropout_rate)

    def forward(self, data, unit=None):
        """Encode ``data`` and return the latent initial condition.

        ``unit`` is accepted for interface compatibility and is not used.
        """
        _, h_n = self.encoder(data)

        # h_n stacks one final state per direction along dim 0; concatenate
        # them on the feature axis (a no-op split for a unidirectional GRU).
        h_n = torch.cat([*h_n], -1)

        h_n = self.dropout(h_n)
        return self.ic_linear(h_n)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using the configured lr / weight decay."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        return optimizer

    def _population_slices(self, latents):
        """Yield ``(population, latent_block)`` pairs in ``self.neurons`` order."""
        latent_idx = 0
        for neuron in self.neurons:
            pop_latent_size = self.population_latent_sizes[neuron]
            yield neuron, latents[:, :, latent_idx:latent_idx + pop_latent_size]
            latent_idx += pop_latent_size

    def _apply_readout(self, latents):
        """Map latents to log-rates with the configured readout.

        ``latents`` is LayerNorm-ed over its last dimension first. The separate
        readouts slice it along dim 2, so a 3-D tensor is assumed there
        (presumably ``(batch, time, latent)`` - TODO confirm).
        """
        latents = self.readout_layer_norm(latents)

        if self.readout_type in ['monotonic', 'linear']:
            return self.readout(latents)

        elif self.readout_type == 'separate':
            return torch.cat(
                [self.readout[neuron](pop_latents)
                 for neuron, pop_latents in self._population_slices(latents)],
                dim=-1,
            )

        elif self.readout_type == 'separate_nonlinear':
            logrates = []
            for neuron, pop_latents in self._population_slices(latents):
                hidden = self.readout[f'{neuron}_0'](pop_latents)
                logrates.append(self.readout[f'{neuron}_1'](self.readout_act(hidden)))
            return torch.cat(logrates, dim=-1)

        else:
            raise ValueError(f"Unknown readout_type: {self.readout_type}")

    def _shared_step(self, batch, batch_ix, split):
        """Log R^2 metrics and return the curriculum-weighted Poisson NLL.

        NOTE(review): ``self.log``, ``self.log_dict`` and ``self.current_epoch``
        are PyTorch Lightning ``LightningModule`` APIs, but this class derives
        from ``nn.Module``; they must be provided by a subclass/mixin - confirm.

        Args:
            batch: ``(spikes, rates, latents, _)``.
            batch_ix: Batch index (unused).
            split: Prefix for logged metric names, e.g. ``"train"``.

        Returns:
            Mean Poisson NLL over the first ``num_points`` entries along dim 1.
            ``num_points`` grows by ``points_per_group`` every
            ``epochs_per_group`` epochs.
        """
        spikes, rates, latents, _ = batch
        outputs = self.forward(spikes)
        if len(outputs) == 2:
            pred_logrates, pred_latents = outputs
        else:
            # The concrete models in this file return
            # (ic, logrates, latents, ...).
            pred_logrates, pred_latents = outputs[1], outputs[2]
        # Prepare the data for metrics
        pred_rates = torch.exp(pred_logrates)
        early_rates, mid_rates, late_rates = torch.chunk(rates, 3, dim=1)
        early_pred_rates, mid_pred_rates, late_pred_rates = torch.chunk(
            pred_rates, 3, dim=1
        )
        early_latents, mid_latents, late_latents = torch.chunk(latents, 3, dim=1)
        early_pred_latents, mid_pred_latents, late_pred_latents = torch.chunk(
            pred_latents, 3, dim=1
        )
        results = {
            f"{split}/r2_observ": r2_score(pred_rates, rates),
            f"{split}/r2_observ/early": r2_score(early_pred_rates, early_rates),
            f"{split}/r2_observ/middle": r2_score(mid_pred_rates, mid_rates),
            f"{split}/r2_observ/late": r2_score(late_pred_rates, late_rates),
            f"{split}/r2_latent": regression_r2_score(latents, pred_latents),
            f"{split}/r2_latent/early": regression_r2_score(
                early_latents, early_pred_latents
            ),
            f"{split}/r2_latent/middle": regression_r2_score(
                mid_latents, mid_pred_latents
            ),
            f"{split}/r2_latent/late": regression_r2_score(
                late_latents, late_pred_latents
            ),
        }
        self.log_dict(results)
        loss_all = F.poisson_nll_loss(
            pred_logrates, spikes, full=True, reduction="none"
        )
        # Incrementally consider more points in the loss
        total_points = loss_all.shape[1]
        group_number = int(self.current_epoch / self.epochs_per_group) + 1
        num_points = min(group_number * self.points_per_group, total_points)
        self.log(f"{split}/num_points", float(num_points))
        loss = torch.mean(loss_all[:, :num_points, :])
        self.log(f"{split}/loss", loss)
        return loss

    def training_step(self, batch, batch_ix):
        """Training step: ``_shared_step`` with the ``"train"`` prefix."""
        return self._shared_step(batch, batch_ix, "train")

    def validation_step(self, batch, batch_ix):
        """Validation step: ``_shared_step`` with the ``"valid"`` prefix."""
        return self._shared_step(batch, batch_ix, "valid")


class DiscreteTimeModel(AbstractLatentSAE):
    """Latent model with a discrete-time recurrent decoder.

    The inherited encoder produces the initial condition. ``RNN(cell)``
    unrolls it for ``pred_horizon`` steps (``window_size - ic_window_size``
    for causal models, otherwise the input length). The consolidated readout
    then maps the latents to log-rates.
    """
    def __init__(self, cell: nn.Module, config:dict, noise:int=0):
        super().__init__(config=config)
        self.cell = cell
        self.neurons = config['neurons']
        self.device = config['device']
        if config['causal_model']:
            self.pred_horizon = int(config['window_size'] - config['ic_window_size'])

        # Decoder unrolls the same cell module registered above
        self.decoder = RNN(cell)

        # readout_type is the primary control; heatmap is read from config as-is
        self.readout_type = config.get('readout_type', 'separate')
        self.heatmap = config['heatmap']

        self.noise = noise
        self.stimulated_populations = config.get('stimulated_populations')

    def forward(self, data, timebin=None, checkpoint=False, bin=None, external_input=None):
        """Encode, unroll the decoder and read out log-rates.

        ``timebin``, ``checkpoint`` and ``bin`` are accepted for interface
        compatibility and are not used.

        The decoder input is ``external_input`` (moved to ``self.device`` and
        truncated or zero-padded along dim 1 to the prediction horizon) when
        given. Otherwise it is a ``(batch, horizon, 1)`` tensor of standard
        normal noise if ``self.noise > 0``, or zeros.

        Returns:
            ``(ic, logrates, latents, hidden_input, hidden_output, influences)``.
        """
        ic_drop = super().forward(data)

        batch_size, n_steps, _ = data.shape
        horizon = getattr(self, 'pred_horizon', n_steps)

        if external_input is None:
            use_noise = self.noise is not None and self.noise > 0
            sampler = torch.randn if use_noise else torch.zeros
            input_placeholder = sampler((batch_size, horizon, 1), device=self.device)
        else:
            stim = external_input.to(self.device)
            available = stim.shape[1]
            if available > horizon:
                stim = stim[:, :horizon, :]
            elif available < horizon:
                padding = torch.zeros(
                    (batch_size, horizon - available, stim.shape[2]),
                    device=self.device, dtype=stim.dtype,
                )
                stim = torch.cat([stim, padding], dim=1)
            input_placeholder = stim

        latents, _, input_hidden, output_hidden, influence = self.decoder(
            input_placeholder, ic_drop
        )
        logrates = self._apply_readout(latents)

        return ic_drop, logrates, latents, input_hidden, output_hidden, influence
    
class ContinuousTimeModel(AbstractLatentSAE):
    """Latent model whose decoder is a torchdyn ``NeuralODE`` over ``cell``.

    The inherited encoder produces the initial condition. The ODE is solved
    at ``pred_horizon`` evenly spaced times in ``[0, 1]``, and the
    consolidated readout maps the latents to log-rates.

    Args:
        cell: Vector-field module passed to ``NeuralODE``.
        config: Model configuration (see ``AbstractLatentSAE``); also reads
            ``latent_size``, ``device``, ``causal_model``, ``window_size`` and,
            for causal models, ``ic_window_size``.
        noise: Stored as ``self.noise`` for parity with ``DiscreteTimeModel``;
            not used by ``forward``.
        mlp_hidden_dims: Accepted for interface compatibility; not used.
    """
    def __init__(
        self,
        cell: nn.Module,
        config:dict,
        noise:int = 0,
        mlp_hidden_dims: Optional[list[int]] = None,
    ):
        super().__init__(config=config)
        self.cell = cell
        self.latent_size = config['latent_size']
        self.device = config['device']
        self.noise = noise

        if config['causal_model']:
            self.pred_horizon = int(config['window_size'] - config['ic_window_size'])
        else:
            self.pred_horizon = int(config['window_size'])

        # Continuous-time decoder; the readout comes from the base class
        self.decoder = NeuralODE(self.cell)

    def forward(self, data, timebin=None, checkpoint=False, bin=None):
        """Encode, integrate the ODE and read out log-rates.

        ``timebin``, ``checkpoint`` and ``bin`` are accepted for interface
        compatibility and are not used.

        Returns:
            ``(ic, logrates, latents, None, None, None)``, matching the tuple
            layout of ``DiscreteTimeModel.forward``.
        """
        ic_drop = super().forward(data)

        # Build the evaluation grid on the same device as the state so the
        # solver does not mix CPU and GPU tensors.
        t_span = torch.linspace(0, 1, self.pred_horizon, device=ic_drop.device)

        _, latents = self.decoder(ic_drop, t_span)
        # Swap the solver's leading time axis with the batch axis
        latents = latents.transpose(0, 1)

        logrates = self._apply_readout(latents)

        return ic_drop, logrates, latents, None, None, None
    
    
class MiniMLPDynamics(nn.Module):
    """Dynamics of one target population as a sum of per-population terms.

    For ``target`` the output, of width ``latent_sizes[target]``, is::

        target_mlp(z_target)                      # self term ("Addition")
        + sum_p sign_p * g(pop_mlps[p](z_p))      # one term per population
        + input_mlp(u)                            # optional, added once

    Here ``z_p`` is population ``p``'s slice of the state after a
    per-population LayerNorm. ``g`` is softplus unless ``act_func`` is
    ``None`` (identity). ``sign_p`` is -1 when the population name contains
    ``"I"`` and +1 otherwise. ``u`` is the ``input`` columns stored right after
    all latent columns of the state.

    Args:
        mlp_hidden_dim: Hidden width of every two-layer MLP.
        bias: Whether MLP layers use a bias.
        neurons: Ordered population names (only keys are used); defaults to empty.
        noise: If > 0 the external-input term is enabled for any target.
        target: Population whose dynamics this module produces.
        softplus: Use softplus-parametrized (positive) weights in ``pop_mlps``.
        latent_sizes: Population -> latent width; defaults to 1 per population.
        act_func: Activation name between MLP layers; ``None`` also disables
            the output softplus (linear baseline).
        input: Width of the external input; 0 disables ``input_mlp``.
        stimulated_populations: Targets allowed to receive external input.
    """
    def __init__(self, mlp_hidden_dim=128, bias=True, neurons=None, noise=0,
                 target='E1', softplus=True, latent_sizes=None, act_func='tanh',
                 input=0, stimulated_populations=None):
        super().__init__()
        if neurons is None:
            neurons = {}

        self.neurons = neurons
        self.noise = noise
        self.softplus = softplus
        self.target = target
        self.bias = bias
        self.latent_sizes = latent_sizes or {pop: 1 for pop in neurons}
        self.total_latent_size = sum(self.latent_sizes.values())
        self.input = input
        self.act_func = act_func
        # Populations allowed to receive external input (None -> allow none)
        # Expect list of population names matching keys of `neurons`
        self.stimulated_populations = stimulated_populations
        self.target_latent_size = self.latent_sizes.get(target)

        self.pop_layer_norms = nn.ModuleDict({
            neuron: nn.LayerNorm(self.latent_sizes[neuron]) for neuron in self.neurons
        })

        # Per-population contribution to the target
        self.pop_mlps = nn.ModuleDict({
            neuron: nn.Sequential(
                MLPLayer(
                    input_dim=self.latent_sizes[neuron],
                    output_dim=mlp_hidden_dim,
                    bias=self.bias,
                    monotonic=self.softplus
                ),
                get_activation_function(act_func),
                MLPLayer(
                    input_dim=mlp_hidden_dim,
                    output_dim=self.target_latent_size,
                    bias=self.bias,
                    monotonic=self.softplus
                )
            ) for neuron in self.neurons
        })

        # Unconstrained self term for the target population
        self.target_mlp = nn.Sequential(
            MLPLayer(
                input_dim=self.target_latent_size,
                output_dim=mlp_hidden_dim,
                bias=self.bias,
                monotonic=False
            ),
            get_activation_function(act_func),
            MLPLayer(
                input_dim=mlp_hidden_dim,
                output_dim=self.target_latent_size,
                bias=self.bias,
                monotonic=False
            )
        )

        if self.input and self.input > 0:
            self.input_mlp = nn.Sequential(
                MLPLayer(
                    input_dim=self.input,
                    output_dim=mlp_hidden_dim,
                    bias=self.bias,
                    monotonic=False
                ),
                MLPLayer(
                    input_dim=mlp_hidden_dim,
                    output_dim=self.target_latent_size,
                    bias=self.bias,
                    monotonic=False
                )
            )
        else:
            self.input_mlp = None

    def _input_enabled(self):
        """Return True if the external-input term applies to this target."""
        stimulated = (
            self.stimulated_populations is not None
            and self.target in self.stimulated_populations
        )
        noisy = self.noise is not None and self.noise > 0
        return stimulated or noisy

    def forward(self, x):
        """Compute the target's derivative contribution and per-term influences.

        Args:
            x: 2-D state; the first ``total_latent_size`` columns are the
                population latents in ``neurons`` order, optionally followed
                by ``input`` external-input columns.

        Returns:
            ``(total_output, influence)``. ``influence`` maps ``'Addition'``,
            each population name and (if used) ``'external_input'`` to its
            contribution.
        """
        outputs = []
        influence = {}

        latent_start = 0
        for neuron in self.neurons:
            latent_end = latent_start + self.latent_sizes[neuron]
            x_latent = self.pop_layer_norms[neuron](x[:, latent_start:latent_end])
            latent_start = latent_end

            if self.target == neuron:
                x_a = self.target_mlp(x_latent)
                outputs.append(x_a)
                influence['Addition'] = x_a

            x_n = self.pop_mlps[neuron](x_latent)
            # act_func=None is the linear baseline: no positivity softplus
            if self.act_func is not None:
                x_n = F.softplus(x_n)
            # Populations whose name contains "I" are treated as inhibitory
            if "I" in neuron:
                x_n = -x_n
            outputs.append(x_n)
            influence[neuron] = x_n

        # External input does not depend on the population loop, so it is
        # added exactly once.
        # NOTE(review): if x carries no input columns (e.g. a continuous solver
        # passing only the latents), this slice is empty and input_mlp will fail.
        if self.input_mlp is not None and self._input_enabled():
            input_start = self.total_latent_size
            x_input = x[:, input_start:input_start + self.input]
            x_impulse = self.input_mlp(x_input)
            outputs.append(x_impulse)
            influence['external_input'] = x_impulse

        # Every term has shape (batch, target_latent_size)
        total_output = sum(outputs)
        return total_output, influence

# NOTE: Candidate for removal: equivalent to MLPDynamicsCell with a single hidden layer and mlp_activation=None (identity activation).
class LinearDynamicsCell(nn.RNNCell):
    """Activation-free two-layer MLP vector field for discrete or continuous solvers.

    Discrete use (``hidden`` given and ``dt`` truthy): the hidden state is
    LayerNorm-ed, concatenated with ``input`` and fed to the MLP. The call
    returns ``(hidden_norm + dt * f(state), None)``, an Euler step taken from
    the *normalized* hidden state (NOTE(review): confirm this is intended).

    Continuous use (``hidden`` is None): ``input`` is treated as the state,
    LayerNorm-ed, and ``f(state)`` is returned as the derivative.
    NOTE(review): when ``dt`` is truthy the MLP expects
    ``hidden_size + input_size`` features, so this path only fits shapes when
    ``input_size == 0``.
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        mlp_hidden_dim: int = 128,
        bias: bool = True,
        dt: float = 0.1,
    ):
        super().__init__(input_size, hidden_size, bias)
        self.dt = dt

        self.layer_norm = nn.LayerNorm(hidden_size)

        # Discrete mode (dt truthy) also consumes the external input features
        in_features = hidden_size + input_size if self.dt else hidden_size
        self.dynamics_model = nn.Sequential(
            MLPLayer(in_features, mlp_hidden_dim, bias=bias),
            MLPLayer(mlp_hidden_dim, hidden_size, bias=bias)
        )

    def forward(self, input, hidden=None, time=None):
        """Return ``(next_hidden, None)`` in discrete mode, else the derivative.

        ``time`` is accepted for the ``RNN`` calling convention and is unused.
        """
        if not hasattr(self, "dynamics_model"):
            derivative = super().forward(input, hidden)
        elif hidden is None:
            # Continuous solver: input is h_t
            derivative = self.dynamics_model(self.layer_norm(input))
        else:
            # Discrete solver: input is x_t, hidden is h_t
            hidden = self.layer_norm(hidden)
            derivative = self.dynamics_model(torch.cat([hidden, input], dim=1))

        if self.dt and hidden is not None:
            return hidden + self.dt * derivative, None
        return derivative
    

class MLPDynamicsCell(nn.RNNCell):
    """MLP vector field usable as a discrete-time cell or a continuous-time ODE function.

    With a non-empty ``mlp_hidden_dims`` the cell builds
    ``[MLPLayer, activation] * len(mlp_hidden_dims) + [MLPLayer -> hidden_size]``.
    With an empty list no MLP is built and the plain ``nn.RNNCell`` update is
    used as ``f``.

    Discrete use (``hidden`` given and ``dt`` truthy): returns
    ``(hidden' + dt * f(...), None)``. When the MLP is used, ``hidden'`` is
    the LayerNorm-ed hidden state and ``f`` sees ``cat([hidden', input])``.
    Continuous use (``hidden`` is None): returns ``f`` applied to the
    LayerNorm-ed ``input``.

    Args:
        input_size: External input width.
        hidden_size: Latent state width.
        bias: Whether layers use a bias.
        dt: Euler step size; a falsy value also drops ``input`` from the MLP
            input width.
        mlp_hidden_dims: Hidden layer widths (default ``[128]``).
        mlp_activation: Activation name (``None`` for identity).
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        dt: float = 0.1,
        mlp_hidden_dims: Optional[list[int]] = None,
        mlp_activation: Optional[str] = 'tanh'
    ):
        super().__init__(input_size, hidden_size, bias)
        # Avoid a shared mutable default; same effective default as before
        if mlp_hidden_dims is None:
            mlp_hidden_dims = [128]
        self.dt = dt

        self.layer_norm = nn.LayerNorm(hidden_size)

        if len(mlp_hidden_dims) > 0:
            layer_in_dim = hidden_size + input_size if self.dt else hidden_size
            layers = []
            for layer_out_dim in mlp_hidden_dims:
                layers.append(MLPLayer(layer_in_dim, layer_out_dim, bias=bias))
                layers.append(get_activation_function(mlp_activation))
                layer_in_dim = layer_out_dim
            output_layer = MLPLayer(layer_in_dim, hidden_size, bias=bias)

            self.dynamics_model = nn.Sequential(*layers, output_layer)

    def forward(self, input, hidden=None, time=None):
        """Return ``(next_hidden, None)`` in discrete mode, else the derivative.

        ``time`` is accepted for the ``RNN`` calling convention and is unused.
        """
        if hasattr(self, "dynamics_model"):
            if hidden is not None:
                hidden = self.layer_norm(hidden)
                # Discrete solver: input is x_t, hidden is h_t
                state = torch.cat([hidden, input], dim=1)
            else:
                # Continuous solver: input is h_t
                state = self.layer_norm(input)

            derivative = self.dynamics_model(state)
        else:
            derivative = super().forward(input, hidden)

        if self.dt and hidden is not None:
            # Discrete solver expects the next state h_{t+1}
            return hidden + self.dt * derivative, None
        # Continuous solver expects the derivative dh/dt
        return derivative

class MiniMLPDynamicsCell(nn.RNNCell):
    """Population-structured dynamics cell built from one ``MiniMLPDynamics`` per population.

    Each population's block of the derivative comes from its own
    ``MiniMLPDynamics(target=population)``. The blocks are concatenated in
    ``neurons`` order. With ``compositional_func`` each block is further passed
    through ``composition[pop]`` together with that population's current state
    slice.

    Discrete use (``hidden`` given and ``dt`` truthy): the state is
    ``cat([hidden, input])`` and the call returns
    ``(hidden + dt * derivative, influences)``, where ``influences`` maps each
    population to its ``MiniMLPDynamics`` influence dict. Continuous use
    (``hidden`` is None): ``input`` is the state and the derivative is returned.

    Args:
        input_size: External input width (forwarded to ``MiniMLPDynamics``).
        hidden_size: Latent state width (passed to ``nn.RNNCell``).
        bias: Whether layers use a bias.
        dt: Euler step size for discrete use.
        mlp_hidden_dims: One hidden width per population, or a single value
            broadcast to all (default ``[128]``).
        mlp_activation: Activation name for the per-population MLPs.
        neurons: Ordered population mapping (defaults to empty).
        noise: Forwarded to ``MiniMLPDynamics``.
        softplus: Forwarded to ``MiniMLPDynamics``.
        latent_sizes: Population -> latent width (default 1 each).
        compositional_func: Enable the per-population composition MLPs.
        stimulated_populations: Forwarded to ``MiniMLPDynamics``.

    Raises:
        ValueError: If ``mlp_hidden_dims`` has neither length 1 nor
            ``len(neurons)``.
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        dt: float = 0.1,
        mlp_hidden_dims: Optional[list[int]] = None,
        mlp_activation: str = 'tanh',
        neurons: Optional[dict] = None,
        noise: int = 0,
        softplus: bool=True,
        latent_sizes: Optional[dict[str,int]]=None,
        compositional_func: bool = False,
        stimulated_populations: Optional[list[str]] = None
    ):
        super().__init__(input_size, hidden_size, bias)
        # Avoid shared mutable defaults; same effective defaults as before
        if mlp_hidden_dims is None:
            mlp_hidden_dims = [128]
        if neurons is None:
            neurons = {}

        # One hidden width per population (a single value is broadcast)
        if len(mlp_hidden_dims) != len(neurons):
            if len(mlp_hidden_dims) == 1:
                mlp_hidden_dims = [mlp_hidden_dims[0]] * len(neurons)
            else:
                raise ValueError("mlp_hidden_dims must be same length as " \
                "neurons or a single value")

        self.dt = dt
        self.input_size = input_size
        self.neurons = neurons
        self.latent_sizes = latent_sizes or {pop: 1 for pop in neurons}
        self.total_latent_size = sum(self.latent_sizes.values())
        self.compositional_func = compositional_func
        self.stimulated_populations = stimulated_populations

        # One MiniMLPDynamics per target population
        self.dynamics_model = nn.ModuleDict({
                neuron: MiniMLPDynamics(
                    mlp_hidden_dim=mlp_hidden_dims[i], bias=bias,
                    neurons=neurons, noise=noise, target=neuron,
                    softplus=softplus, latent_sizes=self.latent_sizes,
                    act_func=mlp_activation,
                    input=self.input_size,
                    stimulated_populations=self.stimulated_populations
                    ) for i, neuron in enumerate(self.neurons)
            })

        # Composition MLP per population: [MiniMLP output, own state] -> block
        if compositional_func:
            self.composition = nn.ModuleDict({
                n: nn.Sequential(
                    MLPLayer(int(tgt_l_size*2), mlp_hidden_dims[i]),
                    nn.Tanh(),
                    MLPLayer(mlp_hidden_dims[i], tgt_l_size)
                ) for i, (n, tgt_l_size) in enumerate(self.latent_sizes.items())
            })

    def forward(self, input, hidden=None, time=None):
        """Return ``(next_hidden, influences)`` in discrete mode, else the derivative.

        ``time`` is accepted for the ``RNN`` calling convention and is unused.
        """
        # Defined up front so the fallback branch below can still return it
        influences = None
        if hasattr(self, "dynamics_model"):
            # Discrete solver: input is x_t, hidden is h_t.
            # Continuous solver: input is h_t.
            state = input if hidden is None else torch.cat([hidden, input], dim=1)

            influences = {}
            derivative_parts = []
            latent_start = 0
            for neuron, dynamics in self.dynamics_model.items():
                out, influence = dynamics(state)
                if self.compositional_func:
                    target_latent_size = self.latent_sizes[neuron]
                    own_state = state[:, latent_start:latent_start + target_latent_size]
                    comp_out = self.composition[neuron](torch.cat([out, own_state], dim=1))
                    # Record the composed output under the population's own key
                    influence[neuron] = comp_out
                    latent_start += target_latent_size
                    out = comp_out
                derivative_parts.append(out)
                influences[neuron] = influence

            derivative = torch.cat(derivative_parts, dim=1)
        else:
            derivative = super().forward(input, hidden)

        if self.dt and hidden is not None:
            # Discrete solver expects the next state h_{t+1}
            return hidden + self.dt * derivative, influences
        # Continuous solver expects the derivative dh/dt
        return derivative