"""
Sequential Neural Network Builder for HelioX Training

This module provides a Sequential class that simplifies neural network construction,
similar to PyTorch/TensorFlow APIs. It automatically handles layer connections,
gradient routing, and network orchestration.

Example usage:
    model = Sequential([
        InputLayer(784),
        DenseLayer(64, 'passive_hpc'),
        DenseLayer(10, 'point'),
        SoftMaxLayer()
    ])
    model.build()
    model.set_stim(img, target)
"""

import numpy as np
from neuron import h
from typing import List, Union, Optional, Dict, Any

from hybrid_backend import BackendConfig, HybridBackend

from id_manager import NetworkIDManager
from layers import InputLayer, DenseLayer, OutputLayer, SoftMaxLayer
from connection_pattern import ConnectionPattern


class Sequential:
    """
    Sequential neural network builder that automatically handles layer connections.

    The layer list must start with an InputLayer, end with a SoftMaxLayer and
    contain at least one DenseLayer in between.  The last DenseLayer is treated
    as the output layer; every other DenseLayer is treated as a hidden layer.
    Construction only categorizes the layers and attaches the backend; the
    neurons and connections are created by :meth:`build`.
    """

    # Number of projections per dendrite used for connection patterns,
    # weight shapes and the parameter estimate in summary().
    _NUM_PROJ_DEND = 1

    def __init__(self,
                 layers: List[Union[InputLayer, DenseLayer, SoftMaxLayer]],
                 seed: int = 1234,
                 connection_pattern: Optional[ConnectionPattern] = None,
                 backend_config: Optional[BackendConfig] = None,
                 backend: Optional[HybridBackend] = None):
        """
        Initialize the Sequential neural network.

        Args:
            layers: List of layer objects in order (InputLayer, DenseLayer(s), SoftMaxLayer)
            seed: Random seed used for the default ConnectionPattern and for
                weight initialization
            connection_pattern: Optional ConnectionPattern manager; when None a
                new ``ConnectionPattern(seed=seed)`` is created
            backend_config: Configuration passed to ``HybridBackend(...)`` when
                ``backend`` is not supplied
            backend: Optional ready-made backend; takes precedence over
                ``backend_config``

        Raises:
            ValueError: If the layer sequence is empty or malformed.
        """
        self.layers = layers
        self.seed = seed
        # Explicit None checks so that a falsy-but-valid object is never replaced.
        self.connection_pattern = (
            connection_pattern if connection_pattern is not None
            else ConnectionPattern(seed=seed)
        )
        self.backend = backend if backend is not None else HybridBackend(backend_config)

        # Validate layer sequence
        self._validate_layers()

        # Initialize network components
        self.id_manager = NetworkIDManager()
        self.pc = h.ParallelContext()
        self.nhost = self.pc.nhost()
        self.ihost = self.pc.id()

        # Network state
        self._built = False
        self._weights_initialized = False

        # Connection information (populated during build)
        self._gradient_sids_registry = {}
        self._connection_info = {}

        # Weight storage (for compatibility with existing training code)
        self.weights = {}
        self.gradients = {}

        # Layer references for easy access
        self.input_layer = None
        self.hidden_layers = []
        self.output_layer = None
        self.softmax_layer = None

        self._categorize_layers()
        self._attach_backend()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _require_built(self, message: str) -> None:
        """Raise RuntimeError(message) unless build() has completed."""
        if not self._built:
            raise RuntimeError(message)

    def _layer_seed(self, layer_name: str) -> int:
        """
        Return a reproducible RNG seed for the given layer name.

        The built-in ``hash()`` of a string is salted per interpreter process,
        so it cannot be used for reproducible seeding; CRC32 of the UTF-8
        encoded name is stable across runs and processes.
        """
        import zlib  # local import keeps the fix self-contained

        return self.seed + zlib.crc32(str(layer_name).encode('utf-8')) % 1000000

    def _validate_layers(self):
        """
        Validate the layer sequence.

        Raises:
            ValueError: If the list is empty, does not start with an InputLayer,
                does not end with a SoftMaxLayer, or has no DenseLayer between.
        """
        if not self.layers:
            raise ValueError("Sequential requires at least one layer")

        if not isinstance(self.layers[0], InputLayer):
            raise ValueError("First layer must be InputLayer")

        if not isinstance(self.layers[-1], SoftMaxLayer):
            raise ValueError("Last layer must be SoftMaxLayer")

        # Check that we have appropriate dense layers
        if not any(isinstance(layer, DenseLayer) for layer in self.layers[1:-1]):
            raise ValueError("Sequential requires at least one DenseLayer between InputLayer and SoftMaxLayer")

    def _categorize_layers(self):
        """
        Split the layers into input / hidden / output / softmax references.

        DenseLayers that do not carry a ``_name_set`` attribute get their
        ``layer_name`` and ``layer_role`` assigned here: the last DenseLayer
        becomes ``'output'``, the others ``'hidden_<i>'`` where ``i`` is the
        position among the layers between input and softmax.
        """
        self.input_layer = self.layers[0]
        self.softmax_layer = self.layers[-1]

        middle_layers = self.layers[1:-1]
        dense_positions = [
            pos for pos, layer in enumerate(middle_layers)
            if isinstance(layer, DenseLayer)
        ]
        # The output layer is the last *dense* layer, which is not necessarily
        # the last middle layer when non-dense layers are present.
        last_dense_pos = dense_positions[-1] if dense_positions else None

        dense_layers = []
        for pos in dense_positions:
            layer = middle_layers[pos]
            dense_layers.append(layer)
            # Ensure unique layer names
            if hasattr(layer, '_name_set'):
                continue
            if pos == last_dense_pos:
                layer.layer_name = 'output'
                layer.layer_role = 'output'
            else:
                layer.layer_name = f'hidden_{pos}'
                layer.layer_role = 'hidden'
            layer._name_set = True

        # The last dense layer is treated as output layer, the rest as hidden
        self.output_layer = dense_layers[-1] if dense_layers else None
        self.hidden_layers = dense_layers[:-1]

    def _attach_backend(self):
        """Attach backend controller to all layers that support it."""
        for layer in self.layers:
            if hasattr(layer, "set_backend"):
                layer.set_backend(self.backend)

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
    def build(self):
        """
        Build the neural network by creating neurons and connections.

        This method orchestrates the entire network construction process:
        1. Create neurons for all layers
        2. Setup SoftMax mechanism
        3. Create layer connections with gradient routing
        4. Initialize stimulators

        When the backend exposes ``begin_network_registration`` /
        ``end_network_registration`` they are called around the four phases;
        the end hook runs even if a phase raises.

        Raises:
            RuntimeError: If the network has already been built.
        """
        if self._built:
            raise RuntimeError("Network already built. Create a new Sequential instance to rebuild.")

        registration_started = False
        if hasattr(self.backend, "begin_network_registration"):
            self.backend.begin_network_registration(self)
            registration_started = True

        try:
            # Phase 1: Create neurons for all layers
            self._create_neurons()

            # Phase 2: Setup SoftMax mechanism
            self._setup_softmax()

            # Phase 3: Create connections between layers
            self._create_connections()

            # Phase 4: Initialize stimulators
            self._initialize_stimulators()
        finally:
            if registration_started and hasattr(self.backend, "end_network_registration"):
                self.backend.end_network_registration(self)

        self._built = True
        print(f"Sequential network built successfully with {len(self.layers)} layers")

    def _create_neurons(self):
        """Create neurons for input, hidden, output and SoftMax layers (in that order)."""
        self.input_layer.phase1_create_neurons(self.pc, self.id_manager, self.ihost)

        for layer in self.hidden_layers:
            layer.phase1_create_neurons(self.pc, self.id_manager, self.ihost)

        if self.output_layer is not None:
            self.output_layer.phase1_create_neurons(self.pc, self.id_manager, self.ihost)

        # Create SoftMax neuron
        self.softmax_layer.phase1_create_neuron(self.pc, self.id_manager, self.ihost)

    def _setup_softmax(self):
        """Set up the SoftMax mechanism and connect it to the output layer's voltage SIDs."""
        output_voltage_sids = self.output_layer.get_voltage_sids()

        self.softmax_layer.phase2_setup_mechanism(self.pc, self.id_manager)
        self.softmax_layer.phase3_connect_inputs(self.pc, output_voltage_sids)

    def _create_connections(self):
        """
        Connect consecutive layers and hand each layer its gradient routing info.

        The last dense layer (the output layer) receives the SoftMax gradient
        SIDs; every earlier hidden layer receives the info produced by
        ``_get_hidden_layer_next_info``.
        """
        softmax_grad_sids = self.softmax_layer.get_grad_sids()

        # Register gradient SIDs for output to hidden connections
        output2hidden_grad_sids = self._register_output_to_hidden_gradients()

        # hidden layers first, output layer (if any) last
        middle_layers = list(self.hidden_layers)
        if self.output_layer is not None:
            middle_layers.append(self.output_layer)

        last_idx = len(middle_layers) - 1
        prev_layer = self.input_layer

        for idx, layer in enumerate(middle_layers):
            if idx == last_idx:
                # Last layer before SoftMax
                next_layer_obj = self.softmax_layer
                next_layer_info = {
                    'softmax_grad_sids': softmax_grad_sids,
                    'output2hidden_grad_sids': output2hidden_grad_sids
                }
            else:
                # Regular hidden layer - needs aggregator for gradients from next layer.
                # idx is also this layer's position inside self.hidden_layers.
                next_layer_obj = middle_layers[idx + 1]
                next_layer_info = self._get_hidden_layer_next_info(idx, output2hidden_grad_sids)

            connection_pattern, connection_locs = self._get_connection_pattern(prev_layer, layer)

            layer.connect_layers(prev_layer, self.pc, self.id_manager,
                                 next_layer_info, next_layer=next_layer_obj,
                                 connection_pattern=connection_pattern,
                                 connection_locs=connection_locs)

            prev_layer = layer

    def _register_output_to_hidden_gradients(self):
        """
        Register gradient SIDs for output to previous-layer connections.

        Returns:
            One list per output neuron, as returned by
            ``id_manager.register_sids_batch`` for the size of the layer that
            feeds the output layer (last hidden layer, or the input layer when
            there is no hidden layer).
        """
        n_out = self.output_layer.n_neurons

        # Find the layer that connects to output layer
        connecting_layer = self.hidden_layers[-1] if self.hidden_layers else self.input_layer
        n_connecting = connecting_layer.n_neurons

        return [
            self.id_manager.register_sids_batch(f'output{i}_to_prev_grad', n_connecting)
            for i in range(n_out)
        ]

    def _get_hidden_layer_next_info(self, layer_idx: int, output2hidden_grad_sids: List[List[int]]):
        """
        Build the next-layer info dict for the hidden layer at ``layer_idx``.

        Registers one aggregated-gradient SID per neuron of that hidden layer.

        NOTE(review): 'n_neurons' and 'grad_sids' always describe the output
        layer, even when the layer following this hidden layer is another
        hidden layer - confirm this is intended for networks with more than
        one hidden layer.
        """
        n_out = self.output_layer.n_neurons

        # Register aggregated gradient SIDs for this hidden layer
        hidden_layer = self.hidden_layers[layer_idx]
        hidden_aggregated_grad_sids = self.id_manager.register_sids_batch(
            f'hidden{layer_idx}_aggregated_grad', hidden_layer.n_neurons
        )

        return {
            'n_neurons': n_out,
            'grad_sids': output2hidden_grad_sids,
            'aggregated_grad_sids': hidden_aggregated_grad_sids
        }

    def _get_connection_pattern(self, prev_layer: Any, current_layer: DenseLayer):
        """
        Get the connection pattern between ``prev_layer`` and ``current_layer``.

        Returns:
            ``(None, None)`` unless the current layer's neuron type is
            ``'passive_hpc'``; otherwise whatever
            ``self.connection_pattern.get_connection_pattern`` returns for the
            key ``'<prev name>_to_<current name>'``.
        """
        if current_layer.neuron_type != 'passive_hpc':
            return None, None

        if hasattr(current_layer, 'get_total_dendrites'):
            total_dend = current_layer.get_total_dendrites()
        else:
            total_dend = 1

        pattern_key = f"{prev_layer.layer_name}_to_{current_layer.layer_name}"
        return self.connection_pattern.get_connection_pattern(
            pattern_key=pattern_key,
            n_source=prev_layer.n_neurons,
            n_target=current_layer.n_neurons,
            projections_per_connection=self._NUM_PROJ_DEND,
            total_dendrites=total_dend
        )

    def _initialize_stimulators(self):
        """Create the input layer stimulators and set their NetCon delay and weight."""
        self.input_layer.phase2_create_stimulators(self.pc, self.id_manager, self.ihost)

        # Set NetCon parameters for compatibility
        for ncstim in self.input_layer.ncstim_list:
            ncstim.delay = 1
            ncstim.weight[0] = 0.05

    # ------------------------------------------------------------------
    # Simulation control
    # ------------------------------------------------------------------
    def forward(self, simulation_time: float = 50.0, v_init: float = -65.0):
        """
        Run forward pass through the network.

        Calls ``h.finitialize(v_init)`` followed by
        ``h.continuerun(simulation_time)``.

        Args:
            simulation_time: Duration of simulation in milliseconds
            v_init: Value passed to ``h.finitialize`` (defaults to the
                previously hard-coded -65)

        Raises:
            RuntimeError: If the network has not been built.
        """
        self._require_built("Network must be built before running forward pass")

        h.finitialize(v_init)
        h.continuerun(simulation_time)

    def initialize_backends(self, dt: float, v_init: float, export_path: Optional[str] = None):
        """
        Initialize HelioX with the current model export.

        Forwards ``dt``, ``v_init`` and ``export_path`` to
        ``self.backend.initialize`` and then calls ``post_backend_init()`` on
        every layer that defines it.
        """
        self.backend.initialize(dt=dt, v_init=v_init, export_path=export_path)
        for layer in self.layers:
            if hasattr(layer, "post_backend_init"):
                layer.post_backend_init()

    def set_stim(self, img: np.ndarray, tgt: int):
        """
        Set input stimulation and target for training/inference.

        Args:
            img: Input image array (or array-like); it is flattened before
                being handed to the input layer
            tgt: Target class index, forwarded to the SoftMax layer

        Raises:
            RuntimeError: If the network has not been built.
        """
        self._require_built("Network must be built before setting stimulation")

        flat_img = np.asarray(img).flatten()
        self.input_layer.set_stim(flat_img, pc=self.pc)

        # Set target for SoftMax
        self.softmax_layer.set_target(tgt)

    def _set_training_mode(self, training: bool) -> None:
        """Call ``set_training_mode(training)`` on every layer that defines it."""
        for layer in self.layers:
            if hasattr(layer, 'set_training_mode'):
                layer.set_training_mode(training)

    def is_train(self):
        """
        Set network to training mode.

        Raises:
            RuntimeError: If the network has not been built.
        """
        self._require_built("Network must be built before setting training mode")
        self._set_training_mode(True)

    def is_test(self):
        """
        Set network to test mode.

        Raises:
            RuntimeError: If the network has not been built.
        """
        self._require_built("Network must be built before setting test mode")
        self._set_training_mode(False)

    # ------------------------------------------------------------------
    # Weights
    # ------------------------------------------------------------------
    def prepare_for_training(self):
        """
        Prepare network for training by initializing weights if needed.

        This should be called before training starts to ensure weights are
        available for the optimizer.

        Raises:
            RuntimeError: If the network has not been built.
        """
        self._require_built("Network must be built before preparing for training")

        # reset_weights() generates the initial weights when they are missing
        # and then loads them into every trainable layer.
        self.reset_weights()

    def reset_weights(self):
        """
        Load the stored weights into all trainable layers.

        Initial weights are generated first when none have been initialized.

        Raises:
            RuntimeError: If the network has not been built.
        """
        self._require_built("Network must be built before resetting weights")

        if not self._weights_initialized:
            self._initialize_weights()

        for layer in self.layers:
            if not getattr(layer, 'trainable', False):
                continue
            layer_name = layer.layer_name
            if layer_name in self.weights:
                stored = self.weights[layer_name]
                layer.load_weights(stored['weight'], stored['bias'])

    def _initialize_weights(self):
        """
        Generate initial weights and zero biases for all trainable DenseLayers.

        Weights are drawn uniformly from ``[-limit, limit]`` with
        ``limit = sqrt(6 / (fan_in + fan_out))``.  Hidden ``'passive_hpc'``
        layers get a 3D weight array ``(n_current, n_prev, projections)``,
        all other dense layers a 2D array ``(n_current, n_prev)``.  Each layer
        uses its own RNG seeded reproducibly from ``self.seed`` and the layer
        name.
        """
        for i, layer in enumerate(self.layers):
            if not getattr(layer, 'trainable', False):
                continue
            if not isinstance(layer, DenseLayer):
                continue
            if i == 0:
                # No previous layer to size the weight matrix against
                continue

            layer_name = layer.layer_name
            n_prev = self.layers[i - 1].n_neurons
            n_current = layer.n_neurons

            # Stable per-layer seed (the built-in str hash is salted per process)
            weight_rng = np.random.default_rng(seed=self._layer_seed(layer_name))

            if layer.layer_role == 'hidden' and layer.neuron_type == 'passive_hpc':
                # Hidden layer with morphological neurons - need 3D weights
                num_proj_dend = self._NUM_PROJ_DEND
                limit = np.sqrt(6. / (n_prev * num_proj_dend + n_current))
                weight_shape = (n_current, n_prev, num_proj_dend)
            else:
                # Point neurons (hidden or output) - need 2D weights
                limit = np.sqrt(6. / (n_prev + n_current))
                weight_shape = (n_current, n_prev)

            self.weights[layer_name] = {
                'weight': weight_rng.uniform(-limit, limit, weight_shape),
                'bias': np.zeros((n_current,))
            }

        self._weights_initialized = True

    def _adopt_weights(self, weights_dict: Dict[str, Dict[str, np.ndarray]]) -> None:
        """
        Replace ``self.weights`` with copies of the arrays in ``weights_dict``.

        Marks the weights as initialized and, when the network is already
        built, loads them into the layers.
        """
        self.weights = {
            layer_name: {
                'weight': np.copy(weight_data['weight']),
                'bias': np.copy(weight_data['bias'])
            }
            for layer_name, weight_data in weights_dict.items()
        }
        self._weights_initialized = True

        # Load weights to layers if network is built
        if self._built:
            self.reset_weights()

    def sync_weights_from(self, source_sequential):
        """
        Synchronize weights from another Sequential network.

        Raises:
            TypeError: If ``source_sequential`` is not a Sequential instance.
            RuntimeError: If the source network's weights are not initialized.
        """
        if not isinstance(source_sequential, Sequential):
            raise TypeError("source_sequential must be a Sequential instance")

        if not source_sequential._weights_initialized:
            raise RuntimeError("Source network weights not initialized")

        self._adopt_weights(source_sequential.weights)

    def load_weights(self, weights_dict: Dict[str, Dict[str, np.ndarray]]):
        """
        Load weights from external dictionary.

        Args:
            weights_dict: Mapping of layer name to a dict with ``'weight'`` and
                ``'bias'`` arrays; the arrays are copied.
        """
        self._adopt_weights(weights_dict)

    # ------------------------------------------------------------------
    # Outputs and gradients
    # ------------------------------------------------------------------
    def extract_gradients(self, dt: float = 0.025, record_time: float = 30.0,
                          backend: str = "auto") -> Dict[str, Dict[str, np.ndarray]]:
        """
        Extract gradients from all trainable layers.

        Args:
            dt: Time step
            record_time: Recording time
            backend: Backend selector ('auto' or 'heliox')

        Returns:
            Dictionary of gradients keyed by layer name.  A layer returning a
            ``(dw, db)`` pair yields ``{'weight': dw, 'bias': db}`` (skipped if
            either is None); a layer returning a dict yields one such entry
            per key of that dict.  The result is also stored in
            ``self.gradients``.

        Raises:
            RuntimeError: If the network has not been built.
        """
        self._require_built("Network must be built before extracting gradients")

        # Ensure weights are initialized (needed for optimizer to get current weights)
        if not self._weights_initialized:
            self._initialize_weights()

        gradients = {}

        for layer in self.layers:
            if not (getattr(layer, 'trainable', False) and hasattr(layer, 'extract_gradients')):
                continue

            layer_name = layer.layer_name
            grad_data = layer.extract_gradients(dt, record_time, backend=backend)
            if grad_data is None:
                continue

            if isinstance(grad_data, dict):
                gradients[layer_name] = {
                    key: {'weight': value[0], 'bias': value[1]}
                    for key, value in grad_data.items()
                }
                continue

            dw, db = grad_data
            if dw is not None and db is not None:
                gradients[layer_name] = {'weight': dw, 'bias': db}

        self.gradients = gradients
        return gradients

    def get_softmax_outputs(self, backend: str = "auto"):
        """
        Get SoftMax layer outputs.

        Raises:
            RuntimeError: If the network has not been built.
        """
        self._require_built("Network must be built before getting outputs")

        return self.softmax_layer.get_softmax_outputs(backend=backend)

    def compare_backends(self) -> Dict[str, Dict[str, float]]:
        """
        Summarize the per-layer state differences on the HelioX side.

        Returns:
            Mapping from ``'<layer name>.<key>'`` to the metrics reported by
            each layer's ``compute_backend_differences()`` (the SoftMax layer
            uses the prefix ``'softmax'``).  Empty when there is no backend or
            ``backend.enable_heliox`` is false.

        Raises:
            RuntimeError: If HelioX is enabled but the network is not built.
        """
        if self.backend is None or not self.backend.enable_heliox:
            return {}
        self._require_built("Network must be built before backend comparison")

        # (prefix, layer) pairs in reporting order: hidden, output, softmax
        targets = [(layer.layer_name, layer) for layer in self.hidden_layers
                   if hasattr(layer, "compute_backend_differences")]
        if self.output_layer is not None and hasattr(self.output_layer, "compute_backend_differences"):
            targets.append((self.output_layer.layer_name, self.output_layer))
        if hasattr(self.softmax_layer, "compute_backend_differences"):
            targets.append(("softmax", self.softmax_layer))

        comparisons: Dict[str, Dict[str, float]] = {}
        for prefix, layer in targets:
            diff_map = layer.compute_backend_differences()
            if not diff_map:
                continue
            for key, metrics in diff_map.items():
                comparisons[f"{prefix}.{key}"] = metrics

        return comparisons

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------
    def get_layer_by_name(self, layer_name: str):
        """Return the first layer whose ``layer_name`` matches, or None."""
        for layer in self.layers:
            if hasattr(layer, 'layer_name') and layer.layer_name == layer_name:
                return layer
        return None

    def get_layer_by_index(self, index: int):
        """Return the layer at ``index``, or None when the index is negative or out of range."""
        if 0 <= index < len(self.layers):
            return self.layers[index]
        return None

    def summary(self) -> str:
        """
        Get network summary.

        Returns:
            A multi-line string listing each layer, an estimated parameter
            count for trainable DenseLayers (weights + bias), the total, and
            the built / weights-initialized flags.
        """
        rule = "=" * 60
        summary_lines = [rule, "Sequential Network Summary", rule]

        total_params = 0

        for i, layer in enumerate(self.layers):
            layer_info = f"Layer {i}: "
            params = 0

            if isinstance(layer, InputLayer):
                layer_info += f"InputLayer({layer.n_neurons})"
            elif isinstance(layer, DenseLayer):
                layer_info += f"DenseLayer({layer.n_neurons}, '{layer.neuron_type}')"
                prev_layer = self.layers[i - 1] if i > 0 else None
                if getattr(layer, 'trainable', False) and prev_layer is not None:
                    # Estimate parameters
                    n_prev = prev_layer.n_neurons
                    n_current = layer.n_neurons
                    if layer.layer_role == 'hidden' and layer.neuron_type == 'passive_hpc':
                        params = n_current * n_prev * self._NUM_PROJ_DEND + n_current  # weights + bias
                    else:
                        params = n_current * n_prev + n_current  # weights + bias
            elif isinstance(layer, SoftMaxLayer):
                layer_info += f"SoftMaxLayer({layer.n_classes})"
            else:
                layer_info += "Unknown Layer"

            if params > 0:
                layer_info += f" - Parameters: {params:,}"
                total_params += params

            summary_lines.append(layer_info)

        summary_lines.append(rule)
        summary_lines.append(f"Total Parameters: {total_params:,}")
        summary_lines.append(f"Built: {'Yes' if self._built else 'No'}")
        summary_lines.append(f"Weights Initialized: {'Yes' if self._weights_initialized else 'No'}")
        summary_lines.append(rule)

        return "\n".join(summary_lines)

    def __len__(self):
        """Return number of layers."""
        return len(self.layers)

    def __getitem__(self, index):
        """Get layer by index (plain indexing into the layer list)."""
        return self.layers[index]

    def __repr__(self):
        """Return ``Sequential(<Type> -> <Type> -> ...)`` using the layer class names."""
        layer_names = [type(layer).__name__ for layer in self.layers]
        return f"Sequential({' -> '.join(layer_names)})"


# Convenience functions for common network architectures
def create_mnist_network(hidden_size: int = 64, 
                        seed: int = 1234, 
                        use_vecstim: bool = True,
                        connection_pattern: Optional[ConnectionPattern] = None,
                        backend_config: Optional[BackendConfig] = None) -> Sequential:
    """
    Assemble the standard four-layer MNIST classifier.

    The stack is: a 784-unit InputLayer, one 'passive_hpc' DenseLayer of
    ``hidden_size`` neurons, a 10-neuron 'point' DenseLayer and a 10-class
    SoftMaxLayer.  The returned network has not been built yet.

    Args:
        hidden_size: Number of hidden layer neurons
        seed: Random seed forwarded to Sequential
        use_vecstim: Forwarded to InputLayer (VecStim when True, NetStim when False)
        connection_pattern: Optional connection pattern manager forwarded to Sequential
        backend_config: Optional backend configuration forwarded to Sequential

    Returns:
        Sequential network for MNIST classification
    """
    pixels = InputLayer(784, use_vecstim=use_vecstim)
    hidden = DenseLayer(
        hidden_size, 'passive_hpc', 'hidden',
        morph_file='2013_03_06_cell11_1125_H41_06.asc',
        cm=1.5,
        layer_role='hidden',
    )
    readout = DenseLayer(10, 'point', 'output', layer_role='output')
    classifier = SoftMaxLayer(10)

    stack = [pixels, hidden, readout, classifier]
    return Sequential(
        stack,
        seed=seed,
        connection_pattern=connection_pattern,
        backend_config=backend_config,
    )


def create_custom_network(layer_configs: List[Dict[str, Any]], 
                         seed: int = 1234,
                         connection_pattern: Optional[ConnectionPattern] = None,
                         backend_config: Optional[BackendConfig] = None) -> Sequential:
    """
    Create a custom network from layer configuration.

    Each configuration dictionary must contain a ``'type'`` entry
    (``'input'``, ``'dense'`` or ``'softmax'``); all remaining entries are
    passed as keyword arguments to the matching layer constructor.  The
    caller's dictionaries are not modified, so the same configuration list
    can be reused for several networks.

    Args:
        layer_configs: List of layer configuration dictionaries
        seed: Random seed forwarded to Sequential
        connection_pattern: Optional connection pattern manager forwarded to Sequential
        backend_config: Optional backend configuration forwarded to Sequential

    Returns:
        Sequential network with custom architecture

    Raises:
        KeyError: If a configuration has no ``'type'`` entry.
        ValueError: If a ``'type'`` value is not recognized.
    """
    layer_factories = {
        'input': InputLayer,
        'dense': DenseLayer,
        'softmax': SoftMaxLayer,
    }

    layers = []
    for config in layer_configs:
        # Work on a shallow copy: popping 'type' from the caller's dict would
        # make a second call with the same configs fail with KeyError.
        params = dict(config)
        layer_type = params.pop('type')

        if layer_type not in layer_factories:
            raise ValueError(f"Unknown layer type: {layer_type}")
        layers.append(layer_factories[layer_type](**params))

    return Sequential(layers, seed=seed, connection_pattern=connection_pattern, backend_config=backend_config)
