"""
RichLazyControlledRNN Model

This module implements a rank-controlled recurrent neural network that:
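- Initializes the recurrent weights from a Gaussian whose standard deviation is
  multiplied by a `scale` factor (the rich/lazy control)
- Applies layer normalization to the pre-activation before the ReLU nonlinearity
- Supports connectivity dimensionality computation (participation ratio of the
  singular values of the recurrent weight matrix)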
- Provides both logits and softmax outputs
"""

import torch
import torch.nn as nn
import numpy as np
import math
from typing import Optional, Tuple

class RichLazyControlledRNN(nn.Module):
    """ReLU recurrent network with layer norm and a scaled recurrent init.

    At every time step the hidden state is updated as
    ``h = relu(layer_norm(input2h(x_t) + h2h(h)))`` and a linear readout is
    applied to the new hidden state.

    Args:
        input_size: Number of features of each input vector.
        hidden_size: Number of recurrent units.
        output_size: Number of readout units (logits).
        scale: Multiplier applied to the standard deviation of the Gaussian
            used to initialize the recurrent weights; the resulting standard
            deviation is ``scale * sqrt(2 / (hidden_size + hidden_size))``.
            NOTE(review): judging by the class name this is the knob for the
            rich/lazy regime -- confirm against the training code.
        input_rank: Currently unused by this class; kept in the signature so
            existing callers keep working.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        scale: float = 1.0,
        input_rank: Optional[int] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.scale = scale

        # Input projection; nn.Linear's default initialization is kept as-is.
        self.input2h = nn.Linear(input_size, hidden_size)

        # Recurrent projection: weights are overwritten with a Gaussian draw
        # whose std is the Xavier/Glorot value multiplied by `scale`. The bias
        # keeps nn.Linear's default initialization.
        self.h2h = nn.Linear(hidden_size, hidden_size)
        with torch.no_grad():
            W = torch.randn(hidden_size, hidden_size) * (scale * math.sqrt(2.0 / (hidden_size + hidden_size)))
            self.h2h.weight.copy_(W)

        # Layer normalization applied to the pre-activation
        self.layer_norm = nn.LayerNorm(hidden_size)

        # Output layer (logits)
        self.output = nn.Linear(hidden_size, output_size)

        # Softmax over the last (class) dimension, used to analyze error patterns
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, h: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the network over a full input sequence.

        Args:
            x: Input tensor of shape [seq_len, batch_size, input_size].
            h: Optional initial hidden state of shape [batch_size, hidden_size].
                When omitted, a zero state matching ``x``'s device and dtype
                is used.

        Returns:
            A tuple ``(logits, probabilities, hidden)``:
                logits: Tensor of shape [seq_len, batch_size, output_size].
                probabilities: Softmax of ``logits`` over the last dimension,
                    same shape as ``logits``.
                hidden: Tensor of shape [seq_len, batch_size, hidden_size]
                    holding the hidden state after every time step.
        """
        if h is None:
            # Match the input's dtype as well as its device; otherwise a model
            # converted with .double()/.half() fails in the first h2h call.
            h = torch.zeros(x.size(1), self.hidden_size, device=x.device, dtype=x.dtype)

        outputs = []
        hidden_states = []

        for t in range(x.size(0)):
            # Combined input and recurrent drive
            combined = self.input2h(x[t]) + self.h2h(h)

            # Layer norm is applied before the nonlinearity
            h = torch.relu(self.layer_norm(combined))

            hidden_states.append(h)
            outputs.append(self.output(h))

        stacked_outputs = torch.stack(outputs)
        stacked_hidden = torch.stack(hidden_states)
        return stacked_outputs, self.softmax(stacked_outputs), stacked_hidden

    def get_h2h_weights_dim(self) -> float:
        """Compute the participation ratio (PR) of the recurrent weights.

        The PR is ``(sum_i s_i) ** 2 / sum_i s_i ** 2`` where ``s_i`` are the
        singular values of ``h2h.weight``. It equals ``hidden_size`` when all
        singular values are equal and 1 when only one is non-zero. An all-zero
        weight matrix yields ``nan`` (0 / 0).

        Returns:
            The participation ratio as a Python float.
        """
        with torch.no_grad():
            # Only the singular values are needed, so skip computing U and V.
            s = torch.linalg.svdvals(self.h2h.weight)
            nom = torch.sum(s)
            denom = torch.sum(s ** 2)
            return float((nom ** 2) / denom)

# Public API: only the model class is exported by `from <module> import *`.
__all__ = ['RichLazyControlledRNN']