"""
Connection Pattern Manager for HelioX Training

This module provides a centralized manager for connection patterns that can be
shared across multiple neural network instances to ensure consistency during
batch training. This is particularly important when multiple networks need to
use the same random connectivity structure.

"""

import zlib
from typing import Dict, Tuple, Any, Optional

import numpy as np


class ConnectionPattern:
    """
    Manages connection patterns for neural network layers.

    This class generates and caches connection patterns (dendrite indices and
    locations) that can be shared across multiple network instances. A pattern
    is derived only from ``seed`` and ``pattern_key``. The key is folded into the
    seed with a process-stable CRC-32 checksum rather than the salted builtin
    ``hash()``. Networks with the same configuration therefore get identical
    connection patterns both within one process and across separate runs, which
    batch training and reproducibility require.

    Note:
        Cached arrays are returned by reference, not copied. Every caller that
        requests the same ``pattern_key`` from the same instance receives the
        same ``np.ndarray`` objects, so in-place mutation by one caller changes
        the pattern seen by all of them.
    """

    # The per-key seed offset is reduced into [0, _KEY_OFFSET_MODULUS).
    _KEY_OFFSET_MODULUS = 1000000
    # Base seed offsets that keep the index stream and the location stream apart.
    _INDEX_STREAM_OFFSET = 1000000
    _LOCATION_STREAM_OFFSET = 2000000

    def __init__(self, seed: int = 1234, rng: Optional[np.random.Generator] = None):
        """
        Initialize the ConnectionPattern manager.

        Args:
            seed: Random seed for pattern generation
            rng: Deprecated and ignored; accepted only for backward compatibility.
                Fresh generators are derived from ``seed`` for every new pattern.
        """
        self.seed = seed
        # The rng is intentionally not stored. Each pattern builds its own
        # generators from the seed, so results do not depend on call order.
        # pattern_key -> {'dendrite_indices', 'dendrite_locations', 'metadata'}
        self._pattern_cache: Dict[str, Dict[str, Any]] = {}

    @classmethod
    def _key_seed_offset(cls, pattern_key: str) -> int:
        """
        Return a deterministic seed offset in [0, 1000000) for ``pattern_key``.

        The builtin ``hash()`` of a ``str`` is salted per interpreter process
        (PYTHONHASHSEED), so it cannot be used for reproducible seeding.
        CRC-32 of the UTF-8 encoded key is the same in every process.
        """
        return zlib.crc32(str(pattern_key).encode("utf-8")) % cls._KEY_OFFSET_MODULUS

    def get_connection_pattern(
        self,
        pattern_key: str,
        n_source: int,
        n_target: int,
        projections_per_connection: int,
        total_dendrites: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Get or generate a connection pattern.

        This method returns the cached pattern if one exists for ``pattern_key``.
        Otherwise it generates a new one following the same logic as the
        original cell.py implementation.

        Args:
            pattern_key: Unique identifier for this connection pattern (e.g., "input_to_hidden")
            n_source: Number of source neurons (e.g., input layer size)
            n_target: Number of target neurons (e.g., hidden layer size)
            projections_per_connection: Number of projections per connection (typically 1)
            total_dendrites: Total number of dendrites available on target neurons

        Returns:
            Tuple of (dendrite_indices, dendrite_locations):
                - dendrite_indices: Integer array of shape
                  (n_source, n_target, projections_per_connection) with values
                  in [0, total_dendrites)
                - dendrite_locations: Float array of the same shape with values
                  in [0, 1), the normalized position on the dendrite
            Both arrays are the cached objects themselves, not copies.

        Raises:
            ValueError: If any size argument is not positive, or if
                ``pattern_key`` is already cached with different dimensions
                or a different dendrite count.
        """
        # Validate input parameters
        if total_dendrites <= 0:
            raise ValueError(f"total_dendrites must be positive, got {total_dendrites}")
        if n_source <= 0 or n_target <= 0:
            raise ValueError(f"Source and target sizes must be positive, got {n_source}, {n_target}")
        if projections_per_connection <= 0:
            raise ValueError(f"projections_per_connection must be positive, got {projections_per_connection}")

        # Serve from the cache, but never hand back a pattern built for other
        # dimensions. A silent shape mismatch would corrupt the caller's network.
        cached_pattern = self._pattern_cache.get(pattern_key)
        if cached_pattern is not None:
            if not self.validate_pattern_compatibility(
                pattern_key, n_source, n_target, projections_per_connection, total_dendrites
            ):
                raise ValueError(
                    f"Pattern '{pattern_key}' is already cached with metadata "
                    f"{cached_pattern['metadata']}, which does not match the requested "
                    f"n_source={n_source}, n_target={n_target}, "
                    f"projections_per_connection={projections_per_connection}, "
                    f"total_dendrites={total_dendrites}. Use a different pattern_key "
                    f"or call remove_pattern() first."
                )
            return cached_pattern['dendrite_indices'], cached_pattern['dendrite_locations']

        # Generate a new pattern following the original logic from cell.py.
        # The key-derived offset makes different layers get different patterns.
        key_offset = self._key_seed_offset(pattern_key)

        # Indices and locations each use their own generator, so neither
        # stream depends on how many values the other one draws.
        conn_rng = np.random.default_rng(seed=self.seed + self._INDEX_STREAM_OFFSET + key_offset)
        loc_rng = np.random.default_rng(seed=self.seed + self._LOCATION_STREAM_OFFSET + key_offset)

        shape = (n_source, n_target, projections_per_connection)

        # Which dendrite each projection lands on.
        dendrite_indices = conn_rng.integers(0, total_dendrites, shape)

        # Where along that dendrite it lands (uniform in [0, 1)).
        dendrite_locations = loc_rng.random(shape)

        # Cache the generated pattern together with the parameters that produced it.
        self._pattern_cache[pattern_key] = {
            'dendrite_indices': dendrite_indices,
            'dendrite_locations': dendrite_locations,
            'metadata': {
                'n_source': n_source,
                'n_target': n_target,
                'projections_per_connection': projections_per_connection,
                'total_dendrites': total_dendrites
            }
        }

        return dendrite_indices, dendrite_locations

    def has_pattern(self, pattern_key: str) -> bool:
        """
        Check if a pattern with the given key exists in cache.

        Args:
            pattern_key: The key to check

        Returns:
            True if pattern exists, False otherwise
        """
        return pattern_key in self._pattern_cache

    def get_pattern_metadata(self, pattern_key: str) -> Optional[Dict[str, Any]]:
        """
        Get metadata for a cached pattern.

        Args:
            pattern_key: The pattern key

        Returns:
            A shallow copy of the pattern's metadata dictionary, or None if the
            pattern doesn't exist
        """
        if pattern_key in self._pattern_cache:
            return self._pattern_cache[pattern_key]['metadata'].copy()
        return None

    def clear_cache(self):
        """Clear all cached patterns."""
        self._pattern_cache.clear()

    def remove_pattern(self, pattern_key: str) -> bool:
        """
        Remove a specific pattern from cache.

        Args:
            pattern_key: The pattern key to remove

        Returns:
            True if pattern was removed, False if it didn't exist
        """
        if pattern_key in self._pattern_cache:
            del self._pattern_cache[pattern_key]
            return True
        return False

    def get_cache_info(self) -> Dict[str, Any]:
        """
        Get information about cached patterns.

        Returns:
            Dictionary with 'num_patterns', 'seed', and 'patterns'. The
            'patterns' entry maps each key to its array 'shape' and 'total_dendrites'.
        """
        info = {
            'num_patterns': len(self._pattern_cache),
            'seed': self.seed,
            'patterns': {}
        }

        for key, pattern_data in self._pattern_cache.items():
            metadata = pattern_data['metadata']
            info['patterns'][key] = {
                'shape': (metadata['n_source'], metadata['n_target'], metadata['projections_per_connection']),
                'total_dendrites': metadata['total_dendrites']
            }

        return info

    def validate_pattern_compatibility(
        self,
        pattern_key: str,
        n_source: int,
        n_target: int,
        projections_per_connection: int,
        total_dendrites: int
    ) -> bool:
        """
        Validate that requested parameters match cached pattern.

        Args:
            pattern_key: The pattern key to validate
            n_source: Expected number of source neurons
            n_target: Expected number of target neurons
            projections_per_connection: Expected projections per connection
            total_dendrites: Expected total dendrites

        Returns:
            True if nothing is cached under ``pattern_key`` or if all four
            parameters equal the cached metadata; False otherwise
        """
        if pattern_key not in self._pattern_cache:
            return True  # No cached pattern, so any parameters are valid

        metadata = self._pattern_cache[pattern_key]['metadata']
        return (
            metadata['n_source'] == n_source and
            metadata['n_target'] == n_target and
            metadata['projections_per_connection'] == projections_per_connection and
            metadata['total_dendrites'] == total_dendrites
        )


class SharedConnectionPatternManager:
    """
    Process-wide registry of ConnectionPattern managers, keyed by seed.

    Networks in the same training session that ask for the same seed receive
    the very same ConnectionPattern object, and so share its pattern cache.
    The registry is a class attribute. Every caller in the process sees it
    until clear_all_managers() is called.
    """

    # seed -> shared ConnectionPattern instance
    _instances: Dict[int, ConnectionPattern] = {}

    @classmethod
    def get_pattern_manager(cls, seed: int = 1234, rng: Optional[np.random.Generator] = None) -> ConnectionPattern:
        """
        Return the shared ConnectionPattern for ``seed``, creating it on first use.

        Args:
            seed: Random seed for pattern generation
            rng: Optional external random number generator. When supplied, a
                fresh ConnectionPattern is built with it and returned without
                being registered. ConnectionPattern documents this argument as
                deprecated and ignored.

        Returns:
            ConnectionPattern instance for the given seed
        """
        if rng is not None:
            # Private, unregistered manager for callers that pass their own generator.
            return ConnectionPattern(seed, rng)

        registry = cls._instances
        manager = registry.get(seed)
        if manager is None:
            manager = ConnectionPattern(seed)
            registry[seed] = manager
        return manager

    @classmethod
    def clear_all_managers(cls):
        """Forget every registered manager (and with it their pattern caches)."""
        cls._instances.clear()

    @classmethod
    def get_manager_info(cls) -> Dict[str, Any]:
        """
        Summarize all registered managers.

        Returns:
            Dictionary with 'num_managers' (count of registered seeds) and
            'managers' (seed -> that manager's get_cache_info() result)
        """
        per_seed = {
            seed: manager.get_cache_info()
            for seed, manager in cls._instances.items()
        }
        return {
            'num_managers': len(cls._instances),
            'managers': per_seed,
        }


# Example usage and compatibility functions for integration with existing code
def create_connection_pattern_for_network(
    seed: int,
    n_input: int,
    n_hidden: int,
    projections_per_dendrite: int,
    total_dendrites: int
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fetch the input-to-hidden connection pattern for a network built with ``seed``.

    This wrapper keeps the argument names used by the Network_parallel / cell.py
    code. It delegates to the shared per-seed ConnectionPattern under the fixed
    key "input_to_hidden".

    Args:
        seed: Random seed
        n_input: Number of input neurons (N_in)
        n_hidden: Number of hidden neurons (N_hidden)
        projections_per_dendrite: Number of projections per dendrite (num_proj_dend);
            forwarded as ``projections_per_connection``
        total_dendrites: Total number of dendrites (total_dend)

    Returns:
        Tuple of (dendrite_indices, dendrite_locations) compatible with existing code
    """
    shared = SharedConnectionPatternManager.get_pattern_manager(seed)
    # Translate the legacy argument names into ConnectionPattern's vocabulary.
    request = {
        'n_source': n_input,
        'n_target': n_hidden,
        'projections_per_connection': projections_per_dendrite,
        'total_dendrites': total_dendrites,
    }
    return shared.get_connection_pattern(pattern_key="input_to_hidden", **request)
