"""
Optimizer classes for HelioX training.

This module provides optimizer abstractions similar to PyTorch/TensorFlow,
automating gradient accumulation and weight updates across batch training.
"""

import numpy as np
from typing import List, Dict, Tuple, Optional
from abc import ABC, abstractmethod
from neuron import h


class Optimizer(ABC):
    """
    Base optimizer class for managing weight updates.

    This abstract class defines the interface for all optimizers, handling
    gradient collection and weight updates across multiple networks (the
    networks of one batch are updated together and end up with identical
    weights).

    ``collect_gradients`` returns a dict whose values are 5-tuples
    ``(avg_dw, avg_db, layer_id, weight_attr, bias_attr)``:

    * Sequential networks: ``layer_id`` is the layer name and both attribute
      slots are ``None``.
    * Network_parallel networks: ``layer_id`` is an index into
      ``network.layers`` and the attribute slots name the network attributes
      holding the weight matrix and bias vector.

    Note: Gradients from extract_gradients include impedance correction factors
    but NOT learning rate. The optimizer's learning_rate controls the actual
    learning rate for weight updates.
    """

    # Gradient-backend preferences understood by this class.
    _SUPPORTED_BACKENDS = ("auto", "heliox")

    def __init__(self, learning_rate: float = 0.005):
        """
        Initialize the optimizer.

        Args:
            learning_rate: Learning rate for weight updates (default 0.005).
                Gradients from extract_gradients only include impedance
                correction factors, not the learning rate.
        """
        self.learning_rate = learning_rate
        self.iteration = 0
        self.gradient_backend = "auto"
        # Passed as record_time to extract_gradients / optimizer_step.
        self.record_time = 30.0
        # Allow callers to force Python-side optimizers even when HelioX exposes one.
        self.use_heliox_optimizer = True

    def set_gradient_backend(self, backend: str):
        """
        Select which backend to use when collecting gradients.

        Args:
            backend: 'auto' or 'heliox'. Other values are stored as-is; the
                HelioX optimizer path is then skipped and gradient collection
                raises ValueError.
        """
        self.gradient_backend = backend

    @abstractmethod
    def step(self, networks: List) -> None:
        """
        Perform one optimization step.

        Args:
            networks: List of network instances to update
        """

    def zero_grad(self, networks: List) -> None:
        """
        Zero out gradients in all networks.

        Currently a no-op: gradients are accumulated in the mechanisms
        themselves, so nothing needs resetting here.

        Args:
            networks: List of network instances
        """
        # Placeholder for future use if gradient state ever moves into Python.

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _is_sequential(network) -> bool:
        """Return True if ``network`` is a Sequential model (duck-typed by class name to avoid an import)."""
        return network.__class__.__name__ == 'Sequential'

    @staticmethod
    def _find_layer(network, layer_name):
        """Return the layer named ``layer_name`` in ``network`` or None if no layer matches."""
        layer = network.get_layer_by_name(layer_name)
        if layer is None:
            # Fall back to scanning for a layer carrying a matching `layer_name` attribute.
            for candidate in network.layers:
                if hasattr(candidate, 'layer_name') and candidate.layer_name == layer_name:
                    return candidate
        return layer

    def _resolve_backend_key(self) -> str:
        """
        Map the gradient-backend preference to the concrete extraction backend.

        Raises:
            ValueError: if the preference is neither 'auto' nor 'heliox'.
        """
        backend_key = 'heliox' if self.gradient_backend == 'auto' else self.gradient_backend
        if backend_key != 'heliox':
            raise ValueError(f"Unsupported backend '{backend_key}'")
        return backend_key

    @staticmethod
    def _configure_batch_backend(backend, networks: List) -> bool:
        """
        Configure a shared backend for batch optimization when it is not yet set up.

        Returns:
            False if configuration was attempted and raised; True otherwise.
        """
        configure_fn = getattr(backend, "configure_batch_optimizer", None)
        if not callable(configure_fn):
            return True
        initialized = getattr(backend, "initialized", False)
        batch_configured = getattr(backend, "batch_configured", False)
        if initialized or batch_configured:
            return True
        try:
            configure_fn(networks)
        except Exception as exc:  # pylint: disable=broad-except
            # Best effort: report and let the caller fall back to the Python path.
            print(f"Warning: failed to configure HelioX batch optimizer: {exc}")
            return False
        return True

    # ------------------------------------------------------------ HelioX path

    def _try_heliox_optimizer(self, networks: List) -> bool:
        """
        Attempt to execute the HelioX-side optimizer.

        The HelioX path is attempted only when it is enabled on this optimizer,
        the gradient-backend preference allows it, all networks share a single
        backend with ``enable_heliox`` set, and that backend reports
        ``optimizer_ready``.

        Args:
            networks: List of network instances to update

        Returns:
            True if the backend performed the update; False if the caller
            should fall back to the Python-side update.
        """
        if not getattr(self, "use_heliox_optimizer", True):
            return False
        if not networks:
            return False
        # Check the preference before touching the backend so an unsupported
        # preference never triggers batch configuration as a side effect.
        if self.gradient_backend not in self._SUPPORTED_BACKENDS:
            return False

        backend = getattr(networks[0], "backend", None)
        if backend is None or not getattr(backend, "enable_heliox", False):
            return False

        if len(networks) > 1:
            if not all(getattr(net, "backend", None) is backend for net in networks):
                return False
            if not self._configure_batch_backend(backend, networks):
                return False

        if not getattr(backend, "optimizer_ready", False):
            return False

        executed = backend.optimizer_step(self.learning_rate, self.record_time, h.dt)
        return bool(executed)

    # ------------------------------------------------------ gradient collection

    def collect_gradients(self, networks: List) -> Dict[str, Tuple]:
        """
        Collect batch-averaged gradients from all networks.

        Args:
            networks: List of network instances

        Returns:
            Dictionary mapping a layer key to a 5-tuple
            ``(avg_dw, avg_db, layer_id, weight_attr, bias_attr)`` (see class
            docstring); empty when ``networks`` is empty.
        """
        if not networks:
            return {}
        if self._is_sequential(networks[0]):
            return self._collect_gradients_sequential(networks)
        # Anything that is not Sequential is treated as Network_parallel.
        return self._collect_gradients_network_parallel(networks)

    def _collect_gradients_sequential(self, networks: List) -> Dict[str, Tuple]:
        """
        Collect gradients from Sequential networks.

        Args:
            networks: List of Sequential instances

        Returns:
            Dictionary mapping layer name to
            ``(avg_dw, avg_db, layer_name, None, None)``.

        Raises:
            ValueError: if the backend preference is unsupported or a nested
                gradient map lacks an entry for the resolved backend.
        """
        if not networks:
            return {}

        batch_size = len(networks)
        backend_key = self._resolve_backend_key()
        sums = {}

        for network in networks:
            network_gradients = network.extract_gradients(h.dt, record_time=self.record_time, backend=backend_key)
            for layer_name, grad_data in network_gradients.items():
                # Handle nested dictionaries if a backend map is returned.
                if 'weight' not in grad_data:
                    if backend_key not in grad_data:
                        raise ValueError(f"Gradient data for backend '{backend_key}' not available")
                    grad_data = grad_data[backend_key]

                if layer_name not in sums:
                    sums[layer_name] = {
                        'sum_weight': np.zeros_like(grad_data['weight']),
                        'sum_bias': np.zeros_like(grad_data['bias']),
                    }
                sums[layer_name]['sum_weight'] += grad_data['weight']
                sums[layer_name]['sum_bias'] += grad_data['bias']

        # Average over the whole batch; tuple layout matches update_weights.
        return {
            layer_name: (acc['sum_weight'] / batch_size, acc['sum_bias'] / batch_size, layer_name, None, None)
            for layer_name, acc in sums.items()
        }

    def _collect_gradients_network_parallel(self, networks: List) -> Dict[str, Tuple]:
        """
        Collect gradients from Network_parallel networks.

        Args:
            networks: List of Network_parallel instances

        Returns:
            Dictionary mapping ``"layer_{idx}_{w}_{b}"`` to
            ``(avg_dw, avg_db, layer_idx, weight_attr, bias_attr)``.

        Raises:
            ValueError: if the backend preference is unsupported.
        """
        if not networks:
            return {}

        first_network = networks[0]
        batch_size = len(networks)
        backend_key = self._resolve_backend_key()

        # (index into network.layers, weight attribute, bias attribute)
        trainable_layers_info = [
            (1, 'in2hd_w', 'in2hd_b'),  # Hidden layer (layers[1])
            (2, 'h2out_w', 'h2out_b'),  # Output layer (layers[2])
        ]

        gradients = {}
        for layer_idx, weight_attr, bias_attr in trainable_layers_info:
            sum_dw = np.zeros(getattr(first_network, weight_attr).shape)
            sum_db = np.zeros(getattr(first_network, bias_attr).shape)

            for network in networks:
                dw, db = network.layers[layer_idx].extract_gradients(h.dt, record_time=self.record_time, backend=backend_key)
                sum_dw += dw
                sum_db += db

            layer_key = f"layer_{layer_idx}_{weight_attr}_{bias_attr}"
            gradients[layer_key] = (sum_dw / batch_size, sum_db / batch_size,
                                    layer_idx, weight_attr, bias_attr)

        return gradients

    # ---------------------------------------------------------- weight updates

    def update_weights(self, networks: List, gradients: Dict) -> None:
        """
        Apply ``w <- w - learning_rate * grad`` to all networks and advance ``iteration``.

        Args:
            networks: List of network instances
            gradients: Dictionary of gradients from collect_gradients

        Nothing happens (and ``iteration`` is not advanced) when either
        argument is empty.
        """
        if not networks or not gradients:
            return

        if self._is_sequential(networks[0]):
            self._update_weights_sequential(networks, gradients)
        else:
            self._update_weights_network_parallel(networks, gradients)

        self.iteration += 1

    def _update_weights_sequential(self, networks: List, gradients: Dict) -> None:
        """
        Update weights in Sequential networks.

        New weights are computed from the first network's stored weights and
        then written to every network's ``weights`` dict and loaded into the
        matching layer.

        Args:
            networks: List of Sequential instances
            gradients: Dictionary of gradients
        """
        if not networks or not gradients:
            return

        first_network = networks[0]
        new_weights = {}
        for layer_name, (avg_dw, avg_db, _, _, _) in gradients.items():
            # Skip layers that cannot be found or have no stored weights.
            if self._find_layer(first_network, layer_name) is None:
                continue
            if layer_name not in first_network.weights:
                continue
            stored = first_network.weights[layer_name]
            new_weights[layer_name] = {
                'weight': stored['weight'] - self.learning_rate * avg_dw,
                'bias': stored['bias'] - self.learning_rate * avg_db,
            }

        for network in networks:
            for layer_name, weight_data in new_weights.items():
                if layer_name in network.weights:
                    network.weights[layer_name]['weight'] = np.copy(weight_data['weight'])
                    network.weights[layer_name]['bias'] = np.copy(weight_data['bias'])

                layer = self._find_layer(network, layer_name)
                if layer and hasattr(layer, 'load_weights'):
                    layer.load_weights(weight_data['weight'], weight_data['bias'])

    def _update_weights_network_parallel(self, networks: List, gradients: Dict) -> None:
        """
        Update weights in Network_parallel networks.

        Args:
            networks: List of Network_parallel instances
            gradients: Dictionary of gradients
        """
        if not networks or not gradients:
            return

        first_network = networks[0]
        for avg_dw, avg_db, layer_idx, weight_attr, bias_attr in gradients.values():
            new_weight = getattr(first_network, weight_attr) - self.learning_rate * avg_dw
            new_bias = getattr(first_network, bias_attr) - self.learning_rate * avg_db

            for network in networks:
                # Each network gets its own array copy so networks do not alias weights.
                setattr(network, weight_attr, np.copy(new_weight))
                setattr(network, bias_attr, np.copy(new_bias))
                network.layers[layer_idx].load_weights(new_weight, new_bias)


class SGD(Optimizer):
    """
    Plain stochastic gradient descent.

    Every step moves each weight against its batch-averaged gradient scaled
    by ``learning_rate`` (``w <- w - learning_rate * grad``). When the HelioX
    backend can run the optimizer itself, that path is used instead.
    """

    def __init__(self, learning_rate: float = 0.005):
        """
        Create an SGD optimizer.

        Args:
            learning_rate: Step size applied to the averaged gradients (default 0.005)
        """
        super().__init__(learning_rate)

    def step(self, networks: List) -> None:
        """
        Run a single SGD update over ``networks``.

        Args:
            networks: List of network instances to update
        """
        handled_by_backend = self._try_heliox_optimizer(networks)
        if not handled_by_backend:
            # Python fallback: average the gradients, then apply them.
            # update_weights advances self.iteration on its own.
            self.update_weights(networks, self.collect_gradients(networks))
            return
        self.iteration += 1


class SGDMomentum(Optimizer):
    """
    SGD with momentum optimizer.

    Keeps a velocity per layer key and, each step, applies
    ``v <- momentum * v - learning_rate * grad`` followed by ``w <- w + v``.
    New weights are computed from the first network and copied to all
    networks in the batch.
    """

    def __init__(self, learning_rate: float = 0.005, momentum: float = 0.9):
        """
        Initialize SGD with momentum.

        Args:
            learning_rate: Learning rate (default 0.005)
            momentum: Momentum factor (typically 0.9)
        """
        super().__init__(learning_rate)
        self.momentum = momentum
        # layer key -> {'weight': ndarray, 'bias': ndarray}
        self.velocity = {}

    def step(self, networks: List) -> None:
        """
        Perform one SGD with momentum optimization step.

        ``iteration`` is advanced when the HelioX backend handles the step or
        when the Python-side update runs; it is left unchanged when there is
        nothing to update.

        Args:
            networks: List of network instances to update
        """
        if self._try_heliox_optimizer(networks):
            self.iteration += 1
            return

        gradients = self.collect_gradients(networks)
        if not networks or not gradients:
            return

        # Duck-typed network-type check (avoids importing Sequential).
        is_sequential = networks[0].__class__.__name__ == 'Sequential'

        for layer_key, (avg_dw, avg_db, layer_idx_or_name, weight_attr, bias_attr) in gradients.items():
            velocity = self._update_velocity(layer_key, avg_dw, avg_db)
            if is_sequential:
                self._apply_sequential(networks, layer_idx_or_name, velocity)
            else:
                self._apply_network_parallel(networks, layer_idx_or_name, weight_attr, bias_attr, velocity)

        self.iteration += 1

    def _update_velocity(self, layer_key, avg_dw, avg_db) -> Dict:
        """Advance and return the velocity for ``layer_key``, creating it as zeros on first use."""
        if layer_key not in self.velocity:
            self.velocity[layer_key] = {
                'weight': np.zeros_like(avg_dw),
                'bias': np.zeros_like(avg_db),
            }
        velocity = self.velocity[layer_key]
        velocity['weight'] = self.momentum * velocity['weight'] - self.learning_rate * avg_dw
        velocity['bias'] = self.momentum * velocity['bias'] - self.learning_rate * avg_db
        return velocity

    @staticmethod
    def _apply_sequential(networks: List, layer_name, velocity: Dict) -> None:
        """Add ``velocity`` to a Sequential layer's weights and push them to every network."""
        first_network = networks[0]
        if layer_name not in first_network.weights:
            return

        new_weight = first_network.weights[layer_name]['weight'] + velocity['weight']
        new_bias = first_network.weights[layer_name]['bias'] + velocity['bias']

        for network in networks:
            # Same guard as Optimizer._update_weights_sequential: skip a network
            # without stored weights for this layer instead of raising KeyError.
            if layer_name in network.weights:
                network.weights[layer_name]['weight'] = np.copy(new_weight)
                network.weights[layer_name]['bias'] = np.copy(new_bias)

            layer = network.get_layer_by_name(layer_name)
            if layer is None:
                # Fall back to scanning for a matching `layer_name` attribute.
                for candidate in network.layers:
                    if hasattr(candidate, 'layer_name') and candidate.layer_name == layer_name:
                        layer = candidate
                        break

            if layer and hasattr(layer, 'load_weights'):
                layer.load_weights(new_weight, new_bias)

    @staticmethod
    def _apply_network_parallel(networks: List, layer_idx, weight_attr, bias_attr, velocity: Dict) -> None:
        """Add ``velocity`` to a Network_parallel layer's weights and push them to every network."""
        first_network = networks[0]
        new_weight = getattr(first_network, weight_attr) + velocity['weight']
        new_bias = getattr(first_network, bias_attr) + velocity['bias']

        for network in networks:
            # Each network gets its own array copy so networks do not alias weights.
            setattr(network, weight_attr, np.copy(new_weight))
            setattr(network, bias_attr, np.copy(new_bias))
            network.layers[layer_idx].load_weights(new_weight, new_bias)


class Adam(Optimizer):
    """
    Adam optimizer (Adaptive Moment Estimation).

    Adam combines the benefits of AdaGrad and RMSProp, maintaining both first
    and second moment estimates of the gradients. Bias correction uses the
    number of moment updates applied to each layer rather than the global
    ``iteration`` counter. ``iteration`` is also advanced by HelioX-side steps,
    which never touch the Python moments.
    """

    def __init__(self,
                 learning_rate: float = 0.001,
                 beta1: float = 0.9,
                 beta2: float = 0.999,
                 epsilon: float = 1e-8):
        """
        Initialize Adam optimizer.

        Args:
            learning_rate: Learning rate (default 0.001 for Adam)
            beta1: Exponential decay rate for first moment estimates
            beta2: Exponential decay rate for second moment estimates
            epsilon: Small value to prevent division by zero
        """
        super().__init__(learning_rate)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.m = {}  # First moment estimate per layer key
        self.v = {}  # Second moment estimate per layer key
        # Number of moment updates applied per layer key (Adam's timestep t).
        self._moment_steps = {}

    def step(self, networks: List) -> None:
        """
        Perform one Adam optimization step.

        Args:
            networks: List of network instances to update
        """
        if self._try_heliox_optimizer(networks):
            self.iteration += 1
            return

        gradients = self.collect_gradients(networks)
        if not networks or not gradients:
            return

        # Duck-typed network-type check (avoids importing Sequential).
        is_sequential = networks[0].__class__.__name__ == 'Sequential'
        self.iteration += 1

        for layer_key, (avg_dw, avg_db, layer_idx_or_name, weight_attr, bias_attr) in gradients.items():
            weight_update, bias_update = self._compute_updates(layer_key, avg_dw, avg_db)
            if is_sequential:
                self._apply_sequential(networks, layer_idx_or_name, weight_update, bias_update)
            else:
                self._apply_network_parallel(networks, layer_idx_or_name, weight_attr, bias_attr,
                                             weight_update, bias_update)

    def _compute_updates(self, layer_key, avg_dw, avg_db) -> Tuple:
        """
        Advance the moments for ``layer_key`` and return ``(weight_update, bias_update)``.

        The returned updates are to be subtracted from the current weights.
        """
        if layer_key not in self.m:
            self.m[layer_key] = {'weight': np.zeros_like(avg_dw), 'bias': np.zeros_like(avg_db)}
            self.v[layer_key] = {'weight': np.zeros_like(avg_dw), 'bias': np.zeros_like(avg_db)}
            # Fresh zero moments: bias correction restarts at t = 1.
            self._moment_steps[layer_key] = 0
        # Moments that exist without a recorded count (e.g. assigned externally)
        # fall back to the global iteration counter, matching earlier behaviour.
        t = self._moment_steps.get(layer_key, self.iteration - 1) + 1
        self._moment_steps[layer_key] = t

        m = self.m[layer_key]
        v = self.v[layer_key]

        # Biased first and second raw moment estimates.
        m['weight'] = self.beta1 * m['weight'] + (1 - self.beta1) * avg_dw
        m['bias'] = self.beta1 * m['bias'] + (1 - self.beta1) * avg_db
        v['weight'] = self.beta2 * v['weight'] + (1 - self.beta2) * avg_dw**2
        v['bias'] = self.beta2 * v['bias'] + (1 - self.beta2) * avg_db**2

        # Bias-corrected moments.
        m_hat_weight = m['weight'] / (1 - self.beta1**t)
        m_hat_bias = m['bias'] / (1 - self.beta1**t)
        v_hat_weight = v['weight'] / (1 - self.beta2**t)
        v_hat_bias = v['bias'] / (1 - self.beta2**t)

        weight_update = self.learning_rate * m_hat_weight / (np.sqrt(v_hat_weight) + self.epsilon)
        bias_update = self.learning_rate * m_hat_bias / (np.sqrt(v_hat_bias) + self.epsilon)
        return weight_update, bias_update

    @staticmethod
    def _apply_sequential(networks: List, layer_name, weight_update, bias_update) -> None:
        """Subtract the updates from a Sequential layer's weights and push them to every network."""
        first_network = networks[0]
        if layer_name not in first_network.weights:
            return

        new_weight = first_network.weights[layer_name]['weight'] - weight_update
        new_bias = first_network.weights[layer_name]['bias'] - bias_update

        for network in networks:
            # Same guard as Optimizer._update_weights_sequential: skip a network
            # without stored weights for this layer instead of raising KeyError.
            if layer_name in network.weights:
                network.weights[layer_name]['weight'] = np.copy(new_weight)
                network.weights[layer_name]['bias'] = np.copy(new_bias)

            layer = network.get_layer_by_name(layer_name)
            if layer is None:
                # Fall back to scanning for a matching `layer_name` attribute.
                for candidate in network.layers:
                    if hasattr(candidate, 'layer_name') and candidate.layer_name == layer_name:
                        layer = candidate
                        break

            if layer and hasattr(layer, 'load_weights'):
                layer.load_weights(new_weight, new_bias)

    @staticmethod
    def _apply_network_parallel(networks: List, layer_idx, weight_attr, bias_attr,
                                weight_update, bias_update) -> None:
        """Subtract the updates from a Network_parallel layer's weights and push them to every network."""
        first_network = networks[0]
        new_weight = getattr(first_network, weight_attr) - weight_update
        new_bias = getattr(first_network, bias_attr) - bias_update

        for network in networks:
            # Each network gets its own array copy so networks do not alias weights.
            setattr(network, weight_attr, np.copy(new_weight))
            setattr(network, bias_attr, np.copy(new_bias))
            network.layers[layer_idx].load_weights(new_weight, new_bias)
