# dnon_modules_squad.py
import json
import time
import logging
import numpy as np
import boto3
import torch
from scipy.special import softmax
from manifold_representation_StrategyQA import RiemannianManifold, EuclideanEmbedding
from memory_dynamics_StrategyQA import ShortTermMemory, LongTermMemory, DeepSubconsciousMemory
from information_geometry_StrategyQA import FisherInformationMatrix, NaturalGradient

# Set up logging
# Configures the root logger at import time (INFO level, "time - level - message"
# layout). This is a module-level side effect: importing this module changes
# logging configuration for the whole process if it has not been configured yet.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-scoped logger used by the agent classes defined below.
logger = logging.getLogger(__name__)

class BaseAgent:
    """Shared behaviour for the DNON agents.

    An agent sends prompts to a Claude model through a Bedrock runtime client,
    embeds each response with a Cohere embedding model, and tracks the
    resulting vector as a state on a ``RiemannianManifold``.

    Subclasses are expected to set ``self.system_prompt``; the class-level
    default is an empty string so that a bare ``BaseAgent`` can still be
    queried instead of failing on a missing attribute.
    """

    # Embedding model invoked through Bedrock by ``generate_embedding``.
    EMBEDDING_MODEL_ID = "cohere.embed-multilingual-v3"
    # Texts longer than this many characters are embedded chunk by chunk.
    EMBED_CHUNK_CHARS = 2000
    # Only the first few chunks of a long text are embedded.
    MAX_EMBED_CHUNKS = 5
    # Size of the random fallback vector used when every chunk fails.
    FALLBACK_EMBED_DIM = 1024
    # Default prompt; subclasses overwrite it per instance in ``__init__``.
    system_prompt = ""

    def __init__(self, state_dim=256, temperature=0.02):
        """Create an agent with an empty state of dimension ``state_dim``.

        Args:
            state_dim: Dimension of the manifold state and Fisher matrix.
            temperature: Sampling temperature sent with every model request.
        """
        self.MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"
        self.temperature = temperature
        # Must be supplied via ``set_bedrock_client`` before querying.
        self.bedrock_runtime = None
        # Per-query latencies in milliseconds (successful queries only).
        self.latency_stats = []

        # Add manifold representation
        self.manifold = RiemannianManifold(dim=state_dim)
        self.state = None
        self.state_dim = state_dim
        self.state_history = []

        # Add Fisher information matrix for natural gradient
        self.fisher_matrix = FisherInformationMatrix(dim=state_dim)

    def set_bedrock_client(self, bedrock_client):
        """Attach the Bedrock runtime client used for all model calls."""
        self.bedrock_runtime = bedrock_client

    def query(self, user_input):
        """Query the Claude model with the given input.

        The system prompt and ``user_input`` are sent as a single user
        message. On success the latency is recorded and the agent state is
        updated from an embedding of the response.

        Returns:
            The response text, or the string ``"An error occurred: ..."`` if
            the model call fails (the error is logged, not raised).
        """
        logger.info(f"Query sent to {self.__class__.__name__}: {user_input[:50]}...")

        # Start timing
        start_time = time.time()

        try:
            response = self.bedrock_runtime.invoke_model(
                modelId=self.MODEL_ID,
                body=json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": 4090,
                    "messages": [{"role": "user", "content": [
                        {"type": "text", "text": f"{self.system_prompt}\n\n{user_input}"}
                    ]}],
                    "temperature": self.temperature,
                    "top_p": 1,
                })
            )

            response_body = json.loads(response['body'].read())
            response_text = response_body['content'][0]['text']

            # Calculate and record latency
            latency_ms = (time.time() - start_time) * 1000
            self.latency_stats.append(latency_ms)
        except Exception as e:
            logger.error(f"Error querying model: {e}")
            return f"An error occurred: {e}"

        # State tracking is bookkeeping: a failure here must not throw away
        # a model response that was obtained successfully.
        try:
            self.update_state(self.generate_embedding(response_text))
        except Exception as e:
            logger.error(f"Error updating agent state: {e}")

        logger.info(f"Response from {self.__class__.__name__}: {response_text[:100]}... ({latency_ms:.2f}ms)")
        return response_text

    def generate_embedding(self, text):
        """Generate a ``state_dim``-sized embedding of ``text`` using Cohere.

        Long texts are split into chunks whose embeddings are averaged. The
        raw embedding is then resized to ``state_dim``.

        Returns:
            A 1-D NumPy array of length ``state_dim``. If embedding fails the
            error is logged and a random vector is returned instead.
        """
        try:
            if len(text) > self.EMBED_CHUNK_CHARS:
                embedding = self._embed_long_text(text)
            else:
                embedding = self._embed_text(text)
            return self._fit_to_state_dim(embedding)
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            return np.random.rand(self.state_dim)

    def _embed_text(self, text):
        """Embed one piece of text with a single Bedrock call."""
        response = self.bedrock_runtime.invoke_model(
            modelId=self.EMBEDDING_MODEL_ID,
            contentType="application/json",
            body=json.dumps({"texts": [text], "input_type": "search_document"})
        )
        return np.array(json.loads(response['body'].read())['embeddings'][0])

    def _embed_long_text(self, text):
        """Embed the leading chunks of a long text and average the results."""
        size = self.EMBED_CHUNK_CHARS
        chunks = [text[i:i + size] for i in range(0, len(text), size)]
        logger.info(f"Text length {len(text)} chars, split into {len(chunks)} chunks")

        chunk_embeddings = []
        for i, chunk in enumerate(chunks[:self.MAX_EMBED_CHUNKS]):
            # A failed chunk is skipped so the remaining ones can still be used.
            try:
                chunk_embeddings.append(self._embed_text(chunk))
            except Exception as chunk_e:
                logger.error(f"Error embedding chunk {i}: {chunk_e}")

        if chunk_embeddings:
            return np.mean(chunk_embeddings, axis=0)
        # Fallback to random if all chunks failed
        return np.random.rand(self.FALLBACK_EMBED_DIM)

    def _fit_to_state_dim(self, embedding):
        """Resize a 1-D embedding to exactly ``state_dim`` entries.

        Longer vectors are compressed by averaging ``state_dim`` contiguous,
        near-equal groups that together cover every input value. Shorter
        vectors are zero-padded; averaging would otherwise operate on empty
        groups and produce NaN.
        """
        embedding = np.asarray(embedding, dtype=float)
        size = len(embedding)
        if size == self.state_dim:
            return embedding
        if size < self.state_dim:
            padded = np.zeros(self.state_dim)
            padded[:size] = embedding
            return padded
        groups = np.array_split(embedding, self.state_dim)
        return np.array([group.mean() for group in groups])

    def update_state(self, embedding):
        """Update agent state on the manifold.

        The first embedding becomes the initial state; later embeddings are
        projected onto the manifold and the state moves half-way
        (``step_size=0.5``) along the geodesic towards the projected point.
        A copy of every new state is appended to ``state_history``.
        """
        if self.state is None:
            self.state = EuclideanEmbedding(embedding)
        else:
            # Project the embedding onto the manifold
            manifold_point = self.manifold.project(embedding)

            # Update state using information geometry
            self.state = self.manifold.geodesic(self.state, manifold_point, step_size=0.5)

        # Store state history for convergence analysis
        self.state_history.append(self.state.copy())

    def get_state(self):
        """Return the current state coordinates, or zeros if no state exists yet."""
        if self.state is None:
            return np.zeros(self.state_dim)
        return self.state.coordinates

    def compute_fisher_information(self):
        """Return the Fisher information matrix.

        The matrix is refreshed from the coordinates of the state history
        when at least two states have been recorded.
        """
        if len(self.state_history) > 1:
            self.fisher_matrix.update([sh.coordinates for sh in self.state_history])
        return self.fisher_matrix.get_matrix()


class PerceptionAgent(BaseAgent):
    """Agent whose prompt asks the model to structure a context/question pair.

    It adds no behaviour of its own; it only installs the perception prompt
    as ``system_prompt`` on top of ``BaseAgent``.
    """

    # Instruction text installed on every instance as ``system_prompt``.
    _PERCEPTION_PROMPT = """You are a specialized perception module for reading comprehension on the SQuAD dataset. Your task is to:

1. Parse the given context and question with extreme precision
2. Extract ALL key entities, concepts, and their relationships from both the context and question
3. Identify the type of information being requested:
   - FACTUAL DETAILS: specific facts mentioned in the context
   - ENTITY INFORMATION: details about people, places, organizations
   - TEMPORAL RELATIONSHIPS: when events occurred
   - CAUSAL RELATIONSHIPS: why something happened
   - DEFINITIONS: what something means or refers to

4. Structure the information to make answer extraction easier by:
   - Highlighting sentences that likely contain the answer
   - Noting connections between the question entities and context
   - Identifying potential answer spans in the context

Output a structured representation of the context and question that will aid in answering the question accurately.
DO NOT answer the question yourself. Focus ONLY on structuring it for precise reasoning.
"""

    def __init__(self, state_dim=256, temperature=0.02):
        """Initialise the base agent and attach the perception prompt."""
        super().__init__(state_dim=state_dim, temperature=temperature)
        self.system_prompt = self._PERCEPTION_PROMPT


class MemoryAgent(BaseAgent):
    """Agent that augments its query with three memory stores.

    Each query is embedded, looked up in short-term, long-term and deep
    subconscious memory, and the retrieved material is appended to the
    prompt before the model is called.
    """

    # Instruction text installed on every instance as ``system_prompt``.
    _MEMORY_PROMPT = """You are a specialized memory module for reading comprehension on the SQuAD dataset. Your job is to:

1. Provide ALL relevant knowledge needed to answer the question accurately
2. Focus primarily on information from the provided context
3. Supplement with background knowledge only when necessary for clarification
4. For named entities in the context, provide relevant facts about them
5. For specialized terminology, provide brief definitions if they appear in the context
6. Identify connections between concepts in the question and those in the context
7. Highlight the most relevant passages that likely contain the answer

Present this information in clear, factual statements organized by topic.
DO NOT solve the problem or provide the final answer - only provide the relevant knowledge.
"""

    # Substrings whose presence raises the consolidation importance score.
    _IMPORTANCE_CUES = ("define", "who", "when", "where", "why", "how", "explain")
    # Substrings that each add to the salience score.
    _SALIENCE_CUES = ("important", "significant", "critical", "crucial",
                      "key", "essential", "primary", "main", "fundamental")

    def __init__(self, state_dim=256, temperature=0.02,
                 st_decay_rate=0.1, lt_consolidation_rate=0.01, ds_access_coef=0.001):
        """Initialise the base agent, the memory prompt and the three stores.

        Args:
            state_dim: Embedding/state dimension shared with the stores.
            temperature: Sampling temperature for model requests.
            st_decay_rate: Decay rate passed to ``ShortTermMemory``.
            lt_consolidation_rate: Consolidation rate passed to ``LongTermMemory``.
            ds_access_coef: Access coefficient passed to ``DeepSubconsciousMemory``.
        """
        super().__init__(state_dim=state_dim, temperature=temperature)
        self.system_prompt = self._MEMORY_PROMPT

        # Three-tier memory: short-term, long-term, deep subconscious.
        self.stm = ShortTermMemory(capacity=20, decay_rate=st_decay_rate, dim=state_dim)
        self.ltm = LongTermMemory(capacity=100, consolidation_rate=lt_consolidation_rate, dim=state_dim)
        self.dsm = DeepSubconsciousMemory(access_coef=ds_access_coef, dim=state_dim)

    def process_with_memory(self, query, facts=None):
        """Answer ``query`` with context retrieved from the memory stores.

        Retrieval happens before the stores are updated, so the query never
        retrieves itself. Optional ``facts`` are appended as a bullet list.

        Returns:
            A ``(model_response, memory_context)`` tuple.
        """
        embedded_query = self.generate_embedding(query)

        # Look the query up in every tier first.
        stm_hits = self.stm.retrieve(embedded_query)
        ltm_hits = self.ltm.retrieve(embedded_query)
        dsm_hits = self.dsm.retrieve(embedded_query)

        # Promote from short-term to long-term memory according to importance.
        self.ltm.consolidate(self.stm, embedded_query, self.calculate_importance(query))

        # Feed the deep store with the long-term result, weighted by salience.
        self.dsm.update(embedded_query, ltm_hits, self.calculate_emotional_salience(query))

        # Finally remember the query itself.
        self.stm.add(query, embedded_query)

        memory_context = f"""
        SHORT-TERM MEMORY (recent questions):
        {stm_hits}

        LONG-TERM MEMORY (consolidated knowledge):
        {ltm_hits}

        DEEP KNOWLEDGE (reading comprehension rules):
        {dsm_hits}
        """

        has_facts = bool(facts) and len(facts) > 0
        if has_facts:
            bullet_list = "\n".join(f"- {fact}" for fact in facts)
            memory_context += f"\n\nRELEVANT FACTS:\n{bullet_list}"

        prompt = f"{query}\n\nUse the following knowledge:\n{memory_context}"
        answer = self.query(prompt)
        return answer, memory_context

    def calculate_importance(self, query):
        """Score how strongly ``query`` should be consolidated, in ``[0, 1]``.

        The score grows with the number of whitespace-separated words
        (capped at 0.9) and gets a 0.2 bonus when any cue substring occurs
        in the lower-cased query.
        """
        word_count = len(query.split())
        score = min(0.9, 0.3 + 0.1 * word_count)

        lowered = query.lower()
        for cue in self._IMPORTANCE_CUES:
            if cue in lowered:
                score += 0.2
                break

        return min(1.0, score)

    def calculate_emotional_salience(self, query):
        """Score the salience of ``query``, in ``[0, 1]``.

        Starts at 0.1 and adds 0.1 for every cue substring found in the
        lower-cased query.
        """
        lowered = query.lower()
        matched = [cue for cue in self._SALIENCE_CUES if cue in lowered]

        score = 0.1  # Base salience
        for _ in matched:
            score += 0.1

        return min(1.0, score)


class ReasoningAgent(BaseAgent):
    """Agent whose prompt asks the model for step-by-step answer extraction.

    It adds no behaviour of its own; it only installs the reasoning prompt
    as ``system_prompt`` on top of ``BaseAgent``.
    """

    # Instruction text installed on every instance as ``system_prompt``.
    _REASONING_PROMPT = """You are a specialized reasoning module for reading comprehension on the SQuAD dataset. Your job is to:

1. Process the structured question and relevant knowledge with extreme precision
2. Apply methodical, step-by-step logical reasoning to extract the answer from the context
3. Focus on the specific information requested in the question
4. Locate the precise span of text in the context that answers the question
5. Verify that your answer is directly supported by the context
6. Ensure your answer is as concise as possible while fully answering the question

Follow this exact reasoning structure:
Step 1: Identify the type of information being requested
Step 2: Locate relevant passages in the context
Step 3: Extract candidate answer spans
Step 4: Verify each candidate against the question
Step 5: Select the most precise answer span

Final Answer: [Extracted answer span from the context]

Be extremely precise. Your answer should be the exact text span from the context that answers the question.
"""

    def __init__(self, state_dim=256, temperature=0.02):
        """Initialise the base agent and attach the reasoning prompt."""
        super().__init__(state_dim=state_dim, temperature=temperature)
        self.system_prompt = self._REASONING_PROMPT


class ExecutiveAgent(BaseAgent):
    """Agent that integrates the other modules' outputs into the final answer.

    Besides the executive prompt it keeps a smoothed weight for each of the
    perception, memory and reasoning modules.
    """

    # Instruction text installed on every instance as ``system_prompt``.
    _EXECUTIVE_PROMPT = """You are an executive decision-making module that integrates outputs from three specialized modules for SQuAD questions:

1. Perception Module: How the question and context are structured
2. Memory Module: Relevant knowledge and facts
3. Reasoning Module: Step-by-step logical reasoning

Given the module weights, make the final determination by following this process:

1. Carefully review the context and question
2. Examine the reasoning process used to extract the answer
3. Verify that the answer is directly supported by the context
4. Ensure the answer is the most precise span of text that fully answers the question
5. If multiple answer candidates exist, select the one with the strongest support

Return ONLY the answer text, with no additional explanation. Your response should be the exact span of text from the context that answers the question.
"""

    # Order in which routing weights are read: P, M, R.
    _MODULE_ORDER = ('perception', 'memory', 'reasoning')

    def __init__(self, state_dim=256, temperature=0.01):
        """Initialise the base agent, the executive prompt and the weights."""
        super().__init__(state_dim=state_dim, temperature=temperature)
        self.system_prompt = self._EXECUTIVE_PROMPT

        # Starting attention per module, in ``_MODULE_ORDER``.
        self.module_attention_weights = np.array([0.33, 0.33, 0.33])

    def update_attention_weights(self, routing_weights):
        """Blend new routing weights into the stored attention weights.

        Args:
            routing_weights: Mapping with ``'perception'``, ``'memory'`` and
                ``'reasoning'`` entries.

        Returns:
            The updated attention weights, rescaled so they sum to 1.
        """
        incoming = np.array([routing_weights[name] for name in self._MODULE_ORDER])

        # Exponential moving average: 70% previous weights, 30% new ones.
        retain = 0.7
        blended = retain * self.module_attention_weights + (1 - retain) * incoming

        # Rescale so the weights sum to 1.
        self.module_attention_weights = blended / np.sum(blended)
        return self.module_attention_weights


# dnon_system_squad.py
import time
import json
import logging
import re
import matplotlib.pyplot as plt
import os
import numpy as np
from convergence_analysis_StrategyQA import ConvergenceAnalyzer
from dynamic_router_StrategyQA import DynamicRouter

# Module-scoped logger. The "# dnon_system_squad.py" header above suggests this
# section was a separate module originally; in a single file this simply
# re-binds the same ``logger`` name.
logger = logging.getLogger(__name__)

class DNONSystem_SQuAD:
    """Pipeline that answers SQuAD questions with four cooperating agents.

    A question flows through perception, memory and reasoning agents; a
    router weights their outputs and an executive agent produces the final
    answer. After every question a Lyapunov value over the agent states is
    recorded for convergence analysis.
    """

    def __init__(self, session, opensearch_clients=None, model_pool=None, state_dim=256):
        """
        Initialize the DNON system for SQuAD.

        Args:
            session: Object providing ``client('bedrock-runtime')``
                (presumably a boto3 session).
            opensearch_clients: Accepted for interface compatibility; unused.
            model_pool: Accepted for interface compatibility; unused.
            state_dim: State dimension shared by all agents and by the
                placeholder optimal states.
        """
        self.bedrock_runtime = session.client('bedrock-runtime')
        self.opensearch_client = None  # We're not using OpenSearch

        # Initialize modules
        from dnon_modules_squad import PerceptionAgent, MemoryAgent, ReasoningAgent, ExecutiveAgent
        self.perception = PerceptionAgent(state_dim=state_dim)
        self.memory = MemoryAgent(state_dim=state_dim, st_decay_rate=0.1,
                                  lt_consolidation_rate=0.01, ds_access_coef=0.001)
        self.reasoning = ReasoningAgent(state_dim=state_dim)
        self.executive = ExecutiveAgent(state_dim=state_dim)

        # Set bedrock client for all agents
        for agent in (self.perception, self.memory, self.reasoning, self.executive):
            agent.set_bedrock_client(self.bedrock_runtime)

        # Initialize router
        self.router = DynamicRouter(self.bedrock_runtime, beta=2.0)

        # Initialize convergence analyzer
        self.convergence_analyzer = ConvergenceAnalyzer()

        # Metrics tracking: one result dict per solved question.
        self.results = []

        # Placeholder optimal states for convergence analysis: random
        # vectors, not learned targets.
        self.optimal_states = {
            module: np.random.randn(state_dim)
            for module in ('perception', 'memory', 'reasoning', 'executive')
        }

    def solve(self, question, context=None, facts=None, decomposition=None, expected_answer=None):
        """Process a question through the DNON system.

        Args:
            question: The question text.
            context: Optional passage; appended to every module's input.
            facts: Optional facts forwarded to the memory module.
            decomposition: Accepted for interface compatibility; unused.
            expected_answer: Reference answer. Only checked when the
                instance has a truthy ``training_mode`` attribute.

        Returns:
            A dict with every module output, the routing weights, the final
            answer, correctness (``None`` unless checked), the Lyapunov
            value, the convergence flag and per-module timings in seconds.
            The same dict is appended to ``self.results``.
        """
        print(f"Processing question: {question}")

        # Prepare input with context if provided
        perception_input = question
        if context:
            perception_input = f"Context: {context}\n\nQuestion: {question}"

        # Step 1: Perception module processes the question and context
        perception_start = time.time()
        perception_output = self.perception.query(perception_input)
        perception_time = time.time() - perception_start
        print(f"✓ Perception module (took {perception_time:.2f}s)")

        # Step 2: Memory module retrieves relevant knowledge
        memory_start = time.time()
        memory_input = perception_output
        if context:
            memory_input += f"\n\nFull Context:\n{context}"
        memory_output, memory_context = self.memory.process_with_memory(memory_input, facts)
        memory_time = time.time() - memory_start
        print(f"✓ Memory module (took {memory_time:.2f}s)")

        # Step 3: Reasoning module solves the problem
        reasoning_start = time.time()
        reasoning_input = f"Structured Question:\n{perception_output}\n\nRelevant Knowledge:\n{memory_output}"
        if context:
            reasoning_input += f"\n\nContext Passage:\n{context}"
        reasoning_output = self.reasoning.query(reasoning_input)
        reasoning_time = time.time() - reasoning_start
        print(f"✓ Reasoning module (took {reasoning_time:.2f}s)")

        # Step 4: Compute routing weights
        outputs = {
            'perception': perception_output,
            'memory': memory_output,
            'reasoning': reasoning_output
        }
        routing_weights = self.router.compute_routing_matrix(outputs)

        # Step 5: Executive module combines outputs with routing weights
        executive_start = time.time()
        # Update executive module attention weights
        self.executive.update_attention_weights(routing_weights)

        executive_input = f"""
Module weights:
- Perception: {routing_weights['perception']:.2f}
- Memory: {routing_weights['memory']:.2f}
- Reasoning: {routing_weights['reasoning']:.2f}

Perception output:
{perception_output}

Memory output:
{memory_output}

Reasoning output:
{reasoning_output}
"""
        if context:
            executive_input += f"\n\nContext Passage:\n{context}"

        executive_output = self.executive.query(executive_input)
        executive_time = time.time() - executive_start
        print(f"✓ Executive module (took {executive_time:.2f}s)")

        # Extract answer from executive output
        final_answer = executive_output.strip()

        # For training, check if answer is correct. ``training_mode`` is not
        # set by this class; it is expected to be assigned by the caller.
        is_correct = None
        if expected_answer and getattr(self, 'training_mode', False):
            from squad_dataset import check_answer_correctness
            is_correct = check_answer_correctness(final_answer, expected_answer)

        # Step 6: Analyze convergence
        current_states = self._get_current_states()

        # Compute Lyapunov function
        lyapunov_value = self.convergence_analyzer.compute_lyapunov_function(
            current_states, self.optimal_states)

        # Check convergence
        is_converged = self.convergence_analyzer.check_convergence()

        # Record result
        result = {
            'problem': question,
            'context': context,
            'perception_output': perception_output,
            'memory_output': memory_output,
            'memory_context': memory_context,
            'reasoning_output': reasoning_output,
            'routing_weights': routing_weights,
            'executive_output': executive_output,
            'final_answer': final_answer,
            'is_correct': is_correct,
            'lyapunov_value': lyapunov_value,
            'is_converged': is_converged,
            'timings': {
                'perception': perception_time,
                'memory': memory_time,
                'reasoning': reasoning_time,
                'executive': executive_time,
                # Sum of the four module timings; routing is not included.
                'total': perception_time + memory_time + reasoning_time + executive_time
            }
        }

        self.results.append(result)
        return result

    def _get_current_states(self):
        """Get current states of all modules, keyed by module name."""
        return {
            'perception': self.perception.get_state(),
            'memory': self.memory.get_state(),
            'reasoning': self.reasoning.get_state(),
            'executive': self.executive.get_state()
        }

# convergence_analysis_trainable_squad.py
from convergence_analysis_StrategyQA import ConvergenceAnalyzer
import numpy as np
import logging

# Module-scoped logger used by AdaptiveConvergenceAnalyzer_SQuAD below.
logger = logging.getLogger(__name__)

class AdaptiveConvergenceAnalyzer_SQuAD(ConvergenceAnalyzer):
    """Convergence analyzer for SQuAD that can optionally force convergence.

    With ``force_convergence`` enabled, the optimal states drift slowly
    towards the observed states, and convergence is reported unconditionally
    once enough adaptation steps have been taken.
    """

    def __init__(self, force_convergence=False, max_iterations=20):
        """Set up the analyzer.

        Args:
            force_convergence: Enables state adaptation and forced convergence.
            max_iterations: Scale for the adaptation schedule; convergence is
                forced after 1.5 times this many adaptation steps.
        """
        super().__init__()
        # Counts calls to ``update_optimal_states`` made while forcing is on.
        self.iteration_count = 0
        self.max_iterations = max_iterations
        self.force_convergence = force_convergence
        logger.info(f"Initialized AdaptiveConvergenceAnalyzer for SQuAD, force_convergence={force_convergence}")

    def update_optimal_states(self, current_states, original_optimal_states):
        """Return optimal states nudged towards the current states.

        When forcing is disabled, ``original_optimal_states`` is returned
        untouched. Otherwise each module present in ``current_states`` gets a
        convex blend of its original optimal state and its current state;
        modules without an original optimal state keep their current state.
        """
        if not self.force_convergence:
            return original_optimal_states

        self.iteration_count += 1

        # The rate starts near 0.05 and shrinks with progress, bottoming out
        # at 20% of that once 80% of ``max_iterations`` has elapsed.
        progress = min(0.8, self.iteration_count / self.max_iterations)
        adaptation_rate = 0.05 * (1.0 - progress)

        adapted_states = {}
        for module, state in current_states.items():
            if module not in original_optimal_states:
                adapted_states[module] = state
                continue
            anchor = np.array(original_optimal_states[module])
            observed = np.array(state)
            adapted_states[module] = (1 - adaptation_rate) * anchor + adaptation_rate * observed

        logger.info(f"SQuAD: Updated optimal states with adaptation rate {adaptation_rate:.4f}")
        return adapted_states

    def check_convergence(self, tolerance=1e-3, window=5):
        """Report convergence, forcing it after enough adaptation steps.

        The parent's Lyapunov-based check is always evaluated; its result is
        overridden with ``True`` only when forcing is enabled and the
        iteration count has reached ``1.5 * max_iterations``.
        """
        converged_naturally = super().check_convergence(tolerance, window)

        forced = self.force_convergence and self.iteration_count >= self.max_iterations * 1.5
        if not forced:
            return converged_naturally

        logger.info(f"SQuAD: Forcing convergence after {self.iteration_count} iterations")
        return True


# convergence_analysis_StrategyQA.py
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from scipy.linalg import eigh

class ConvergenceAnalyzer:
    """Analyzes convergence properties of the DNON system"""
    
    def __init__(self):
        self.lyapunov_history = []
        self.stability_thresholds = {
            'stable': -0.1,     # Eigenvalues below this are considered stable
            'neutral': 0.01,    # Eigenvalues around zero are considered neutral
            'unstable': 0.1     # Eigenvalues above this are considered unstable
        }
    
    def compute_lyapunov_function(self, current_states, optimal_states):
        """Compute Lyapunov function V(S) = sum_i ||S_i - S_i*||^2"""
        lyapunov = 0
        for module, state in current_states.items():
            if module in optimal_states:
                # Compute squared distance from optimal state
                state_array = np.array(state)
                optimal_array = np.array(optimal_states[module])
                
                # Ensure same shape
                min_dim = min(len(state_array), len(optimal_array))
                diff = state_array[:min_dim] - optimal_array[:min_dim]
                
                lyapunov += np.sum(diff**2)
        
        self.lyapunov_history.append(lyapunov)
        return lyapunov
    
    def check_convergence(self, tolerance=1e-3, window=5):
        """Check if the system has converged based on Lyapunov history"""
        if len(self.lyapunov_history) < window + 1:
            return False
            
        # Check if Lyapunov function is consistently decreasing
        is_decreasing = all(self.lyapunov_history[-i-1] - self.lyapunov_history[-i] > -tolerance 
                           for i in range(1, window))
        
        # Check if value is below threshold
        is_small = self.lyapunov_history[-1] < tolerance
        
        # Check if change is minimal
        recent_changes = [abs(self.lyapunov_history[-i] - self.lyapunov_history[-i-1]) 
                         for i in range(1, window)]
        is_stable = all(change < tolerance for change in recent_changes)
        
        return is_small and is_stable
    
    def compute_lipschitz_constants(self, module_dynamics):
        """Compute Lipschitz constants for module dynamics functions"""
        lipschitz_constants = {}
        
        for module, dynamics in module_dynamics.items():
            # Sample points from state space
            samples = 10
            sample_points = [np.random.randn(100) for _ in range(samples)]
            
            # Compute Lipschitz constant as max ratio of output to input distance
            max_ratio = 0
            for i in range(samples):
                for j in range(i+1, samples):
                    x_i = sample_points[i]
                    x_j = sample_points[j]
                    
                    # Apply dynamics
                    y_i = dynamics(x_i)
                    y_j = dynamics(x_j)
                    
                    # Compute distances
                    input_dist = np.linalg.norm(x_i - x_j)
                    output_dist = np.linalg.norm(y_i - y_j)
                    
                    if input_dist > 0:
                        ratio = output_dist / input_dist
                        max_ratio = max(max_ratio, ratio)
            
            lipschitz_constants[module] = max_ratio
        
        return lipschitz_constants
    
    def analyze_stability(self, jacobian_matrix):
        """Analyze stability using eigenvalues of the Jacobian"""
        # Compute eigenvalues
        eigenvalues = eigh(jacobian_matrix, eigvals_only=True)
        
        # Analyze stability
        max_real_part = max(eigenvalues).real
        min_real_part = min(eigenvalues).real
        
        # Determine stability category
        if max_real_part < self.stability_thresholds['stable']:
            stability = "Stable"
        elif max_real_part < self.stability_thresholds['neutral']:
            stability = "Marginally Stable"
        else:
            stability = "Unstable"
        
        return {
            'eigenvalues': eigenvalues,
            'max_eigenvalue': max_real_part,
            'min_eigenvalue': min_real_part,
            'stability': stability
        }
    
    def estimate_jacobian(self, dynamics_function, state, epsilon=1e-6):
        """Estimate Jacobian matrix of dynamics function at given state"""
        n = len(state)
        jacobian = np.zeros((n, n))
        
        # Evaluate function at current state
        f_x = dynamics_function(state)
        
        # Compute partial derivatives
        for i in range(n):
            # Create perturbed state
            perturbed_state = state.copy()
            perturbed_state[i] += epsilon
            
            # Evaluate function at perturbed state
            f_perturbed = dynamics_function(perturbed_state)
            
            # Estimate partial derivatives
            jacobian[:, i] = (f_perturbed - f_x) / epsilon
        
        return jacobian
    
    def plot_convergence(self, title="DNON System Convergence", save_path="dnon_convergence.png"):
        """Plot the Lyapunov history on a log-scaled y axis, save and show it.

        Args:
            title: Figure title.
            save_path: Where the PNG is written (300 dpi). Defaults to the
                file name this method has always used.

        The figure is closed afterwards so repeated calls do not accumulate
        open figures.
        """
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(self.lyapunov_history)
            ax.set_title(title)
            ax.set_xlabel("Iteration")
            ax.set_ylabel("Lyapunov Function Value")
            ax.set_yscale('log')  # Log scale to better visualize convergence
            ax.grid(True)
            fig.tight_layout()
            fig.savefig(save_path, dpi=300)
            plt.show()
        finally:
            plt.close(fig)
        
    def verify_convergence_conditions(self, lipschitz_constants, routing_matrix):
        """Check the two convergence conditions on precomputed quantities.

        Args:
            lipschitz_constants: Mapping of module name to its constant; every
                value must be strictly negative (contractive) to pass.
            routing_matrix: Iterable of rows; every row must sum to at most 1
                (a tolerance of 0.01 is allowed for numerical error).

        Returns:
            Dict with ``'converges'`` (both checks passed), ``'lipschitz_ok'``,
            ``'routing_ok'`` and the ``'lipschitz_constants'`` that were passed in.
        """
        # Contractivity: stop at the first constant that is not < 0
        lipschitz_ok = True
        for constant in lipschitz_constants.values():
            if not constant < 0:
                lipschitz_ok = False
                break

        # Routing: no row may exceed 1 beyond the numerical tolerance
        routing_ok = not any(not sum(row) <= 1.01 for row in routing_matrix)

        report = {
            'converges': lipschitz_ok and routing_ok,
            'lipschitz_ok': lipschitz_ok,
            'routing_ok': routing_ok,
            'lipschitz_constants': lipschitz_constants,
        }
        return report

    def verify_lipschitz_conditions(self, module_dynamics_functions, states,
                                    num_samples=20, state_dim=100):
        """Estimate each module's one-sided Lipschitz constant and check it is negative.

        The convergence condition requires every F_i to have constant
        L_i < 0 (contractive). A ratio of norms ``||F(x)-F(y)|| / ||x-y||`` is
        never negative, so it can never satisfy that test. This method
        therefore estimates the one-sided constant

            L_i ~ max over sampled pairs of <F(x) - F(y), x - y> / ||x - y||^2

        which is negative exactly when the sampled dynamics pull states
        together.

        Args:
            module_dynamics_functions: Mapping of module name to a callable
                taking and returning a 1-D vector of length ``state_dim``.
            states: Unused; kept so existing callers keep working.
            num_samples: Number of random state pairs drawn per module.
            state_dim: Dimension of the sampled standard-normal state vectors.

        Returns:
            Tuple ``(lipschitz_constants, condition_satisfied)``: the estimated
            constant per module, and whether all of them are strictly negative.

        Raises:
            ValueError: If ``num_samples`` is smaller than 1.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        lipschitz_constants = {}
        condition_satisfied = True

        for module, dynamics_fn in module_dynamics_functions.items():
            max_ratio = -float('inf')

            # Sampling only gives a lower bound on the true supremum.
            for _ in range(num_samples):
                x = np.random.randn(state_dim)
                y = np.random.randn(state_dim)

                f_x = np.asarray(dynamics_fn(x), dtype=float)
                f_y = np.asarray(dynamics_fn(y), dtype=float)

                displacement = x - y
                squared_distance = float(np.dot(displacement, displacement))
                if squared_distance <= 0:
                    # Identical samples carry no information about the slope.
                    continue

                ratio = float(np.dot(f_x - f_y, displacement)) / squared_distance
                max_ratio = max(max_ratio, ratio)

            lipschitz_constants[module] = max_ratio

            # Contractive only when even the worst sampled pair shrinks.
            if max_ratio >= 0:
                condition_satisfied = False

        return lipschitz_constants, condition_satisfied

    def verify_routing_condition(self, routing_matrix):
        """Check that routing weights sum to at most 1 (with a 0.01 tolerance).

        Args:
            routing_matrix: Either a dict of per-module weights, whose values
                are summed together, or an iterable of rows, each of which is
                summed separately.

        Returns:
            ``True`` when no checked sum exceeds 1.01, ``False`` otherwise.
        """
        limit = 1.01  # 1 plus a small allowance for numerical error

        # Dict form: a single set of per-module weights
        if isinstance(routing_matrix, dict):
            return not (sum(routing_matrix.values()) > limit)

        # Matrix form: rows are checked one by one, stopping at the first violation
        return not any(sum(row) > limit for row in routing_matrix)

    def verify_convergence_theorem(self, module_dynamics, routing_matrix, noise_variance):
        """
        Verify the DNON Convergence Theorem (Theorem 5)

        Requirements:
        1. Each F_i is L_i-Lipschitz with L_i < 0 (contractive)
        2. The noise terms are bounded with covariance <= sigma^2 * I
        3. Routing coefficients satisfy sum_j Omega_ji(t) <= 1

        Args:
            module_dynamics: Mapping of module name to dynamics callable,
                forwarded to ``verify_lipschitz_conditions``.
            routing_matrix: Routing weights, forwarded to
                ``verify_routing_condition``.
            noise_variance: Noise variance; must be strictly positive and
                finite for the noise condition to hold.

        Returns:
            Dict with ``'converges'``, the three individual flags
            (``'lipschitz_ok'``, ``'routing_ok'``, ``'noise_ok'``), ``'alpha'``
            (largest ``|L_i|``, or ``None`` when the Lipschitz check failed or
            there are no modules), ``'radius'`` (``sqrt(noise_variance / alpha)``
            or ``None`` when it cannot be computed) and ``'lipschitz_constants'``.
        """
        # Check Lipschitz condition
        lipschitz_constants, lipschitz_ok = self.verify_lipschitz_conditions(module_dynamics, None)

        # Check routing condition
        routing_ok = self.verify_routing_condition(routing_matrix)

        # Check noise condition
        noise_ok = noise_variance > 0 and noise_variance < float('inf')

        # Alpha is only defined when the check passed AND there is at least
        # one module; max() of an empty sequence would raise ValueError.
        alpha = None
        if lipschitz_ok and lipschitz_constants:
            alpha = max(abs(L) for L in lipschitz_constants.values())

        # The neighbourhood radius needs a strictly positive alpha.
        radius = None
        if noise_ok and alpha is not None and alpha > 0:
            radius = np.sqrt(noise_variance / alpha)

        # Overall convergence condition
        converges = lipschitz_ok and routing_ok and noise_ok

        return {
            'converges': converges,
            'lipschitz_ok': lipschitz_ok,
            'routing_ok': routing_ok,
            'noise_ok': noise_ok,
            'alpha': alpha,
            'radius': radius,
            'lipschitz_constants': lipschitz_constants
        }

class LyapunovOptimizer:
    """Optimizes system parameters using linearised stability analysis.

    The dynamics model must expose ``set_params(params)``,
    ``find_equilibrium()`` and ``step`` (the callable handed to the Jacobian
    estimator). Stability is measured as the largest real part among the
    eigenvalues of the Jacobian estimated at the model's equilibrium.
    """

    def __init__(self, dynamics_model, initial_params):
        """Store the model and starting parameters and create the analyzer.

        Args:
            dynamics_model: Object providing ``set_params``,
                ``find_equilibrium`` and ``step``.
            initial_params: Starting parameter vector; must support ``len()``,
                ``copy()`` and item assignment.
        """
        self.dynamics_model = dynamics_model
        self.params = initial_params
        self.analyzer = ConvergenceAnalyzer()

    def _max_real_eigenvalue(self):
        """Return the largest eigenvalue real part of the Jacobian at the model's current equilibrium."""
        equilibrium = self.dynamics_model.find_equilibrium()
        jacobian = self.analyzer.estimate_jacobian(
            self.dynamics_model.step,
            equilibrium
        )
        # A Jacobian is not symmetric in general, so a symmetric solver such
        # as ``eigh`` (which reads only one triangle) gives wrong eigenvalues.
        # The general solver may return complex values; only the real parts
        # decide stability.
        return float(np.max(np.linalg.eigvals(jacobian).real))

    def optimize_for_stability(self, target_stability=-0.1):
        """Optimize parameters so the dominant eigenvalue real part hits a target.

        Minimises ``(max_real_eigenvalue - target_stability) ** 2`` with
        L-BFGS-B, each parameter bounded to ``[0.01, 10.0]``. The optimum is
        stored in ``self.params`` and applied to the model.

        Args:
            target_stability: Desired value of the largest eigenvalue real part.

        Returns:
            Dict with ``'optimal_params'``, ``'objective_value'``,
            ``'success'`` and ``'message'`` taken from the optimizer result.
        """
        def objective(params):
            # The model is mutated on every evaluation; the final
            # ``set_params`` below leaves it at the optimum.
            self.dynamics_model.set_params(params)
            return (self._max_real_eigenvalue() - target_stability) ** 2

        # Run optimization
        result = minimize(
            objective,
            self.params,
            method='L-BFGS-B',
            bounds=[(0.01, 10.0) for _ in range(len(self.params))]
        )

        # Update parameters with optimal values
        self.params = result.x
        self.dynamics_model.set_params(self.params)

        return {
            'optimal_params': self.params,
            'objective_value': result.fun,
            'success': result.success,
            'message': result.message
        }

    def compute_stability_region(self, param_ranges, param_names=None):
        """Sweep one or two parameters and plot the resulting stability measure.

        Only the first one or two entries of ``param_ranges`` are swept (a
        notice is printed when more are given); they map onto ``params[0]``
        and ``params[1]``. A single range is sampled at 100 points, two ranges
        on a 20 x 20 grid. The plot is saved to ``stability_region_1d.png`` or
        ``stability_region_2d.png`` and then shown.

        Args:
            param_ranges: Sequence of ``(low, high)`` pairs.
            param_names: Optional axis labels; defaults to ``p0``, ``p1``, ...

        Returns:
            ``numpy.ndarray`` of largest eigenvalue real parts: shape ``(100,)``
            for one parameter, or ``(20, 20)`` indexed ``[p2_index, p1_index]``
            for two.
        """
        # Default parameter names if not provided
        if param_names is None:
            param_names = [f'p{i}' for i in range(len(param_ranges))]

        # Check dimensions
        if len(param_ranges) > 2:
            print("Stability region visualization only supported for 1-2 parameters")
            param_ranges = param_ranges[:2]
            param_names = param_names[:2]

        one_dimensional = len(param_ranges) == 1

        try:
            if one_dimensional:
                # 1D parameter sweep
                param_values = np.linspace(param_ranges[0][0], param_ranges[0][1], 100)
                stability_values = np.zeros_like(param_values)

                for i, value in enumerate(param_values):
                    params = self.params.copy()
                    params[0] = value
                    self.dynamics_model.set_params(params)
                    stability_values[i] = self._max_real_eigenvalue()
            else:
                # 2D parameter sweep
                p1_values = np.linspace(param_ranges[0][0], param_ranges[0][1], 20)
                p2_values = np.linspace(param_ranges[1][0], param_ranges[1][1], 20)
                P1, P2 = np.meshgrid(p1_values, p2_values)
                stability_values = np.zeros_like(P1)

                for i, p1 in enumerate(p1_values):
                    for j, p2 in enumerate(p2_values):
                        params = self.params.copy()
                        params[0] = p1
                        params[1] = p2
                        self.dynamics_model.set_params(params)
                        # meshgrid's default layout puts p1 along columns.
                        stability_values[j, i] = self._max_real_eigenvalue()
        finally:
            # The sweep overwrote the model's parameters; put back the ones
            # this optimizer considers current so later calls are unaffected.
            self.dynamics_model.set_params(self.params)

        if one_dimensional:
            # Plot stability
            plt.figure(figsize=(10, 6))
            plt.plot(param_values, stability_values)
            plt.axhline(y=0, color='r', linestyle='--', label='Stability Boundary')
            plt.xlabel(param_names[0])
            plt.ylabel('Max Eigenvalue Real Part')
            plt.title('Stability Analysis')
            plt.legend()
            plt.grid(True)
            plt.savefig("stability_region_1d.png", dpi=300)
            plt.show()
        else:
            # Plot stability region; the black contour marks the zero level
            plt.figure(figsize=(10, 8))
            contour = plt.contourf(P1, P2, stability_values, 20, cmap='coolwarm')
            plt.colorbar(contour, label='Max Eigenvalue Real Part')
            plt.contour(P1, P2, stability_values, levels=[0], colors='k')
            plt.xlabel(param_names[0])
            plt.ylabel(param_names[1])
            plt.title('Stability Region')
            plt.tight_layout()
            plt.savefig("stability_region_2d.png", dpi=300)
            plt.show()

        return stability_values

# dataset_utils_squad.py
import numpy as np
import os
import json
import logging

# Configure root logging at INFO level with timestamp, logger name, level and
# message. basicConfig does nothing if the root logger already has handlers,
# so an earlier configuration in the same process takes precedence.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger used by the helpers below.
logger = logging.getLogger(__name__)

def split_dataset_squad(data, train_size=14, dev_size=3, test_size=3, seed=42):
    """Shuffle a column-oriented SQuAD dataset and cut it into train/dev/test parts.

    Args:
        data: Mapping with parallel sequences under the keys ``'ids'``,
            ``'titles'``, ``'contexts'``, ``'questions'`` and ``'answers'``.
        train_size: Number of examples in the training part.
        dev_size: Number of examples in the development part.
        test_size: Number of examples in the test part.
        seed: Seed applied to NumPy's global random generator before shuffling.

    Returns:
        Tuple ``(train_data, dev_data, test_data)``; each is a dict with the
        same five keys holding the selected examples as lists.

    Raises:
        ValueError: If the dataset has fewer examples than the three sizes combined.
    """
    np.random.seed(seed)

    # Look up every column first, in a fixed order
    columns = {
        field: data[field]
        for field in ('ids', 'titles', 'contexts', 'questions', 'answers')
    }
    example_count = len(columns['questions'])

    shuffled = np.random.permutation(example_count)

    # Refuse to produce undersized splits
    required = train_size + dev_size + test_size
    if example_count < required:
        raise ValueError(f"Dataset too small: {example_count} < {required}")

    def take(selection):
        """Build one split by picking the given row indices from every column."""
        return {
            field: [values[i] for i in selection]
            for field, values in columns.items()
        }

    # Consecutive, non-overlapping windows of the shuffled index order
    dev_start = train_size
    test_start = train_size + dev_size
    test_stop = train_size + dev_size + test_size

    train_data = take(shuffled[:dev_start])
    dev_data = take(shuffled[dev_start:test_start])
    test_data = take(shuffled[test_start:test_stop])

    logger.info(f"SQuAD dataset split: {len(train_data['questions'])} train, {len(dev_data['questions'])} dev, {len(test_data['questions'])} test")
    return train_data, dev_data, test_data

# dnon_system_trainable_squad.py
from dnon_system_squad import DNONSystem_SQuAD
from dynamic_router_trainable_squad import TrainableDynamicRouter_SQuAD
from convergence_analysis_trainable_squad import AdaptiveConvergenceAnalyzer_SQuAD
from squad_dataset import check_answer_correctness
import numpy as np
import os
import time
import logging
import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)

class TrainableDNONSystem_SQuAD(DNONSystem_SQuAD):
    """Enhanced DNON System with training capabilities for SQuAD"""
    
    def __init__(self, session, force_convergence: bool = False) -> None:
        """Build the base system, then swap in the trainable router and adaptive analyzer.

        Args:
            session: Passed unchanged to the ``DNONSystem_SQuAD`` constructor.
            force_convergence: Forwarded to ``AdaptiveConvergenceAnalyzer_SQuAD``.
        """
        super().__init__(session)
        
        # Replace standard router with trainable version
        # NOTE(review): relies on the base constructor having set
        # self.bedrock_runtime — confirm against DNONSystem_SQuAD.
        self.router = TrainableDynamicRouter_SQuAD(self.bedrock_runtime, beta=1.5)
        
        # Replace convergence analyzer - with longer max_iterations
        self.convergence_analyzer = AdaptiveConvergenceAnalyzer_SQuAD(
            force_convergence=force_convergence, 
            max_iterations=50
        )
        
        # Training state: solve() only updates the router while training_mode
        # is True; the three result lists start empty.
        self.training_mode = False
        self.train_results = []
        self.dev_results = []
        self.test_results = []
        
        logger.info("Initialized TrainableDNONSystem for SQuAD")
    
    def initialize_from_baseline(self, baseline_model):
        """Seed this system from an existing baseline model.

        Sets the router beta to 90% of the baseline's, records the baseline's
        perception/memory/reasoning/executive states as ``optimal_states``, and
        shallow-copies memory contents and state histories wherever both sides
        expose the corresponding attribute.

        Args:
            baseline_model: Model exposing ``router``, ``perception``,
                ``memory``, ``reasoning`` and ``executive``.

        Returns:
            ``self``, to allow call chaining.
        """
        # Begin a little below the baseline routing temperature
        self.router.beta = baseline_model.router.beta * 0.9

        # Snapshot the baseline's module states as the reference states
        self.optimal_states = {
            name: getattr(baseline_model, name).get_state()
            for name in ('perception', 'memory', 'reasoning', 'executive')
        }

        source_memory = baseline_model.memory

        # Short- and long-term stores: copy the item collections
        for store_name in ('stm', 'ltm'):
            if hasattr(source_memory, store_name) and hasattr(self.memory, store_name):
                source_store = getattr(source_memory, store_name)
                getattr(self.memory, store_name).items = source_store.items.copy()

        # Deep subconscious store: copy patterns when the baseline has any
        if (hasattr(source_memory, 'dsm') and hasattr(self.memory, 'dsm')
                and hasattr(source_memory.dsm, 'patterns')):
            self.memory.dsm.patterns = source_memory.dsm.patterns.copy()

        # State histories of the three modules that track them
        for name in ('perception', 'memory', 'reasoning'):
            source_module = getattr(baseline_model, name)
            if hasattr(source_module, 'state_history'):
                getattr(self, name).state_history = source_module.state_history.copy()

        logger.info(f"Initialized from baseline model with beta={self.router.beta:.3f} and copied memory systems")
        return self

    def set_training_mode(self, is_training):
        """Set whether the system is in training mode.

        Args:
            is_training: Stored as ``self.training_mode``; ``solve`` only
                updates the router when this is truthy and an expected answer
                is supplied.

        Returns:
            ``self``, to allow call chaining.
        """
        self.training_mode = is_training
        return self
    
    def solve(self, question, context=None, facts=None, decomposition=None, expected_answer=None):
        """Solve a question via the base system, regularize routing, and optionally train.

        The base ``solve`` result has its ``'routing_weights'`` passed through
        ``regularize_routing``. When the system is in training mode and an
        expected answer is given, the answer is scored, ``'is_correct'`` is
        stored in the result, and the router is updated. If the convergence
        analyzer has ``force_convergence`` enabled, ``optimal_states`` is
        refreshed from the current states as well.

        Args:
            question: Question text.
            context: Optional context, forwarded to the base ``solve``.
            facts: Optional facts, forwarded to the base ``solve``.
            decomposition: Optional decomposition, forwarded to the base ``solve``.
            expected_answer: Optional gold answer; enables the training update.

        Returns:
            The result dict from the base ``solve``, modified as described above.
        """
        result = super().solve(
            question,
            context=context,
            facts=facts,
            decomposition=decomposition,
            expected_answer=expected_answer,
        )

        # Smooth the routing weights against the previous routing decision
        result['routing_weights'] = self.regularize_routing(result['routing_weights'])

        # Outside training, or without a gold answer, there is nothing to learn from
        if not self.training_mode or expected_answer is None:
            return result

        is_correct = check_answer_correctness(result['final_answer'], expected_answer)
        result['is_correct'] = is_correct

        # Feed the outcome back into the router
        self.router.update_parameters(is_correct)

        # Optionally refresh the reference states using the regularized weights
        if getattr(self.convergence_analyzer, 'force_convergence', False):
            current_states = self._get_current_states()
            self.optimal_states = self.convergence_analyzer.update_optimal_states(
                current_states,
                self.optimal_states,
                routing_weights=result['routing_weights'],
            )

        return result

    def regularize_routing(self, routing_matrix):
        """Smooth, floor and renormalize per-module routing weights in place.

        Steps:
        1. If the router has routing history, blend each weight with the most
           recent historical weight for the same module (new-weight factor 0.3
           in training mode, 0.7 otherwise).
        2. Raise every weight to at least 0.15.
        3. Divide every weight by the total so the weights sum to 1.

        Args:
            routing_matrix: Dict mapping module name to weight; modified in place.

        Returns:
            The same ``routing_matrix`` object.
        """
        history = self.router.routing_history
        if history:
            previous = history[-1]

            # A smaller factor on the new weight applies while training
            blend = 0.3 if self.training_mode else 0.7

            for name, weight in list(routing_matrix.items()):
                if name in previous:
                    routing_matrix[name] = blend * weight + (1-blend) * previous[name]

        # Floor: no module may drop below the minimum share
        floor = 0.15
        routing_matrix.update(
            {name: max(floor, weight) for name, weight in routing_matrix.items()}
        )

        # Rescale so the weights sum to one again
        total = sum(routing_matrix.values())
        for name in routing_matrix:
            routing_matrix[name] /= total

        return routing_matrix

    def _process_dataset(self, data, update_router=False):
        """Run every example in ``data`` through ``solve`` and collect the results.

        SQuAD questions are solved together with their context passage, so the
        ``'contexts'`` column is forwarded to ``solve`` whenever it is present.

        Args:
            data: Mapping with ``'questions'`` and ``'answers'`` sequences and
                optional parallel ``'contexts'``, ``'facts'`` and
                ``'decompositions'`` sequences.
            update_router: When true, the expected answer is passed to
                ``solve`` (which updates the router if training mode is on)
                and stagnating routing weights trigger exploration. When
                false, examples are solved without the expected answer and
                ``'is_correct'`` is attached afterwards.

        Returns:
            List of result dicts from ``solve``, each extended with
            ``'question_text'`` and ``'expected'``.
        """
        results = []
        initial_weights = None
        questions = data['questions']
        answers = data['answers']
        # Missing optional columns fall back to per-question placeholders.
        contexts = data.get('contexts', [None for _ in questions])
        facts = data.get('facts', [[] for _ in questions])
        decompositions = data.get('decompositions', [[] for _ in questions])

        examples = zip(questions, contexts, answers, facts, decompositions)
        for i, (question, context, expected, fact_list, decomp_list) in enumerate(examples):
            logger.info(f"Processing question {i+1}/{len(questions)}")

            # Process with expected answer for training, or without for evaluation
            if update_router:
                result = self.solve(question, context=context, facts=fact_list,
                                    decomposition=decomp_list, expected_answer=expected)

                if 'routing_weights' in result:
                    weights = result['routing_weights']

                    # The first example's weights are the reference point
                    if initial_weights is None:
                        initial_weights = weights.copy()
                    elif i > len(questions) // 3:  # After first third of questions
                        # Total absolute drift from the reference weights
                        weight_diff = sum(abs(weights[k] - initial_weights[k]) for k in weights)
                        if weight_diff < 0.08:
                            logger.warning("SQuAD: Weights not changing significantly! Applying exploration.")

                            if hasattr(self.router, 'force_asymmetry'):
                                self.router.force_asymmetry()

                                # Reprocess the question with new weights
                                result = self.solve(question, context=context, facts=fact_list,
                                                    decomposition=decomp_list, expected_answer=expected)
            else:
                result = self.solve(question, context=context, facts=fact_list,
                                    decomposition=decomp_list)
                # Add correctness after processing
                result['is_correct'] = self._check_answer_correctness(result['final_answer'], expected)

            # Add additional info to result
            result['question_text'] = question
            result['expected'] = expected

            results.append(result)

        return results


#    def regularize_routing(self, routing_matrix):
#        """Add regularization to prevent drastic changes"""
        # Apply regularization if we have history
#        if len(self.router.routing_history) > 0:
#            prev_routing = self.router.routing_history[-1]
            
            # Smooth transition with previous weights
#            alpha = 0.7  # Smoothing factor
#            for module in routing_matrix:
#                if module in prev_routing:
#                    routing_matrix[module] = alpha * routing_matrix[module] + (1-alpha) * prev_routing[module]
        
        # Ensure minimum weight per module
#        min_weight = 0.15
#        for module in routing_matrix:
#            routing_matrix[module] = max(min_weight, routing_matrix[module])
            
        # Re-normalize
#        total = sum(routing_matrix.values())
#        for module in routing_matrix:
#            routing_matrix[module] /= total
        
#        return routing_matrix
        
#    def _process_dataset(self, data, update_router=False):
#        """Process a SQuAD dataset batch with or without router updates"""
#        results = []
#        questions = data['questions']
#        contexts = data['contexts']
#        answers = data['answers']
        
#        for i, (question, context, expected) in enumerate(zip(questions, contexts, answers)):
#            logger.info(f"Processing question {i+1}/{len(questions)}")
            
            # Process with expected answer for training, or without for evaluation
#            if update_router:
#                result = self.solve(question, context=context, expected_answer=expected)
#            else:
#                result = self.solve(question, context=context)
                # Add correctness after processing
#                result['is_correct'] = check_answer_correctness(result['final_answer'], expected)
            
            # Add additional info to result
#            result['question_text'] = question
#            result['context_text'] = context
#            result['expected'] = expected
            
#            results.append(result)
        
#        return results

    def train(self, train_data, dev_data, epochs=3):
        """Train the router for a number of epochs with a dev-set safeguard.

        Each epoch solves up to 30 training examples (updating the router with
        regularization after each one) and then up to 20 dev examples. The beta
        of the best dev epoch is remembered; training stops early when dev
        accuracy falls below 70% of the best seen, and the best beta is
        restored at the end.

        Args:
            train_data: Mapping with ``'questions'``, ``'contexts'`` and
                ``'answers'`` sequences.
            dev_data: Mapping with the same keys, used for evaluation only.
            epochs: Maximum number of epochs.

        Returns:
            List of per-epoch dicts with ``'epoch'``, ``'train_acc'``,
            ``'dev_acc'`` and ``'beta'``. An epoch that triggers early
            stopping is not recorded.
        """
        # Reference weights for the regularized router update: the latest
        # recorded routing if there is one, otherwise a uniform default
        if hasattr(self.router, 'routing_history') and self.router.routing_history:
            reference_weights = self.router.routing_history[-1].copy()
        else:
            reference_weights = {
                'perception': 0.33,
                'memory': 0.33,
                'reasoning': 0.33
            }

        def run_pass(dataset, limit, update):
            """Solve the first ``limit`` examples; return (number correct, number seen)."""
            correct = 0
            seen = 0
            for idx in range(limit):
                question = dataset['questions'][idx]
                context = dataset['contexts'][idx]
                expected = dataset['answers'][idx]

                outcome = self.solve(question, context=context)
                is_correct = check_answer_correctness(outcome['final_answer'], expected)

                if is_correct:
                    correct += 1
                seen += 1

                if update:
                    self.router.update_parameters_with_regularization(is_correct, reference_weights)
            return correct, seen

        history = []
        best_beta = self.router.beta
        best_dev_acc = 0.0

        for epoch_number in range(1, epochs + 1):
            print(f"Epoch {epoch_number}/{epochs}")

            # Training pass over a capped subset, with router updates
            train_correct, train_total = run_pass(
                train_data, min(len(train_data['questions']), 30), True
            )

            # Dev pass over a capped subset, no updates
            dev_correct, dev_total = run_pass(
                dev_data, min(len(dev_data['questions']), 20), False
            )

            train_acc = train_correct / train_total if train_total > 0 else 0
            dev_acc = dev_correct / dev_total if dev_total > 0 else 0

            print(f"Epoch {epoch_number} - Train acc: {train_acc:.2%}, Dev acc: {dev_acc:.2%}")

            # Remember the configuration of the best dev epoch
            if dev_acc > best_dev_acc:
                best_dev_acc = dev_acc
                best_beta = self.router.beta

            # From the second epoch on, bail out on a sharp dev-accuracy drop
            if epoch_number > 1 and dev_acc < 0.7 * best_dev_acc:
                print("Performance degrading significantly. Reverting to best state.")
                self.router.beta = best_beta
                break

            history.append({
                'epoch': epoch_number,
                'train_acc': train_acc,
                'dev_acc': dev_acc,
                'beta': self.router.beta
            })

        # Always finish on the best beta observed
        self.router.beta = best_beta

        return history
        
#    def train(self, train_data, dev_data, num_epochs=3):
#        """Train with early stopping"""
#        training_history = {'train_accuracy': [], 'dev_accuracy': [], 'beta_values': []}
        
        # Keep track of best model
#        best_dev_accuracy = 0
#        patience = 2  # Stop after no improvement for 2 epochs
#        no_improvement_count = 0
#        best_state = None
        
#        for epoch in range(num_epochs):
#            logger.info(f"=== Epoch {epoch+1}/{num_epochs} ===")
            
            # Training phase
#            self.set_training_mode(True)
#New
            # Add after self.set_training_mode(True):
            # Add epoch-specific exploration for StrategyQA
#            if epoch == 0:
                # First epoch: Push beta higher to emphasize differences
#                old_beta = self.router.beta
#                self.router.beta = min(4.0, self.router.beta * 1.8)  # Higher multiplier for StrategyQA
#                logger.info(f"First epoch exploration: beta {old_beta:.2f} → {self.router.beta:.2f}")
#            elif epoch == 1:
                # Second epoch: Emphasize reasoning module for StrategyQA
#                if hasattr(self.router, 'force_asymmetry'):
                    # Directly inject reasoning-biased weights for StrategyQA
#                    self.router.force_asymmetry()  # This creates reasoning-heavy weights
#                    logger.info("Second epoch: Forced reasoning-heavy weights")
#end
#            train_results = self._process_dataset(train_data, update_router=True)
#            train_accuracy = sum(r['is_correct'] for r in train_results) / len(train_data['questions'])
#            logger.info(f"Training accuracy: {train_accuracy:.2%}")
            
            # Validation phase
#            self.set_training_mode(False)
#            dev_results = self._process_dataset(dev_data, update_router=False)
#            dev_accuracy = sum(r['is_correct'] for r in dev_results) / len(dev_data['questions'])
#            logger.info(f"Dev accuracy: {dev_accuracy:.2%}")
            
            # Record history
#            training_history['train_accuracy'].append(train_accuracy)
#            training_history['dev_accuracy'].append(dev_accuracy)
#            training_history['beta_values'].append(self.router.beta)
#            logger.info(f"Current router beta: {self.router.beta:.3f}")
            
            # Early stopping check
#            if dev_accuracy > best_dev_accuracy:
#                best_dev_accuracy = dev_accuracy
#                no_improvement_count = 0
                # Save best model state
#                best_state = {
#                    'beta': self.router.beta,
#                    'optimal_states': self.optimal_states.copy() if hasattr(self, 'optimal_states') else None
#                }
#                logger.info(f"New best model saved with dev accuracy: {best_dev_accuracy:.2%}")
#            else:
#                no_improvement_count += 1
#                if no_improvement_count >= patience:
#                    logger.info(f"No improvement for {patience} epochs. Early stopping.")
#                    # Restore best model
#                    if best_state:
#                        self.router.beta = best_state['beta']
##                        if best_state['optimal_states']:
#                            self.optimal_states = best_state['optimal_states']
#                        logger.info(f"Restored best model with beta={self.router.beta:.3f}")
#                    break
        
#        return training_history
    
    def evaluate(self, test_data):
        """Evaluate on test data with training mode switched off.

        Logs accuracy, the fraction of results flagged ``'is_converged'`` and
        the mean ``'lyapunov_value'``, and stores the raw results in
        ``self.test_results``. An empty test set yields 0.0 for all three
        statistics instead of a division by zero.

        Args:
            test_data: Mapping with at least a ``'questions'`` sequence, in
                the form accepted by ``_process_dataset``.

        Returns:
            Tuple ``(test_results, test_accuracy)``.
        """
        self.set_training_mode(False)
        test_results = self._process_dataset(test_data, update_router=False)

        num_questions = len(test_data['questions'])
        num_results = len(test_results)

        num_correct = sum(r['is_correct'] for r in test_results)
        test_accuracy = num_correct / num_questions if num_questions else 0.0

        # Convergence stats; results lacking the keys count as not converged / 0
        if num_results:
            num_converged = sum(1 for r in test_results if r.get('is_converged', False))
            convergence_rate = num_converged / num_results
            avg_lyapunov = np.mean([r.get('lyapunov_value', 0) for r in test_results])
        else:
            convergence_rate = 0.0
            avg_lyapunov = 0.0

        logger.info(f"\nTest accuracy: {test_accuracy:.2%}")
        logger.info(f"Convergence rate: {convergence_rate:.2%}")
        logger.info(f"Average Lyapunov value: {avg_lyapunov:.4f}")

        self.test_results = test_results
        return test_results, test_accuracy
        
    def visualize_training_progress(self, output_dir="results/squad"):
        """Render a 2x2 summary figure of training and save it as a PNG.

        Panels: router accuracy history, router beta history, Lyapunov values
        (log scale), and the last ten routing-weight records for the
        perception, memory and reasoning modules. The figure is written to
        ``<output_dir>/training_progress_squad.png`` at 300 dpi and closed.

        Args:
            output_dir: Directory for the image; created if it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)

        def decorate(title, xlabel, ylabel):
            """Apply title, axis labels and a grid to the current subplot."""
            plt.title(title)
            plt.xlabel(xlabel)
            plt.ylabel(ylabel)
            plt.grid(True)

        plt.figure(figsize=(15, 10))

        # Panel 1: per-example accuracy
        plt.subplot(2, 2, 1)
        plt.plot(self.router.accuracy_history, 'b.-')
        decorate('SQuAD Training Accuracy Progression',
                 'Training Example',
                 'Accuracy (1=Correct, 0=Wrong)')

        # Panel 2: beta over update steps
        plt.subplot(2, 2, 2)
        plt.plot(self.router.beta_history, 'r-')
        decorate('Router Beta Parameter Evolution', 'Update Step', 'Beta Value')

        # Panel 3: Lyapunov values on a log axis
        plt.subplot(2, 2, 3)
        plt.plot(self.convergence_analyzer.lyapunov_history)
        plt.yscale('log')
        decorate('Lyapunov Function Values', 'Iteration', 'Lyapunov Value')

        # Panel 4: most recent routing weights, one line per module
        plt.subplot(2, 2, 4)
        if self.router.routing_history:
            recent = self.router.routing_history[-10:]
            module_styles = (
                ('perception', 'b-o', 'Perception'),
                ('memory', 'g-o', 'Memory'),
                ('reasoning', 'r-o', 'Reasoning'),
            )
            # Collect all three series before drawing any of them
            series = {
                key: [record[key] for record in recent]
                for key, _, _ in module_styles
            }
            positions = range(len(series['perception']))
            for key, line_format, label in module_styles:
                plt.plot(positions, series[key], line_format, label=label)
            decorate('Recent Routing Weights', 'Recent Problems', 'Weight')
            plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'training_progress_squad.png'), dpi=300)
        plt.close()

        logger.info(f"Saved training visualization to {output_dir}/training_progress_squad.png")

    def _check_answer_correctness(self, predicted_answer, expected_answer):
        """
        Score a predicted answer by delegating to
        ``squad_dataset.check_answer_correctness``.

        Args:
            predicted_answer: Answer produced by the system.
            expected_answer: Reference answer from the dataset.

        Returns:
            Whatever ``check_answer_correctness`` returns for the pair.
        """
        # Imported at call time; the same function is also imported at module level.
        from squad_dataset import check_answer_correctness
        return check_answer_correctness(predicted_answer, expected_answer)

# dynamic_router_StrategyQA.py
import json
import numpy as np
import logging
import torch
from scipy.special import softmax
from information_bottleneck_StrategyQA import MutualInformationEstimator, MineEstimator

logger = logging.getLogger(__name__)

class DynamicRouter:
    """Compute routing weights for the perception, memory and reasoning modules.

    Module outputs are embedded through the Bedrock runtime client, mutual
    information (MI) proxies are derived from pairs of embeddings, and those
    scores are normalised into one weight per module. Every routing decision
    is appended to ``routing_history`` and the MI values behind it are stored
    in ``mi_values_history`` under the 1-based index of that decision.
    """

    # Embedding model requested from the Bedrock runtime.
    _EMBED_MODEL_ID = "cohere.embed-multilingual-v3"
    # Longest text sent in one embedding request; longer inputs are chunked.
    _MAX_CHUNK_CHARS = 1900
    # Only this many leading chunks of a long text are embedded.
    _MAX_CHUNKS = 5
    # Upper bound on the number of cached embeddings.
    _MAX_CACHE_ENTRIES = 1000
    # Size of the random vector returned when no embedding could be obtained.
    _FALLBACK_DIM = 1024
    # Relative weight of each MI proxy in the combined estimate.
    _MI_METHOD_WEIGHTS = {'cosine': 0.4, 'kl': 0.3, 'mine': 0.3}

    def __init__(self, bedrock_runtime, beta=1.0, l1_strength=0.0, entropy_strength=0.0):
        """
        Args:
            bedrock_runtime: client whose ``invoke_model`` is used for embeddings.
            beta: multiplier applied to every MI score before normalisation.
            l1_strength: amount subtracted from each scaled score (floored at 0).
            entropy_strength: enables the entropy step of ``compute_routing_matrix``.
        """
        self.bedrock_runtime = bedrock_runtime
        self.beta = beta  # Temperature parameter for routing
        self.l1_strength = l1_strength  # L1 regularization strength
        self.entropy_strength = entropy_strength  # Entropy regularization strength
        self.routing_history = []
        self.mi_values_history = {}

        # Initialize MI estimators
        self.mi_estimator = MutualInformationEstimator()
        self.mine_estimator = MineEstimator(input_dim=256, hidden_dim=128)

        # Cache for embeddings to avoid recomputation
        self.embedding_cache = {}

    @staticmethod
    def _as_text(value):
        """Serialise dict outputs to JSON text; return anything else unchanged."""
        if isinstance(value, dict):
            return json.dumps(value)
        return value

    def estimate_mutual_information(self, x_text, y_text):
        """Estimate a mutual-information proxy between two texts (or dicts).

        Both inputs are embedded and the weighted combination of the cosine,
        KL/JS and MINE proxies is returned. Any failure is logged and the
        default value 0.5 is returned instead.
        """
        x_text = self._as_text(x_text)
        y_text = self._as_text(y_text)

        try:
            # Generate embeddings (or use cached versions)
            x_embedding = self._get_embedding(x_text)
            y_embedding = self._get_embedding(y_text)
            return self._mi_from_embeddings(x_embedding, y_embedding)
        except Exception as e:
            logger.error(f"MI estimation error: {e}")
            return 0.5  # Default value

    def _mi_from_embeddings(self, x_embedding, y_embedding):
        """Combine the three MI proxies for two embedding vectors into one score."""
        mi_methods = {
            'cosine': self._cosine_mi(x_embedding, y_embedding),
            'kl': self._kl_mi(x_embedding, y_embedding),
            'mine': self._mine_mi(x_embedding, y_embedding)
        }
        weights = self._MI_METHOD_WEIGHTS
        return sum(weights[m] * mi_methods[m] for m in mi_methods)

    def _get_embedding(self, text):
        """Return the embedding for ``text``, using the cache when possible.

        Texts longer than ``_MAX_CHUNK_CHARS`` are split into chunks, at most
        ``_MAX_CHUNKS`` of them are embedded, and the chunk embeddings are
        averaged. If no embedding can be obtained a random vector of length
        ``_FALLBACK_DIM`` is returned; such fallbacks are never cached, so a
        later call retries the real request.
        """
        # Key on the whole text. Keying on a prefix made distinct texts that
        # share their first characters resolve to the same cached embedding.
        cache_key = (len(text), hash(text))

        cached = self.embedding_cache.get(cache_key)
        if cached is not None:
            return cached

        try:
            if len(text) > self._MAX_CHUNK_CHARS:
                embedding = self._embed_long_text(text)
            else:
                embedding = self._embed_single(text)
        except Exception as e:
            logger.error(f"Embedding generation error: {e}")
            return np.random.rand(self._FALLBACK_DIM)  # Return fallback

        if embedding is None:
            # Every chunk request failed; the errors were already logged.
            return np.random.rand(self._FALLBACK_DIM)

        self._cache_embedding(cache_key, embedding)
        return embedding

    def _embed_single(self, text):
        """Request one embedding from the Bedrock runtime and return it as an array."""
        response = self.bedrock_runtime.invoke_model(
            modelId=self._EMBED_MODEL_ID,
            contentType="application/json",
            body=json.dumps({"texts": [text], "input_type": "search_document"})
        )
        return np.array(json.loads(response['body'].read())['embeddings'][0])

    def _embed_long_text(self, text):
        """Embed the leading chunks of a long text and average them.

        Returns ``None`` when no chunk could be embedded.
        """
        limit = self._MAX_CHUNK_CHARS
        covered = min(len(text), limit * self._MAX_CHUNKS)

        chunk_embeddings = []
        for i, start in enumerate(range(0, covered, limit)):
            try:
                chunk_embeddings.append(self._embed_single(text[start:start + limit]))
            except Exception as chunk_e:
                # A failed chunk is skipped; the remaining chunks still count.
                logger.error(f"Error embedding chunk {i}: {chunk_e}")

        if not chunk_embeddings:
            return None
        return np.mean(chunk_embeddings, axis=0)

    def _cache_embedding(self, cache_key, embedding):
        """Store an embedding and evict the oldest entry once the cache is full."""
        self.embedding_cache[cache_key] = embedding

        if len(self.embedding_cache) > self._MAX_CACHE_ENTRIES:
            # Dicts iterate in insertion order, so the first key is the entry
            # that has been cached the longest (FIFO eviction, not random).
            oldest_key = next(iter(self.embedding_cache))
            del self.embedding_cache[oldest_key]

    def _cosine_mi(self, x_embedding, y_embedding):
        """Estimate MI using cosine similarity as proxy, mapped to [0, 1]."""
        x_norm = np.linalg.norm(x_embedding)
        y_norm = np.linalg.norm(y_embedding)

        # A zero vector has no direction, so similarity is undefined.
        if x_norm == 0 or y_norm == 0:
            return 0.0

        similarity = np.dot(x_embedding, y_embedding) / (x_norm * y_norm)

        # Transform from [-1, 1] to [0, 1]
        return (similarity + 1) / 2

    def _kl_mi(self, x_embedding, y_embedding):
        """Estimate MI as ``1 - JS(P || Q)`` over softmaxed embedding prefixes."""
        # Softmax turns the first 100 dimensions into probability distributions.
        x_prob = softmax(x_embedding[:100])
        y_prob = softmax(y_embedding[:100])

        js_div = self.mi_estimator.jensen_shannon_divergence(x_prob, y_prob)

        return 1.0 - js_div

    def _mine_mi(self, x_embedding, y_embedding):
        """Estimate MI with the MINE estimator, clamped to [0, 1] as a float.

        Falls back to the cosine proxy if the estimator raises.
        """
        try:
            # Limit dimensions for compatibility
            min_dim = min(self.mine_estimator.input_dim, len(x_embedding), len(y_embedding))

            # Convert to tensors with a leading batch dimension of 1
            x = torch.tensor(x_embedding[:min_dim], dtype=torch.float32).unsqueeze(0)
            y = torch.tensor(y_embedding[:min_dim], dtype=torch.float32).unsqueeze(0)

            with torch.no_grad():
                mi_est = self.mine_estimator.estimate(x, y)

            # float() so that a tensor result does not leak into the routing
            # weights and the recorded history.
            return min(1.0, max(0.0, float(mi_est)))
        except Exception as e:
            logger.error(f"MINE MI estimation error: {e}, falling back to cosine similarity")
            return self._cosine_mi(x_embedding, y_embedding)

    def _record_routing(self, routing_matrix, mi_values):
        """Append a routing decision and file its MI values under its 1-based index."""
        self.routing_history.append(routing_matrix)
        self.mi_values_history[len(self.routing_history)] = mi_values

    def compute_routing_matrix(self, outputs):
        """Compute routing weights based on mutual information between module outputs.

        Args:
            outputs: mapping with the keys 'perception', 'memory' and 'reasoning'.

        Returns:
            dict with one weight per module; the weights sum to 1.
        """
        perception_out = outputs['perception']
        memory_out = outputs['memory']
        reasoning_out = outputs['reasoning']

        # Estimate pairwise mutual information
        mi_pm = self.estimate_mutual_information(perception_out, memory_out)
        mi_pr = self.estimate_mutual_information(perception_out, reasoning_out)
        mi_mr = self.estimate_mutual_information(memory_out, reasoning_out)

        # Apply temperature scaling
        scaled_mi_pm = self.beta * mi_pm
        scaled_mi_pr = self.beta * mi_pr
        scaled_mi_mr = self.beta * mi_mr

        # Apply L1 regularization if specified (soft-threshold, floored at 0)
        if self.l1_strength > 0:
            scaled_mi_pm = max(0, scaled_mi_pm - self.l1_strength)
            scaled_mi_pr = max(0, scaled_mi_pr - self.l1_strength)
            scaled_mi_mr = max(0, scaled_mi_mr - self.l1_strength)

        # Each module's weight is its share of the pairwise scores it takes
        # part in (a linear normalisation, not a softmax). Every pair is
        # counted for two modules, hence the factor 2.
        sum_mi = scaled_mi_pm + scaled_mi_pr + scaled_mi_mr
        if sum_mi > 0:
            p_weight = (scaled_mi_pm + scaled_mi_pr) / (2 * sum_mi)
            m_weight = (scaled_mi_pm + scaled_mi_mr) / (2 * sum_mi)
            r_weight = (scaled_mi_pr + scaled_mi_mr) / (2 * sum_mi)
        else:
            # Default equal weights
            p_weight = m_weight = r_weight = 1/3

        # Apply entropy regularization if specified
        # NOTE(review): all three weights are multiplied by the same scalar and
        # then renormalised, so this step leaves the weights unchanged. Kept
        # as-is until the intended regulariser is confirmed.
        if self.entropy_strength > 0:
            weights = np.array([p_weight, m_weight, r_weight])
            entropy = -np.sum(weights * np.log(weights + 1e-10))
            weights = weights * np.exp(self.entropy_strength * entropy)
            weights = weights / np.sum(weights)
            p_weight, m_weight, r_weight = weights

        # Normalize to ensure they sum to 1
        total = p_weight + m_weight + r_weight
        p_weight /= total
        m_weight /= total
        r_weight /= total

        routing_matrix = {
            'perception': p_weight,
            'memory': m_weight,
            'reasoning': r_weight,
        }

        mi_values = {
            'perception_memory': mi_pm,
            'perception_reasoning': mi_pr,
            'memory_reasoning': mi_mr
        }

        self._record_routing(routing_matrix, mi_values)

        return routing_matrix

    def estimate_conditional_mutual_information(self, x_text, y_text, z_text):
        """
        Estimate conditional mutual information I(X;Y|Z) using the chain rule:
        I(X;Y|Z) = I(X;Y,Z) - I(X;Z)

        The joint (Y,Z) is approximated by concatenating the first half of each
        embedding. The result is floored at 0; on any failure the error is
        logged and the default value 0.3 is returned.
        """
        x_text = self._as_text(x_text)
        y_text = self._as_text(y_text)
        z_text = self._as_text(z_text)

        try:
            # Generate embeddings
            x_embedding = self._get_embedding(x_text)
            y_embedding = self._get_embedding(y_text)
            z_embedding = self._get_embedding(z_text)

            # Combine y and z embeddings for joint distribution
            yz_embedding = np.concatenate(
                [y_embedding[:len(y_embedding)//2], z_embedding[:len(z_embedding)//2]])

            # Work on the embeddings directly. Sending them back through the
            # text-based estimator failed inside the embedding lookup and made
            # both terms collapse to the same default, i.e. a constant result.
            mi_x_yz = self._mi_from_embeddings(x_embedding, yz_embedding)
            mi_x_z = self._mi_from_embeddings(x_embedding, z_embedding)

            # Conditional mutual information, ensured non-negative
            return max(0.0, mi_x_yz - mi_x_z)

        except Exception as e:
            logger.error(f"Conditional MI estimation error: {e}")
            return 0.3  # Default value

    def compute_routing_matrix_advanced(self, outputs):
        """
        Compute routing weights based on conditional mutual information between module outputs:
        ω_ij(t) = exp(β · I(S_i(t); S_j(t) | S_{-ij}(t))) / ∑_{k,l} exp(β · I(S_k(t); S_l(t) | S_{-kl}(t)))

        Args:
            outputs: mapping with the keys 'perception', 'memory' and 'reasoning';
                dict values are serialised to JSON text first.

        Returns:
            dict with one weight per module; the weights sum to 1.
        """
        perception_out = self._as_text(outputs['perception'])
        memory_out = self._as_text(outputs['memory'])
        reasoning_out = self._as_text(outputs['reasoning'])

        # S_{-ij} is everything except modules i and j. With three modules
        # that is exactly the remaining one, so each pair is conditioned on
        # the module that is not part of the pair.
        cmi_pm = self.estimate_conditional_mutual_information(
            perception_out, memory_out, reasoning_out)
        cmi_pr = self.estimate_conditional_mutual_information(
            perception_out, reasoning_out, memory_out)
        cmi_mr = self.estimate_conditional_mutual_information(
            memory_out, reasoning_out, perception_out)

        # Apply temperature scaling, then softmax over the three pair scores
        routing_values = softmax(
            np.array([self.beta * cmi_pm, self.beta * cmi_pr, self.beta * cmi_mr]))

        # Pair scores are assigned to modules as: (p,m) -> perception,
        # (p,r) -> reasoning, (m,r) -> memory.
        # NOTE(review): this pair-to-module assignment is kept from the
        # existing behaviour; confirm it is the intended mapping.
        p_weight = routing_values[0]
        r_weight = routing_values[1]
        m_weight = routing_values[2]

        # Normalize to ensure they sum to 1
        total = p_weight + m_weight + r_weight
        p_weight /= total
        m_weight /= total
        r_weight /= total

        routing_matrix = {
            'perception': p_weight,
            'memory': m_weight,
            'reasoning': r_weight,
        }

        mi_values = {
            'perception_memory_conditional': cmi_pm,
            'perception_reasoning_conditional': cmi_pr,
            'memory_reasoning_conditional': cmi_mr
        }

        self._record_routing(routing_matrix, mi_values)

        return routing_matrix


class ModularRouter:
    """Router that keeps a smoothed module-to-module routing matrix."""

    def __init__(self, bedrock_runtime, modules=None, beta=1.0):
        """
        Store the configuration and start from a uniform N x N routing matrix.

        When ``modules`` is empty or omitted the four default module names
        (perception, memory, reasoning, executive) are used.
        """
        self.bedrock_runtime = bedrock_runtime
        self.beta = beta

        if modules:
            self.modules = modules
        else:
            self.modules = ['perception', 'memory', 'reasoning', 'executive']
        self.module_count = len(self.modules)

        # Every entry starts at 1/N.
        size = self.module_count
        self.routing_matrix = np.ones((size, size)) / size

        # Estimator used for the pairwise scores in update_routing.
        self.mi_estimator = MutualInformationEstimator()

        # One matrix snapshot is appended per update.
        self.routing_history = []

    def update_routing(self, module_states):
        """
        Recompute pairwise MI for the modules present in ``module_states`` and
        blend the row-normalised result into the routing matrix.

        Returns the updated routing matrix; a copy of it is also appended to
        ``routing_history``.
        """
        size = self.module_count
        names = self.modules
        pairwise = np.zeros((size, size))

        for src_idx in range(size):
            for dst_idx in range(size):
                # The diagonal stays at zero.
                if src_idx == dst_idx:
                    continue
                src_name = names[src_idx]
                dst_name = names[dst_idx]
                # Pairs with a missing state keep a score of zero.
                if src_name not in module_states or dst_name not in module_states:
                    continue
                pairwise[src_idx, dst_idx] = self.mi_estimator.estimate(
                    module_states[src_name],
                    module_states[dst_name]
                )

        # Temperature scaling.
        scaled = pairwise * self.beta

        # Turn each row with a positive total into a distribution.
        for row_idx in range(size):
            row_total = np.sum(scaled[row_idx])
            if row_total > 0:
                scaled[row_idx] = scaled[row_idx] / row_total

        # Exponential smoothing: 70% previous matrix, 30% new estimate.
        smoothing = 0.3
        self.routing_matrix = (1 - smoothing) * self.routing_matrix + smoothing * scaled

        self.routing_history.append(self.routing_matrix.copy())

        return self.routing_matrix

    def get_routing_strength(self, source, target):
        """Return the matrix entry for ``source`` -> ``target``, or 0.0 for unknown names."""
        if source not in self.modules or target not in self.modules:
            return 0.0
        row = self.modules.index(source)
        col = self.modules.index(target)
        return self.routing_matrix[row, col]
    
class InformationTransferOperator:
    """
    Information Transfer Operator between models:
    T_{i→j}(S_i, S_j) = φ_j(∫_{M_i} K(S_i, m) · ψ_{i→j}(m) dm)

    The integral is approximated with random samples drawn around the
    source state.
    """

    def __init__(self, kernel_bandwidth=0.1, n_samples=100):
        # sigma is both the kernel bandwidth and the spread of the samples.
        self.n_samples = n_samples
        self.sigma = kernel_bandwidth

    def kernel_function(self, s_i, m):
        """
        Gaussian RBF kernel: K(S_i, m) = exp(-||S_i - m||² / 2σ²)
        """
        squared_distance = np.sum((s_i - m) ** 2)
        return np.exp(-squared_distance / (2 * self.sigma ** 2))

    def geometry_aware_kernel(self, s_i, m, metric_tensor):
        """
        Geometry-aware kernel: K(S_i, m) = exp(-(m-S_i)^T G_i (m-S_i) / 2σ²)
        """
        offset = m - s_i
        exponent = -0.5 * offset.T @ metric_tensor @ offset / self.sigma ** 2
        return np.exp(exponent)

    def transfer_mapping(self, m, source_dim, target_dim):
        """
        Transfer mapping ψ_{i→j}(m) from manifold i to manifold j.

        A larger source is truncated to ``target_dim`` entries; otherwise the
        vector is copied into a zero vector of length ``target_dim``.
        """
        if source_dim > target_dim:
            # Keep only the leading target_dim components.
            return m[:target_dim]

        # Zero-pad up to the target dimension.
        padded = np.zeros(target_dim)
        padded[:source_dim] = m
        return padded

    def transfer_operator(self, source_state, target_state, source_manifold=None, target_manifold=None):
        """
        Apply the operator: sample around ``source_state``, average the mapped
        samples with kernel weights, and blend the result into ``target_state``.

        ``source_manifold`` and ``target_manifold`` are accepted but not used.
        """
        n_source = len(source_state)
        n_target = len(target_state)

        # Gaussian perturbations of the source state, one row per sample.
        noise = np.random.randn(self.n_samples, n_source)
        candidates = source_state + self.sigma * noise

        accumulated = np.zeros(n_target)
        weight_total = 0

        for candidate in candidates:
            weight = self.kernel_function(source_state, candidate)
            mapped = self.transfer_mapping(candidate, n_source, n_target)
            accumulated += weight * mapped
            weight_total += weight

        # With no samples (or zero total weight) the contribution stays zero.
        if weight_total > 0:
            accumulated /= weight_total

        # The target keeps 70% of its own state.
        target_share = 0.7
        return target_share * target_state + (1 - target_share) * accumulated


# dynamic_router_trainable_squad.py
from dynamic_router_StrategyQA import DynamicRouter
import numpy as np
import logging

logger = logging.getLogger(__name__)

class TrainableDynamicRouter_SQuAD(DynamicRouter):
    """Enhanced router with training capabilities for SQuAD.

    After every answered question ``update_parameters`` nudges ``beta`` up or
    down, within ``[beta_min, beta_max]``, based on whether the answer was
    correct and on how spread out the most recent MI values are.
    """

    def __init__(self, bedrock_runtime, beta=1.0, learning_rate=0.05, beta_min=0.8, beta_max=4):
        """
        Args:
            bedrock_runtime: client passed through to ``DynamicRouter``.
            beta: initial routing temperature.
            learning_rate: base step size for beta after an incorrect answer.
            beta_min: lower clamp applied when beta is decreased.
            beta_max: upper clamp applied when beta is increased.
        """
        super().__init__(bedrock_runtime, beta)
        self.learning_rate = learning_rate  # Learning rate for updating beta
        self.beta_min = beta_min            # Minimum beta value
        self.beta_max = beta_max            # Maximum beta value
        self.accuracy_history = []
        self.beta_history = [beta]
        logger.info(f"Initialized TrainableDynamicRouter for SQuAD with beta={beta}")

    def update_parameters(self, correct):
        """Update router parameters based on correctness of last answer.

        Appends 1.0/0.0 to ``accuracy_history``, adjusts ``beta`` when MI
        values have been recorded, and appends the resulting beta to
        ``beta_history``.
        """
        self.accuracy_history.append(1.0 if correct else 0.0)

        # Extract MI diversity to guide beta updates
        if self.mi_values_history:
            latest_mi = list(self.mi_values_history.values())[-1]

            # Use whatever values were recorded rather than fixed key names:
            # the plain and the conditional routing methods store their MI
            # values under different keys, and looking up the plain keys
            # raised KeyError after a conditional routing step.
            mi_values = list(latest_mi.values())

            if mi_values:
                # MI diversity - how different are the mutual information values
                mi_std = np.std(mi_values)

                if not correct:
                    half_step = self.learning_rate * 0.5
                    if mi_std < 0.1:
                        # MI values too similar: raise beta to differentiate modules more
                        self.beta = min(self.beta_max, self.beta + half_step)
                    else:
                        # MI values already diverse: reduce beta slightly
                        self.beta = max(self.beta_min, self.beta - half_step)
                else:
                    # For correct answers, make smaller adjustments
                    if mi_std < 0.05:
                        self.beta = min(self.beta_max, self.beta + 0.02)
                    elif mi_std > 0.3:
                        self.beta = max(self.beta_min, self.beta - 0.02)

        self.beta_history.append(self.beta)
        logger.info(f"SQuAD Router: Updated beta to {self.beta:.3f} based on {'correct' if correct else 'incorrect'} result")



# evaluation_squad.py
import numpy as np
import matplotlib.pyplot as plt
import time
import json
import os
from collections import defaultdict

class DNONEvaluator_SQuAD:
    """Evaluation framework for DNON system performance on SQuAD.

    Results accumulate in ``results`` and ``latency_stats`` across calls to
    ``evaluate``; ``compute_metrics`` and ``visualize_results`` summarise
    whatever has been accumulated and write their output into ``save_dir``.
    """

    def __init__(self, save_dir="evaluation_results_squad"):
        """Create the evaluator and make sure ``save_dir`` exists."""
        self.save_dir = save_dir
        self.results = defaultdict(list)
        self.metrics = {}
        self.latency_stats = defaultdict(list)

        # Create directory for results
        os.makedirs(save_dir, exist_ok=True)

    def evaluate(self, dnon_system, test_data):
        """Evaluate DNON system on SQuAD questions.

        Args:
            dnon_system: object whose ``solve(question, context=...)`` returns a
                dict with 'final_answer' and a 'timings' dict per module.
            test_data: dict with parallel lists under 'questions', 'contexts'
                and 'answers'.

        Returns:
            list of the raw result dicts, one per evaluated question.

        Side effects:
            Appends to ``results`` / ``latency_stats`` and writes
            ``evaluation_results.json`` into ``save_dir``.
        """
        # Imported once per call instead of once per question.
        from squad_dataset import check_answer_correctness

        all_results = []

        questions = test_data['questions']
        contexts = test_data['contexts']
        answers = test_data['answers']

        start_time = time.time()

        for i, (question, context, expected_answer) in enumerate(
                zip(questions, contexts, answers)):

            print(f"Evaluating question {i+1}/{len(questions)}")

            # Process through DNON system
            result = dnon_system.solve(question, context=context)

            # Record results
            self.results['questions'].append(question)
            self.results['contexts'].append(context)
            self.results['expected_answers'].append(expected_answer)
            self.results['final_answers'].append(result['final_answer'])

            # Check correctness
            is_correct = check_answer_correctness(result['final_answer'], expected_answer)
            self.results['is_correct'].append(is_correct)

            # Record latency
            for module in ['perception', 'memory', 'reasoning', 'executive']:
                self.latency_stats[module].append(result['timings'][module])

            all_results.append(result)

        total_time = time.time() - start_time

        # An empty question list must not abort the run with a division by zero.
        question_count = len(questions)
        avg_time = total_time / question_count if question_count else 0.0

        # Save results (contexts are truncated to keep the file readable)
        with open(os.path.join(self.save_dir, 'evaluation_results.json'), 'w',
                  encoding='utf-8') as f:
            json.dump({
                'questions': self.results['questions'],
                'contexts': [c[:200] + "..." if len(c) > 200 else c for c in self.results['contexts']],
                'expected_answers': self.results['expected_answers'],
                'final_answers': self.results['final_answers'],
                'is_correct': self.results['is_correct'],
                'latency_stats': dict(self.latency_stats),
                'total_time': total_time,
                'avg_time_per_question': avg_time
            }, f, indent=2)

        return all_results

    def compute_metrics(self):
        """Compute evaluation metrics for SQuAD.

        Returns the metrics dict (accuracy, average F1, counts and per-module
        latency statistics), stores it in ``self.metrics`` and writes it to
        ``metrics.json`` in ``save_dir``.
        """
        # Calculate accuracy
        correct_count = int(sum(self.results['is_correct']))
        total_count = len(self.results['is_correct'])
        accuracy = correct_count / total_count if total_count > 0 else 0

        # Calculate F1 score for partial matches
        f1_scores = [
            self.compute_f1(pred, exp)
            for pred, exp in zip(self.results['final_answers'], self.results['expected_answers'])
        ]
        avg_f1 = float(np.mean(f1_scores)) if f1_scores else 0

        # Calculate latency statistics. Everything is converted to a built-in
        # float: numpy integer scalars (e.g. min/max of integer timings) are
        # rejected by json.dump below.
        latency_metrics = {}
        for module, latencies in self.latency_stats.items():
            latency_metrics[module] = {
                'mean': float(np.mean(latencies)),
                'median': float(np.median(latencies)),
                'std': float(np.std(latencies)),
                'min': float(np.min(latencies)),
                'max': float(np.max(latencies))
            }

        # Store and return metrics
        self.metrics = {
            'accuracy': accuracy,
            'avg_f1': avg_f1,
            'correct': correct_count,
            'total': total_count,
            'latency': latency_metrics
        }

        # Save metrics
        with open(os.path.join(self.save_dir, 'metrics.json'), 'w', encoding='utf-8') as f:
            json.dump(self.metrics, f, indent=2)

        return self.metrics

    def compute_f1(self, prediction, ground_truth):
        """Compute the token-level F1 score between a prediction and a reference.

        Tokens are the lower-cased, whitespace-separated words. Overlap is
        counted as a multiset, so repeated tokens are matched as often as they
        occur on both sides. ``None`` is treated as an empty answer. When
        ``ground_truth`` is a list or tuple of references, the best score over
        the references is returned.

        Returns 0 when either side has no tokens or nothing overlaps.
        """
        from collections import Counter

        if isinstance(ground_truth, (list, tuple)):
            return max((self.compute_f1(prediction, gt) for gt in ground_truth), default=0)

        prediction_tokens = ("" if prediction is None else str(prediction)).lower().split()
        ground_truth_tokens = ("" if ground_truth is None else str(ground_truth)).lower().split()

        # If either is empty, F1 = 0
        if not prediction_tokens or not ground_truth_tokens:
            return 0

        # Multiset intersection: a set intersection would count a repeated
        # token once in the numerator but several times in the denominators.
        overlap = sum((Counter(prediction_tokens) & Counter(ground_truth_tokens)).values())
        if overlap == 0:
            return 0

        precision = overlap / len(prediction_tokens)
        recall = overlap / len(ground_truth_tokens)

        return 2 * precision * recall / (precision + recall)

    def visualize_results(self):
        """Visualize evaluation results for SQuAD.

        Draws a 2x2 figure (scores, mean latency, latency histograms, total
        latency per question), saves it as ``evaluation_results.png`` in
        ``save_dir``, shows it, and then closes the figure.
        """
        if not self.metrics:
            self.compute_metrics()

        # Create figure
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Plot accuracy and F1
        metrics = [self.metrics['accuracy'], self.metrics['avg_f1']]
        labels = ['Exact Match', 'F1 Score']
        axes[0, 0].bar(labels, metrics, color=['blue', 'green'])
        axes[0, 0].set_ylim(0, 1)
        axes[0, 0].set_ylabel('Score')
        axes[0, 0].set_title(f"Accuracy: {self.metrics['accuracy']:.2f}, F1: {self.metrics['avg_f1']:.2f}")

        # Plot latency by module
        modules = list(self.latency_stats.keys())
        mean_latencies = [self.metrics['latency'][m]['mean'] for m in modules]
        axes[0, 1].bar(modules, mean_latencies)
        axes[0, 1].set_ylabel('Latency (s)')
        axes[0, 1].set_title('Average Module Latency')
        axes[0, 1].tick_params(axis='x', rotation=45)

        # Plot latency distribution
        for module in modules:
            axes[1, 0].hist(self.latency_stats[module], alpha=0.5, label=module, bins=20)
        axes[1, 0].set_xlabel('Latency (s)')
        axes[1, 0].set_ylabel('Frequency')
        axes[1, 0].set_title('Latency Distribution by Module')
        axes[1, 0].legend()

        # Plot cumulative latency (sum over modules for each question)
        cum_latencies = np.zeros(len(self.results['questions']))
        for module in modules:
            cum_latencies += np.array(self.latency_stats[module])
        axes[1, 1].plot(cum_latencies, marker='o')
        axes[1, 1].set_xlabel('Question Index')
        axes[1, 1].set_ylabel('Total Latency (s)')
        axes[1, 1].set_title('Total Processing Time per Question')

        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, 'evaluation_results.png'), dpi=300)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

# manifold_representation_StrategyQA.py
import numpy as np
from numpy.linalg import norm, inv
import torch

class EuclideanEmbedding:
    """A point stored as plain Euclidean coordinates, standing in for a manifold point."""

    def __init__(self, coordinates):
        # dim is taken from the input sequence as given.
        self.dim = len(coordinates)
        self.coordinates = np.array(coordinates)

    def copy(self):
        """Return a new embedding backed by its own coordinate array."""
        duplicated = self.coordinates.copy()
        return EuclideanEmbedding(duplicated)

    def distance(self, other):
        """Return the Euclidean (L2) distance between this point and ``other``."""
        offset = self.coordinates - other.coordinates
        return norm(offset)

    def to_tensor(self):
        """Return the coordinates as a float32 PyTorch tensor."""
        values = self.coordinates
        return torch.tensor(values, dtype=torch.float32)

class RiemannianManifold:
    """Riemannian manifold approximated with a simple, rescalable metric tensor.

    The base metric is the identity matrix. ``metric_tensor_at`` rescales it by
    ``1 / (1 + curvature * ||x||^2)`` to obtain a position-dependent metric;
    the remaining operations are flat-space approximations, as noted on each
    method.
    """
    def __init__(self, dim=256, curvature=0.01):
        self.dim = dim
        self.curvature = curvature
        # Identity metric tensor (can be modified for non-Euclidean geometry)
        self.metric_tensor = np.eye(dim)

    def metric(self, x):
        """Return the base metric tensor; the point ``x`` is not used here."""
        # In a more sophisticated implementation, this would return a
        # position-dependent metric tensor G(x)
        return self.metric_tensor

    def distance(self, x, y):
        """Return ``sqrt(d^T G d)`` for ``d = x - y`` under the base metric.

        This Mahalanobis-style distance is used as an approximation of the
        geodesic distance. ``x`` and ``y`` must expose ``.coordinates``.
        """
        diff = x.coordinates - y.coordinates
        return np.sqrt(diff @ self.metric_tensor @ diff)

    def project(self, vector):
        """Wrap a raw vector in an ``EuclideanEmbedding`` (no real projection)."""
        return EuclideanEmbedding(vector)

    def geodesic(self, start, end, step_size=1.0):
        """Step from ``start`` toward ``end`` by exactly ``step_size``.

        The direction is the metric-transformed difference, normalised to
        unit length. The step length does not depend on how far ``end`` is,
        so a large ``step_size`` can overshoot ``end``. If the direction is
        zero, a copy of ``start`` is returned.
        """
        diff = end.coordinates - start.coordinates
        # Apply metric to get direction in tangent space
        direction = diff @ self.metric_tensor
        magnitude = norm(direction)
        if magnitude > 0:
            direction = direction / magnitude * step_size

        return EuclideanEmbedding(start.coordinates + direction)

    def parallel_transport(self, vector, from_point, to_point):
        """Return ``vector`` unchanged (identity transport for simplicity)."""
        return vector

    def exponential_map(self, point, tangent_vector):
        """Map a tangent vector to the manifold by plain vector addition."""
        new_coords = point.coordinates + tangent_vector
        return EuclideanEmbedding(new_coords)

    def logarithmic_map(self, point, target):
        """Return the coordinate difference ``target - point``."""
        return target.coordinates - point.coordinates

    def cognitive_distance_full(self, x, y):
        """
        Compute the length of the straight-line path from x to y under the
        position-dependent metric:

        D(S_i, S_j) = ∫_0^1 sqrt(γ̇(t)^T G_γ(t) γ̇(t)) dt,
        with γ(t) = (1 - t) * x + t * y.

        Returns the value computed by ``scipy.integrate.quad`` (the error
        estimate is discarded).
        """
        start = x.coordinates
        # The path is a straight line, so its velocity is constant in t.
        velocity = y.coordinates - start

        def integrand(t):
            point = start + t * velocity
            G_t = self.metric_tensor_at(point)
            return np.sqrt(velocity @ G_t @ velocity)

        # Adaptive numerical quadrature over t in [0, 1]
        from scipy import integrate
        distance, _ = integrate.quad(integrand, 0, 1)

        return distance

    def metric_tensor_at(self, point):
        """
        Compute the metric tensor G(x) at a specific point.

        ``point`` may be a NumPy array of coordinates or any object exposing
        ``.coordinates``. The base metric is scaled by
        ``1 / (1 + curvature * ||point||^2)``.
        """
        if isinstance(point, np.ndarray):
            point_norm = np.linalg.norm(point)
        else:
            point_norm = np.linalg.norm(point.coordinates)

        # Scale metric tensor based on curvature and distance from origin
        scale_factor = 1.0 / (1.0 + self.curvature * point_norm**2)
        return scale_factor * self.metric_tensor

    def exponential_map_enhanced(self, point, tangent_vector, epsilon=1e-6):
        """
        Metric-aware exponential map Exp_p(v).

        Tangent vectors shorter than ``epsilon`` return a copy of ``point``.
        Otherwise the tangent is transformed by solving
        ``(G + epsilon * I) s = v`` and the result is
        ``point + ||v|| * s``.
        NOTE(review): ``s`` is not normalised before being scaled by ``||v||``;
        confirm this scaling is intended.
        """
        tangent_norm = np.linalg.norm(tangent_vector)
        if tangent_norm < epsilon:
            return EuclideanEmbedding(point.coordinates.copy())

        # In a curved space, the exponential map would involve solving geodesic equations
        # For simplicity, we use a scaled direction based on the metric
        G = self.metric_tensor_at(point.coordinates)
        scaled_tangent = np.linalg.solve(G + epsilon * np.eye(len(G)), tangent_vector)

        new_coords = point.coordinates + tangent_norm * scaled_tangent
        return EuclideanEmbedding(new_coords)

    def logarithmic_map_enhanced(self, point, target, epsilon=1e-6):
        """
        Metric-aware logarithmic map Log_p(q).

        The direction is ``G(point) @ (target - point)``; its length is
        rescaled to the path length given by ``cognitive_distance_full``.
        ``epsilon`` guards the division when the direction is zero.
        """
        diff = target.coordinates - point.coordinates

        # In a curved space, the logarithmic map would involve solving inverse geodesic equations
        G = self.metric_tensor_at(point.coordinates)
        log_vector = G @ diff

        # Rescale to the path length. This previously called a
        # ``cognitive_distance`` method that this class does not define.
        true_distance = self.cognitive_distance_full(point, target)
        log_norm = np.linalg.norm(log_vector) + epsilon
        return log_vector * (true_distance / log_norm)

class HyperbolicManifold(RiemannianManifold):
    """Hyperbolic manifold implementation (Poincaré disk model)"""
    def __init__(self, dim=2, curvature=-1):
        super().__init__(dim, curvature)

    def metric(self, x):
        """Return the Poincaré metric ``4 / (1 - |x|^2)^2 * I`` at point ``x``.

        Points on or outside the unit ball are treated as if ``|x|^2`` were
        0.99 so the scale stays finite.
        """
        r2 = np.sum(x.coordinates**2)
        if r2 >= 1.0:  # Ensure point is inside the Poincaré disk
            r2 = 0.99
        scale = 4 / ((1 - r2)**2)
        return scale * np.eye(self.dim)

    def distance(self, x, y):
        """Return the hyperbolic distance between ``x`` and ``y``.

        Uses ``arccosh(1 + 2|x - y|^2 / ((1 - |x|^2)(1 - |y|^2)))``.
        If either point lies on or outside the unit ball, the sentinel
        value 100.0 is returned instead.
        """
        x_norm = np.sum(x.coordinates**2)
        y_norm = np.sum(y.coordinates**2)

        # Check each point separately: when both are outside the ball the
        # product (1 - |x|^2)(1 - |y|^2) is positive again, so testing the
        # product alone would let that case through.
        if x_norm >= 1.0 or y_norm >= 1.0:
            return 100.0  # Large distance for points at/beyond boundary

        numerator = 2 * norm(x.coordinates - y.coordinates)**2
        denominator = (1 - x_norm) * (1 - y_norm)
        return np.arccosh(1 + numerator / denominator)

class ManifoldFactory:
    """Builds manifold objects from a case-insensitive type name."""

    @staticmethod
    def create(manifold_type, **kwargs):
        """Instantiate the manifold named by ``manifold_type``.

        ``'euclidean'`` yields a ``RiemannianManifold`` and ``'hyperbolic'``
        a ``HyperbolicManifold``; ``kwargs`` are forwarded to the constructor.
        Any other name raises ``ValueError``.
        """
        kind = manifold_type.lower()
        if kind == 'euclidean':
            return RiemannianManifold(**kwargs)
        if kind == 'hyperbolic':
            return HyperbolicManifold(**kwargs)
        raise ValueError(f"Unknown manifold type: {manifold_type}")
        

class ConnectionMetrics:
    """
    Cognitive distance between points that live on different manifolds:
    D(S_i, S_j) = sqrt(sum_k g_k^ij · (f_i^k(S_i) - f_j^k(S_j))^2)

    The weights g_k^ij are drawn afresh from ``np.random`` on every call
    (see ``_get_metric_tensor``), so results vary unless the NumPy global
    random state is seeded.
    """
    def __init__(self, manifolds):
        """Keep a reference to the given list of manifolds."""
        self.manifolds = manifolds

    def cognitive_distance(self, point_i, point_j, manifold_i, manifold_j):
        """
        Return the weighted Euclidean distance between the two projected points.

        The shorter projected vector is zero-padded at the end so both have
        equal length before the per-dimension weights are applied.
        """
        # Build one projection per manifold, sized to its point's own length
        project_i = self._get_projection(manifold_i, len(point_i.coordinates))
        project_j = self._get_projection(manifold_j, len(point_j.coordinates))

        # Per-dimension weights linking the two manifolds (simplified)
        weights = self._get_metric_tensor(manifold_i, manifold_j)

        vec_i = project_i(point_i.coordinates)
        vec_j = project_j(point_j.coordinates)

        # Zero-pad whichever vector is shorter
        width = max(len(vec_i), len(vec_j))
        shortfall_i = width - len(vec_i)
        shortfall_j = width - len(vec_j)
        if shortfall_i > 0:
            vec_i = np.pad(vec_i, (0, shortfall_i))
        if shortfall_j > 0:
            vec_j = np.pad(vec_j, (0, shortfall_j))

        # Weighted sum of squared component differences, then the root
        gap = vec_i - vec_j
        total = np.sum(weights[:width] * gap ** 2)
        return np.sqrt(total)

    def _get_projection(self, manifold, dim):
        """Return a callable that reduces a coordinate vector to at most ``dim`` entries.

        Vectors no longer than ``dim`` are returned unchanged; longer ones go
        through an SVD of the vector viewed as a single column. ``manifold``
        is not consulted.
        """
        def proj(coords):
            if len(coords) > dim:
                # SVD-based reduction of the column vector
                from scipy.linalg import svd
                left, singular, _ = svd(coords.reshape(-1, 1), full_matrices=False)
                return left[:dim, 0] * singular[0]
            return coords
        return proj

    def _get_metric_tensor(self, manifold_i, manifold_j):
        """
        Return the diagonal weights g_k^ij as a 1-D array.

        The length is the larger of the two manifolds' ``dim``; every entry
        is 1 plus uniform random noise scaled by 0.1.
        """
        size = max(manifold_i.dim, manifold_j.dim)
        noise = np.random.rand(size)
        return np.ones(size) + 0.1 * noise

# memory_dynamics_StrategyQA.py
import numpy as np
import time
import heapq
from collections import deque

class ShortTermMemory:
    """Short-term memory with temporal decay.

    Holds up to ``capacity`` ``(text, embedding, timestamp, activation)``
    entries for similarity-based retrieval. Separately, the
    ``update_dynamics*`` methods advance a continuous state vector of
    length ``dim``.
    """
    def __init__(self, capacity=20, decay_rate=0.1, dim=256):
        self.capacity = capacity
        self.decay_rate = decay_rate
        # The bounded deque drops its oldest entry when a new one is appended
        # at full capacity.
        self.items = deque(maxlen=capacity)  # (text, embedding, timestamp, activation)
        self.dim = dim

    def add(self, text, embedding):
        """Add a new item, stamped with the current time, at full activation."""
        timestamp = time.time()
        activation = 1.0  # Full activation for new items
        self.items.append((text, embedding, timestamp, activation))

    def retrieve(self, query_embedding, top_k=3):
        """Return the ``top_k`` best-scoring texts joined by newlines.

        Each item's score is cosine similarity to the query, multiplied by
        ``exp(-decay_rate * age)`` and by its stored activation. Retrieval
        also updates activations (see ``_update_activations``). An empty
        memory yields an empty string.
        """
        now = time.time()

        scored_items = []
        for text, emb, timestamp, activation in self.items:
            similarity = self.cosine_similarity(emb, query_embedding)
            decay_factor = np.exp(-self.decay_rate * (now - timestamp))
            scored_items.append((similarity * decay_factor * activation, text))

        # Tuples order by score first; equal scores fall back to comparing text.
        scored_items.sort(reverse=True)
        top_texts = [text for _, text in scored_items[:top_k]]

        # Update activations of retrieved items (rehearsal)
        self._update_activations(top_texts)

        return "\n".join(top_texts)

    def _update_activations(self, retrieved_texts):
        """Boost activation of retrieved texts (+0.2, capped at 1.0) and
        lower all others (-0.05, floored at 0.1)."""
        for i, (text, emb, timestamp, activation) in enumerate(self.items):
            if text in retrieved_texts:
                new_activation = min(1.0, activation + 0.2)
            else:
                new_activation = max(0.1, activation - 0.05)
            self.items[i] = (text, emb, timestamp, new_activation)

    def apply_decay(self):
        """Scale each activation by ``exp(-decay_rate * age)`` and drop items
        whose resulting activation is not above 0.1.

        ``age`` is measured from the item's original timestamp on every call.
        """
        now = time.time()
        decayed_items = []

        for text, emb, timestamp, activation in self.items:
            decay_factor = np.exp(-self.decay_rate * (now - timestamp))
            new_activation = activation * decay_factor

            # Only keep items with sufficient activation
            if new_activation > 0.1:
                decayed_items.append((text, emb, timestamp, new_activation))

        self.items = deque(decayed_items, maxlen=self.capacity)

    def get_state(self):
        """Return the mean of ``embedding * activation`` over all items,
        or a zero vector of length ``dim`` when the memory is empty."""
        if not self.items:
            return np.zeros(self.dim)

        embeddings = np.array([emb * activation for _, emb, _, activation in self.items])
        return np.mean(embeddings, axis=0)

    def cosine_similarity(self, a, b):
        """Return the cosine similarity of ``a`` and ``b`` (0.0 if either has zero norm).

        If the computation produces an array rather than a scalar, its mean
        is returned as a float.
        """
        dot_product = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)

        if norm_a == 0 or norm_b == 0:
            return 0.0

        similarity = dot_product / (norm_a * norm_b)

        # Convert array to scalar if necessary
        if isinstance(similarity, np.ndarray):
            similarity = float(similarity.mean())

        return similarity

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Create the state used by the differential-equation dynamics.

        ``initial_state`` defaults to a zero vector of length ``dim``
        (which itself defaults to ``self.dim``).
        """
        if dim is None:
            dim = self.dim

        self.state = initial_state if initial_state is not None else np.zeros(dim)
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.stochastic_terms = np.zeros(dim)

    def _ensure_dynamics_state(self):
        """Initialise the dynamics state with defaults if it does not exist yet."""
        if not hasattr(self, "state"):
            self.initialize_dynamics()

    def update_dynamics(self, perception_input, dt=0.1):
        """
        Update the short-term memory state using the differential equation:
        dS_M^ST(t)/dt = F_M^ST(S_M^ST(t), S_P(t)) - (1/tau_ST) * S_M^ST(t) + xi_M^ST(t)

        Performs one Euler step of size ``dt`` and returns the state. When
        ``dt`` is None the wall-clock time since the previous update is used.
        Steps of 0.01 or less are skipped. The dynamics state is created with
        defaults if ``initialize_dynamics`` has not been called.
        """
        self._ensure_dynamics_state()
        current_time = time.time()

        if dt is None:
            dt = current_time - self.previous_time

        if dt > 0.01:  # Minimum time step
            # Generate stochastic perturbation (noise)
            self.stochastic_terms = 0.01 * np.random.randn(self.dim)

            # Define the internal dynamics function F_M^ST
            def F_M_ST(state, perception):
                interaction = 0.7 * state + 0.3 * perception
                return np.tanh(interaction)  # Nonlinearity for stability

            # decay_rate plays the role of tau_ST here (the state is divided
            # by it). NOTE(review): retrieve/apply_decay use decay_rate as an
            # exponential rate instead; confirm this is intended.
            decay_term = self.state / self.decay_rate
            dynamics_term = F_M_ST(self.state, perception_input)

            # Euler integration step
            self.previous_state = self.state.copy()
            self.state = self.state + dt * (dynamics_term - decay_term + self.stochastic_terms)

            # Store time for next update
            self.previous_time = current_time

        return self.state

    def update_dynamics_enhanced(self, perception_input, dt=0.1):
        """
        Enhanced implementation of Short-Term Memory dynamics:
        dS_M^ST(t)/dt = F_M^ST(S_M^ST(t), S_P(t)) - (1/tau_ST) * S_M^ST(t) + xi_M^ST(t)

        Same Euler step as ``update_dynamics``, but the perception input is
        weighted by the scalar from ``_compute_attention_salience``. The noise
        sample is not stored on the instance. The dynamics state is created
        with defaults if ``initialize_dynamics`` has not been called.
        """
        self._ensure_dynamics_state()
        current_time = time.time()

        if dt is None:
            dt = current_time - self.previous_time

        if dt > 0.01:  # Minimum time step
            # Generate stochastic perturbation (noise)
            xi_M_ST = 0.01 * np.random.randn(self.dim)

            # Define the internal dynamics function F_M^ST with attention
            def F_M_ST(state, perception):
                attention_weights = self._compute_attention_salience(state, perception)
                interaction = 0.7 * state + attention_weights * 0.3 * perception
                return np.tanh(interaction)  # Nonlinearity for stability

            # decay_rate plays the role of tau_ST (see update_dynamics)
            decay_term = self.state / self.decay_rate
            dynamics_term = F_M_ST(self.state, perception_input)

            # Euler integration step
            self.previous_state = self.state.copy()
            self.state = self.state + dt * (dynamics_term - decay_term + xi_M_ST)

            # Store time for next update
            self.previous_time = current_time

        return self.state

    def _compute_attention_salience(self, state, perception):
        """Return a scalar attention weight in (0, 1) for the perception input.

        Only the first ``min(100, len(state), len(perception))`` components
        are used. The weight combines cosine similarity with the size of the
        perception component orthogonal to the state, then applies a sigmoid.
        """
        # Truncate both vectors to a common, bounded length
        max_dim = min(100, min(len(state), len(perception)))
        proj_state = state[:max_dim]
        proj_perception = perception[:max_dim]

        state_norm = np.linalg.norm(proj_state)
        perception_norm = np.linalg.norm(proj_perception)

        # Normalise only when both vectors are non-zero
        if state_norm > 0 and perception_norm > 0:
            proj_state = proj_state / state_norm
            proj_perception = proj_perception / perception_norm

        similarity = np.dot(proj_state, proj_perception)

        # Calculate information novelty (orthogonal component magnitude)
        orthogonal_component = proj_perception - similarity * proj_state
        novelty = np.linalg.norm(orthogonal_component)

        attention_weight = 0.3 * (1 + similarity) + 0.7 * novelty

        # Sigmoid centred at attention_weight == 1
        return 1.0 / (1.0 + np.exp(-2 * attention_weight + 2))

class LongTermMemory:
    """Long-term memory with importance-based consolidation.

    Stores ``(text, embedding, importance, creation_time)`` entries in
    insertion order, up to ``capacity``. Separately, the ``update_dynamics*``
    methods advance a continuous state vector of length ``dim``.
    """
    def __init__(self, capacity=100, consolidation_rate=0.01, dim=256):
        self.capacity = capacity
        self.consolidation_rate = consolidation_rate
        self.dim = dim
        self.items = []  # (text, embedding, importance, creation_time)

    def consolidate(self, stm, embedding, importance_score):
        """Copy the STM item most similar to ``embedding`` into long-term memory.

        Nothing happens unless ``importance_score`` exceeds 0.5, ``stm.items``
        is non-empty, and no stored item is already similar to the chosen
        one. When the store grows past ``capacity`` the least important item
        is removed; the remaining items keep their insertion order.
        """
        consolidation_threshold = 0.5
        if not importance_score > consolidation_threshold:
            return

        # cosine_similarity always yields a scalar, so the results can be
        # compared directly.
        candidates = [
            (self.cosine_similarity(emb, embedding), text, emb)
            for text, emb, _, _ in stm.items
        ]
        if not candidates:
            return

        # On equal similarity the earliest STM item wins.
        _, text, emb = max(candidates, key=lambda candidate: candidate[0])

        # Skip if a similar item already exists in LTM
        if self._has_similar(emb):
            return

        self.items.append((text, emb, importance_score, time.time()))

        if len(self.items) > self.capacity:
            # Remove the (first) least important item without sorting the
            # list, so _compute_adaptive_threshold still sees the most
            # recently added items at the end.
            weakest = min(range(len(self.items)), key=lambda i: self.items[i][2])
            del self.items[weakest]

    def retrieve(self, query_embedding, top_k=2):
        """Return the ``top_k`` best-scoring texts joined by newlines.

        Score is cosine similarity to the query multiplied by
        ``0.5 + 0.5 * importance``. An empty memory yields an empty string.
        """
        scored_items = []
        for text, emb, importance, _ in self.items:
            similarity = self.cosine_similarity(emb, query_embedding)
            scored_items.append((similarity * (0.5 + 0.5 * importance), text))

        # Tuples order by score first; equal scores fall back to comparing text.
        scored_items.sort(reverse=True)

        return "\n".join([text for _, text in scored_items[:top_k]])

    def _has_similar(self, embedding, similarity_threshold=0.9):
        """Return True if any stored embedding has cosine similarity above the threshold."""
        return any(
            self.cosine_similarity(emb, embedding) > similarity_threshold
            for _, emb, _, _ in self.items
        )

    def get_state(self):
        """Return the mean of ``embedding * importance`` over all items,
        or a zero vector of length ``dim`` when the memory is empty."""
        if not self.items:
            return np.zeros(self.dim)

        embeddings = np.array([emb * importance for _, emb, importance, _ in self.items])
        return np.mean(embeddings, axis=0)

    def cosine_similarity(self, a, b):
        """Return the cosine similarity of ``a`` and ``b`` (0.0 if either has zero norm).

        If the computation produces an array rather than a scalar, its mean
        is returned as a float.
        """
        dot_product = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)

        if norm_a == 0 or norm_b == 0:
            return 0.0

        similarity = dot_product / (norm_a * norm_b)

        # Convert array to scalar if necessary
        if isinstance(similarity, np.ndarray):
            similarity = float(similarity.mean())

        return similarity

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Create the state used by the differential-equation dynamics.

        ``initial_state`` defaults to a zero vector of length ``dim``
        (which itself defaults to ``self.dim``).
        """
        if dim is None:
            dim = self.dim

        self.state = initial_state if initial_state is not None else np.zeros(dim)
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.stochastic_terms = np.zeros(dim)

    def _ensure_dynamics_state(self):
        """Initialise the dynamics state with defaults if it does not exist yet."""
        if not hasattr(self, "state"):
            self.initialize_dynamics()

    def update_dynamics(self, stm_state, dt=0.1):
        """
        Update the long-term memory state using the differential equation:
        dS_M^LT(t)/dt = eta_LT · Phi(S_M^ST(t), S_M^LT(t)) + xi_M^LT(t)

        Performs one Euler step of size ``dt`` and returns the state. When
        ``dt`` is None the wall-clock time since the previous update is used.
        Steps of 0.01 or less are skipped. The dynamics state is created with
        defaults if ``initialize_dynamics`` has not been called.
        """
        self._ensure_dynamics_state()
        current_time = time.time()

        if dt is None:
            dt = current_time - self.previous_time

        if dt > 0.01:  # Minimum time step
            # Generate stochastic perturbation (noise)
            self.stochastic_terms = 0.005 * np.random.randn(self.dim)

            # Define the consolidation function Phi
            def Phi(stm, ltm):
                # Relative size of the STM/LTM gap
                importance = np.linalg.norm(stm - ltm) / (np.linalg.norm(stm) + 0.001)

                # Only consolidate if importance exceeds threshold
                threshold = 0.2
                if importance > threshold:
                    # Move toward STM, modulated by importance
                    return importance * (stm - ltm)
                return np.zeros_like(ltm)

            consolidation_term = Phi(stm_state, self.state)

            # Euler integration step
            self.previous_state = self.state.copy()
            self.state = self.state + dt * (self.consolidation_rate * consolidation_term + self.stochastic_terms)

            # Store time for next update
            self.previous_time = current_time

        return self.state

    def update_dynamics_enhanced(self, stm_state, dt=0.1):
        """
        Enhanced implementation of Long-Term Memory consolidation:
        dS_M^LT(t)/dt = eta_LT · Phi(S_M^ST(t), S_M^LT(t)) + xi_M^LT(t)

        Same Euler step as ``update_dynamics``, but importance blends the
        relative STM/LTM gap, cosine dissimilarity and an entropy-style
        density of the STM vector, and is compared against the threshold from
        ``_compute_adaptive_threshold``. The noise sample is not stored on the
        instance. The dynamics state is created with defaults if
        ``initialize_dynamics`` has not been called.
        """
        self._ensure_dynamics_state()
        current_time = time.time()

        if dt is None:
            dt = current_time - self.previous_time

        if dt > 0.01:  # Minimum time step
            # Generate stochastic perturbation (noise)
            xi_M_LT = 0.005 * np.random.randn(self.dim)

            # Define the consolidation function Phi with adaptive threshold
            def Phi(stm, ltm):
                # Relative size of the STM/LTM gap
                magnitude_importance = np.linalg.norm(stm - ltm) / (np.linalg.norm(stm) + 0.001)

                # Cosine similarity, or 0 when either vector is zero
                stm_norm = np.linalg.norm(stm)
                ltm_norm = np.linalg.norm(ltm)
                if stm_norm > 0 and ltm_norm > 0:
                    semantic_similarity = np.dot(stm, ltm) / (stm_norm * ltm_norm)
                else:
                    semantic_similarity = 0

                # Entropy (in bits) of the L1-normalised absolute STM values
                stm_normalized = stm / (np.sum(np.abs(stm)) + 1e-10)
                info_density = -np.sum(np.abs(stm_normalized) * np.log2(np.abs(stm_normalized) + 1e-10))

                # Fixed-weight blend of the three signals
                importance = (0.4 * magnitude_importance +
                              0.3 * (1 - semantic_similarity) +
                              0.3 * min(1.0, info_density / 10))

                threshold = self._compute_adaptive_threshold()

                # Only consolidate if importance exceeds threshold
                if importance > threshold:
                    # Move toward STM, modulated by importance
                    return importance * (stm - ltm)
                return np.zeros_like(ltm)

            consolidation_term = Phi(stm_state, self.state)

            # Euler integration step
            self.previous_state = self.state.copy()
            self.state = self.state + dt * (self.consolidation_rate * consolidation_term + xi_M_LT)

            # Store time for next update
            self.previous_time = current_time

        return self.state

    def _compute_adaptive_threshold(self):
        """Return the importance threshold used by ``update_dynamics_enhanced``.

        With fewer than five stored items this is the base value 0.2.
        Otherwise it is ``0.7 * 0.2 + 0.3 * (0.8 * mean importance of the
        last five items)``.
        """
        base_threshold = 0.2

        if len(self.items) >= 5:
            recent_importance = [item[2] for item in self.items[-5:]]
            adaptive_component = 0.8 * np.mean(recent_importance)
            return 0.7 * base_threshold + 0.3 * adaptive_component

        return base_threshold

class DeepSubconsciousMemory:
    """Deep subconscious memory with emotional and procedural patterns.

    Two stores live side by side:

    * ``self.patterns`` - at most 50 tuples of
      ``(pattern_embedding, emotional_weight, update_count)`` maintained by
      ``update`` and queried by ``retrieve`` / ``get_state``.
    * ``self.state`` - a continuous state vector advanced by
      ``update_dynamics``. It is created by ``initialize_dynamics`` or, if that
      was never called, lazily on the first ``update_dynamics`` call.
    """

    def __init__(self, access_coef=0.001, dim=256):
        """Create an empty memory.

        Args:
            access_coef: Probability that a call to ``update`` actually
                modifies the pattern store; also the step size applied in
                ``update_dynamics``.
            dim: Length of the embedding and state vectors.
        """
        self.access_coef = access_coef
        self.dim = dim
        self.patterns = []  # (pattern_embedding, emotional_weight, update_count)
        self.core_knowledge = [  # Built-in arithmetic knowledge
            "Addition combines quantities: a + b = total",
            "Subtraction removes quantities: a - b = remainder", 
            "Order of operations: parentheses, exponents, multiplication/division, addition/subtraction",
            "When adding or subtracting entities, track each separately then combine at the end",
            "Division distributes items equally: total ÷ groups = items per group"
        ]
        self.core_embeddings = None  # Will be initialized on first access

    def update(self, embedding, ltm_output, emotional_salience):
        """Stochastically fold ``embedding`` into the pattern store.

        The call is a no-op unless a uniform random draw is <= ``access_coef``.
        When it proceeds, the most similar stored pattern is blended with the
        embedding if their cosine similarity exceeds 0.8; otherwise a new
        pattern is added provided ``emotional_salience`` exceeds 0.3. The store
        is capped at 50 patterns by dropping the lowest emotional weight.

        Args:
            embedding: Vector to integrate.
            ltm_output: Accepted for interface compatibility; not used here.
            emotional_salience: Scalar weight attached to this observation.
        """
        # Only make updates at the slow rate determined by access coefficient
        if np.random.random() > self.access_coef:
            return

        # Best (similarity, index) pair; ties on similarity resolve to the
        # higher index, exactly as a descending sort of the pairs would.
        best_similarity, best_idx = max(
            (
                (self.cosine_similarity(pattern, embedding), idx)
                for idx, (pattern, _, _) in enumerate(self.patterns)
            ),
            default=(None, None),
        )

        if best_idx is not None and best_similarity > 0.8:
            pattern, weight, count = self.patterns[best_idx]

            # Slow integration of new information, re-normalised to unit length
            update_rate = 0.1
            updated_pattern = (1 - update_rate) * pattern + update_rate * embedding
            updated_pattern = updated_pattern / np.linalg.norm(updated_pattern)

            # Running mean of the emotional weight over all updates
            new_weight = (weight * count + emotional_salience) / (count + 1)

            self.patterns[best_idx] = (updated_pattern, new_weight, count + 1)
        elif emotional_salience > 0.3:
            # Add new pattern only if emotional salience is high enough
            self.patterns.append((embedding, emotional_salience, 1))

        # Limit number of patterns by removing the lowest emotional weight
        if len(self.patterns) > 50:
            self.patterns.sort(key=lambda x: x[1])
            self.patterns = self.patterns[1:]

    def retrieve(self, query_embedding, top_k=3):
        """Return the ``top_k`` best-scoring entries as newline-joined text.

        Core knowledge is scored by cosine similarity against placeholder
        embeddings (random vectors created on first use, so core rankings are
        not semantically meaningful). Stored patterns are scored by
        similarity times emotional weight.

        Args:
            query_embedding: Vector to compare against.
            top_k: Maximum number of lines returned.

        Returns:
            A string with one retrieved description per line.
        """
        # Initialize core embeddings if needed
        if self.core_embeddings is None:
            # In a real implementation, these would be actual embeddings
            # Here we just create random embeddings as placeholders
            self.core_embeddings = [np.random.rand(self.dim) for _ in self.core_knowledge]

        core_scores = [
            (self.cosine_similarity(emb, query_embedding), text)
            for emb, text in zip(self.core_embeddings, self.core_knowledge)
        ]

        pattern_scores = [
            (
                self.cosine_similarity(pattern, query_embedding) * weight,
                f"Pattern with emotional significance: {weight:.2f}",
            )
            for pattern, weight, _ in self.patterns
        ]

        # Combine and get top results
        all_scores = core_scores + pattern_scores
        all_scores.sort(reverse=True)

        return "\n".join([text for _, text in all_scores[:top_k]])

    def get_state(self):
        """Return the mean of stored patterns, each scaled by its emotional weight.

        Returns a zero vector of length ``dim`` when no patterns are stored.
        """
        if not self.patterns:
            return np.zeros(self.dim)

        embeddings = np.array([pattern * weight for pattern, weight, _ in self.patterns])
        return np.mean(embeddings, axis=0)

    def cosine_similarity(self, a, b):
        """Calculate cosine similarity between two vectors (0.0 if either has zero norm)."""
        dot_product = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)

        if norm_a == 0 or norm_b == 0:
            return 0.0

        similarity = dot_product / (norm_a * norm_b)

        # Convert array to scalar if necessary
        if isinstance(similarity, np.ndarray):
            similarity = float(similarity.mean())

        return similarity

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Initialize the state used by ``update_dynamics``.

        Args:
            initial_state: Starting state vector; zeros when omitted.
            dim: Length of the zero state when ``initial_state`` is omitted;
                defaults to ``self.dim``.
        """
        if dim is None:
            dim = self.dim

        self.state = initial_state if initial_state is not None else np.zeros(dim)
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.emotional_state = np.zeros(5)  # 5 basic emotions

    def update_dynamics(self, ltm_state, emotional_input, dt=None):
        """
        Update the deep subconscious memory using the equation:
        S_M^DS(t+1) = S_M^DS(t) + alpha_DS · Psi(S_M^LT(t), S_M^DS(t), S_E(t))

        Here alpha_DS is ``access_coef`` and
        Psi = 0.1 * ltm_state * ||emotional_state|| - 0.01 * state.
        The step is only applied when ``dt`` exceeds 1.0.

        Args:
            ltm_state: Long-term memory state vector.
            emotional_input: Vector blended into ``emotional_state`` (which is
                created with length 5).
            dt: Elapsed time; measured from the previous update when omitted.

        Returns:
            The (possibly unchanged) state vector.
        """
        # The dynamics attributes are not created by __init__; set them up on
        # demand so this method is usable without a prior initialize_dynamics().
        if not hasattr(self, "state"):
            self.initialize_dynamics()

        current_time = time.time()

        if dt is None:
            dt = current_time - self.previous_time

        # Deep subconscious updates very slowly
        if dt > 1.0:  # Much longer time scale than STM/LTM
            # Exponential moving average of the emotional input
            self.emotional_state = 0.9 * self.emotional_state + 0.1 * emotional_input

            # Psi: drive towards LTM content scaled by emotional magnitude,
            # with a small decay of the current state
            emotional_magnitude = np.linalg.norm(self.emotional_state)
            integration_term = 0.1 * (ltm_state * emotional_magnitude) - 0.01 * self.state

            # Update step
            self.previous_state = self.state.copy()
            self.state = self.state + self.access_coef * integration_term

            # Store time for next update
            self.previous_time = current_time

        return self.state

# information_geometry.py
import numpy as np
import torch
from numpy.linalg import norm, pinv
from scipy.linalg import sqrtm

class FisherInformationMatrix:
    """Empirical Fisher information estimate for natural-gradient calculations.

    The matrix is the covariance of the most recent ``max_samples`` stored
    samples plus ``regularization * I``. Until at least two samples are stored
    it is just ``regularization * I``.
    """

    def __init__(self, dim=256, regularization=0.1):
        """
        Args:
            dim: Side length of the (square) matrix.
            regularization: Multiple of the identity added to the covariance.
        """
        self.dim = dim
        self.regularization = regularization
        self.matrix = np.eye(dim) * regularization  # Initialize with regularization
        self.samples = []
        self.max_samples = 100

    @staticmethod
    def _to_numpy(sample):
        """Convert a torch tensor to a numpy array; pass anything else through."""
        if isinstance(sample, torch.Tensor):
            # detach()/cpu() so grad-tracking or non-CPU tensors convert cleanly
            return sample.detach().cpu().numpy()
        return sample

    def update(self, samples):
        """Add ``samples`` to the window and recompute the matrix.

        Args:
            samples: Sequence of vectors (numpy arrays or torch tensors). An
                empty or ``None`` argument leaves the matrix unchanged.
        """
        if samples is None or len(samples) == 0:
            return

        # Store samples, keeping only the newest ``max_samples``
        self.samples.extend(self._to_numpy(sample) for sample in samples)
        if len(self.samples) > self.max_samples:
            self.samples = self.samples[-self.max_samples:]

        if len(self.samples) > 1:
            # One flattened sample per row, centred on the window mean
            sample_array = np.array(self.samples).reshape(len(self.samples), -1)
            centered = sample_array - np.mean(sample_array, axis=0)

            # Sum of per-sample outer products, averaged over the window
            cov = centered.T @ centered / len(centered)

            # Update the Fisher matrix with regularization
            self.matrix = cov + np.eye(self.dim) * self.regularization

    def get_matrix(self):
        """Get the current Fisher matrix."""
        return self.matrix

    def inverse(self):
        """Get the (pseudo-)inverse of the Fisher matrix for natural gradient."""
        # Pseudo-inverse for numerical stability
        return np.linalg.pinv(self.matrix)

    def get_natural_gradient(self, gradient):
        """Convert a Euclidean gradient to a natural gradient: ``F^+ @ gradient``."""
        return self.inverse() @ gradient

    def natural_metric_tensor(self, theta):
        """
        Compute the natural metric tensor G_θ at parameter θ.

        The current matrix is returned regardless of ``theta``.
        """
        # In information geometry, the metric tensor is the Fisher information
        return self.matrix

    def natural_gradient_dynamics(self, parameter, gradient, learning_rate=0.01):
        """
        Implement natural gradient dynamics:
        dΘ_ij/dt = -η · G_Θ^(-1)(Θ_ij) · ∇_Θ_ij L(S(t))

        This method was written for an optimizer that owns ``parameters`` and
        ``fisher_matrices`` dictionaries. On a plain Fisher matrix (which has
        neither) the step is preconditioned with this instance's own matrix
        and nothing is stored.

        Args:
            parameter: Current parameter vector.
            gradient: Euclidean gradient of the loss at ``parameter``.
            learning_rate: Step size η.

        Returns:
            The updated parameter vector.

        Raises:
            ValueError: If the optimizer-style dictionaries are present but
                ``parameter`` is not found in them.
        """
        parameters = getattr(self, "parameters", None)
        fisher_matrices = getattr(self, "fisher_matrices", None)

        if parameters is None or fisher_matrices is None:
            # Standalone use: G is this object's matrix.
            fisher_matrix = self.get_matrix()
            natural_grad = np.linalg.solve(
                fisher_matrix + 1e-6 * np.eye(len(fisher_matrix)), gradient
            )
            return parameter - learning_rate * natural_grad

        # Optimizer-style use: locate the parameter by value
        param_name = None
        for name, param in parameters.items():
            if np.array_equal(param, parameter):
                param_name = name
                break

        if param_name is None:
            raise ValueError("Parameter not found in the optimizer")

        fisher_matrix = fisher_matrices[param_name].get_matrix()

        # Small ridge keeps the solve well-posed
        natural_grad = np.linalg.solve(
            fisher_matrix + 1e-6 * np.eye(len(fisher_matrix)), gradient
        )

        updated_param = parameter - learning_rate * natural_grad
        parameters[param_name] = updated_param

        return updated_param

class NaturalGradient:
    """Natural gradient optimizer.

    Holds a ``name -> parameter vector`` mapping and one
    ``FisherInformationMatrix`` per parameter, sized by ``len(param)``.
    """

    def __init__(self, parameters, learning_rate=0.01):
        """
        Args:
            parameters: Dict mapping names to parameter vectors. The dict is
                stored by reference and updated in place by ``step``.
            learning_rate: Step size used by ``step``.
        """
        self.parameters = parameters
        self.learning_rate = learning_rate

        # Initialize Fisher matrix for each parameter
        self.fisher_matrices = {
            name: FisherInformationMatrix(dim=len(param))
            for name, param in self.parameters.items()
        }

    def update_fisher_matrices(self, parameter_samples):
        """Update Fisher matrices with new parameter samples.

        Args:
            parameter_samples: Dict mapping parameter names to sample lists.

        Raises:
            KeyError: If a name has no Fisher matrix.
        """
        for name, samples in parameter_samples.items():
            self.fisher_matrices[name].update(samples)

    def step(self, gradients):
        """Update every parameter that has an entry in ``gradients``.

        Args:
            gradients: Dict mapping parameter names to Euclidean gradients.

        Returns:
            The parameter dict after the update.
        """
        # Snapshot the items so the dict can be reassigned while looping
        for name, param in list(self.parameters.items()):
            if name in gradients:
                # Convert gradient to natural gradient
                natural_grad = self.fisher_matrices[name].get_natural_gradient(gradients[name])

                # Update parameter
                self.parameters[name] = param - self.learning_rate * natural_grad

        return self.parameters

    def natural_gradient_dynamics(self, parameter, gradient, learning_rate=0.01):
        """
        Implement natural gradient dynamics:
        dΘ_ij/dt = -η · G_Θ^(-1)(Θ_ij) · ∇_Θ_ij L(S(t))

        The parameter is located among the registered ones (by identity
        first, then by value), preconditioned with its Fisher matrix, and the
        stored value is replaced by the result.

        Args:
            parameter: A registered parameter vector.
            gradient: Euclidean gradient of the loss at ``parameter``.
            learning_rate: Step size η for this call (independent of
                ``self.learning_rate``).

        Returns:
            The updated parameter vector.

        Raises:
            ValueError: If ``parameter`` matches no registered parameter.
        """
        param_name = next(
            (name for name, param in self.parameters.items() if param is parameter),
            None,
        )
        if param_name is None:
            param_name = next(
                (
                    name
                    for name, param in self.parameters.items()
                    if np.array_equal(param, parameter)
                ),
                None,
            )

        if param_name is None:
            raise ValueError("Parameter not found in the optimizer")

        fisher_matrix = self.fisher_matrices[param_name].get_matrix()

        # Small ridge keeps the solve well-posed
        natural_grad = np.linalg.solve(
            fisher_matrix + 1e-6 * np.eye(len(fisher_matrix)), gradient
        )

        updated_param = parameter - learning_rate * natural_grad
        self.parameters[param_name] = updated_param

        return updated_param

class KullbackLeibler:
    """Divergence measures between discrete probability distributions."""

    @staticmethod
    def _as_distribution(values):
        """Clip entries to at least 1e-10 and rescale to sum to 1 if needed."""
        values = np.maximum(values, 1e-10)  # avoids log(0) and division by zero
        total = np.sum(values)
        if abs(total - 1.0) > 1e-6:
            values = values / total
        return values

    @staticmethod
    def kl_divergence(p, q):
        """Calculate KL divergence KL(p || q) in nats.

        Both inputs are clipped to at least 1e-10 and normalised to sum to 1
        when they do not already (tolerance 1e-6).
        """
        q = KullbackLeibler._as_distribution(q)
        p = KullbackLeibler._as_distribution(p)

        return np.sum(p * np.log(p / q))

    @staticmethod
    def js_divergence(p, q):
        """Calculate Jensen-Shannon divergence between two distributions.

        The inputs are normalised *before* the midpoint is formed, so the
        result is the same whether or not ``p`` and ``q`` are pre-normalised
        (the midpoint of raw inputs with different sums would be skewed
        towards the larger one). Array-likes such as lists are accepted.
        """
        p = KullbackLeibler._as_distribution(p)
        q = KullbackLeibler._as_distribution(q)
        m = 0.5 * (p + q)
        return 0.5 * KullbackLeibler.kl_divergence(p, m) + 0.5 * KullbackLeibler.kl_divergence(q, m)

    @staticmethod
    def wasserstein_distance(p, q):
        """Approximate 1D Wasserstein distance between two distributions.

        Computed as the sum of absolute differences between the sorted values
        of ``p`` and ``q``.
        NOTE(review): this treats the entries as samples rather than as
        probability masses on a shared support, and sums instead of averaging;
        confirm this is the intended quantity.
        """
        p_sorted = np.sort(p)
        q_sorted = np.sort(q)

        return np.sum(np.abs(p_sorted - q_sorted))
    
class InformationGeometricOptimizer:
    """Natural-gradient optimizer with per-parameter Fisher matrices and weight decay."""

    def __init__(self, learning_rate=0.01, regularization=0.001):
        # Registries are filled by register_parameter().
        self.parameters = {}
        self.fisher_matrices = {}
        self.regularization = regularization
        self.learning_rate = learning_rate

    def register_parameter(self, name, parameter, dim=None):
        """Store ``parameter`` under ``name`` and create its Fisher matrix.

        When ``dim`` is omitted it is ``len(parameter)`` for sized objects and
        1 otherwise.
        """
        self.parameters[name] = parameter
        if dim is None:
            if hasattr(parameter, '__len__'):
                dim = len(parameter)
            else:
                dim = 1
        self.fisher_matrices[name] = FisherInformationMatrix(dim=dim)

    def update_fisher(self, name, samples):
        """Feed ``samples`` to the Fisher matrix of ``name``; unknown names are ignored."""
        if name not in self.fisher_matrices:
            return
        self.fisher_matrices[name].update(samples)

    def optimize_step(self, name, gradient):
        """Take one natural-gradient step on ``name`` followed by multiplicative decay.

        Returns the new parameter value, which is also stored.
        Raises ValueError when ``name`` was never registered.
        """
        if name not in self.parameters:
            raise ValueError(f"Parameter {name} not registered")

        metric = self.fisher_matrices[name].get_matrix()

        # Solve (F + 1e-6 I) d = gradient for the preconditioned direction
        damped_metric = metric + 1e-6 * np.eye(len(metric))
        direction = np.linalg.solve(damped_metric, gradient)

        # Descent step, then shrink by (1 - lr * regularization)
        stepped = self.parameters[name] - self.learning_rate * direction
        shrink = 1 - self.learning_rate * self.regularization
        self.parameters[name] = stepped * shrink

        return self.parameters[name]

    def compute_riemannian_distance(self, theta1, theta2, fisher=None):
        """
        Distance between two parameter vectors under a fixed metric:
        sqrt((theta2 - theta1)^T · fisher · (theta2 - theta1)).

        The identity matrix is used when ``fisher`` is omitted.
        """
        metric = np.eye(len(theta1)) if fisher is None else fisher
        delta = theta2 - theta1
        return np.sqrt(delta.T @ metric @ delta)

# information_bottleneck_StrategyQA.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.special import softmax
from numpy.linalg import norm, pinv

class MutualInformationEstimator:
    """Heuristic mutual-information estimators; every result is clipped to [0, 1].

    Inputs may be torch tensors or array-likes. Except for
    ``jensen_shannon_divergence`` they are flattened and truncated to their
    first 100 entries.
    """

    _MAX_DIMS = 100  # entries used per input, for efficiency

    @staticmethod
    def _to_flat_array(values, limit=None):
        """Return ``values`` as a 1-D numpy array, optionally truncated to ``limit`` entries."""
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        flat = np.array(values).flatten()
        return flat if limit is None else flat[:limit]

    def _paired_arrays(self, x, y):
        """Flatten both inputs and cut them to a common length (at most 100)."""
        x = self._to_flat_array(x, self._MAX_DIMS)
        y = self._to_flat_array(y, self._MAX_DIMS)
        n = min(len(x), len(y))
        return x[:n], y[:n]

    def estimate(self, x, y, method='mine'):
        """Estimate mutual information between x and y.

        Args:
            x, y: Inputs to compare.
            method: ``'mine'``, ``'kl'`` or ``'jensen_shannon'``; any other
                value selects the histogram estimator.
        """
        if method == 'mine':
            return self.mine_estimate(x, y)
        elif method == 'kl':
            return self.kl_estimate(x, y)
        elif method == 'jensen_shannon':
            return 1.0 - self.jensen_shannon_divergence(x, y)
        else:
            return self.histogram_estimate(x, y)

    def mine_estimate(self, x, y):
        """Correlation-based stand-in for Mutual Information Neural Estimation.

        No network is trained: the Pearson correlation of the paired entries
        is mapped through the Gaussian formula I(X;Y) = -0.5 * log(1 - corr²)
        (with corr² capped at 0.99) and clipped to [0, 1].

        Returns 0.0 when fewer than two paired entries exist, when either
        input is constant, or when the correlation is not finite: a constant
        signal carries no information about the other.
        """
        x, y = self._paired_arrays(x, y)

        # Correlation is undefined for constant or too-short inputs; without
        # this guard the NaN would slip through min()/max() and come out as 1.
        if len(x) < 2 or x.min() == x.max() or y.min() == y.max():
            return 0.0

        # Normalize
        x = (x - np.mean(x)) / (np.std(x) + 1e-8)
        y = (y - np.mean(y)) / (np.std(y) + 1e-8)

        corr = np.corrcoef(x, y)[0, 1]
        if not np.isfinite(corr):
            return 0.0

        mi_est = -0.5 * np.log(1 - min(corr ** 2, 0.99))
        return float(max(0.0, min(1.0, mi_est)))

    def kl_estimate(self, x, y):
        """KL divergence between a joint approximation and the product of its marginals.

        NOTE(review): the joint is constructed as the outer product of
        softmax(x) and softmax(y), i.e. it already factorises, so the value is
        zero up to rounding for any input. A real joint estimate is needed for
        this to be informative.
        """
        p_x = softmax(self._to_flat_array(x, self._MAX_DIMS))
        p_y = softmax(self._to_flat_array(y, self._MAX_DIMS))

        # Joint approximation and its marginals
        joint_xy = np.outer(p_x, p_y)
        joint_xy = joint_xy / np.sum(joint_xy)
        prod_xy = np.outer(np.sum(joint_xy, axis=1), np.sum(joint_xy, axis=0))

        # KL(joint || product), skipping cells where either term is zero
        valid = (joint_xy > 0) & (prod_xy > 0)
        kl = np.sum(joint_xy[valid] * np.log(joint_xy[valid] / prod_xy[valid]))

        # Scale to [0,1] range
        return float(min(1.0, max(0.0, kl)))

    def histogram_estimate(self, x, y, bins=10):
        """Estimate MI (in nats, clipped to [0, 1]) from a 2-D histogram.

        Bin *probabilities* (counts / total) are used for the joint, and the
        marginals are its row and column sums, so the discrete formula
        sum p_xy * log(p_xy / (p_x * p_y)) applies directly and does not
        depend on the scale of the data.
        """
        x, y = self._paired_arrays(x, y)
        if len(x) == 0:
            return 0.0

        counts, _, _ = np.histogram2d(x, y, bins=bins)
        p_xy = counts / counts.sum()
        p_x = p_xy.sum(axis=1, keepdims=True)
        p_y = p_xy.sum(axis=0, keepdims=True)
        independent = p_x * p_y

        # Cells with p_xy > 0 necessarily have non-zero marginals
        occupied = p_xy > 0
        mi = np.sum(p_xy[occupied] * np.log(p_xy[occupied] / independent[occupied]))

        # Scale to [0,1] range
        return float(min(1.0, max(0.0, mi)))

    def jensen_shannon_divergence(self, p, q):
        """Compute Jensen-Shannon divergence (nats, clipped to [0, 1]) between p and q.

        Entries are clipped to at least 1e-10 and each input is normalised to
        sum to 1 when it does not already.
        """
        p = self._to_flat_array(p)
        q = self._to_flat_array(q)

        # Add small constant to avoid log(0)
        p = np.maximum(p, 1e-10)
        q = np.maximum(q, 1e-10)

        # Normalize if needed
        if abs(np.sum(p) - 1.0) > 1e-6:
            p = p / np.sum(p)
        if abs(np.sum(q) - 1.0) > 1e-6:
            q = q / np.sum(q)

        # Compute midpoint distribution
        m = 0.5 * (p + q)

        js_div = 0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m))

        # Handle numerical issues
        js_div = min(1.0, max(0.0, js_div))

        return js_div


class MineEstimator(nn.Module):
    """Neural network estimator for mutual information (MINE lower bound).

    A small MLP critic T(x, y) scores concatenated pairs; the bound is
    E_joint[T] - log E_marginal[exp(T)], where marginal pairs are formed by
    shuffling ``y`` within the batch.
    """

    def __init__(self, input_dim=100, hidden_dim=64):
        """
        Args:
            input_dim: Number of leading features of each input fed to the
                critic (its first layer takes ``2 * input_dim`` features).
            hidden_dim: Width of the two hidden layers.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Critic network applied to both joint and shuffled pairs
        self.joint_net = nn.Sequential(
            nn.Linear(input_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, y):
        """Compute the MINE lower bound for a batch.

        Args:
            x, y: Tensors of shape (batch, features) or (features,). Features
                beyond ``input_dim`` are dropped; narrower inputs are not
                padded and will not match the first linear layer.

        Returns:
            A scalar tensor with the bound. With a batch of one the shuffled
            pair equals the joint pair, so the bound is zero.
        """
        # Ensure batch dimension
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if y.dim() == 1:
            y = y.unsqueeze(0)

        # Limit dimensions for efficiency
        x = x[:, :self.input_dim]
        y = y[:, :self.input_dim]

        batch_size = x.shape[0]

        # Critic on the true (joint) pairs
        joint_output = self.joint_net(torch.cat([x, y], dim=1))

        # Critic on pairs with y shuffled across the batch (marginal samples)
        y_shuffled = y[torch.randperm(batch_size)]
        marg_output = self.joint_net(torch.cat([x, y_shuffled], dim=1))

        # log(mean(exp(T))) computed as logsumexp(T) - log(N): identical in
        # exact arithmetic but does not overflow when T is large.
        marg_flat = marg_output.reshape(-1)
        log_mean_exp = torch.logsumexp(marg_flat, dim=0) - float(np.log(marg_flat.numel()))

        # MINE lower bound: E_joint[T] - log(E_marg[e^T])
        return torch.mean(joint_output) - log_mean_exp

    def estimate(self, x, y):
        """Estimate mutual information between x and y, squashed to (0, 1).

        Inputs are converted to float32 tensors if needed, standardised with
        their global mean and std, passed through ``forward`` without gradient
        tracking, and the bound is mapped through a sigmoid.

        Returns:
            A Python float.
        """
        # Convert to torch tensors if needed
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.float32)
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, dtype=torch.float32)

        # Ensure we have the right shapes
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if y.dim() == 1:
            y = y.unsqueeze(0)

        # Normalize inputs
        x = (x - x.mean()) / (x.std() + 1e-8)
        y = (y - y.mean()) / (y.std() + 1e-8)

        with torch.no_grad():
            mi_est = self.forward(x, y)

        # Scale to [0,1] using sigmoid
        return torch.sigmoid(mi_est).item()


class InformationBottleneck:
    """Implementation of the Information Bottleneck principle.

    Scores a representation Z of input X against a target Y with
    I(Z;Y) - beta * I(Z;X), using ``self.mi_estimator.estimate`` for both
    mutual-information terms.
    """

    def __init__(self, beta=1.0):
        """
        Args:
            beta: Trade-off weight on the complexity term I(Z;X).
        """
        self.beta = beta  # Trade-off parameter
        self.mi_estimator = MutualInformationEstimator()

    def compute_objective(self, x, y, z):
        """Compute IB objective: I(Z;Y) - beta * I(Z;X).

        Returns:
            Tuple ``(ib_objective, I(Z;Y), I(Z;X))``.
        """
        # I(Z;Y) - relevance term (maximize)
        mi_zy = self.mi_estimator.estimate(z, y)

        # I(Z;X) - complexity term (minimize)
        mi_zx = self.mi_estimator.estimate(z, x)

        return mi_zy - self.beta * mi_zx, mi_zy, mi_zx

    def optimize_routing(self, module_outputs, task_relevance):
        """Build a row-normalised routing matrix between modules.

        Args:
            module_outputs: Dict mapping module names to their outputs.
            task_relevance: Target signal used as Y in the IB objective.

        Returns:
            An (n, n) array. Entry (i, j), i != j, is exp(IB objective) with
            module i as X and module j as Z; diagonal entries start at 1/n.
            Each row is then divided by its sum.
        """
        modules = list(module_outputs.keys())
        n_modules = len(modules)
        routing = np.ones((n_modules, n_modules)) / n_modules

        for i, src in enumerate(modules):
            for j, tgt in enumerate(modules):
                if i == j:
                    continue
                ib_obj, _, _ = self.compute_objective(
                    module_outputs[src], task_relevance, module_outputs[tgt]
                )
                # Higher objective = higher weight
                routing[i, j] = np.exp(ib_obj)

        # Normalize routing matrix row by row
        for i in range(n_modules):
            row_sum = routing[i].sum()
            if row_sum > 0:
                routing[i] = routing[i] / row_sum

        return routing

    def optimize_beta(self, x, y, z_generator, z_dims=None, beta_vals=None):
        """
        Find the beta maximising the IB objective over a (beta, z_dim) sweep.

        Args:
            x, y: Input and target signals.
            z_generator: Callable ``(x, dim) -> z`` producing a representation.
            z_dims: Iterable of dimensions to try (default ``range(1, 10)``).
            beta_vals: Iterable of beta values (default 20 log-spaced values
                from 1e-2 to 1e1).

        Returns:
            ``(results, best_result)`` where ``results`` has one dict per
            (beta, z_dim) pair. ``self.beta`` is left at the best beta.

        Raises:
            ValueError: If the sweep is empty.
        """
        if z_dims is None:
            z_dims = range(1, 10)
        if beta_vals is None:
            beta_vals = np.logspace(-2, 1, 20)  # Log-spaced beta values

        # z_dims is walked once per beta; materialise it so a one-shot
        # iterator is not exhausted after the first beta.
        z_dims = list(z_dims)

        results = []
        for beta in beta_vals:
            self.beta = beta

            for dim in z_dims:
                # Generate a compressed representation Z with the given dimension
                z = z_generator(x, dim)

                ib_obj, mi_zy, mi_zx = self.compute_objective(x, y, z)

                results.append({
                    'beta': beta,
                    'z_dim': dim,
                    'relevance': mi_zy,
                    'complexity': mi_zx,
                    'ib_objective': ib_obj
                })

        if not results:
            raise ValueError("optimize_beta needs at least one beta value and one z dimension")

        # Find the optimal beta value (maximizing IB objective)
        best_result = max(results, key=lambda r: r['ib_objective'])
        self.beta = best_result['beta']

        return results, best_result

    def compute_information_plane(self, x, y, z_generator, z_dims=None):
        """
        Compute the information plane (I(Z;X) vs I(Z;Y)) for different Z dimensions.

        Returns:
            List of dicts with keys ``'z_dim'``, ``'I_ZX'`` and ``'I_ZY'``,
            one per dimension (default ``range(1, 50)``).
        """
        if z_dims is None:
            z_dims = range(1, 50)

        plane_points = []

        for dim in z_dims:
            # Generate a compressed representation Z with the given dimension
            z = z_generator(x, dim)

            mi_zx = self.mi_estimator.estimate(z, x)
            mi_zy = self.mi_estimator.estimate(z, y)

            plane_points.append({
                'z_dim': dim,
                'I_ZX': mi_zx,
                'I_ZY': mi_zy
            })

        return plane_points

    def _differentiable_mi(self, x, y, sigma=0.1):
        """Differentiable approximation of mutual information for training.

        Columns of ``x`` and ``y`` are standardised over axis 0 and the result
        is 0.5 * (sum of squared cross-correlations). Both the torch and the
        numpy path use the population standard deviation, so they agree on the
        same data and a perfectly correlated column pair contributes 0.5.

        Args:
            x, y: Tensors or arrays with samples along axis 0; inputs with
                more than two dimensions are flattened per sample.
            sigma: Unused; kept for interface compatibility.

        Returns:
            A scalar tensor when ``x`` is a tensor, otherwise a numpy scalar.
        """
        if isinstance(x, torch.Tensor):
            if not isinstance(y, torch.Tensor):
                y = torch.as_tensor(y, dtype=x.dtype, device=x.device)

            if x.dim() > 2:
                x = x.reshape(x.shape[0], -1)
            if y.dim() > 2:
                y = y.reshape(y.shape[0], -1)

            # unbiased=False: population std, so dividing by n below yields
            # the Pearson correlation (and matches the numpy branch).
            x_normalized = (x - x.mean(dim=0)) / (x.std(dim=0, unbiased=False) + 1e-8)
            y_normalized = (y - y.mean(dim=0)) / (y.std(dim=0, unbiased=False) + 1e-8)

            n = x.shape[0]
            corr_matrix = torch.matmul(x_normalized.t(), y_normalized) / n

            # Sum squared correlations (approximation of mutual information)
            return 0.5 * torch.sum(corr_matrix ** 2)

        # Numpy path
        x = np.asarray(x)
        y = np.asarray(y)
        if x.ndim > 2:
            x = x.reshape(x.shape[0], -1)
        if y.ndim > 2:
            y = y.reshape(y.shape[0], -1)

        x_normalized = (x - np.mean(x, axis=0)) / (np.std(x, axis=0) + 1e-8)
        y_normalized = (y - np.mean(y, axis=0)) / (np.std(y, axis=0) + 1e-8)

        n = x.shape[0]
        corr_matrix = np.matmul(x_normalized.T, y_normalized) / n

        # Sum squared correlations
        return 0.5 * np.sum(corr_matrix ** 2)


#memory_dynamics_SQuAD_Without_LTM.py
import numpy as np
import time
import heapq
from collections import deque

class ShortTermMemory:
    """Short-term memory with temporal decay.

    Items live in a bounded deque and are ranked at retrieval time by cosine
    similarity, an exponential time-decay factor and a per-item activation.
    Independently of the stored items, a vector ``state`` can be evolved with
    :meth:`update_dynamics`.
    """

    def __init__(self, capacity=20, decay_rate=0.1, dim=256):
        """Create an empty memory.

        Args:
            capacity: Maximum number of stored items (deque ``maxlen``).
            decay_rate: Coefficient used in ``exp(-decay_rate * elapsed)`` and
                as the divisor of the decay term in :meth:`update_dynamics`.
            dim: Length of the zero vectors produced by this class.
        """
        self.capacity = capacity
        self.decay_rate = decay_rate
        # Entries are (text, embedding, timestamp, activation) tuples; once
        # ``capacity`` is reached the deque drops the oldest entry on append.
        self.items = deque(maxlen=capacity)
        self.dim = dim

    def add(self, text, embedding):
        """Append a new item stamped with the current time and activation 1.0."""
        self.items.append((text, embedding, time.time(), 1.0))

    def retrieve(self, query_embedding, top_k=3):
        """Return the ``top_k`` best-scoring texts joined by newlines.

        The score of an item is ``similarity * exp(-decay_rate * age) *
        activation``. As a side effect every stored item's activation is
        adjusted via :meth:`_update_activations`.

        Returns:
            A newline-joined string; empty when nothing is stored.
        """
        now = time.time()

        scored_items = []
        for text, emb, timestamp, activation in self.items:
            similarity = self.cosine_similarity(emb, query_embedding)
            decay_factor = np.exp(-self.decay_rate * (now - timestamp))
            scored_items.append((similarity * decay_factor * activation, text))

        # Highest score first; equal scores fall back to comparing the texts.
        scored_items.sort(reverse=True)
        top_texts = [text for _, text in scored_items[:top_k]]

        # Retrieval acts as rehearsal for the returned items.
        self._update_activations(top_texts)

        return "\n".join(top_texts)

    def _update_activations(self, retrieved_texts):
        """Boost retrieved items by 0.2 (max 1.0); lower the rest by 0.05 (min 0.1)."""
        for i, (text, emb, timestamp, activation) in enumerate(self.items):
            if text in retrieved_texts:
                activation = min(1.0, activation + 0.2)
            else:
                activation = max(0.1, activation - 0.05)
            self.items[i] = (text, emb, timestamp, activation)

    def apply_decay(self):
        """Scale activations by the time-decay factor and drop items at or below 0.1.

        NOTE(review): the factor is computed from each item's original
        timestamp, which is never refreshed, so repeated calls compound the
        decay -- confirm this is intended.
        """
        now = time.time()
        survivors = []

        for text, emb, timestamp, activation in self.items:
            decay_factor = np.exp(-self.decay_rate * (now - timestamp))
            new_activation = activation * decay_factor

            if new_activation > 0.1:
                survivors.append((text, emb, timestamp, new_activation))

        self.items = deque(survivors, maxlen=self.capacity)

    def get_state(self):
        """Return the mean of the activation-scaled embeddings.

        This is a plain mean of ``embedding * activation`` (it is not divided
        by the sum of activations). An empty memory yields ``zeros(dim)``.
        """
        if not self.items:
            return np.zeros(self.dim)

        scaled = np.array([emb * activation for _, emb, _, activation in self.items])
        return np.mean(scaled, axis=0)

    def cosine_similarity(self, a, b):
        """Return the cosine similarity of ``a`` and ``b``.

        Returns 0.0 when either norm is zero. If the quotient is an array
        (e.g. for 2-D inputs) it is reduced to the float mean of its entries.
        """
        dot_product = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)

        if norm_a == 0 or norm_b == 0:
            return 0.0

        similarity = dot_product / (norm_a * norm_b)

        if isinstance(similarity, np.ndarray):
            similarity = float(similarity.mean())

        return similarity

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Set up ``state``, ``previous_state``, ``previous_time`` and the noise buffer.

        Args:
            initial_state: Starting state; stored as-is (not copied). Defaults
                to a zero vector of length ``dim``.
            dim: Length of the zero vectors; defaults to ``self.dim``.
        """
        if dim is None:
            dim = self.dim

        self.state = initial_state if initial_state is not None else np.zeros(dim)
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.stochastic_terms = np.zeros(dim)

    def update_dynamics(self, perception_input, dt=0.1):
        """
        Update the short-term memory state using the differential equation:
        dS_M^ST(t)/dt = F_M^ST(S_M^ST(t), S_P(t)) - (1/tau_ST) * S_M^ST(t) + xi_M^ST(t)

        One explicit Euler step is taken when ``dt`` exceeds 0.01; otherwise
        the state is returned unchanged. If :meth:`initialize_dynamics` has
        not been called yet, the dynamics start from the default zero state
        instead of failing on a missing attribute.

        Args:
            perception_input: Perception vector fed into F_M^ST.
            dt: Step size; ``None`` means "wall-clock time since last update".

        Returns:
            The (possibly updated) state vector.
        """
        if not hasattr(self, "state"):
            self.initialize_dynamics()

        current_time = time.time()

        if dt is None:
            dt = current_time - self.previous_time

        if dt > 0.01:  # Minimum time step
            # Fresh Gaussian perturbation xi for this step
            self.stochastic_terms = 0.01 * np.random.randn(self.dim)

            # NOTE(review): decay_rate plays the role of tau_ST here (divisor),
            # whereas retrieve()/apply_decay() use it as a rate (multiplier).
            decay_term = self.state / self.decay_rate

            # F_M^ST: blend of state and perception, squashed by tanh
            dynamics_term = np.tanh(0.7 * self.state + 0.3 * perception_input)

            # Euler integration step
            self.previous_state = self.state.copy()
            self.state = self.state + dt * (dynamics_term - decay_term + self.stochastic_terms)

            self.previous_time = current_time

        return self.state


class LongTermMemory:
    """Simplified long-term memory stub that keeps the interface but stores no data.

    Consolidation is a no-op, retrieval always returns an empty string and the
    memory state is a zero vector.
    """

    def __init__(self, capacity=100, consolidation_rate=0.01, dim=256):
        """Record the configuration; no item storage is created."""
        self.capacity = capacity
        self.consolidation_rate = consolidation_rate
        self.dim = dim

    def consolidate(self, stm, embedding, importance_score):
        """Stub: ignore the arguments and do nothing."""
        pass

    def retrieve(self, query_embedding, top_k=2):
        """Stub: always return an empty string since nothing is stored."""
        return ""

    def _has_similar(self, embedding, similarity_threshold=0.9):
        """Stub: nothing is stored, so no similar item can exist."""
        return False

    def get_state(self):
        """Return ``zeros(dim)`` since nothing is stored."""
        return np.zeros(self.dim)

    def cosine_similarity(self, a, b):
        """Return the cosine similarity of ``a`` and ``b``.

        Returns 0.0 when either norm is zero. If the quotient is an array
        (e.g. for 2-D inputs) it is reduced to the float mean of its entries.
        """
        dot_product = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)

        if norm_a == 0 or norm_b == 0:
            return 0.0

        similarity = dot_product / (norm_a * norm_b)

        if isinstance(similarity, np.ndarray):
            similarity = float(similarity.mean())

        return similarity

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Set up ``state``, ``previous_state``, ``previous_time`` and the noise buffer.

        Args:
            initial_state: Starting state; stored as-is (not copied). Defaults
                to a zero vector of length ``dim``.
            dim: Length of the zero vectors; defaults to ``self.dim``.
        """
        if dim is None:
            dim = self.dim

        self.state = initial_state if initial_state is not None else np.zeros(dim)
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.stochastic_terms = np.zeros(dim)

    def update_dynamics(self, stm_state, dt=0.1):
        """Stub: return the current state without changing it.

        When :meth:`initialize_dynamics` has not been called there is no
        ``state`` attribute yet; a zero vector of length ``dim`` is returned
        in that case instead of raising ``AttributeError``.
        """
        if hasattr(self, "state"):
            return self.state
        return np.zeros(self.dim)


class DeepSubconsciousMemory:
    """Deep subconscious memory with emotional and procedural patterns.

    Holds a fixed list of core-knowledge strings plus a slowly growing list of
    learned patterns, and a vector ``state`` that :meth:`update_dynamics`
    evolves.
    """

    def __init__(self, access_coef=0.001, dim=256):
        """Create an empty pattern store.

        Args:
            access_coef: Probability threshold gating :meth:`update`, and the
                step multiplier in :meth:`update_dynamics`.
            dim: Length of the vectors produced by this class.
        """
        self.access_coef = access_coef
        self.dim = dim
        # Learned patterns: (pattern_embedding, emotional_weight, update_count)
        self.patterns = []
        self.core_knowledge = [  # Built-in reading comprehension knowledge
            "Focus on the specific text span that answers the question",
            "Pay attention to relevant entities mentioned in both question and context",
            "Consider temporal relationships and sequences of events in the passage",
            "Look for explicit factual details that match the question requirements",
            "Definitions are often provided directly in the context"
        ]
        self.core_embeddings = None  # Created on the first retrieve() call

    def update(self, embedding, ltm_output, emotional_salience):
        """Merge ``embedding`` into the closest pattern or store it as a new one.

        The call is skipped when a uniform random draw exceeds
        ``access_coef``. Otherwise, if the best cosine similarity against the
        stored patterns is above 0.8, that pattern is blended 90/10 with the
        embedding and renormalised, and its weight becomes the running mean of
        the saliences. If no pattern is that similar, the embedding is stored
        as a new pattern only when ``emotional_salience`` exceeds 0.3.

        Args:
            embedding: Vector to integrate.
            ltm_output: Not used by this method.
            emotional_salience: Weight contribution of this observation.
        """
        # Only make updates at the slow rate determined by access coefficient
        if np.random.random() > self.access_coef:
            return

        scored = [
            (self.cosine_similarity(pattern, embedding), idx)
            for idx, (pattern, _, _) in enumerate(self.patterns)
        ]
        best = max(scored) if scored else None

        if best is not None and best[0] > 0.8:
            idx = best[1]
            pattern, weight, count = self.patterns[idx]

            # Slow integration of new information, then rescale to unit length
            update_rate = 0.1
            updated_pattern = (1 - update_rate) * pattern + update_rate * embedding
            updated_pattern = updated_pattern / np.linalg.norm(updated_pattern)

            # Running mean of emotional salience over all updates
            new_weight = (weight * count + emotional_salience) / (count + 1)

            self.patterns[idx] = (updated_pattern, new_weight, count + 1)
        elif emotional_salience > 0.3:
            self.patterns.append((embedding, emotional_salience, 1))

    def retrieve(self, query_embedding, top_k=3):
        """Return the ``top_k`` best-scoring entries joined by newlines.

        Core-knowledge strings are scored by cosine similarity against
        placeholder embeddings (random vectors created on first use); learned
        patterns are scored by similarity times emotional weight and rendered
        as a descriptive string.
        """
        if self.core_embeddings is None:
            # In a real implementation, these would be actual embeddings
            # Here we just create random embeddings as placeholders
            self.core_embeddings = [np.random.rand(self.dim) for _ in self.core_knowledge]

        core_scores = [
            (self.cosine_similarity(emb, query_embedding), knowledge)
            for emb, knowledge in zip(self.core_embeddings, self.core_knowledge)
        ]

        pattern_scores = []
        for pattern, weight, _ in self.patterns:
            similarity = self.cosine_similarity(pattern, query_embedding)
            pattern_scores.append(
                (similarity * weight, f"Pattern with emotional significance: {weight:.2f}")
            )

        all_scores = core_scores + pattern_scores
        all_scores.sort(reverse=True)

        return "\n".join([text for _, text in all_scores[:top_k]])

    def get_state(self):
        """Return the mean of the weight-scaled pattern embeddings.

        This is a plain mean of ``pattern * weight`` (not divided by the sum
        of weights). With no patterns stored, ``zeros(dim)`` is returned.
        """
        if not self.patterns:
            return np.zeros(self.dim)

        scaled = np.array([pattern * weight for pattern, weight, _ in self.patterns])
        return np.mean(scaled, axis=0)

    def cosine_similarity(self, a, b):
        """Return the cosine similarity of ``a`` and ``b``.

        Returns 0.0 when either norm is zero. If the quotient is an array
        (e.g. for 2-D inputs) it is reduced to the float mean of its entries.
        """
        dot_product = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)

        if norm_a == 0 or norm_b == 0:
            return 0.0

        similarity = dot_product / (norm_a * norm_b)

        if isinstance(similarity, np.ndarray):
            similarity = float(similarity.mean())

        return similarity

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Set up ``state``, ``previous_state``, ``previous_time`` and ``emotional_state``.

        Args:
            initial_state: Starting state; stored as-is (not copied). Defaults
                to a zero vector of length ``dim``.
            dim: Length of the zero state; defaults to ``self.dim``.
        """
        if dim is None:
            dim = self.dim

        self.state = initial_state if initial_state is not None else np.zeros(dim)
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.emotional_state = np.zeros(5)  # 5 basic emotions

    def update_dynamics(self, ltm_state, emotional_input, dt=None):
        """
        Update the deep subconscious memory using the equation:
        S_M^DS(t+1) = S_M^DS(t) + alpha_DS · Psi(S_M^LT(t), S_M^DS(t), S_E(t))

        The update only happens when ``dt`` exceeds 1.0; otherwise the state
        is returned unchanged. If :meth:`initialize_dynamics` has not been
        called yet, the dynamics start from the default zero state instead of
        failing on a missing attribute.

        Args:
            ltm_state: Long-term memory state vector.
            emotional_input: Emotional input blended into ``emotional_state``
                (which is created with 5 entries).
            dt: Elapsed time; ``None`` means "wall-clock time since last update".

        Returns:
            The (possibly updated) state vector.
        """
        if not hasattr(self, "state"):
            self.initialize_dynamics()

        current_time = time.time()

        if dt is None:
            dt = current_time - self.previous_time

        # Deep subconscious updates very slowly
        if dt > 1.0:  # Much longer time scale than STM/LTM
            # Exponential moving average of the emotional input
            self.emotional_state = 0.9 * self.emotional_state + 0.1 * emotional_input

            # Psi: LTM state scaled by emotional magnitude, minus a small pull
            # of the current state towards zero.
            emotional_magnitude = np.linalg.norm(self.emotional_state)
            emotional_patterns = ltm_state * emotional_magnitude
            integration_term = 0.1 * emotional_patterns - 0.01 * self.state

            self.previous_state = self.state.copy()
            self.state = self.state + self.access_coef * integration_term

            self.previous_time = current_time

        return self.state

#memory_dynamics_SQuAD_Without_Memory.py
import numpy as np
import time

class ShortTermMemory:
    """No-op stand-in for short-term memory.

    Offers the short-term memory methods but never stores anything: writes
    are discarded, retrieval yields an empty string and vectors are zeros.
    """

    def __init__(self, capacity=20, decay_rate=0.1, dim=256):
        # Configuration is recorded; no item container is created.
        self.dim = dim
        self.decay_rate = decay_rate
        self.capacity = capacity

    def add(self, text, embedding):
        """Discard the item."""
        return None

    def retrieve(self, query_embedding, top_k=3):
        """Always produce an empty string (there is nothing to retrieve)."""
        return ""

    def _update_activations(self, retrieved_texts):
        """Nothing to update."""
        return None

    def apply_decay(self):
        """Nothing to decay."""
        return None

    def get_state(self):
        """Zero vector of length ``dim``."""
        return np.zeros(self.dim)

    def cosine_similarity(self, a, b):
        """Cosine similarity of ``a`` and ``b``; 0.0 if either has zero norm.

        An array-valued quotient is collapsed to the float mean of its entries.
        """
        numerator = np.dot(a, b)
        length_a = np.linalg.norm(a)
        length_b = np.linalg.norm(b)

        if not (length_a != 0 and length_b != 0):
            return 0.0

        ratio = numerator / (length_a * length_b)
        return float(ratio.mean()) if isinstance(ratio, np.ndarray) else ratio

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Create ``state``, ``previous_state``, ``previous_time`` and ``stochastic_terms``."""
        size = self.dim if dim is None else dim

        if initial_state is None:
            initial_state = np.zeros(size)

        self.state = initial_state
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.stochastic_terms = np.zeros(size)

    def update_dynamics(self, perception_input, dt=0.1):
        """Ignore the input and return a zero vector of length ``dim``."""
        return np.zeros(self.dim)


class LongTermMemory:
    """No-op stand-in for long-term memory.

    Offers the long-term memory methods but never stores anything:
    consolidation is ignored, retrieval yields an empty string and vectors
    are zeros.
    """

    def __init__(self, capacity=100, consolidation_rate=0.01, dim=256):
        # Configuration is recorded; no item container is created.
        self.dim = dim
        self.consolidation_rate = consolidation_rate
        self.capacity = capacity

    def consolidate(self, stm, embedding, importance_score):
        """Discard the request."""
        return None

    def retrieve(self, query_embedding, top_k=2):
        """Always produce an empty string (there is nothing to retrieve)."""
        return ""

    def _has_similar(self, embedding, similarity_threshold=0.9):
        """Always ``False``: nothing is stored to compare against."""
        return False

    def get_state(self):
        """Zero vector of length ``dim``."""
        return np.zeros(self.dim)

    def cosine_similarity(self, a, b):
        """Cosine similarity of ``a`` and ``b``; 0.0 if either has zero norm.

        An array-valued quotient is collapsed to the float mean of its entries.
        """
        numerator = np.dot(a, b)
        length_a = np.linalg.norm(a)
        length_b = np.linalg.norm(b)

        if not (length_a != 0 and length_b != 0):
            return 0.0

        ratio = numerator / (length_a * length_b)
        return float(ratio.mean()) if isinstance(ratio, np.ndarray) else ratio

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Create ``state``, ``previous_state``, ``previous_time`` and ``stochastic_terms``."""
        size = self.dim if dim is None else dim

        if initial_state is None:
            initial_state = np.zeros(size)

        self.state = initial_state
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.stochastic_terms = np.zeros(size)

    def update_dynamics(self, stm_state, dt=0.1):
        """Ignore the input and return a zero vector of length ``dim``."""
        return np.zeros(self.dim)


class DeepSubconsciousMemory:
    """No-op stand-in for deep subconscious memory.

    Offers the deep-memory methods but never stores anything: updates are
    ignored, retrieval yields an empty string and vectors are zeros.
    """

    def __init__(self, access_coef=0.001, dim=256):
        # Configuration is recorded; no pattern list is created.
        self.dim = dim
        self.access_coef = access_coef
        self.core_embeddings = None
        self.core_knowledge = []  # deliberately empty

    def update(self, embedding, ltm_output, emotional_salience):
        """Discard the update."""
        return None

    def retrieve(self, query_embedding, top_k=3):
        """Always produce an empty string (there is nothing to retrieve)."""
        return ""

    def get_state(self):
        """Zero vector of length ``dim``."""
        return np.zeros(self.dim)

    def cosine_similarity(self, a, b):
        """Cosine similarity of ``a`` and ``b``; 0.0 if either has zero norm.

        An array-valued quotient is collapsed to the float mean of its entries.
        """
        numerator = np.dot(a, b)
        length_a = np.linalg.norm(a)
        length_b = np.linalg.norm(b)

        if not (length_a != 0 and length_b != 0):
            return 0.0

        ratio = numerator / (length_a * length_b)
        return float(ratio.mean()) if isinstance(ratio, np.ndarray) else ratio

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Create ``state``, ``previous_state``, ``previous_time`` and ``emotional_state``."""
        size = self.dim if dim is None else dim

        if initial_state is None:
            initial_state = np.zeros(size)

        self.state = initial_state
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.emotional_state = np.zeros(5)  # 5 basic emotions

    def update_dynamics(self, ltm_state, emotional_input, dt=None):
        """Ignore the inputs and return a zero vector of length ``dim``."""
        return np.zeros(self.dim)

#memory_dynamics_SQuAD_Without_STM.py
import numpy as np
import time
import heapq
from collections import deque

class ShortTermMemory:
    """Inert short-term memory used when STM is switched off.

    Every method of the short-term memory interface is present, but nothing
    is ever stored: writes are dropped, retrieval returns an empty string and
    state vectors are all zeros.
    """

    def __init__(self, capacity=20, decay_rate=0.1, dim=256):
        # Only the settings are kept; there is no item storage.
        self.capacity, self.decay_rate, self.dim = capacity, decay_rate, dim

    def add(self, text, embedding):
        """Drop the item without storing it."""
        return None

    def retrieve(self, query_embedding, top_k=3):
        """Return an empty string, as no items exist."""
        return ""

    def _update_activations(self, retrieved_texts):
        """No activations exist, so nothing happens."""
        return None

    def apply_decay(self):
        """No items exist, so nothing happens."""
        return None

    def get_state(self):
        """Return ``zeros(dim)``."""
        return np.zeros(self.dim)

    def cosine_similarity(self, a, b):
        """Compute cosine similarity; yields 0.0 when either vector has zero norm.

        If the quotient is an array it is reduced to the float mean of its entries.
        """
        inner = np.dot(a, b)
        norms = (np.linalg.norm(a), np.linalg.norm(b))

        if norms[0] == 0 or norms[1] == 0:
            return 0.0

        cosine = inner / (norms[0] * norms[1])
        if not isinstance(cosine, np.ndarray):
            return cosine
        return float(cosine.mean())

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Populate ``state``, ``previous_state``, ``previous_time`` and ``stochastic_terms``."""
        width = dim if dim is not None else self.dim

        self.state = np.zeros(width) if initial_state is None else initial_state
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.stochastic_terms = np.zeros(width)

    def update_dynamics(self, perception_input, dt=0.1):
        """Return ``zeros(dim)`` regardless of the input."""
        return np.zeros(self.dim)


class LongTermMemory:
    """Long-term memory with importance-based consolidation.

    Items are stored in a plain list capped at ``capacity``; when the cap is
    exceeded the least important item is evicted. A vector ``state`` can be
    evolved with :meth:`update_dynamics`.
    """

    def __init__(self, capacity=100, consolidation_rate=0.01, dim=256):
        """Create an empty memory.

        Args:
            capacity: Maximum number of stored items.
            consolidation_rate: Multiplier of the consolidation term in
                :meth:`update_dynamics`.
            dim: Length of the vectors produced by this class.
        """
        self.capacity = capacity
        self.consolidation_rate = consolidation_rate
        self.dim = dim
        # Entries are (text, embedding, importance, creation_time) tuples
        self.items = []

    def consolidate(self, stm, embedding, importance_score):
        """Store ``embedding`` when it is important enough and not a near-duplicate.

        Since STM is disabled, the embedding is added directly under a
        generated text label. Nothing is stored when ``importance_score`` is
        at or below 0.5, or when :meth:`_has_similar` reports an existing
        near-duplicate.

        Args:
            stm: Not used by this method.
            embedding: Vector to store.
            importance_score: Importance of the item; also used for eviction.
        """
        consolidation_threshold = 0.5
        if importance_score <= consolidation_threshold:
            return

        if self._has_similar(embedding):
            return

        # One timestamp serves both the generic label and the creation time
        creation_time = time.time()
        text = f"Consolidated memory item at {creation_time}"
        self.items.append((text, embedding, importance_score, creation_time))

        # Keep within capacity by removing the least important item
        if len(self.items) > self.capacity:
            self.items.sort(key=lambda x: x[2])
            self.items.pop(0)

    def retrieve(self, query_embedding, top_k=2):
        """Return the ``top_k`` best-scoring texts joined by newlines.

        The score of an item is ``similarity * (0.5 + 0.5 * importance)``.
        """
        scored_items = []
        for text, emb, importance, _ in self.items:
            similarity = self.cosine_similarity(emb, query_embedding)
            scored_items.append((similarity * (0.5 + 0.5 * importance), text))

        # Highest score first; equal scores fall back to comparing the texts.
        scored_items.sort(reverse=True)

        return "\n".join([text for _, text in scored_items[:top_k]])

    def _has_similar(self, embedding, similarity_threshold=0.9):
        """Return True if any stored embedding exceeds ``similarity_threshold``."""
        return any(
            self.cosine_similarity(emb, embedding) > similarity_threshold
            for _, emb, _, _ in self.items
        )

    def get_state(self):
        """Return the mean of the importance-scaled embeddings.

        This is a plain mean of ``embedding * importance`` (not divided by the
        sum of importances). An empty memory yields ``zeros(dim)``.
        """
        if not self.items:
            return np.zeros(self.dim)

        scaled = np.array([emb * importance for _, emb, importance, _ in self.items])
        return np.mean(scaled, axis=0)

    def cosine_similarity(self, a, b):
        """Return the cosine similarity of ``a`` and ``b``.

        Returns 0.0 when either norm is zero. If the quotient is an array
        (e.g. for 2-D inputs) it is reduced to the float mean of its entries.
        """
        dot_product = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)

        if norm_a == 0 or norm_b == 0:
            return 0.0

        similarity = dot_product / (norm_a * norm_b)

        if isinstance(similarity, np.ndarray):
            similarity = float(similarity.mean())

        return similarity

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Set up ``state``, ``previous_state``, ``previous_time`` and the noise buffer.

        Args:
            initial_state: Starting state; stored as-is (not copied). Defaults
                to a zero vector of length ``dim``.
            dim: Length of the zero vectors; defaults to ``self.dim``.
        """
        if dim is None:
            dim = self.dim

        self.state = initial_state if initial_state is not None else np.zeros(dim)
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.stochastic_terms = np.zeros(dim)

    def update_dynamics(self, stm_state, dt=0.1):
        """
        Update the long-term memory state using the differential equation:
        dS_M^LT(t)/dt = eta_LT · Phi(S_M^ST(t), S_M^LT(t)) + xi_M^LT(t)

        In this ablation Phi is identically zero (STM is disabled), so only
        the noise term moves the state. One explicit Euler step is taken when
        ``dt`` exceeds 0.01; otherwise the state is returned unchanged. If
        :meth:`initialize_dynamics` has not been called yet, the dynamics
        start from the default zero state instead of failing on a missing
        attribute.

        Args:
            stm_state: Short-term memory state; ignored because Phi is zero.
            dt: Step size; ``None`` means "wall-clock time since last update".

        Returns:
            The (possibly updated) state vector.
        """
        if not hasattr(self, "state"):
            self.initialize_dynamics()

        current_time = time.time()

        if dt is None:
            dt = current_time - self.previous_time

        if dt > 0.01:  # Minimum time step
            # Fresh Gaussian perturbation xi for this step
            self.stochastic_terms = 0.005 * np.random.randn(self.dim)

            # Phi: no consolidation happens since STM is disabled
            consolidation_term = np.zeros_like(self.state)

            # Euler integration step
            self.previous_state = self.state.copy()
            self.state = self.state + dt * (
                self.consolidation_rate * consolidation_term + self.stochastic_terms
            )

            self.previous_time = current_time

        return self.state


class DeepSubconsciousMemory:
    """Deep subconscious memory with emotional and procedural patterns.

    Holds a fixed list of core-knowledge strings plus a slowly growing list of
    learned patterns, and a vector ``state`` that :meth:`update_dynamics`
    evolves.
    """

    def __init__(self, access_coef=0.001, dim=256):
        """Create an empty pattern store.

        Args:
            access_coef: Probability threshold gating :meth:`update`, and the
                step multiplier in :meth:`update_dynamics`.
            dim: Length of the vectors produced by this class.
        """
        self.access_coef = access_coef
        self.dim = dim
        # Learned patterns: (pattern_embedding, emotional_weight, update_count)
        self.patterns = []
        self.core_knowledge = [  # Built-in reading comprehension knowledge
            "Focus on the specific text span that answers the question",
            "Pay attention to relevant entities mentioned in both question and context",
            "Consider temporal relationships and sequences of events in the passage",
            "Look for explicit factual details that match the question requirements",
            "Definitions are often provided directly in the context"
        ]
        self.core_embeddings = None  # Created on the first retrieve() call

    def update(self, embedding, ltm_output, emotional_salience):
        """Merge ``embedding`` into the closest pattern or store it as a new one.

        The call is skipped when a uniform random draw exceeds
        ``access_coef``. Otherwise, if the best cosine similarity against the
        stored patterns is above 0.8, that pattern is blended 90/10 with the
        embedding and renormalised, and its weight becomes the running mean of
        the saliences. If no pattern is that similar, the embedding is stored
        as a new pattern only when ``emotional_salience`` exceeds 0.3.

        Args:
            embedding: Vector to integrate.
            ltm_output: Not used by this method.
            emotional_salience: Weight contribution of this observation.
        """
        # Only make updates at the slow rate determined by access coefficient
        if np.random.random() > self.access_coef:
            return

        scored = [
            (self.cosine_similarity(pattern, embedding), idx)
            for idx, (pattern, _, _) in enumerate(self.patterns)
        ]
        best = max(scored) if scored else None

        if best is not None and best[0] > 0.8:
            idx = best[1]
            pattern, weight, count = self.patterns[idx]

            # Slow integration of new information, then rescale to unit length
            update_rate = 0.1
            updated_pattern = (1 - update_rate) * pattern + update_rate * embedding
            updated_pattern = updated_pattern / np.linalg.norm(updated_pattern)

            # Running mean of emotional salience over all updates
            new_weight = (weight * count + emotional_salience) / (count + 1)

            self.patterns[idx] = (updated_pattern, new_weight, count + 1)
        elif emotional_salience > 0.3:
            self.patterns.append((embedding, emotional_salience, 1))

    def retrieve(self, query_embedding, top_k=3):
        """Return the ``top_k`` best-scoring entries joined by newlines.

        Core-knowledge strings are scored by cosine similarity against
        placeholder embeddings (random vectors created on first use); learned
        patterns are scored by similarity times emotional weight and rendered
        as a descriptive string.
        """
        if self.core_embeddings is None:
            # In a real implementation, these would be actual embeddings
            # Here we just create random embeddings as placeholders
            self.core_embeddings = [np.random.rand(self.dim) for _ in self.core_knowledge]

        core_scores = [
            (self.cosine_similarity(emb, query_embedding), knowledge)
            for emb, knowledge in zip(self.core_embeddings, self.core_knowledge)
        ]

        pattern_scores = []
        for pattern, weight, _ in self.patterns:
            similarity = self.cosine_similarity(pattern, query_embedding)
            pattern_scores.append(
                (similarity * weight, f"Pattern with emotional significance: {weight:.2f}")
            )

        all_scores = core_scores + pattern_scores
        all_scores.sort(reverse=True)

        return "\n".join([text for _, text in all_scores[:top_k]])

    def get_state(self):
        """Return the mean of the weight-scaled pattern embeddings.

        This is a plain mean of ``pattern * weight`` (not divided by the sum
        of weights). With no patterns stored, ``zeros(dim)`` is returned.
        """
        if not self.patterns:
            return np.zeros(self.dim)

        scaled = np.array([pattern * weight for pattern, weight, _ in self.patterns])
        return np.mean(scaled, axis=0)

    def cosine_similarity(self, a, b):
        """Return the cosine similarity of ``a`` and ``b``.

        Returns 0.0 when either norm is zero. If the quotient is an array
        (e.g. for 2-D inputs) it is reduced to the float mean of its entries.
        """
        dot_product = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)

        if norm_a == 0 or norm_b == 0:
            return 0.0

        similarity = dot_product / (norm_a * norm_b)

        if isinstance(similarity, np.ndarray):
            similarity = float(similarity.mean())

        return similarity

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Set up ``state``, ``previous_state``, ``previous_time`` and ``emotional_state``.

        Args:
            initial_state: Starting state; stored as-is (not copied). Defaults
                to a zero vector of length ``dim``.
            dim: Length of the zero state; defaults to ``self.dim``.
        """
        if dim is None:
            dim = self.dim

        self.state = initial_state if initial_state is not None else np.zeros(dim)
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.emotional_state = np.zeros(5)  # 5 basic emotions

    def update_dynamics(self, ltm_state, emotional_input, dt=None):
        """
        Update the deep subconscious memory using the equation:
        S_M^DS(t+1) = S_M^DS(t) + alpha_DS · Psi(S_M^LT(t), S_M^DS(t), S_E(t))

        The update only happens when ``dt`` exceeds 1.0; otherwise the state
        is returned unchanged. If :meth:`initialize_dynamics` has not been
        called yet, the dynamics start from the default zero state instead of
        failing on a missing attribute.

        Args:
            ltm_state: Long-term memory state vector.
            emotional_input: Emotional input blended into ``emotional_state``
                (which is created with 5 entries).
            dt: Elapsed time; ``None`` means "wall-clock time since last update".

        Returns:
            The (possibly updated) state vector.
        """
        if not hasattr(self, "state"):
            self.initialize_dynamics()

        current_time = time.time()

        if dt is None:
            dt = current_time - self.previous_time

        # Deep subconscious updates very slowly
        if dt > 1.0:  # Much longer time scale than STM/LTM
            # Exponential moving average of the emotional input
            self.emotional_state = 0.9 * self.emotional_state + 0.1 * emotional_input

            # Psi: LTM state scaled by emotional magnitude, minus a small pull
            # of the current state towards zero.
            emotional_magnitude = np.linalg.norm(self.emotional_state)
            emotional_patterns = ltm_state * emotional_magnitude
            integration_term = 0.1 * emotional_patterns - 0.01 * self.state

            self.previous_state = self.state.copy()
            self.state = self.state + self.access_coef * integration_term

            self.previous_time = current_time

        return self.state

#memory_dynamics_SQuAD_Without_DSM.py
import numpy as np
import time
import heapq
from collections import deque

class ShortTermMemory:
    """Short-term memory with activation-based retrieval and decay.

    Discrete memories live in ``self.items`` as
    ``(text, embedding, activation, timestamp)`` tuples. Independently of the
    item list, a continuous state vector (``self.state``) can be evolved with
    :meth:`update_dynamics`.
    """

    def __init__(self, capacity=20, decay_rate=0.1, dim=256):
        self.capacity = capacity
        self.decay_rate = decay_rate
        self.dim = dim
        self.items = []  # (text, embedding, activation, timestamp)
        # Wall-clock time of the most recent apply_decay() call; used so that
        # each call only decays over the interval not yet accounted for.
        self._last_decay_time = None

    def add(self, text, embedding):
        """Add an item, or refresh an existing near-duplicate.

        If a stored embedding has cosine similarity above 0.9 with
        ``embedding``, that item's activation is boosted by 0.2 (capped at
        1.0) and its timestamp refreshed; the new text is not stored.
        Otherwise the item is appended with activation 1.0 and, if capacity
        is exceeded, the weakest items are evicted.
        """
        timestamp = time.time()

        # Check if similar item already exists
        for i, (existing_text, existing_emb, activation, _) in enumerate(self.items):
            if self.cosine_similarity(embedding, existing_emb) > 0.9:
                # Update activation instead of adding duplicate
                self.items[i] = (existing_text, existing_emb, min(1.0, activation + 0.2), timestamp)
                return

        # Add new item with initial activation
        self.items.append((text, embedding, 1.0, timestamp))

        if len(self.items) > self.capacity:
            # Rank by activation, breaking ties by recency (timestamp, then
            # insertion position). Sorting on activation alone is stable, so
            # when all items share the same activation the newest item would
            # always be the one dropped.
            ranked = sorted(
                enumerate(self.items),
                key=lambda pair: (pair[1][2], pair[1][3], pair[0]),
                reverse=True,
            )
            self.items = [item for _, item in ranked[:self.capacity]]

    def retrieve(self, query_embedding, top_k=3):
        """Return the texts of the ``top_k`` most relevant items, newline-joined.

        Relevance is cosine similarity to ``query_embedding`` multiplied by
        the item's activation. Retrieved items get an activation boost.
        Returns an empty string when the memory is empty.
        """
        if not self.items:
            return ""

        # Score every item: similarity weighted by activation
        scored_items = [
            (self.cosine_similarity(emb, query_embedding) * activation, text)
            for text, emb, activation, _ in self.items
        ]

        # Get top-k items
        scored_items.sort(reverse=True)
        retrieved_texts = [text for _, text in scored_items[:top_k]]

        # Update activations for retrieved items
        self._update_activations(retrieved_texts)

        return "\n".join(retrieved_texts)

    def _update_activations(self, retrieved_texts):
        """Boost activation (by 0.1, capped at 1.0) of items whose text was retrieved."""
        for i, (text, emb, activation, ts) in enumerate(self.items):
            if text in retrieved_texts:
                self.items[i] = (text, emb, min(1.0, activation + 0.1), ts)

    def apply_decay(self):
        """Apply exponential time-based decay and drop items at or below 0.1.

        Each call decays an item only over the time elapsed since the later
        of its timestamp and the previous ``apply_decay`` call. Decaying over
        the item's full age on every call would re-apply time that an
        earlier call already accounted for.
        """
        current_time = time.time()
        last_decay = getattr(self, "_last_decay_time", None)
        survivors = []

        for text, emb, activation, ts in self.items:
            reference = ts if last_decay is None else max(ts, last_decay)
            # Clamp so a clock that moved backwards cannot inflate activation
            time_elapsed = max(0.0, current_time - reference)
            new_activation = activation * np.exp(-self.decay_rate * time_elapsed)

            # Keep items with sufficient activation
            if new_activation > 0.1:
                survivors.append((text, emb, new_activation, ts))

        self.items = survivors
        self._last_decay_time = current_time

    def get_state(self):
        """Return the mean of the activation-scaled item embeddings.

        Returns a zero vector of length ``self.dim`` when the memory is empty.
        """
        if not self.items:
            return np.zeros(self.dim)

        embeddings = np.array([emb * activation for _, emb, activation, _ in self.items])
        return np.mean(embeddings, axis=0)

    def cosine_similarity(self, a, b):
        """Calculate cosine similarity between two vectors.

        Returns 0.0 if either vector has zero norm. If the result is an
        array rather than a scalar, its mean is returned as a float.
        """
        dot_product = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)

        if norm_a == 0 or norm_b == 0:
            return 0.0

        similarity = dot_product / (norm_a * norm_b)

        # Convert array to scalar if necessary
        if isinstance(similarity, np.ndarray):
            similarity = float(similarity.mean())

        return similarity

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Initialize the state for differential equation dynamics.

        Args:
            initial_state: Starting state vector; zeros when None.
            dim: Length of the zero state / noise vector; ``self.dim`` when None.
        """
        if dim is None:
            dim = self.dim

        self.state = initial_state if initial_state is not None else np.zeros(dim)
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.stochastic_terms = np.zeros(dim)

    def update_dynamics(self, perception_input, dt=0.1):
        """
        Update the short-term memory state using the differential equation:
        dS_M^ST(t)/dt = -gamma_ST · S_M^ST(t) + beta_ST · P(t) + xi_M^ST(t)

        where:
        - gamma_ST: decay rate (``self.decay_rate``)
        - beta_ST: input strength (fixed at 0.5)
        - P(t): perception input
        - xi_M^ST(t): stochastic noise

        Args:
            perception_input: Perception vector P(t).
            dt: Integration step; when None it is measured from the previous
                update using the wall clock.

        Returns:
            The current state (unchanged when ``dt`` is not above 0.01).
        """
        # Allow use without an explicit initialize_dynamics() call instead of
        # failing with AttributeError on the missing state attributes.
        if not hasattr(self, "state"):
            self.initialize_dynamics()

        current_time = time.time()

        if dt is None:
            dt = current_time - self.previous_time

        if dt > 0.01:  # Minimum time step
            # Generate stochastic perturbation (noise)
            self.stochastic_terms = 0.01 * np.random.randn(self.dim)

            # Decay term
            decay_term = -self.decay_rate * self.state

            # Input term (perception)
            input_strength = 0.5
            input_term = input_strength * perception_input

            # Euler integration step
            self.previous_state = self.state.copy()
            self.state = self.state + dt * (decay_term + input_term + self.stochastic_terms)

            # Store time for next update
            self.previous_time = current_time

        return self.state


class LongTermMemory:
    """Long-term memory with importance-based consolidation.

    Discrete memories live in ``self.items`` as
    ``(text, embedding, importance, creation_time)`` tuples. Independently of
    the item list, a continuous state vector (``self.state``) can be evolved
    with :meth:`update_dynamics`.
    """

    def __init__(self, capacity=100, consolidation_rate=0.01, dim=256):
        self.capacity = capacity
        self.consolidation_rate = consolidation_rate
        self.dim = dim
        self.items = []  # (text, embedding, importance, creation_time)

    def consolidate(self, stm, embedding, importance_score):
        """Consolidate an item into LTM when it is important and novel.

        Args:
            stm: Object exposing ``items`` as
                ``(text, embedding, activation, timestamp)`` tuples; searched
                for a text whose embedding has similarity above 0.8.
            embedding: Embedding of the item to consolidate.
            importance_score: Item is stored only when this exceeds 0.5.

        Nothing is stored if an item with similarity above 0.9 already
        exists. When capacity is exceeded the least important item is
        removed.
        """
        # Only consolidate if importance exceeds threshold
        consolidation_threshold = 0.5
        if importance_score <= consolidation_threshold:
            return

        # Duplicate check first: no point searching STM for a label that
        # would be thrown away.
        if self._has_similar(embedding):
            return

        # Placeholder label, replaced by the matching STM text if one exists
        text = f"Consolidated memory item at {time.time()}"
        for stm_text, stm_emb, _, _ in stm.items:
            if self.cosine_similarity(stm_emb, embedding) > 0.8:
                text = stm_text
                break

        creation_time = time.time()
        self.items.append((text, embedding, importance_score, creation_time))

        # Keep within capacity
        if len(self.items) > self.capacity:
            # Remove least important item
            self.items.sort(key=lambda x: x[2])
            self.items.pop(0)

    def retrieve(self, query_embedding, top_k=2):
        """Return the texts of the ``top_k`` most relevant items, newline-joined.

        Relevance is cosine similarity to ``query_embedding`` scaled by
        ``0.5 + 0.5 * importance``. Returns an empty string when empty.
        """
        if not self.items:
            return ""

        scored_items = [
            (self.cosine_similarity(emb, query_embedding) * (0.5 + 0.5 * importance), text)
            for text, emb, importance, _ in self.items
        ]

        # Get top-k items
        scored_items.sort(reverse=True)

        return "\n".join([text for _, text in scored_items[:top_k]])

    def _has_similar(self, embedding, similarity_threshold=0.9):
        """Return True if a stored embedding exceeds the similarity threshold."""
        return any(
            self.cosine_similarity(emb, embedding) > similarity_threshold
            for _, emb, _, _ in self.items
        )

    def get_state(self):
        """Return the mean of the importance-scaled item embeddings.

        Returns a zero vector of length ``self.dim`` when the memory is empty.
        """
        if not self.items:
            return np.zeros(self.dim)

        embeddings = np.array([emb * importance for _, emb, importance, _ in self.items])
        return np.mean(embeddings, axis=0)

    def cosine_similarity(self, a, b):
        """Calculate cosine similarity between two vectors.

        Returns 0.0 if either vector has zero norm. If the result is an
        array rather than a scalar, its mean is returned as a float.
        """
        dot_product = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)

        if norm_a == 0 or norm_b == 0:
            return 0.0

        similarity = dot_product / (norm_a * norm_b)

        # Convert array to scalar if necessary
        if isinstance(similarity, np.ndarray):
            similarity = float(similarity.mean())

        return similarity

    def initialize_dynamics(self, initial_state=None, dim=None):
        """Initialize the state for differential equation dynamics.

        Args:
            initial_state: Starting state vector; zeros when None.
            dim: Length of the zero state / noise vector; ``self.dim`` when None.
        """
        if dim is None:
            dim = self.dim

        self.state = initial_state if initial_state is not None else np.zeros(dim)
        self.previous_state = self.state.copy()
        self.previous_time = time.time()
        self.stochastic_terms = np.zeros(dim)

    def update_dynamics(self, stm_state, dt=0.1):
        """
        Update the long-term memory state using the differential equation:
        dS_M^LT(t)/dt = eta_LT · Phi(S_M^ST(t), S_M^LT(t)) + xi_M^LT(t)

        eta_LT is ``self.consolidation_rate``. Phi transfers STM into LTM
        (``stm - 0.5 * ltm``) when the STM norm exceeds 0.1, and otherwise
        applies a slow leak (``-0.01 * ltm``).

        Args:
            stm_state: Current short-term memory state vector.
            dt: Integration step; when None it is measured from the previous
                update using the wall clock.

        Returns:
            The current state (unchanged when ``dt`` is not above 0.01).
        """
        # Allow use without an explicit initialize_dynamics() call instead of
        # failing with AttributeError on the missing state attributes.
        if not hasattr(self, "state"):
            self.initialize_dynamics()

        current_time = time.time()

        if dt is None:
            dt = current_time - self.previous_time

        if dt > 0.01:  # Minimum time step
            # Generate stochastic perturbation (noise)
            self.stochastic_terms = 0.005 * np.random.randn(self.dim)

            # Consolidation function Phi
            if np.linalg.norm(stm_state) > 0.1:
                # Strong STM signal: transfer information from STM to LTM
                consolidation_term = stm_state - 0.5 * self.state
            else:
                # No consolidation, just maintain LTM
                consolidation_term = -0.01 * self.state

            # Euler integration step
            self.previous_state = self.state.copy()
            self.state = self.state + dt * (self.consolidation_rate * consolidation_term + self.stochastic_terms)

            # Store time for next update
            self.previous_time = current_time

        return self.state


#dnon_modules_squad_Without_DSM.py
import json
import time
import logging
import numpy as np
import boto3
import torch
from scipy.special import softmax
from manifold_representation_StrategyQA import RiemannianManifold, EuclideanEmbedding
from memory_dynamics_StrategyQA import ShortTermMemory, LongTermMemory
from information_geometry_StrategyQA import FisherInformationMatrix, NaturalGradient

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class BaseAgent:
    """Base class for the Bedrock-backed agent modules.

    Handles model invocation, response embedding and state tracking.
    ``query`` reads ``self.system_prompt``, which is not set here; subclasses
    are expected to assign it.
    """

    def __init__(self, state_dim=256, temperature=0.02):
        self.MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"
        self.temperature = temperature
        self.bedrock_runtime = None  # injected via set_bedrock_client()
        self.latency_stats = []  # per-call model latency in milliseconds

        # Add manifold representation
        self.manifold = RiemannianManifold(dim=state_dim)
        self.state = None
        self.state_dim = state_dim
        self.state_history = []

        # Add Fisher information matrix for natural gradient
        self.fisher_matrix = FisherInformationMatrix(dim=state_dim)

    def set_bedrock_client(self, bedrock_client):
        """Attach the Bedrock runtime client used for all model calls."""
        self.bedrock_runtime = bedrock_client

    def query(self, user_input):
        """Query the Claude model with the given input.

        The system prompt and ``user_input`` are sent together as a single
        user message. On success the latency is recorded and the agent state
        is updated from an embedding of the response.

        Returns:
            The response text, or the string ``"An error occurred: ..."`` if
            the model call or response parsing fails.
        """
        logger.info(f"Query sent to {self.__class__.__name__}: {user_input[:50]}...")

        # Start timing
        start_time = time.time()

        try:
            response = self.bedrock_runtime.invoke_model(
                modelId=self.MODEL_ID,
                body=json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": 4090,
                    "messages": [{"role": "user", "content": [
                        {"type": "text", "text": f"{self.system_prompt}\n\n{user_input}"}
                    ]}],
                    "temperature": self.temperature,
                    "top_p": 1,
                })
            )

            response_body = json.loads(response['body'].read())
            response_text = response_body['content'][0]['text']

            # Calculate and record latency
            end_time = time.time()
            latency_ms = (end_time - start_time) * 1000
            self.latency_stats.append(latency_ms)

        except Exception as e:
            logger.error(f"Error querying model: {e}")
            return f"An error occurred: {e}"

        # State tracking is auxiliary bookkeeping: a failure here must not
        # discard a response the model has already produced.
        try:
            embedding = self.generate_embedding(response_text)
            self.update_state(embedding)
        except Exception as e:
            logger.error(f"Error updating agent state: {e}")

        logger.info(f"Response from {self.__class__.__name__}: {response_text[:100]}... ({latency_ms:.2f}ms)")
        return response_text

    def _embed_text(self, text):
        """Request one Cohere embedding for ``text`` and return it as an array."""
        response = self.bedrock_runtime.invoke_model(
            modelId="cohere.embed-multilingual-v3",
            contentType="application/json",
            body=json.dumps({"texts": [text], "input_type": "search_document"})
        )
        return np.array(json.loads(response['body'].read())['embeddings'][0])

    def _fit_to_state_dim(self, embedding):
        """Resize ``embedding`` to ``self.state_dim`` entries.

        Longer vectors are compressed by averaging consecutive buckets of
        ``len(embedding) // state_dim`` values (any leftover tail is
        ignored). Shorter vectors are zero-padded; bucket averaging would
        use width 0 there and produce NaN for every entry.
        """
        n = len(embedding)
        if n == self.state_dim:
            return embedding

        if n < self.state_dim:
            padded = np.zeros(self.state_dim)
            padded[:n] = embedding
            return padded

        width = n // self.state_dim
        trimmed = np.asarray(embedding[:width * self.state_dim], dtype=float)
        return trimmed.reshape(self.state_dim, width).mean(axis=1)

    def generate_embedding(self, text):
        """Generate embedding using Cohere model, handling large text inputs.

        Texts longer than 2000 characters are split into 2000-character
        chunks; the first five chunks are embedded individually and
        averaged. The result is resized to ``self.state_dim``.

        Returns:
            A vector of length ``self.state_dim``. On failure a random
            vector is returned (best-effort fallback) and the error logged.
        """
        try:
            # Check if text is too long (using Cohere's 2048 limit)
            if len(text) > 2000:  # Set to 2000 to be safe
                chunks = [text[i:i+2000] for i in range(0, len(text), 2000)]
                logger.info(f"Text length {len(text)} chars, split into {len(chunks)} chunks")

                # Get embeddings for each chunk (limit to first 5 chunks)
                chunk_embeddings = []
                for i, chunk in enumerate(chunks[:5]):
                    try:
                        chunk_embeddings.append(self._embed_text(chunk))
                    except Exception as chunk_e:
                        logger.error(f"Error embedding chunk {i}: {chunk_e}")

                if chunk_embeddings:
                    # Average the chunk embeddings
                    embedding = np.mean(chunk_embeddings, axis=0)
                else:
                    # Fallback to random if all chunks failed
                    embedding = np.random.rand(1024)
            else:
                # Standard case for shorter text
                embedding = self._embed_text(text)

            return self._fit_to_state_dim(embedding)
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            return np.random.rand(self.state_dim)

    def update_state(self, embedding):
        """Update agent state on the manifold and append it to the history.

        The first embedding is wrapped directly; later ones are projected
        with ``self.manifold.project`` and blended into the current state
        via ``self.manifold.geodesic`` with step size 0.5.
        """
        if self.state is None:
            self.state = EuclideanEmbedding(embedding)
        else:
            # Project the embedding onto the manifold
            manifold_point = self.manifold.project(embedding)

            # Update state using information geometry
            self.state = self.manifold.geodesic(self.state, manifold_point, step_size=0.5)

        # Store state history for convergence analysis
        self.state_history.append(self.state.copy())

    def get_state(self):
        """Get current state vector (zeros if no state has been set yet)."""
        if self.state is None:
            return np.zeros(self.state_dim)
        return self.state.coordinates

    def compute_fisher_information(self):
        """Compute Fisher information matrix based on state history.

        The matrix is refreshed only when more than one state is recorded.
        """
        if len(self.state_history) > 1:
            self.fisher_matrix.update([sh.coordinates for sh in self.state_history])
        return self.fisher_matrix.get_matrix()


class PerceptionAgent(BaseAgent):
    """Agent whose prompt asks the model to structure a SQuAD context and
    question for downstream reasoning, without answering the question."""

    # Instruction text exposed as ``self.system_prompt`` on every instance.
    _SYSTEM_PROMPT = """You are a specialized perception module for reading comprehension on the SQuAD dataset. Your task is to:

1. Parse the given context and question with extreme precision
2. Extract ALL key entities, concepts, and their relationships from both the context and question
3. Identify the type of information being requested:
   - FACTUAL DETAILS: specific facts mentioned in the context
   - ENTITY INFORMATION: details about people, places, organizations
   - TEMPORAL RELATIONSHIPS: when events occurred
   - CAUSAL RELATIONSHIPS: why something happened
   - DEFINITIONS: what something means or refers to

4. Structure the information to make answer extraction easier by:
   - Highlighting sentences that likely contain the answer
   - Noting connections between the question entities and context
   - Identifying potential answer spans in the context

Output a structured representation of the context and question that will aid in answering the question accurately.
DO NOT answer the question yourself. Focus ONLY on structuring it for precise reasoning.
"""

    def __init__(self, state_dim=256, temperature=0.02):
        super().__init__(state_dim, temperature)
        self.system_prompt = self._SYSTEM_PROMPT


class MemoryAgent(BaseAgent):
    """Agent that augments its queries with a two-tier (STM + LTM) memory."""

    def __init__(self, state_dim=256, temperature=0.02, 
                st_decay_rate=0.1, lt_consolidation_rate=0.01):
        super().__init__(state_dim, temperature)
        self.system_prompt = """You are a specialized memory module for reading comprehension on the SQuAD dataset. Your job is to:

1. Provide ALL relevant knowledge needed to answer the question accurately
2. Focus primarily on information from the provided context
3. Supplement with background knowledge only when necessary for clarification
4. For named entities in the context, provide relevant facts about them
5. For specialized terminology, provide brief definitions if they appear in the context
6. Identify connections between concepts in the question and those in the context
7. Highlight the most relevant passages that likely contain the answer

Present this information in clear, factual statements organized by topic.
DO NOT solve the problem or provide the final answer - only provide the relevant knowledge.
"""
        
        # Two-tier memory implementation (without DSM)
        self.stm = ShortTermMemory(capacity=20, decay_rate=st_decay_rate, dim=state_dim)
        self.ltm = LongTermMemory(capacity=100, consolidation_rate=lt_consolidation_rate, dim=state_dim)
        
    def process_with_memory(self, query, facts=None):
        """Process query using the two-tier memory system with optional facts injection.

        Args:
            query: Text to answer; also what gets stored in memory.
            facts: Optional iterable of fact strings appended to the context.

        Returns:
            Tuple ``(response_text, memory_context)`` where ``memory_context``
            is the memory/facts text that was appended to the query.
        """
        # 1. Generate embedding for the query
        query_embedding = self.generate_embedding(query)
        
        # 2. Retrieve from both tiers BEFORE storing the query, so the query
        #    is not returned as its own memory.
        stm_output = self.stm.retrieve(query_embedding)
        ltm_output = self.ltm.retrieve(query_embedding)
        
        # 3. Update short-term memory first. Consolidation looks up the item's
        #    text in STM; consolidating before the query is stored there means
        #    the lookup usually misses and LTM keeps only a placeholder label.
        self.stm.add(query, query_embedding)
        
        # 4. Consolidate memory (move from STM to LTM if important)
        importance_score = self.calculate_importance(query)
        self.ltm.consolidate(self.stm, query_embedding, importance_score)
        
        # Build memory context from the memory systems
        memory_context = f"""
        SHORT-TERM MEMORY (recent questions):
        {stm_output}

        LONG-TERM MEMORY (consolidated knowledge):
        {ltm_output}
        """
        
        # If facts are provided, include them in the context
        if facts:
            facts_text = "\n".join([f"- {fact}" for fact in facts])
            memory_context += f"\n\nRELEVANT FACTS:\n{facts_text}"
        
        # Query the LLM with enhanced context
        enhanced_query = f"{query}\n\nUse the following knowledge:\n{memory_context}"
        return self.query(enhanced_query), memory_context
    
    def calculate_importance(self, query):
        """Calculate importance score (at most 1.0) for memory consolidation.

        Starts from ``0.3 + 0.1 * word_count`` capped at 0.9, then adds 0.2
        if any question keyword occurs as a substring of the lowercased
        query.
        """
        # Simple heuristic based on query complexity
        importance = min(0.9, 0.3 + 0.1 * len(query.split()))
        
        # Check for keywords indicating importance
        important_keywords = ["define", "who", "when", "where", "why", "how", "explain"]
        lowered = query.lower()  # lowercase once rather than per keyword
        if any(keyword in lowered for keyword in important_keywords):
            importance += 0.2
            
        return min(1.0, importance)
    
    def get_state(self):
        """Get current state vector: 0.7 * STM state + 0.3 * LTM state."""
        stm_state = self.stm.get_state() if hasattr(self.stm, 'get_state') else np.zeros(self.state_dim)
        ltm_state = self.ltm.get_state() if hasattr(self.ltm, 'get_state') else np.zeros(self.state_dim)
        
        # Combine STM and LTM states with weighted average
        combined_state = 0.7 * stm_state + 0.3 * ltm_state
        return combined_state


class ReasoningAgent(BaseAgent):
    """Agent whose prompt asks the model to reason step by step and extract
    the exact answer span from the context."""

    # Instruction text exposed as ``self.system_prompt`` on every instance.
    _SYSTEM_PROMPT = """You are a specialized reasoning module for reading comprehension on the SQuAD dataset. Your job is to:

1. Process the structured question and relevant knowledge with extreme precision
2. Apply methodical, step-by-step logical reasoning to extract the answer from the context
3. Focus on the specific information requested in the question
4. Locate the precise span of text in the context that answers the question
5. Verify that your answer is directly supported by the context
6. Ensure your answer is as concise as possible while fully answering the question

Follow this exact reasoning structure:
Step 1: Identify the type of information being requested
Step 2: Locate relevant passages in the context
Step 3: Extract candidate answer spans
Step 4: Verify each candidate against the question
Step 5: Select the most precise answer span

Final Answer: [Extracted answer span from the context]

Be extremely precise. Your answer should be the exact text span from the context that answers the question.
"""

    def __init__(self, state_dim=256, temperature=0.02):
        super().__init__(state_dim, temperature)
        self.system_prompt = self._SYSTEM_PROMPT


class ExecutiveAgent(BaseAgent):
    """Agent that integrates the perception, memory and reasoning outputs and
    keeps a running attention weight for each of those three modules."""

    # Instruction text exposed as ``self.system_prompt`` on every instance.
    _SYSTEM_PROMPT = """You are an executive decision-making module that integrates outputs from three specialized modules for SQuAD questions:

1. Perception Module: How the question and context are structured
2. Memory Module: Relevant knowledge and facts
3. Reasoning Module: Step-by-step logical reasoning

Given the module weights, make the final determination by following this process:

1. Carefully review the context and question
2. Examine the reasoning process used to extract the answer
3. Verify that the answer is directly supported by the context
4. Ensure the answer is the most precise span of text that fully answers the question
5. If multiple answer candidates exist, select the one with the strongest support

Return ONLY the answer text, with no additional explanation. Your response should be the exact span of text from the context that answers the question.
"""

    def __init__(self, state_dim=256, temperature=0.01):
        super().__init__(state_dim, temperature)
        self.system_prompt = self._SYSTEM_PROMPT

        # Starting attention per module, ordered perception, memory, reasoning
        self.module_attention_weights = np.array([0.33, 0.33, 0.33])

    def update_attention_weights(self, routing_weights):
        """Blend new routing weights into the module attention weights.

        Args:
            routing_weights: Mapping with 'perception', 'memory' and
                'reasoning' entries.

        Returns:
            The updated weights, rescaled so that they sum to 1.
        """
        ema_factor = 0.7  # share of the previous weights that is kept

        incoming = np.array([
            routing_weights[module]
            for module in ('perception', 'memory', 'reasoning')
        ])

        # Exponential moving average, then renormalise
        blended = ema_factor * self.module_attention_weights + (1 - ema_factor) * incoming
        self.module_attention_weights = blended / np.sum(blended)

        return self.module_attention_weights


#dnon_modules_squad_Without_LTM.py
import json
import time
import logging
import numpy as np
import boto3
import torch
from scipy.special import softmax
from manifold_representation_StrategyQA import RiemannianManifold, EuclideanEmbedding
from memory_dynamics_SQuAD_Without_LTM import ShortTermMemory, LongTermMemory, DeepSubconsciousMemory
from information_geometry_StrategyQA import FisherInformationMatrix, NaturalGradient

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class BaseAgent:
    """Base class for the Bedrock-backed agent modules.

    Handles model invocation, response embedding and state tracking.
    ``query`` reads ``self.system_prompt``, which is not set here; subclasses
    are expected to assign it.
    """

    def __init__(self, state_dim=256, temperature=0.02):
        self.MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"
        self.temperature = temperature
        self.bedrock_runtime = None  # injected via set_bedrock_client()
        self.latency_stats = []  # per-call model latency in milliseconds

        # Add manifold representation
        self.manifold = RiemannianManifold(dim=state_dim)
        self.state = None
        self.state_dim = state_dim
        self.state_history = []

        # Add Fisher information matrix for natural gradient
        self.fisher_matrix = FisherInformationMatrix(dim=state_dim)

    def set_bedrock_client(self, bedrock_client):
        """Attach the Bedrock runtime client used for all model calls."""
        self.bedrock_runtime = bedrock_client

    def query(self, user_input):
        """Query the Claude model with the given input.

        The system prompt and ``user_input`` are sent together as a single
        user message. On success the latency is recorded and the agent state
        is updated from an embedding of the response.

        Returns:
            The response text, or the string ``"An error occurred: ..."`` if
            the model call or response parsing fails.
        """
        logger.info(f"Query sent to {self.__class__.__name__}: {user_input[:50]}...")

        # Start timing
        start_time = time.time()

        try:
            response = self.bedrock_runtime.invoke_model(
                modelId=self.MODEL_ID,
                body=json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": 4090,
                    "messages": [{"role": "user", "content": [
                        {"type": "text", "text": f"{self.system_prompt}\n\n{user_input}"}
                    ]}],
                    "temperature": self.temperature,
                    "top_p": 1,
                })
            )

            response_body = json.loads(response['body'].read())
            response_text = response_body['content'][0]['text']

            # Calculate and record latency
            end_time = time.time()
            latency_ms = (end_time - start_time) * 1000
            self.latency_stats.append(latency_ms)

        except Exception as e:
            logger.error(f"Error querying model: {e}")
            return f"An error occurred: {e}"

        # State tracking is auxiliary bookkeeping: a failure here must not
        # discard a response the model has already produced.
        try:
            embedding = self.generate_embedding(response_text)
            self.update_state(embedding)
        except Exception as e:
            logger.error(f"Error updating agent state: {e}")

        logger.info(f"Response from {self.__class__.__name__}: {response_text[:100]}... ({latency_ms:.2f}ms)")
        return response_text

    def _embed_text(self, text):
        """Request one Cohere embedding for ``text`` and return it as an array."""
        response = self.bedrock_runtime.invoke_model(
            modelId="cohere.embed-multilingual-v3",
            contentType="application/json",
            body=json.dumps({"texts": [text], "input_type": "search_document"})
        )
        return np.array(json.loads(response['body'].read())['embeddings'][0])

    def _fit_to_state_dim(self, embedding):
        """Resize ``embedding`` to ``self.state_dim`` entries.

        Longer vectors are compressed by averaging consecutive buckets of
        ``len(embedding) // state_dim`` values (any leftover tail is
        ignored). Shorter vectors are zero-padded; bucket averaging would
        use width 0 there and produce NaN for every entry.
        """
        n = len(embedding)
        if n == self.state_dim:
            return embedding

        if n < self.state_dim:
            padded = np.zeros(self.state_dim)
            padded[:n] = embedding
            return padded

        width = n // self.state_dim
        trimmed = np.asarray(embedding[:width * self.state_dim], dtype=float)
        return trimmed.reshape(self.state_dim, width).mean(axis=1)

    def generate_embedding(self, text):
        """Generate embedding using Cohere model, handling large text inputs.

        Texts longer than 2000 characters are split into 2000-character
        chunks; the first five chunks are embedded individually and
        averaged. The result is resized to ``self.state_dim``.

        Returns:
            A vector of length ``self.state_dim``. On failure a random
            vector is returned (best-effort fallback) and the error logged.
        """
        try:
            # Check if text is too long (using Cohere's 2048 limit)
            if len(text) > 2000:  # Set to 2000 to be safe
                chunks = [text[i:i+2000] for i in range(0, len(text), 2000)]
                logger.info(f"Text length {len(text)} chars, split into {len(chunks)} chunks")

                # Get embeddings for each chunk (limit to first 5 chunks)
                chunk_embeddings = []
                for i, chunk in enumerate(chunks[:5]):
                    try:
                        chunk_embeddings.append(self._embed_text(chunk))
                    except Exception as chunk_e:
                        logger.error(f"Error embedding chunk {i}: {chunk_e}")

                if chunk_embeddings:
                    # Average the chunk embeddings
                    embedding = np.mean(chunk_embeddings, axis=0)
                else:
                    # Fallback to random if all chunks failed
                    embedding = np.random.rand(1024)
            else:
                # Standard case for shorter text
                embedding = self._embed_text(text)

            return self._fit_to_state_dim(embedding)
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            return np.random.rand(self.state_dim)

    def update_state(self, embedding):
        """Update agent state on the manifold and append it to the history.

        The first embedding is wrapped directly; later ones are projected
        with ``self.manifold.project`` and blended into the current state
        via ``self.manifold.geodesic`` with step size 0.5.
        """
        if self.state is None:
            self.state = EuclideanEmbedding(embedding)
        else:
            # Project the embedding onto the manifold
            manifold_point = self.manifold.project(embedding)

            # Update state using information geometry
            self.state = self.manifold.geodesic(self.state, manifold_point, step_size=0.5)

        # Store state history for convergence analysis
        self.state_history.append(self.state.copy())

    def get_state(self):
        """Get current state vector (zeros if no state has been set yet)."""
        if self.state is None:
            return np.zeros(self.state_dim)
        return self.state.coordinates

    def compute_fisher_information(self):
        """Compute Fisher information matrix based on state history.

        The matrix is refreshed only when more than one state is recorded.
        """
        if len(self.state_history) > 1:
            self.fisher_matrix.update([sh.coordinates for sh in self.state_history])
        return self.fisher_matrix.get_matrix()


class PerceptionAgent(BaseAgent):
    """Agent that installs the SQuAD perception prompt.

    Adds no behavior of its own on top of ``BaseAgent``; it only sets
    ``system_prompt`` to the perception instructions below.
    """

    def __init__(self, state_dim=256, temperature=0.02):
        super().__init__(state_dim=state_dim, temperature=temperature)
        # Prompt lines are flush-left because the literal is used verbatim.
        perception_prompt = """You are a specialized perception module for reading comprehension on the SQuAD dataset. Your task is to:

1. Parse the given context and question with extreme precision
2. Extract ALL key entities, concepts, and their relationships from both the context and question
3. Identify the type of information being requested:
   - FACTUAL DETAILS: specific facts mentioned in the context
   - ENTITY INFORMATION: details about people, places, organizations
   - TEMPORAL RELATIONSHIPS: when events occurred
   - CAUSAL RELATIONSHIPS: why something happened
   - DEFINITIONS: what something means or refers to

4. Structure the information to make answer extraction easier by:
   - Highlighting sentences that likely contain the answer
   - Noting connections between the question entities and context
   - Identifying potential answer spans in the context

Output a structured representation of the context and question that will aid in answering the question accurately.
DO NOT answer the question yourself. Focus ONLY on structuring it for precise reasoning.
"""
        self.system_prompt = perception_prompt


class MemoryAgent(BaseAgent):
    """Memory agent variant that consults short-term and deep memory only.

    A ``LongTermMemory`` instance is still constructed and exposed as
    ``self.ltm``, but ``process_with_memory`` never reads from or writes to it.
    """

    def __init__(self, state_dim=256, temperature=0.02,
                 st_decay_rate=0.1, lt_consolidation_rate=0.01, ds_access_coef=0.001):
        super().__init__(state_dim=state_dim, temperature=temperature)
        memory_prompt = """You are a specialized memory module for reading comprehension on the SQuAD dataset. Your job is to:

1. Provide ALL relevant knowledge needed to answer the question accurately
2. Focus primarily on information from the provided context
3. Supplement with background knowledge only when necessary for clarification
4. For named entities in the context, provide relevant facts about them
5. For specialized terminology, provide brief definitions if they appear in the context
6. Identify connections between concepts in the question and those in the context
7. Highlight the most relevant passages that likely contain the answer

Present this information in clear, factual statements organized by topic.
DO NOT solve the problem or provide the final answer - only provide the relevant knowledge.
"""
        self.system_prompt = memory_prompt

        # All three stores are created; only stm and dsm are used by this variant.
        self.stm = ShortTermMemory(capacity=20, decay_rate=st_decay_rate, dim=state_dim)
        self.ltm = LongTermMemory(capacity=100, consolidation_rate=lt_consolidation_rate, dim=state_dim)
        self.dsm = DeepSubconsciousMemory(access_coef=ds_access_coef, dim=state_dim)

    def process_with_memory(self, query, facts=None):
        """Answer ``query`` with context drawn from short-term and deep memory.

        Args:
            query: Text to embed, look up in memory, and send to the model.
            facts: Optional sequence of fact strings appended to the context.

        Returns:
            Tuple ``(model_response, memory_context)``.
        """
        embedding = self.generate_embedding(query)

        # Retrieve from both stores before either of them is updated.
        recent_items = self.stm.retrieve(embedding)
        deep_items = self.dsm.retrieve(embedding)

        # Deep memory is updated with an empty string in the slot the original
        # code reserved for long-term-memory output.
        self.dsm.update(embedding, "", self.calculate_emotional_salience(query))
        self.stm.add(query, embedding)

        context_parts = [f"""
        SHORT-TERM MEMORY (recent questions):
        {recent_items}

        DEEP KNOWLEDGE (reading comprehension rules):
        {deep_items}
        """]

        if facts and len(facts) > 0:
            bullet_lines = "\n".join(f"- {fact}" for fact in facts)
            context_parts.append(f"\n\nRELEVANT FACTS:\n{bullet_lines}")

        memory_context = "".join(context_parts)
        prompt_with_memory = f"{query}\n\nUse the following knowledge:\n{memory_context}"
        answer = self.query(prompt_with_memory)
        return answer, memory_context

    def calculate_importance(self, query):
        """Score ``query`` in [0, 1] from its word count and question cues.

        The base score grows by 0.1 per whitespace-separated word starting
        from 0.3 and is capped at 0.9; a further 0.2 is added when any cue
        string occurs anywhere in the lower-cased query (substring match).
        """
        word_count = len(query.split())
        score = min(0.9, 0.3 + 0.1 * word_count)

        lowered = query.lower()
        cue_words = ("define", "who", "when", "where", "why", "how", "explain")
        if any(cue in lowered for cue in cue_words):
            score += 0.2

        return min(1.0, score)

    def calculate_emotional_salience(self, query):
        """Score ``query`` in [0, 1] from emphasis words it contains.

        Starts at 0.1 and adds 0.1 for each listed word that occurs anywhere
        in the lower-cased query (substring match), capped at 1.0.
        """
        lowered = query.lower()
        emphasis_words = ("important", "significant", "critical", "crucial",
                          "key", "essential", "primary", "main", "fundamental")

        salience = 0.1
        for marker in emphasis_words:
            if marker not in lowered:
                continue
            salience += 0.1

        return min(1.0, salience)


class ReasoningAgent(BaseAgent):
    """Agent that installs the SQuAD step-by-step reasoning prompt.

    Adds no behavior of its own on top of ``BaseAgent``; it only sets
    ``system_prompt`` to the reasoning instructions below.
    """

    def __init__(self, state_dim=256, temperature=0.02):
        super().__init__(state_dim=state_dim, temperature=temperature)
        # Prompt lines are flush-left because the literal is used verbatim.
        reasoning_prompt = """You are a specialized reasoning module for reading comprehension on the SQuAD dataset. Your job is to:

1. Process the structured question and relevant knowledge with extreme precision
2. Apply methodical, step-by-step logical reasoning to extract the answer from the context
3. Focus on the specific information requested in the question
4. Locate the precise span of text in the context that answers the question
5. Verify that your answer is directly supported by the context
6. Ensure your answer is as concise as possible while fully answering the question

Follow this exact reasoning structure:
Step 1: Identify the type of information being requested
Step 2: Locate relevant passages in the context
Step 3: Extract candidate answer spans
Step 4: Verify each candidate against the question
Step 5: Select the most precise answer span

Final Answer: [Extracted answer span from the context]

Be extremely precise. Your answer should be the exact text span from the context that answers the question.
"""
        self.system_prompt = reasoning_prompt


class ExecutiveAgent(BaseAgent):
    """Agent that produces the final answer and tracks per-module attention.

    Besides installing the executive prompt, it keeps
    ``module_attention_weights``: a length-3 array ordered as
    (perception, memory, reasoning) that is kept normalized to sum to 1.
    """

    # Share of the previous weights retained on each update (EMA factor).
    ATTENTION_EMA_ALPHA = 0.7

    def __init__(self, state_dim=256, temperature=0.01):
        super().__init__(state_dim, temperature)
        self.system_prompt = """You are an executive decision-making module that integrates outputs from three specialized modules for SQuAD questions:

1. Perception Module: How the question and context are structured
2. Memory Module: Relevant knowledge and facts
3. Reasoning Module: Step-by-step logical reasoning

Given the module weights, make the final determination by following this process:

1. Carefully review the context and question
2. Examine the reasoning process used to extract the answer
3. Verify that the answer is directly supported by the context
4. Ensure the answer is the most precise span of text that fully answers the question
5. If multiple answer candidates exist, select the one with the strongest support

Return ONLY the answer text, with no additional explanation. Your response should be the exact span of text from the context that answers the question.
"""

        # Uniform starting weights (P, M, R). Exactly 1/3 each so the array
        # already sums to 1 before the first update (0.33 * 3 summed to 0.99).
        self.module_attention_weights = np.full(3, 1.0 / 3.0)

    def update_attention_weights(self, routing_weights):
        """Blend new routing weights into the attention weights and renormalize.

        Args:
            routing_weights: Mapping with ``'perception'``, ``'memory'`` and
                ``'reasoning'`` entries holding the new per-module weights.

        Returns:
            The updated ``module_attention_weights`` array. If the blended
            weights sum to zero or to a non-finite value, the previous weights
            are kept and returned unchanged instead of being turned into
            NaN/inf by the normalization.

        Raises:
            KeyError: If one of the three entries is missing.
        """
        new_weights = np.array([
            routing_weights['perception'],
            routing_weights['memory'],
            routing_weights['reasoning'],
        ])

        # Exponential moving average of old and new weights.
        alpha = self.ATTENTION_EMA_ALPHA
        blended = alpha * self.module_attention_weights + (1 - alpha) * new_weights

        total = np.sum(blended)
        if not np.isfinite(total) or total == 0:
            logger.warning(f"Skipping attention weight update: blended weights sum to {total}")
            return self.module_attention_weights

        # Normalize to ensure they sum to 1
        self.module_attention_weights = blended / total
        return self.module_attention_weights

#dnon_modules_squad_Without_Memory.py
import json
import time
import logging
import numpy as np
import boto3
import torch
from scipy.special import softmax
from manifold_representation_StrategyQA import RiemannianManifold, EuclideanEmbedding
from memory_dynamics_SQuAD_Without_Memory import ShortTermMemory, LongTermMemory, DeepSubconsciousMemory
from information_geometry_StrategyQA import FisherInformationMatrix, NaturalGradient

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class BaseAgent:
    """Shared plumbing for the LLM-backed agents in this module.

    Wraps a Bedrock runtime client for text generation and embedding requests,
    records the latency of each successful query, and maintains a per-agent
    state (plus its history) that is updated from the embedding of every
    model response.

    Subclasses overwrite ``system_prompt`` after calling ``super().__init__``;
    the base class defaults it to an empty string so ``query`` also works on a
    plain ``BaseAgent``.
    """

    # Bedrock model used to turn text into embedding vectors.
    EMBEDDING_MODEL_ID = "cohere.embed-multilingual-v3"
    # Longest text sent in one embedding request (characters). The original
    # note cites a 2048 limit for the embedding model; 2000 leaves a margin.
    EMBEDDING_CHUNK_CHARS = 2000
    # Only this many leading chunks of a long text are embedded.
    EMBEDDING_MAX_CHUNKS = 5
    # Length of the random stand-in used when every chunk fails to embed.
    EMBEDDING_FALLBACK_DIM = 1024

    def __init__(self, state_dim=256, temperature=0.02):
        self.MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"
        self.temperature = temperature
        self.bedrock_runtime = None
        self.latency_stats = []
        # Default prompt so query() does not fail on a bare BaseAgent.
        self.system_prompt = ""

        # Add manifold representation
        self.manifold = RiemannianManifold(dim=state_dim)
        self.state = None
        self.state_dim = state_dim
        self.state_history = []

        # Add Fisher information matrix for natural gradient
        self.fisher_matrix = FisherInformationMatrix(dim=state_dim)

    def set_bedrock_client(self, bedrock_client):
        """Attach the Bedrock runtime client used for all model calls."""
        self.bedrock_runtime = bedrock_client

    def query(self, user_input):
        """Send ``system_prompt`` plus ``user_input`` to the model.

        On success the latency (ms) is appended to ``latency_stats`` and the
        agent state is updated from the embedding of the response.

        Args:
            user_input: Text appended to the system prompt.

        Returns:
            The model's response text, or the string
            ``"An error occurred: <error>"`` if the model call or the parsing
            of its response fails.
        """
        logger.info(f"Query sent to {self.__class__.__name__}: {user_input[:50]}...")

        # Start timing
        start_time = time.time()

        try:
            response = self.bedrock_runtime.invoke_model(
                modelId=self.MODEL_ID,
                body=json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": 4090,
                    "messages": [{"role": "user", "content": [
                        {"type": "text", "text": f"{self.system_prompt}\n\n{user_input}"}
                    ]}],
                    "temperature": self.temperature,
                    "top_p": 1,
                })
            )

            response_body = json.loads(response['body'].read())
            response_text = response_body['content'][0]['text']
        except Exception as e:
            logger.error(f"Error querying model: {e}")
            return f"An error occurred: {e}"

        # Calculate and record latency
        latency_ms = (time.time() - start_time) * 1000
        self.latency_stats.append(latency_ms)

        # The state update is bookkeeping: a failure here must not discard a
        # response the model has already produced.
        try:
            self.update_state(self.generate_embedding(response_text))
        except Exception as e:
            logger.error(f"Error updating agent state: {e}")

        logger.info(f"Response from {self.__class__.__name__}: {response_text[:100]}... ({latency_ms:.2f}ms)")
        return response_text

    def _embed_text(self, text):
        """Request a single embedding vector for ``text`` from the embedding model."""
        response = self.bedrock_runtime.invoke_model(
            modelId=self.EMBEDDING_MODEL_ID,
            contentType="application/json",
            body=json.dumps({"texts": [text], "input_type": "search_document"})
        )
        return np.array(json.loads(response['body'].read())['embeddings'][0])

    def _fit_to_state_dim(self, embedding):
        """Resize a 1-D ``embedding`` to exactly ``state_dim`` components.

        Longer vectors are mean-pooled over ``state_dim`` consecutive groups
        that together cover every component; shorter vectors are zero-padded.
        """
        size = len(embedding)
        if size == self.state_dim:
            return embedding

        if size < self.state_dim:
            # Integer chunking would give a chunk width of 0 here and produce
            # NaN from the mean of empty slices; pad with zeros instead.
            padded = np.zeros(self.state_dim)
            padded[:size] = embedding
            return padded

        # array_split yields equal groups when size is a multiple of state_dim
        # and near-equal groups otherwise, so no trailing values are dropped.
        groups = np.array_split(embedding, self.state_dim)
        return np.array([group.mean() for group in groups])

    def generate_embedding(self, text):
        """Embed ``text`` and return a vector of length ``state_dim``.

        Texts longer than ``EMBEDDING_CHUNK_CHARS`` are split into chunks; the
        first ``EMBEDDING_MAX_CHUNKS`` chunks are embedded separately and
        averaged. Chunks that fail are logged and skipped.

        Returns:
            A numpy array of length ``state_dim``. If embedding fails entirely
            a random vector is returned instead of raising.
        """
        try:
            limit = self.EMBEDDING_CHUNK_CHARS
            if len(text) > limit:
                chunks = [text[i:i + limit] for i in range(0, len(text), limit)]
                logger.info(f"Text length {len(text)} chars, split into {len(chunks)} chunks")

                chunk_embeddings = []
                for i, chunk in enumerate(chunks[:self.EMBEDDING_MAX_CHUNKS]):
                    try:
                        chunk_embeddings.append(self._embed_text(chunk))
                    except Exception as chunk_e:
                        logger.error(f"Error embedding chunk {i}: {chunk_e}")

                if chunk_embeddings:
                    # Average the chunk embeddings
                    embedding = np.mean(chunk_embeddings, axis=0)
                else:
                    # Fallback to random if all chunks failed
                    embedding = np.random.rand(self.EMBEDDING_FALLBACK_DIM)
            else:
                # Standard case for shorter text
                embedding = self._embed_text(text)

            return self._fit_to_state_dim(embedding)
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            return np.random.rand(self.state_dim)

    def update_state(self, embedding):
        """Update the agent state from ``embedding`` and record a snapshot.

        The first embedding is wrapped directly in ``EuclideanEmbedding``;
        later ones are passed through ``self.manifold.project`` and combined
        with the current state via ``self.manifold.geodesic`` (step size 0.5).
        """
        if self.state is None:
            self.state = EuclideanEmbedding(embedding)
        else:
            # Project the embedding onto the manifold
            manifold_point = self.manifold.project(embedding)

            # Update state using information geometry
            self.state = self.manifold.geodesic(self.state, manifold_point, step_size=0.5)

        # Store a copy of the state for convergence analysis
        self.state_history.append(self.state.copy())

    def get_state(self):
        """Return the current state coordinates, or zeros if no state exists yet."""
        if self.state is None:
            return np.zeros(self.state_dim)
        return self.state.coordinates

    def compute_fisher_information(self):
        """Update the Fisher matrix from the state history and return it.

        The update only happens once at least two states have been recorded.
        """
        if len(self.state_history) > 1:
            self.fisher_matrix.update([sh.coordinates for sh in self.state_history])
        return self.fisher_matrix.get_matrix()


class PerceptionAgent(BaseAgent):
    """Agent that installs the SQuAD perception prompt.

    Adds no behavior of its own on top of ``BaseAgent``; it only sets
    ``system_prompt`` to the perception instructions below.
    """

    def __init__(self, state_dim=256, temperature=0.02):
        super().__init__(state_dim=state_dim, temperature=temperature)
        # Prompt lines are flush-left because the literal is used verbatim.
        perception_prompt = """You are a specialized perception module for reading comprehension on the SQuAD dataset. Your task is to:

1. Parse the given context and question with extreme precision
2. Extract ALL key entities, concepts, and their relationships from both the context and question
3. Identify the type of information being requested:
   - FACTUAL DETAILS: specific facts mentioned in the context
   - ENTITY INFORMATION: details about people, places, organizations
   - TEMPORAL RELATIONSHIPS: when events occurred
   - CAUSAL RELATIONSHIPS: why something happened
   - DEFINITIONS: what something means or refers to

4. Structure the information to make answer extraction easier by:
   - Highlighting sentences that likely contain the answer
   - Noting connections between the question entities and context
   - Identifying potential answer spans in the context

Output a structured representation of the context and question that will aid in answering the question accurately.
DO NOT answer the question yourself. Focus ONLY on structuring it for precise reasoning.
"""
        self.system_prompt = perception_prompt


class MemoryAgent(BaseAgent):
    """Memory agent variant whose query path bypasses every memory store.

    The three memory objects are still constructed so the attribute layout
    matches the other variants, but ``process_with_memory`` never touches
    them. (Whether the imported classes are themselves no-op stubs cannot be
    seen from here -- TODO confirm.)
    """

    def __init__(self, state_dim=256, temperature=0.02,
                 st_decay_rate=0.1, lt_consolidation_rate=0.01, ds_access_coef=0.001):
        super().__init__(state_dim=state_dim, temperature=temperature)
        structuring_prompt = """You are a specialized module for reading comprehension on the SQuAD dataset. Your job is to:

1. Provide relevant information needed to answer the question accurately
2. Focus primarily on information from the provided context
3. Identify connections between concepts in the question and those in the context
4. Highlight the passages that likely contain the answer

DO NOT solve the problem or provide the final answer - only structure the information.
"""
        self.system_prompt = structuring_prompt

        # Constructed for interface parity; unused by this variant's methods.
        self.stm = ShortTermMemory(capacity=20, decay_rate=st_decay_rate, dim=state_dim)
        self.ltm = LongTermMemory(capacity=100, consolidation_rate=lt_consolidation_rate, dim=state_dim)
        self.dsm = DeepSubconsciousMemory(access_coef=ds_access_coef, dim=state_dim)

    def process_with_memory(self, query, facts=None):
        """Send ``query`` straight to the model, with no memory context.

        Args:
            query: Text forwarded unchanged to ``self.query``.
            facts: Accepted for interface parity; ignored.

        Returns:
            Tuple ``(model_response, "")`` -- the memory context is always
            the empty string in this variant.
        """
        # The query is still embedded, as in the other variants, although the
        # resulting vector is not used afterwards.
        self.generate_embedding(query)

        answer = self.query(query)
        return answer, ""

    def calculate_importance(self, query):
        """Return the constant 0.5; kept only for interface parity."""
        return 0.5

    def calculate_emotional_salience(self, query):
        """Return the constant 0.1; kept only for interface parity."""
        return 0.1


class ReasoningAgent(BaseAgent):
    """Agent that installs the SQuAD step-by-step reasoning prompt.

    Adds no behavior of its own on top of ``BaseAgent``; it only sets
    ``system_prompt`` to the reasoning instructions below.
    """

    def __init__(self, state_dim=256, temperature=0.02):
        super().__init__(state_dim=state_dim, temperature=temperature)
        # Prompt lines are flush-left because the literal is used verbatim.
        reasoning_prompt = """You are a specialized reasoning module for reading comprehension on the SQuAD dataset. Your job is to:

1. Process the structured question and relevant knowledge with extreme precision
2. Apply methodical, step-by-step logical reasoning to extract the answer from the context
3. Focus on the specific information requested in the question
4. Locate the precise span of text in the context that answers the question
5. Verify that your answer is directly supported by the context
6. Ensure your answer is as concise as possible while fully answering the question

Follow this exact reasoning structure:
Step 1: Identify the type of information being requested
Step 2: Locate relevant passages in the context
Step 3: Extract candidate answer spans
Step 4: Verify each candidate against the question
Step 5: Select the most precise answer span

Final Answer: [Extracted answer span from the context]

Be extremely precise. Your answer should be the exact text span from the context that answers the question.
"""
        self.system_prompt = reasoning_prompt


class ExecutiveAgent(BaseAgent):
    """Agent that produces the final answer and tracks per-module attention.

    Besides installing the executive prompt, it keeps
    ``module_attention_weights``: a length-3 array ordered as
    (perception, memory, reasoning) that is kept normalized to sum to 1.
    """

    # Share of the previous weights retained on each update (EMA factor).
    ATTENTION_EMA_ALPHA = 0.7

    def __init__(self, state_dim=256, temperature=0.01):
        super().__init__(state_dim, temperature)
        self.system_prompt = """You are an executive decision-making module that integrates outputs from three specialized modules for SQuAD questions:

1. Perception Module: How the question and context are structured
2. Memory Module: Relevant knowledge and facts
3. Reasoning Module: Step-by-step logical reasoning

Given the module weights, make the final determination by following this process:

1. Carefully review the context and question
2. Examine the reasoning process used to extract the answer
3. Verify that the answer is directly supported by the context
4. Ensure the answer is the most precise span of text that fully answers the question
5. If multiple answer candidates exist, select the one with the strongest support

Return ONLY the answer text, with no additional explanation. Your response should be the exact span of text from the context that answers the question.
"""

        # Uniform starting weights (P, M, R). Exactly 1/3 each so the array
        # already sums to 1 before the first update (0.33 * 3 summed to 0.99).
        self.module_attention_weights = np.full(3, 1.0 / 3.0)

    def update_attention_weights(self, routing_weights):
        """Blend new routing weights into the attention weights and renormalize.

        Args:
            routing_weights: Mapping with ``'perception'``, ``'memory'`` and
                ``'reasoning'`` entries holding the new per-module weights.

        Returns:
            The updated ``module_attention_weights`` array. If the blended
            weights sum to zero or to a non-finite value, the previous weights
            are kept and returned unchanged instead of being turned into
            NaN/inf by the normalization.

        Raises:
            KeyError: If one of the three entries is missing.
        """
        new_weights = np.array([
            routing_weights['perception'],
            routing_weights['memory'],
            routing_weights['reasoning'],
        ])

        # Exponential moving average of old and new weights.
        alpha = self.ATTENTION_EMA_ALPHA
        blended = alpha * self.module_attention_weights + (1 - alpha) * new_weights

        total = np.sum(blended)
        if not np.isfinite(total) or total == 0:
            logger.warning(f"Skipping attention weight update: blended weights sum to {total}")
            return self.module_attention_weights

        # Normalize to ensure they sum to 1
        self.module_attention_weights = blended / total
        return self.module_attention_weights

#dnon_modules_squad_Without_STM.py
import json
import time
import logging
import numpy as np
import boto3
import torch
from scipy.special import softmax
from manifold_representation_StrategyQA import RiemannianManifold, EuclideanEmbedding
from memory_dynamics_SQuAD_Without_STM import ShortTermMemory, LongTermMemory, DeepSubconsciousMemory
from information_geometry_StrategyQA import FisherInformationMatrix, NaturalGradient

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class BaseAgent:
    """Shared plumbing for the LLM-backed agents in this module.

    Wraps a Bedrock runtime client for text generation and embedding requests,
    records the latency of each successful query, and maintains a per-agent
    state (plus its history) that is updated from the embedding of every
    model response.

    Subclasses overwrite ``system_prompt`` after calling ``super().__init__``;
    the base class defaults it to an empty string so ``query`` also works on a
    plain ``BaseAgent``.
    """

    # Bedrock model used to turn text into embedding vectors.
    EMBEDDING_MODEL_ID = "cohere.embed-multilingual-v3"
    # Longest text sent in one embedding request (characters). The original
    # note cites a 2048 limit for the embedding model; 2000 leaves a margin.
    EMBEDDING_CHUNK_CHARS = 2000
    # Only this many leading chunks of a long text are embedded.
    EMBEDDING_MAX_CHUNKS = 5
    # Length of the random stand-in used when every chunk fails to embed.
    EMBEDDING_FALLBACK_DIM = 1024

    def __init__(self, state_dim=256, temperature=0.02):
        self.MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"
        self.temperature = temperature
        self.bedrock_runtime = None
        self.latency_stats = []
        # Default prompt so query() does not fail on a bare BaseAgent.
        self.system_prompt = ""

        # Add manifold representation
        self.manifold = RiemannianManifold(dim=state_dim)
        self.state = None
        self.state_dim = state_dim
        self.state_history = []

        # Add Fisher information matrix for natural gradient
        self.fisher_matrix = FisherInformationMatrix(dim=state_dim)

    def set_bedrock_client(self, bedrock_client):
        """Attach the Bedrock runtime client used for all model calls."""
        self.bedrock_runtime = bedrock_client

    def query(self, user_input):
        """Send ``system_prompt`` plus ``user_input`` to the model.

        On success the latency (ms) is appended to ``latency_stats`` and the
        agent state is updated from the embedding of the response.

        Args:
            user_input: Text appended to the system prompt.

        Returns:
            The model's response text, or the string
            ``"An error occurred: <error>"`` if the model call or the parsing
            of its response fails.
        """
        logger.info(f"Query sent to {self.__class__.__name__}: {user_input[:50]}...")

        # Start timing
        start_time = time.time()

        try:
            response = self.bedrock_runtime.invoke_model(
                modelId=self.MODEL_ID,
                body=json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": 4090,
                    "messages": [{"role": "user", "content": [
                        {"type": "text", "text": f"{self.system_prompt}\n\n{user_input}"}
                    ]}],
                    "temperature": self.temperature,
                    "top_p": 1,
                })
            )

            response_body = json.loads(response['body'].read())
            response_text = response_body['content'][0]['text']
        except Exception as e:
            logger.error(f"Error querying model: {e}")
            return f"An error occurred: {e}"

        # Calculate and record latency
        latency_ms = (time.time() - start_time) * 1000
        self.latency_stats.append(latency_ms)

        # The state update is bookkeeping: a failure here must not discard a
        # response the model has already produced.
        try:
            self.update_state(self.generate_embedding(response_text))
        except Exception as e:
            logger.error(f"Error updating agent state: {e}")

        logger.info(f"Response from {self.__class__.__name__}: {response_text[:100]}... ({latency_ms:.2f}ms)")
        return response_text

    def _embed_text(self, text):
        """Request a single embedding vector for ``text`` from the embedding model."""
        response = self.bedrock_runtime.invoke_model(
            modelId=self.EMBEDDING_MODEL_ID,
            contentType="application/json",
            body=json.dumps({"texts": [text], "input_type": "search_document"})
        )
        return np.array(json.loads(response['body'].read())['embeddings'][0])

    def _fit_to_state_dim(self, embedding):
        """Resize a 1-D ``embedding`` to exactly ``state_dim`` components.

        Longer vectors are mean-pooled over ``state_dim`` consecutive groups
        that together cover every component; shorter vectors are zero-padded.
        """
        size = len(embedding)
        if size == self.state_dim:
            return embedding

        if size < self.state_dim:
            # Integer chunking would give a chunk width of 0 here and produce
            # NaN from the mean of empty slices; pad with zeros instead.
            padded = np.zeros(self.state_dim)
            padded[:size] = embedding
            return padded

        # array_split yields equal groups when size is a multiple of state_dim
        # and near-equal groups otherwise, so no trailing values are dropped.
        groups = np.array_split(embedding, self.state_dim)
        return np.array([group.mean() for group in groups])

    def generate_embedding(self, text):
        """Embed ``text`` and return a vector of length ``state_dim``.

        Texts longer than ``EMBEDDING_CHUNK_CHARS`` are split into chunks; the
        first ``EMBEDDING_MAX_CHUNKS`` chunks are embedded separately and
        averaged. Chunks that fail are logged and skipped.

        Returns:
            A numpy array of length ``state_dim``. If embedding fails entirely
            a random vector is returned instead of raising.
        """
        try:
            limit = self.EMBEDDING_CHUNK_CHARS
            if len(text) > limit:
                chunks = [text[i:i + limit] for i in range(0, len(text), limit)]
                logger.info(f"Text length {len(text)} chars, split into {len(chunks)} chunks")

                chunk_embeddings = []
                for i, chunk in enumerate(chunks[:self.EMBEDDING_MAX_CHUNKS]):
                    try:
                        chunk_embeddings.append(self._embed_text(chunk))
                    except Exception as chunk_e:
                        logger.error(f"Error embedding chunk {i}: {chunk_e}")

                if chunk_embeddings:
                    # Average the chunk embeddings
                    embedding = np.mean(chunk_embeddings, axis=0)
                else:
                    # Fallback to random if all chunks failed
                    embedding = np.random.rand(self.EMBEDDING_FALLBACK_DIM)
            else:
                # Standard case for shorter text
                embedding = self._embed_text(text)

            return self._fit_to_state_dim(embedding)
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            return np.random.rand(self.state_dim)

    def update_state(self, embedding):
        """Update the agent state from ``embedding`` and record a snapshot.

        The first embedding is wrapped directly in ``EuclideanEmbedding``;
        later ones are passed through ``self.manifold.project`` and combined
        with the current state via ``self.manifold.geodesic`` (step size 0.5).
        """
        if self.state is None:
            self.state = EuclideanEmbedding(embedding)
        else:
            # Project the embedding onto the manifold
            manifold_point = self.manifold.project(embedding)

            # Update state using information geometry
            self.state = self.manifold.geodesic(self.state, manifold_point, step_size=0.5)

        # Store a copy of the state for convergence analysis
        self.state_history.append(self.state.copy())

    def get_state(self):
        """Return the current state coordinates, or zeros if no state exists yet."""
        if self.state is None:
            return np.zeros(self.state_dim)
        return self.state.coordinates

    def compute_fisher_information(self):
        """Update the Fisher matrix from the state history and return it.

        The update only happens once at least two states have been recorded.
        """
        if len(self.state_history) > 1:
            self.fisher_matrix.update([sh.coordinates for sh in self.state_history])
        return self.fisher_matrix.get_matrix()


class PerceptionAgent(BaseAgent):
    """Agent that installs the SQuAD perception prompt.

    Adds no behavior of its own on top of ``BaseAgent``; it only sets
    ``system_prompt`` to the perception instructions below.
    """

    def __init__(self, state_dim=256, temperature=0.02):
        super().__init__(state_dim=state_dim, temperature=temperature)
        # Prompt lines are flush-left because the literal is used verbatim.
        perception_prompt = """You are a specialized perception module for reading comprehension on the SQuAD dataset. Your task is to:

1. Parse the given context and question with extreme precision
2. Extract ALL key entities, concepts, and their relationships from both the context and question
3. Identify the type of information being requested:
   - FACTUAL DETAILS: specific facts mentioned in the context
   - ENTITY INFORMATION: details about people, places, organizations
   - TEMPORAL RELATIONSHIPS: when events occurred
   - CAUSAL RELATIONSHIPS: why something happened
   - DEFINITIONS: what something means or refers to

4. Structure the information to make answer extraction easier by:
   - Highlighting sentences that likely contain the answer
   - Noting connections between the question entities and context
   - Identifying potential answer spans in the context

Output a structured representation of the context and question that will aid in answering the question accurately.
DO NOT answer the question yourself. Focus ONLY on structuring it for precise reasoning.
"""
        self.system_prompt = perception_prompt


class MemoryAgent(BaseAgent):
    """Agent that enriches LLM queries with context retrieved from memory stores.

    Three stores are created (short-term, long-term, deep subconscious).
    ``process_with_memory`` retrieves only from the long-term and deep stores;
    the short-term store is never queried there, but it is still handed to
    ``LongTermMemory.consolidate``.
    """

    # Substrings whose presence in a (lower-cased) query raises its importance.
    _IMPORTANCE_KEYWORDS = ("define", "who", "when", "where", "why", "how", "explain")

    # Substrings that each add 0.1 to the emotional-salience score.
    _SALIENCE_WORDS = (
        "important", "significant", "critical", "crucial",
        "key", "essential", "primary", "main", "fundamental",
    )

    def __init__(self, state_dim=256, temperature=0.02,
                 st_decay_rate=0.1, lt_consolidation_rate=0.01, ds_access_coef=0.001):
        """Set up the memory prompt and the three memory stores.

        Args:
            state_dim: Forwarded to ``BaseAgent.__init__`` and used as ``dim``
                for every memory store.
            temperature: Forwarded to ``BaseAgent.__init__``.
            st_decay_rate: ``decay_rate`` of the short-term store.
            lt_consolidation_rate: ``consolidation_rate`` of the long-term store.
            ds_access_coef: ``access_coef`` of the deep subconscious store.
        """
        super().__init__(state_dim, temperature)
        self.system_prompt = """You are a specialized memory module for reading comprehension on the SQuAD dataset. Your job is to:

1. Provide ALL relevant knowledge needed to answer the question accurately
2. Focus primarily on information from the provided context
3. Supplement with background knowledge only when necessary for clarification
4. For named entities in the context, provide relevant facts about them
5. For specialized terminology, provide brief definitions if they appear in the context
6. Identify connections between concepts in the question and those in the context
7. Highlight the most relevant passages that likely contain the answer

Present this information in clear, factual statements organized by topic.
DO NOT solve the problem or provide the final answer - only provide the relevant knowledge.
"""

        # The STM store is not used for retrieval in process_with_memory, but
        # it is still created because LTM consolidation receives it.
        self.stm = ShortTermMemory(capacity=20, decay_rate=st_decay_rate, dim=state_dim)
        self.ltm = LongTermMemory(capacity=100, consolidation_rate=lt_consolidation_rate, dim=state_dim)
        self.dsm = DeepSubconsciousMemory(access_coef=ds_access_coef, dim=state_dim)

    def process_with_memory(self, query, facts=None):
        """Query the LLM with the query enriched by LTM/DSM retrieval results.

        Besides querying, this updates the long-term store (consolidation) and
        the deep subconscious store on every call.

        Args:
            query: The text to send to the model.
            facts: Optional iterable of fact strings appended to the memory
                context as a bulleted list. A single ``str`` is treated as one
                fact rather than being iterated character by character.

        Returns:
            tuple: ``(response, memory_context)`` where ``response`` is the
            return value of ``self.query`` and ``memory_context`` is the text
            block that was appended to the query.
        """
        # 1. Embed the query.
        # NOTE(review): generate_embedding is not defined in this class;
        # presumably inherited from BaseAgent - confirm.
        query_embedding = self.generate_embedding(query)

        # 2. Retrieve from the long-term and deep stores only (no STM lookup).
        ltm_output = self.ltm.retrieve(query_embedding)
        dsm_output = self.dsm.retrieve(query_embedding)

        # 3. Consolidate into long-term memory.
        importance_score = self.calculate_importance(query)
        self.ltm.consolidate(self.stm, query_embedding, importance_score)

        # 4. Update deep subconscious memory.
        emotional_salience = self.calculate_emotional_salience(query)
        self.dsm.update(query_embedding, ltm_output, emotional_salience)

        # 5. Build the memory context (the indentation inside the literal is
        # part of the text sent to the model and is kept as-is).
        memory_context = f"""
        LONG-TERM MEMORY (consolidated knowledge):
        {ltm_output}

        DEEP KNOWLEDGE (reading comprehension rules):
        {dsm_output}
        """

        # A lone string would otherwise produce one bullet per character.
        fact_list = [facts] if isinstance(facts, str) else facts
        if fact_list:
            facts_text = "\n".join(f"- {fact}" for fact in fact_list)
            memory_context += f"\n\nRELEVANT FACTS:\n{facts_text}"

        enhanced_query = f"{query}\n\nUse the following knowledge:\n{memory_context}"
        return self.query(enhanced_query), memory_context

    def calculate_importance(self, query):
        """Return a heuristic importance score for memory consolidation.

        The score starts at ``0.3 + 0.1 * word_count`` capped at 0.9, gains 0.2
        if any importance keyword occurs in the query, and is capped at 1.0.

        Args:
            query: The query text.

        Returns:
            float: Importance score, at most 1.0.
        """
        importance = min(0.9, 0.3 + 0.1 * len(query.split()))

        # Matching is substring-based on the lower-cased query, so e.g. "how"
        # also matches inside "show".
        lowered = query.lower()
        if any(keyword in lowered for keyword in self._IMPORTANCE_KEYWORDS):
            importance += 0.2

        return min(1.0, importance)

    def calculate_emotional_salience(self, query):
        """Return a heuristic emotional-salience score for deep memory updates.

        Starts at 0.1 and adds 0.1 for every salience word that occurs in the
        query, capped at 1.0.

        Args:
            query: The query text.

        Returns:
            float: Salience score, at most 1.0.
        """
        # Substring match on the lower-cased query (e.g. "key" matches "monkey").
        lowered = query.lower()

        salience = 0.1  # Base salience
        for word in self._SALIENCE_WORDS:
            if word in lowered:
                salience += 0.1

        return min(1.0, salience)


class ReasoningAgent(BaseAgent):
    """Agent whose prompt asks the model to reason step by step to an answer span.

    The class adds no behaviour of its own; it only installs ``system_prompt``
    on top of what ``BaseAgent`` sets up.
    """

    # Instruction text copied onto every instance as ``system_prompt``.
    _SYSTEM_PROMPT = """You are a specialized reasoning module for reading comprehension on the SQuAD dataset. Your job is to:

1. Process the structured question and relevant knowledge with extreme precision
2. Apply methodical, step-by-step logical reasoning to extract the answer from the context
3. Focus on the specific information requested in the question
4. Locate the precise span of text in the context that answers the question
5. Verify that your answer is directly supported by the context
6. Ensure your answer is as concise as possible while fully answering the question

Follow this exact reasoning structure:
Step 1: Identify the type of information being requested
Step 2: Locate relevant passages in the context
Step 3: Extract candidate answer spans
Step 4: Verify each candidate against the question
Step 5: Select the most precise answer span

Final Answer: [Extracted answer span from the context]

Be extremely precise. Your answer should be the exact text span from the context that answers the question.
"""

    def __init__(self, state_dim=256, temperature=0.02):
        """Initialise the base agent, then install the reasoning prompt.

        Args:
            state_dim: Forwarded to ``BaseAgent.__init__``.
            temperature: Forwarded to ``BaseAgent.__init__``.
        """
        super().__init__(state_dim, temperature)
        self.system_prompt = self._SYSTEM_PROMPT


class ExecutiveAgent(BaseAgent):
    def __init__(self, state_dim=256, temperature=0.01):
        super().__init__(state_dim, temperature)
        self.system_prompt = """You are an executive decision-making module that integrates outputs from three specialized modules for SQuAD questions:

1. Perception Module: How the question and context are structured
2. Memory Module: Relevant knowledge and facts
3. Reasoning Module: Step-by-step logical reasoning

Given the module weights, make the final determination by following this process:

1. Carefully review the context and question
2. Examine the reasoning process used to extract the answer
3. Verify that the answer is directly supported by the context
4. Ensure the answer is the most precise span of text that fully answers the question
5. If multiple answer candidates exist, select the one with the strongest support

Return ONLY the answer text, with no additional explanation. Your response should be the exact span of text from the context that answers the question.
"""
        
        # Initialize attention weights for modules
        self.module_attention_weights = np.array([0.33, 0.33, 0.33])  # P, M, R
        
    def update_attention_weights(self, routing_weights, alpha=0.7):
        """Blend new routing weights into the module attention weights.

        The stored weights are updated with an exponential moving average and
        then normalised to sum to 1. If the blended weights sum to zero or to
        a non-finite value (e.g. a NaN routing weight), the update is skipped
        and the previous weights are kept, so that a single bad input cannot
        permanently poison the moving average.

        Args:
            routing_weights: Mapping with the keys ``'perception'``,
                ``'memory'`` and ``'reasoning'``.
            alpha: EMA factor; the share of the previous weights retained.
                Defaults to 0.7.

        Returns:
            The current ``self.module_attention_weights``, ordered
            (perception, memory, reasoning).

        Raises:
            KeyError: If ``routing_weights`` lacks one of the three keys.
        """
        incoming = np.array([
            routing_weights['perception'],
            routing_weights['memory'],
            routing_weights['reasoning'],
        ])

        # Exponential moving average of old and new weights.
        previous = np.asarray(self.module_attention_weights, dtype=float)
        blended = alpha * previous + (1 - alpha) * incoming

        total = np.sum(blended)
        if not np.isfinite(total) or total == 0:
            # Normalising would yield NaN/inf weights; keep the last good ones.
            logging.getLogger(__name__).warning(
                "Skipping attention-weight update: blended weights sum to %s", total
            )
            return self.module_attention_weights

        # Normalize to ensure they sum to 1
        self.module_attention_weights = blended / total
        return self.module_attention_weights