import numpy as np
import time
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, List, Tuple, Union, Optional
from pathlib import Path

from utils import load_hidden_activities

def getDimensionality_SVD(activity_matrix, center_data=True):
    """
    Calculate participation ratio dimensionality using SVD on an activity matrix.

    The participation ratio is PR = (sum_i lambda_i)^2 / sum_i lambda_i^2, where
    lambda_i are the eigenvalues of the covariance of the (centred) activity matrix.
    Those eigenvalues are proportional to the squared singular values s_i^2 of the
    centred matrix, so PR = (sum_i s_i^2)^2 / sum_i s_i^4. The (n - 1) covariance
    normalisation cancels in the ratio. This gives the same value as the legacy
    getDimensionality applied to the covariance matrix, without forming that
    matrix explicitly.

    Args:
        activity_matrix: 2D array-like of shape (n_conditions, n_units)
        center_data: Whether to subtract the per-unit mean across conditions before
            the SVD (recommended: True)

    Returns:
        dimensionality: Participation ratio value (float). Returns 0.0 when no
        singular value exceeds 1e-12, e.g. when all rows are identical and
        center_data=True.
    """
    activity_matrix = np.asarray(activity_matrix)
    if center_data:
        activity_matrix = activity_matrix - activity_matrix.mean(axis=0)

    # Only the singular values are needed, so skip computing U and V^T.
    s = np.linalg.svd(activity_matrix, compute_uv=False)

    # Drop numerically-zero singular values (e.g. the direction removed by centring).
    s = s[s > 1e-12]
    if s.size == 0:
        return 0.0

    # Covariance eigenvalues are proportional to the squared singular values; using
    # raw singular values here would overestimate the participation ratio.
    variances = s ** 2
    return float(variances.sum() ** 2 / np.sum(variances ** 2))

def getDimensionality(data):
    """
    LEGACY: participation ratio from the eigenvalues of a square matrix.

    Kept for backward compatibility; prefer getDimensionality_SVD for new code.
    The eigenvalues come from np.linalg.eigvals and only their real parts are used.

    Args:
        data: Square matrix to compute dimensionality for (e.g. a correlation or
            similarity matrix)

    Returns:
        dimensionality: (sum of real eigenvalues)^2 / (sum of squared real eigenvalues)
    """
    real_eigs = [np.real(ev) for ev in np.linalg.eigvals(data)]
    total = sum(real_eigs)
    total_sq = sum(ev ** 2 for ev in real_eigs)
    return total ** 2 / total_sq

def get_task_dimensionality(hidden_file, trial_length=10, collapse_only=False, integration_only=False):
    """
    Calculate task dimensionality by averaging across the stimulus dimension and applying SVD.

    Unfiltered path: for each timepoint, the stimulus-level activity matrix is
    reshaped to (64 tasks, 256 stimuli, n_units). It is averaged over stimuli to give
    one representation per task, and the participation ratio of those task
    representations is computed with getDimensionality_SVD.

    Filtered path (collapse_only / integration_only): individual (task, stimulus)
    trials are kept or dropped according to can_decide_after_stim1. The surviving
    trials are then averaged within each task, where a task is identified by its
    (logical_ctx, sensory_ctx, motor_ctx) triple.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)
        collapse_only: If True, only use collapse trials
        integration_only: If True, only use integration trials.
            NOTE(review): if both flags are True, every trial passes the filter.

    Returns:
        task_dim: Array of shape (trial_length,) with one dimensionality value per
        timepoint. The value is 0.0 where fewer than two tasks have surviving trials.
    """
    hidden_activities = load_hidden_activities(hidden_file)
    task_dim = np.zeros(trial_length)

    for timepoint in range(trial_length):
        # Stimulus-level activities: one row per (task, stimulus) trial.
        timepoint_activities = np.array(hidden_activities['stimulus_hidden_states'][timepoint])

        if collapse_only or integration_only:
            stimulus_info = hidden_activities['stimulus_info'][timepoint]

            # Single pass: filter trials and group the kept ones by task. Dicts keep
            # insertion order, so tasks appear in first-seen order.
            trials_by_task = {}
            for trial_idx, trial_info in enumerate(stimulus_info):
                task = trial_info['task']
                is_collapse = can_decide_after_stim1(task, trial_info['stimulus']['stimulus'])
                if (collapse_only and is_collapse) or (integration_only and not is_collapse):
                    task_key = (task['logical_ctx'], task['sensory_ctx'], task['motor_ctx'])
                    trials_by_task.setdefault(task_key, []).append(trial_idx)

            # Fewer than two task groups also covers "fewer than two trials kept".
            if len(trials_by_task) < 2:
                task_dim[timepoint] = 0.0
                continue

            # Average across all surviving stimuli within each task.
            task_activities = np.array([
                timepoint_activities[trial_indices].mean(axis=0)
                for trial_indices in trials_by_task.values()
            ])  # Shape: (n_tasks_with_trials, n_units)
        else:
            # Trials are laid out task-major: (64 tasks, 256 stimuli, n_units).
            task_activities = timepoint_activities.reshape(64, 256, -1).mean(axis=1)

        task_dim[timepoint] = getDimensionality_SVD(task_activities)

    return task_dim

### DIMENSIONALITY BY TASK TYPE (EACH RULE DOMAIN) ###

def group_tasks_by_rule_domain(hidden_activities, timepoint, rule_domain):
    """
    Group task indices by their rule value in one rule domain.

    Args:
        hidden_activities: Loaded hidden activities. Must contain 'task_info',
            indexed by timepoint. Each entry is a list of task dicts with
            'logical_ctx', 'sensory_ctx' and 'motor_ctx' keys.
        timepoint: Which timepoint's task info to use. This argument is required;
            callers in this module pass the timepoint being analysed.
        rule_domain: 'logical', 'sensory', or 'motor'

    Returns:
        Dictionary mapping each rule value 0-3 to the list of task indices
        (positions in task_info) with that value. All four keys are always present.
        A warning is printed for any rule value with no tasks.

    Raises:
        ValueError: If rule_domain is not one of the supported domains.
        KeyError: If a task's rule value is not one of the keys 0-3.
    """
    domain_key_map = {
        'logical': 'logical_ctx',
        'sensory': 'sensory_ctx',
        'motor': 'motor_ctx'
    }

    # Validate before touching hidden_activities so a bad domain name is reported
    # clearly instead of being masked by an unrelated lookup error.
    if rule_domain not in domain_key_map:
        raise ValueError(f"rule_domain must be one of {list(domain_key_map.keys())}, got {rule_domain}")
    domain_key = domain_key_map[rule_domain]

    task_info = hidden_activities['task_info'][timepoint]

    rule_groups = {0: [], 1: [], 2: [], 3: []}
    for task_idx, task in enumerate(task_info):
        rule_groups[task[domain_key]].append(task_idx)

    for rule_value, task_indices in rule_groups.items():
        if not task_indices:
            print(f"Warning: No tasks found for {rule_domain} rule {rule_value}")

    return rule_groups

def get_logical_rule_dimensionality(hidden_file, trial_length=10):
    """
    Calculate logical rule dimensionality by averaging across sensory and motor rules.

    For each timepoint, the stimulus-level activities are averaged over the 256
    stimuli to give 64 task representations. Tasks are then grouped by logical rule
    (0-3) and averaged within each group. The participation ratio of the resulting
    (n_rules, n_units) matrix is computed with getDimensionality_SVD.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)

    Returns:
        logical_rule_dim: Array of shape (trial_length,) with one dimensionality value
        per timepoint. The value is 0.0 if fewer than two logical rules have any tasks.
    """
    hidden_activities = load_hidden_activities(hidden_file)
    logical_rule_dim = np.zeros(trial_length)

    for timepoint in range(trial_length):
        # Task-level activities: average over stimuli -> (64, n_units)
        timepoint_activities = np.array(hidden_activities['stimulus_hidden_states'][timepoint])
        task_activities = timepoint_activities.reshape(64, 256, -1).mean(axis=1)

        logical_groups = group_tasks_by_rule_domain(hidden_activities, timepoint, 'logical')

        # One averaged row per logical rule (AND, NAND, OR, NOR) that has tasks.
        # Empty rules are skipped rather than padded with zeros: a zero row would add
        # a fictitious condition at the origin and distort the participation ratio.
        rule_rows = [
            task_activities[logical_groups[rule_value]].mean(axis=0)
            for rule_value in range(4)
            if logical_groups[rule_value]
        ]

        if len(rule_rows) < 2:
            logical_rule_dim[timepoint] = 0.0
            continue

        logical_rule_dim[timepoint] = getDimensionality_SVD(np.array(rule_rows))

    return logical_rule_dim

def get_sensory_rule_dimensionality(hidden_file, trial_length=10):
    """
    Calculate sensory rule dimensionality by averaging across logical and motor rules.

    For each timepoint, the stimulus-level activities are averaged over the 256
    stimuli to give 64 task representations. Tasks are then grouped by sensory rule
    (0-3) and averaged within each group. The participation ratio of the resulting
    (n_rules, n_units) matrix is computed with getDimensionality_SVD.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)

    Returns:
        sensory_rule_dim: Array of shape (trial_length,) with one dimensionality value
        per timepoint. The value is 0.0 if fewer than two sensory rules have any tasks.
    """
    hidden_activities = load_hidden_activities(hidden_file)
    sensory_rule_dim = np.zeros(trial_length)

    for timepoint in range(trial_length):
        # Task-level activities: average over stimuli -> (64, n_units)
        timepoint_activities = np.array(hidden_activities['stimulus_hidden_states'][timepoint])
        task_activities = timepoint_activities.reshape(64, 256, -1).mean(axis=1)

        sensory_groups = group_tasks_by_rule_domain(hidden_activities, timepoint, 'sensory')

        # One averaged row per sensory rule (RED, VERTICAL, HI-PITCH, CONSTANT) that
        # has tasks. Empty rules are skipped rather than padded with zeros: a zero row
        # would add a fictitious condition at the origin and distort the ratio.
        rule_rows = [
            task_activities[sensory_groups[rule_value]].mean(axis=0)
            for rule_value in range(4)
            if sensory_groups[rule_value]
        ]

        if len(rule_rows) < 2:
            sensory_rule_dim[timepoint] = 0.0
            continue

        sensory_rule_dim[timepoint] = getDimensionality_SVD(np.array(rule_rows))

    return sensory_rule_dim

def get_motor_rule_dimensionality(hidden_file, trial_length=10):
    """
    Calculate motor rule dimensionality by averaging across logical and sensory rules.

    For each timepoint, the stimulus-level activities are averaged over the 256
    stimuli to give 64 task representations. Tasks are then grouped by motor rule
    (0-3) and averaged within each group. The participation ratio of the resulting
    (n_rules, n_units) matrix is computed with getDimensionality_SVD.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)

    Returns:
        motor_rule_dim: Array of shape (trial_length,) with one dimensionality value
        per timepoint. The value is 0.0 if fewer than two motor rules have any tasks.
    """
    hidden_activities = load_hidden_activities(hidden_file)
    motor_rule_dim = np.zeros(trial_length)

    for timepoint in range(trial_length):
        # Task-level activities: average over stimuli -> (64, n_units)
        timepoint_activities = np.array(hidden_activities['stimulus_hidden_states'][timepoint])
        task_activities = timepoint_activities.reshape(64, 256, -1).mean(axis=1)

        motor_groups = group_tasks_by_rule_domain(hidden_activities, timepoint, 'motor')

        # One averaged row per motor rule (LIND, LMID, RIND, RMID) that has tasks.
        # Empty rules are skipped rather than padded with zeros: a zero row would add
        # a fictitious condition at the origin and distort the participation ratio.
        rule_rows = [
            task_activities[motor_groups[rule_value]].mean(axis=0)
            for rule_value in range(4)
            if motor_groups[rule_value]
        ]

        if len(rule_rows) < 2:
            motor_rule_dim[timepoint] = 0.0
            continue

        motor_rule_dim[timepoint] = getDimensionality_SVD(np.array(rule_rows))

    return motor_rule_dim

def get_all_rule_domain_dimensionalities(hidden_file, trial_length=10):
    """
    Run the logical, sensory and motor rule-domain analyses on one file.

    The three analyses are evaluated in that order. Each one loads hidden_file
    on its own.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)

    Returns:
        Dict with keys 'logical_rule_dim', 'sensory_rule_dim' and 'motor_rule_dim',
        each an array of per-timepoint dimensionality values.
    """
    analyses = (
        ('logical_rule_dim', get_logical_rule_dimensionality),
        ('sensory_rule_dim', get_sensory_rule_dimensionality),
        ('motor_rule_dim', get_motor_rule_dimensionality),
    )
    return {name: analysis(hidden_file, trial_length) for name, analysis in analyses}

def test_rule_domain_grouping(hidden_file, timepoint=0):
    """
    Print the task grouping for every rule domain so it can be checked by eye.

    The logical, sensory and motor domains are reported in turn. For each rule
    value (0-3) it prints the number of tasks, up to five of their indices, and the
    L/S/M contexts of up to three example tasks, followed by the domain total.

    Args:
        hidden_file: Path to hidden activities file
        timepoint: Timepoint whose task info is examined (default: 0)
    """
    hidden_activities = load_hidden_activities(hidden_file)
    divider = "=" * 60

    print(f"Testing rule domain grouping for timepoint {timepoint}")
    print(divider)

    for domain in ('logical', 'sensory', 'motor'):
        print(f"\n{domain.upper()} RULE GROUPING:")
        print("-" * 30)

        groups = group_tasks_by_rule_domain(hidden_activities, timepoint, domain)

        for value in range(4):
            members = groups[value]
            print(f"Rule {value}: {len(members)} tasks")

            if not members:
                print("  WARNING: No tasks found!")
                continue

            suffix = '...' if len(members) > 5 else ''
            print(f"  Task indices: {members[:5]}{suffix}")

            tasks_at_t = hidden_activities['task_info'][timepoint]
            print("  Example tasks:")
            for idx in members[:3]:
                t = tasks_at_t[idx]
                print(f"    Task {idx}: L{t['logical_ctx']} S{t['sensory_ctx']} M{t['motor_ctx']}")

        grand_total = sum(len(groups[v]) for v in range(4))
        print(f"Total tasks across all {domain} rules: {grand_total}/64")

    print(divider)

### DIMENSIONALITY BY TASK TYPE ###

def get_stimulus_dimensionality(hidden_file, trial_length=10, collapse_only=False, integration_only=False):
    """
    Calculate stimulus dimensionality by averaging across the task dimension and applying SVD.

    Unfiltered path: for each timepoint, the stimulus-level activities are reshaped
    to (64 tasks, 256 stimuli, n_units). They are averaged over tasks to give one
    representation per stimulus, and the participation ratio of those
    representations is computed with getDimensionality_SVD.

    Filtered path (collapse_only / integration_only): the filter is applied to
    individual (task, stimulus) trials via can_decide_after_stim1, not to whole
    stimuli. Each stimulus (identified by its 'stim_idx') is then averaged over only
    the tasks whose trials survived.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)
        collapse_only: If True, only use collapse trials
        integration_only: If True, only use integration trials.
            NOTE(review): if both flags are True, every trial passes the filter.

    Returns:
        stimulus_dim: Array of shape (trial_length,) with one dimensionality value per
        timepoint. The value is 0.0 where fewer than two trials or fewer than two
        distinct stimuli survive filtering.
    """
    hidden_activities = load_hidden_activities(hidden_file)
    stimulus_dim = np.zeros(trial_length)

    for timepoint in range(trial_length):
        # Stimulus-level activities: one row per (task, stimulus) trial.
        timepoint_activities = np.array(hidden_activities['stimulus_hidden_states'][timepoint])

        if collapse_only or integration_only:
            stimulus_info = hidden_activities['stimulus_info'][timepoint]

            kept_trials = []
            for trial_idx, trial_info in enumerate(stimulus_info):
                is_collapse = can_decide_after_stim1(trial_info['task'], trial_info['stimulus']['stimulus'])
                if (collapse_only and is_collapse) or (integration_only and not is_collapse):
                    kept_trials.append(trial_idx)

            if len(kept_trials) < 2:  # Need at least 2 trials
                stimulus_dim[timepoint] = 0.0
                continue

            # Group kept trials by stimulus in one pass. np.unique returns the sorted
            # stimulus ids, and `group_of` maps each kept trial to its stimulus group.
            stim_ids = np.array([stimulus_info[i]['stimulus']['stim_idx'] for i in kept_trials])
            unique_stimuli, group_of = np.unique(stim_ids, return_inverse=True)
            group_of = group_of.ravel()

            if len(unique_stimuli) < 2:
                stimulus_dim[timepoint] = 0.0
                continue

            # Average across the remaining tasks for each stimulus; rows stay in
            # ascending trial order within each group.
            kept_activities = timepoint_activities[kept_trials]
            stimulus_activities = np.array([
                kept_activities[group_of == group].mean(axis=0)
                for group in range(len(unique_stimuli))
            ])
        else:
            # Trials are laid out task-major: (64 tasks, 256 stimuli, n_units).
            stimulus_activities = timepoint_activities.reshape(64, 256, -1).mean(axis=0)

        stimulus_dim[timepoint] = getDimensionality_SVD(stimulus_activities)

    return stimulus_dim

def get_global_dimensionality(hidden_file, trial_length=10):
    """
    Calculate global dimensionality from every (task, stimulus) trial at once.

    For each timepoint, the full stimulus-level activity matrix (one row per trial)
    is passed directly to getDimensionality_SVD without any averaging. That is
    16,384 rows for the 64 tasks x 256 stimuli layout assumed elsewhere in this
    module.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)

    Returns:
        global_dim: Array of shape (trial_length,) of dimensionality values
    """
    activity_store = load_hidden_activities(hidden_file)

    dims = np.zeros(trial_length)
    dims[:] = [
        getDimensionality_SVD(np.array(activity_store['stimulus_hidden_states'][t]))
        for t in range(trial_length)
    ]
    return dims

def get_motor_response_dimensionality(hidden_file, trial_length=10, collapse_only=False, integration_only=False):
    """
    Calculate motor response dimensionality using an averaging-first approach.

    For each timepoint, the correct motor response (0-3) of every trial is
    reconstructed with get_responses_for_hidden_activities. Trials are averaged
    within each response that occurs, and the participation ratio of the resulting
    (n_responses, n_units) matrix is computed with getDimensionality_SVD.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)
        collapse_only: If True, only use collapse trials
        integration_only: If True, only use integration trials (ignored when
            collapse_only is also True)

    Returns:
        motor_response_dim: Array of shape (trial_length,) of dimensionality values.
        The value is 0.0 where fewer than two distinct responses remain.
    """
    from cpro import CPRO
    env = CPRO(training_mode="minimal")

    hidden_activities = load_hidden_activities(hidden_file)
    motor_response_dim = np.zeros(trial_length)

    for timepoint in range(trial_length):
        # One row per (task, stimulus) trial at this timepoint.
        activities = np.array(hidden_activities['stimulus_hidden_states'][timepoint])
        responses, _, collapse_mask = get_responses_for_hidden_activities(hidden_activities, env, timepoint)

        # Optional trial filter; collapse_only takes precedence if both flags are set.
        if collapse_only:
            keep = collapse_mask
        elif integration_only:
            keep = ~collapse_mask
        else:
            keep = None
        if keep is not None:
            activities, responses = activities[keep], responses[keep]

        # Mean activity for every response that actually occurs.
        centroids = []
        for response_id in range(4):
            selected = responses == response_id
            if selected.sum() > 0:
                centroids.append(activities[selected].mean(axis=0))

        # A single centroid carries no between-response structure.
        motor_response_dim[timepoint] = (
            getDimensionality_SVD(np.array(centroids)) if len(centroids) > 1 else 0.0
        )

    return motor_response_dim

def get_rule_satisfaction_dimensionality(hidden_file, trial_length=10, collapse_only=False, integration_only=False):
    """
    Calculate rule satisfaction dimensionality using an averaging-first approach.

    For each timepoint, trials are split by whether the task rule was satisfied
    (True/False, as reconstructed by get_responses_for_hidden_activities). Each
    split is averaged to one activity vector, and the participation ratio of the
    resulting (2, n_units) matrix is computed with getDimensionality_SVD. Trials can
    optionally be restricted to collapse or integration trials first.

    NOTE(review): after mean-centring, a two-row matrix has rank at most 1. The value
    here can therefore only be 1.0 (the two means differ) or 0.0 (they coincide up
    to numerical noise, or one split is empty). Confirm this is the intended measure.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)
        collapse_only: If True, only use collapse trials
        integration_only: If True, only use integration trials (ignored when
            collapse_only is also True)

    Returns:
        rule_satisfaction_dim: Array of shape (trial_length,) of dimensionality values.
        The value is 0.0 where fewer than two trials remain or only one outcome occurs.
    """
    from cpro import CPRO
    env = CPRO(training_mode="minimal")

    hidden_activities = load_hidden_activities(hidden_file)
    rule_satisfaction_dim = np.zeros(trial_length)

    for timepoint in range(trial_length):
        # One row per (task, stimulus) trial at this timepoint.
        activities = np.array(hidden_activities['stimulus_hidden_states'][timepoint])
        _, satisfied, collapse_mask = get_responses_for_hidden_activities(hidden_activities, env, timepoint)

        # Optional trial filter; collapse_only takes precedence if both flags are set.
        keep = collapse_mask if collapse_only else (~collapse_mask if integration_only else None)
        if keep is not None:
            activities = activities[keep]
            satisfied = satisfied[keep]

        if len(activities) < 2:
            rule_satisfaction_dim[timepoint] = 0.0
            continue

        # Mean activity for the satisfied (True) and unsatisfied (False) trials.
        outcome_means = [
            activities[satisfied == outcome].mean(axis=0)
            for outcome in (True, False)
            if (satisfied == outcome).sum() > 0
        ]

        # With a single outcome present there is no variance along this dimension.
        if len(outcome_means) != 2:
            rule_satisfaction_dim[timepoint] = 0.0
        else:
            rule_satisfaction_dim[timepoint] = getDimensionality_SVD(np.array(outcome_means))

    return rule_satisfaction_dim

def get_all_dimensionalities(hidden_file, trial_length=10):
    """
    Compute every dimensionality measure in this module for one hidden-activity file.

    Measures are evaluated in the order listed below. Each one loads hidden_file
    on its own.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)

    Returns:
        Dict mapping measure name to its per-timepoint dimensionality array.
    """
    # (result key, analysis function, extra keyword arguments)
    measures = [
        ('global_dim', get_global_dimensionality, {}),
        ('task_dim', get_task_dimensionality, {}),
        ('stimulus_dim', get_stimulus_dimensionality, {}),
        ('stimulus_collapse_dim', get_stimulus_dimensionality, {'collapse_only': True}),
        ('stimulus_integration_dim', get_stimulus_dimensionality, {'integration_only': True}),
        ('motor_response_dim', get_motor_response_dimensionality, {}),
        ('motor_response_collapse_dim', get_motor_response_dimensionality, {'collapse_only': True}),
        ('motor_response_integration_dim', get_motor_response_dimensionality, {'integration_only': True}),
        ('rule_satisfaction_dim', get_rule_satisfaction_dimensionality, {}),
    ]
    return {key: measure(hidden_file, trial_length, **options) for key, measure, options in measures}

# Legacy function name for backward compatibility
def get_hidden_rep_dim(hidden_file, trial_length=10):
    """
    LEGACY alias for get_task_dimensionality, kept under its old name.

    Use get_task_dimensionality() in new code. No collapse/integration filtering
    is applied.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)

    Returns:
        Array of per-timepoint task dimensionality values.
    """
    return get_task_dimensionality(hidden_file, trial_length=trial_length)

def get_collapse_integration_dimensionality(hidden_file, trial_length=10):
    """
    LEGACY: stimulus dimensionality for collapse vs integration trials.

    Use get_stimulus_dimensionality() with collapse_only / integration_only instead.

    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze (default: 10)

    Returns:
        Dict with:
        - 'collapse': stimulus dimensionality trace using only collapse trials
        - 'integration': stimulus dimensionality trace using only integration trials
        - 'task_counts': the summary dict from
          get_trial_counts_summary_for_collapse_integration()
    """
    summary = {
        'collapse': get_stimulus_dimensionality(hidden_file, trial_length, collapse_only=True),
        'integration': get_stimulus_dimensionality(hidden_file, trial_length, integration_only=True),
    }
    summary['task_counts'] = get_trial_counts_summary_for_collapse_integration()
    return summary

def get_trial_counts_summary_for_collapse_integration():
    """
    Count stimuli that allow collapse in at least one test task vs. in none.

    A stimulus from env.all_stim_combinations counts as "collapse" if
    can_decide_after_stim1 is true for any task in env.test_tasks. Otherwise it
    counts as "integration".

    NOTE(review): despite the key names, both values count stimuli, not tasks.
    The keys are kept unchanged for backward compatibility.

    Returns:
        Dict with 'tasks_with_collapse' and 'tasks_with_integration' counts.
    """
    from cpro import CPRO
    env = CPRO(training_mode="minimal")

    n_collapse = 0
    n_integration = 0
    for stimulus in env.all_stim_combinations:
        # any() stops at the first task that allows an early decision.
        if any(can_decide_after_stim1(task, stimulus) for task in env.test_tasks):
            n_collapse += 1
        else:
            n_integration += 1

    return {
        'tasks_with_collapse': n_collapse,
        'tasks_with_integration': n_integration
    }

def get_responses_for_hidden_activities(hidden_activities, env, timepoint=0):
    """
    Reconstruct the correct response, rule outcome and collapse flag for every trial.

    For each entry of hidden_activities['stimulus_info'][timepoint], the rule outcome
    comes from env._evaluate_rule(stimulus, task). The motor response comes from
    env._get_motor_response(task, rule_met), and the collapse flag from
    can_decide_after_stim1(task, stimulus).

    Args:
        hidden_activities: Loaded hidden activities from the .pt file
        env: CPRO environment instance
        timepoint: Which timepoint to use for stimulus info (default: 0)

    Returns:
        responses: Array with the correct response for each trial
        rule_satisfied: Array with the rule outcome (True/False) for each trial
        collapse_mask: Array that is True for collapse trials
        (each has one entry per trial, i.e. 64*256 for the full task/stimulus set)
    """
    per_trial = []
    for trial in hidden_activities['stimulus_info'][timepoint]:
        task = trial['task']
        stim = trial['stimulus']['stimulus']  # the actual stimulus combination
        outcome = env._evaluate_rule(stim, task)
        per_trial.append((
            env._get_motor_response(task, outcome),
            outcome,
            can_decide_after_stim1(task, stim),
        ))

    responses = np.array([row[0] for row in per_trial])
    rule_satisfied = np.array([row[1] for row in per_trial])
    collapse_mask = np.array([row[2] for row in per_trial])
    return responses, rule_satisfied, collapse_mask

def get_maximal_training_indices(hidden_activities, timepoint=0):
    """
    Get indices and labels of tasks that correspond to 'maximal' training mode.

    Maximal training excludes the 4 diagonal tasks {0,0,0}, {1,1,1}, {2,2,2} and
    {3,3,3} (logical, sensory and motor context all equal). Only the three context
    fields are compared, so extra keys in a task dict (ids, names, ...) do not
    stop a diagonal task from being excluded.

    Args:
        hidden_activities: Loaded hidden activities from the .pt file; must provide
            ``hidden_activities['task_info'][timepoint]``, a sequence of task dicts
            with 'logical_ctx', 'sensory_ctx' and 'motor_ctx'
        timepoint: Which timepoint to use for task info (default: 0, since task info
            is the same across timepoints)

    Returns:
        Tuple of (indices, labels) where:
        - indices: List of indices corresponding to maximal training tasks
        - labels: List of task labels in format 'L{logical}S{sensory}M{motor}'
    """
    # (logical, sensory, motor) triples of the 4 tasks excluded from maximal training
    diagonal_contexts = [(k, k, k) for k in range(4)]

    task_info = hidden_activities['task_info'][timepoint]

    maximal_indices = []
    maximal_labels = []
    for i, task in enumerate(task_info):
        contexts = (task['logical_ctx'], task['sensory_ctx'], task['motor_ctx'])
        if contexts in diagonal_contexts:
            continue
        maximal_indices.append(i)
        maximal_labels.append(f"L{contexts[0]}S{contexts[1]}M{contexts[2]}")

    return maximal_indices, maximal_labels

def create_theoretical_rsms(maximal_labels):
    """
    Create three theoretical representational similarity matrices (RSMs) for CPRO tasks.

    Args:
        maximal_labels: List of task labels in format 'L{logical}S{sensory}M{motor}'
                       (e.g., from get_maximal_training_indices). Context indices may
                       have any number of digits.

    Returns:
        Dictionary containing three (n_tasks, n_tasks) float RSMs:
        - 'overlap': number of shared rule contexts (0, 1, 2, 3)
        - 'binary_compositional': 1 if any context is shared, else 0
        - 'hierarchical': weighted overlap, logical=4 + sensory=2 + motor=1

    Raises:
        ValueError: If a label does not start with the 'L<int>S<int>M<int>' pattern.
    """
    import re

    label_pattern = re.compile(r"L(\d+)S(\d+)M(\d+)")

    contexts = []
    for label in maximal_labels:
        match = label_pattern.match(label)
        if match is None:
            raise ValueError(
                f"Invalid task label {label!r}; expected 'L{{logical}}S{{sensory}}M{{motor}}'"
            )
        contexts.append([int(group) for group in match.groups()])

    # (n_tasks, 3) integer matrix: columns are logical, sensory, motor context
    contexts = np.array(contexts, dtype=int).reshape(-1, 3)

    # Pairwise "same context" indicator matrices via broadcasting
    same_logical = (contexts[:, 0][:, None] == contexts[:, 0][None, :]).astype(float)
    same_sensory = (contexts[:, 1][:, None] == contexts[:, 1][None, :]).astype(float)
    same_motor = (contexts[:, 2][:, None] == contexts[:, 2][None, :]).astype(float)

    overlap_rsm = same_logical + same_sensory + same_motor
    binary_rsm = (overlap_rsm > 0).astype(float)
    hierarchical_rsm = 4 * same_logical + 2 * same_sensory + 1 * same_motor

    return {
        'overlap': overlap_rsm,
        'binary_compositional': binary_rsm,
        'hierarchical': hierarchical_rsm
    }

def get_maximal_similarity_matrix(init_scale, training_mode, seed, timepoint, results_dir, optimizer):
    """
    Get the cosine-similarity matrix of hidden states for the maximal training
    tasks (every task except the four diagonal L=S=M tasks) at one timepoint.

    The hidden-activities file is expected at
    ``<results_dir>/cpu_experiment_<training_mode>_scale<init_scale>_<optimizer>_seed<seed>/``
    ``hidden_scale<init_scale>_<optimizer>_seed<seed>.pt``.

    Args:
        init_scale: Initialization scale used in the run's directory/file names
        training_mode: Training mode used in the run's directory name
        seed: Random seed used in the run's directory/file names
        timepoint: Index into the per-timepoint hidden states and task info
        results_dir: Root results directory (str or Path; trailing separator optional)
        optimizer: Optimizer name used in the run's directory/file names

    Returns:
        (n_maximal, n_maximal) cosine-similarity matrix, with rows and columns in the
        order returned by get_maximal_training_indices.
    """
    run_name = f"scale{init_scale}_{optimizer}_seed{seed}"
    hidden_file = (
        Path(results_dir)
        / f"cpu_experiment_{training_mode}_{run_name}"
        / f"hidden_{run_name}.pt"
    )

    hidden_activities = load_hidden_activities(hidden_file)
    maximal_indices, _ = get_maximal_training_indices(hidden_activities, timepoint=timepoint)

    timepoint_activities = hidden_activities['hidden_states'][timepoint]

    # Stack into a real float matrix (one row per task). An object-dtype array only
    # worked through sklearn's implicit coercion and hid ragged-shape problems.
    # NOTE(review): assumes each entry exposes .numpy() (e.g. a CPU torch tensor).
    hidden_states = np.stack(
        [np.asarray(activity.numpy(), dtype=np.float64) for activity in timepoint_activities]
    )
    maximal_training_hidden_states = hidden_states[maximal_indices, :]

    return cosine_similarity(maximal_training_hidden_states)

def can_decide_after_stim1(task: Dict, stimulus: Dict) -> bool:
    """
    Report whether the trial's outcome is already determined by the first stimulus.

    The sensory context selects which binary feature matters; a feature counts as
    present when its value in Stim1 equals 1. For AND/NAND an absent feature settles
    the outcome early; for OR/NOR a present feature does.

    Args:
        task: Dict with 'logical_ctx', 'sensory_ctx', 'motor_ctx'
        stimulus: Dict with 'stim1' and 'stim2' containing stimulus features

    Returns:
        True if the decision can be made after Stim1, False if Stim2 is needed.

    Raises:
        ValueError: If 'logical_ctx' is not 0-3.
    """
    sensory_feature = {
        0: 'VDim1',   # RED
        1: 'VDim2',   # VERTICAL
        2: 'ADim1',   # HI-PITCH
        3: 'ADim2',   # CONSTANT
    }[task['sensory_ctx']]
    first_has_feature = stimulus['stim1'][sensory_feature] == 1  # present == coded as 1

    op = task['logical_ctx']

    # AND (0): a missing feature already violates the rule.
    # NAND (1): a missing feature already satisfies the rule.
    if op == 0 or op == 1:
        return not first_has_feature

    # OR (2): a present feature already satisfies the rule.
    # NOR (3): a present feature already violates the rule.
    if op == 2 or op == 3:
        return first_has_feature

    raise ValueError(f"Unknown logical context: {op}")

def categorize_all_trials(env) -> Tuple[List[int], List[int]]:
    """
    Categorize all (task, stimulus) combinations into collapse vs integration.

    Trials are indexed task-major: ``task_idx * n_stimuli + stim_idx``, where
    ``n_stimuli = len(env.all_stim_combinations)`` (256 for the standard CPRO
    stimulus set). The stride is derived from the environment rather than
    hard-coded, so indices stay contiguous for other stimulus sets.

    Args:
        env: CPRO environment instance (reads ``test_tasks`` and ``all_stim_combinations``)

    Returns:
        Tuple of (collapse_indices, integration_indices), each a list of linear
        indices into the flattened (n_tasks * n_stimuli) trial space.
    """
    # Materialize once: the stimulus set is reused for every task and its length
    # defines the stride of the flattened index.
    stimuli = list(env.all_stim_combinations)
    n_stimuli = len(stimuli)

    collapse_indices = []
    integration_indices = []

    for task_idx, task in enumerate(env.test_tasks):
        for stim_idx, stimulus in enumerate(stimuli):
            linear_idx = task_idx * n_stimuli + stim_idx
            if can_decide_after_stim1(task, stimulus):
                collapse_indices.append(linear_idx)
            else:
                integration_indices.append(linear_idx)

    return collapse_indices, integration_indices

def get_trial_counts_summary(env):
    """
    Get basic trial counts for collapse vs integration for sanity checking.

    Args:
        env: CPRO environment instance (reads ``test_tasks`` and ``all_stim_combinations``)

    Returns:
        Dict with 'total_collapse', 'total_integration', 'total_trials' and
        'collapse_fraction'. The fraction is NaN when the environment defines no
        trials, instead of raising ZeroDivisionError.
    """
    total_collapse = 0
    total_integration = 0

    for task in env.test_tasks:
        for stimulus in env.all_stim_combinations:
            if can_decide_after_stim1(task, stimulus):
                total_collapse += 1
            else:
                total_integration += 1

    total_trials = total_collapse + total_integration

    return {
        'total_collapse': total_collapse,
        'total_integration': total_integration,
        'total_trials': total_trials,
        'collapse_fraction': total_collapse / total_trials if total_trials else float('nan')
    }

#####################################################################################
# DECODING FUNCTIONS #
#####################################################################################

# Add these functions to analysis_utils.py

import numpy as np
import os
import sys
from datetime import datetime
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from typing import Dict, List, Tuple, Optional, Union
import warnings

def log_message(message: str, level: str = "INFO"):
    """
    Write one timestamped log line to stderr.

    The line is flushed immediately so it shows up promptly in batch-job logs
    (e.g. SLURM). Tags are plain ASCII so they render in any terminal.

    Args:
        message: Text to log
        level: One of INFO, SUCCESS, ERROR, PROGRESS, SUMMARY; any other value is
            rendered with the [INFO] tag
    """
    known_tags = {
        level_name: f'[{level_name}]'
        for level_name in ('INFO', 'SUCCESS', 'ERROR', 'PROGRESS', 'SUMMARY')
    }
    tag = known_tags.get(level, '[INFO]')

    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{stamp}] {tag} {message}", file=sys.stderr, flush=True)

class CPRODecoder:
    """
    Core decoder class for CPRO decoding analyses using SVM with cross-validation
    and optional permutation testing.

    Features are standardized inside each cross-validation fold: the scaler is fit
    on the training split only, so held-out fold statistics never leak into training.
    """

    def __init__(self, decoder_type='svm', cv_folds=10, n_permutations=1000,
                 random_state=42, svm_params=None, run_permutation_tests=True):
        """
        Initialize decoder.

        Args:
            decoder_type: Type of decoder ('svm' only for now)
            cv_folds: Number of stratified CV folds
            n_permutations: Number of label permutations for significance testing
            random_state: Random seed (CV shuffling, SVC, and base seed for permutations)
            svm_params: Keyword arguments for ``sklearn.svm.SVC``
                (default: {'C': 1.0, 'kernel': 'linear'})
            run_permutation_tests: Whether permutation_test actually runs (default: True)
        """
        self.decoder_type = decoder_type
        self.cv_folds = cv_folds
        self.n_permutations = n_permutations
        self.random_state = random_state
        self.svm_params = svm_params or {'C': 1.0, 'kernel': 'linear'}
        self.run_permutation_tests = run_permutation_tests

    def _make_decoder(self):
        """Return a fresh, unfitted classifier for ``self.decoder_type``."""
        if self.decoder_type == 'svm':
            return SVC(
                random_state=self.random_state,
                decision_function_shape='ovr',  # One-vs-Rest for multiclass
                **self.svm_params
            )
        raise ValueError(f"Unknown decoder type: {self.decoder_type}")

    @staticmethod
    def _drop_nan_rows(X, y):
        """Return (X, y) as arrays restricted to rows whose features and label are non-NaN."""
        X = np.asarray(X)
        y = np.asarray(y)
        keep = ~(np.isnan(X).any(axis=1) | np.isnan(y))
        return X[keep], y[keep]

    def _cv_scores(self, X, y):
        """
        Return stratified K-fold accuracy scores, standardizing within each fold.

        Wrapping scaler and classifier in one pipeline means StandardScaler is fit
        on each training split only.
        """
        from sklearn.pipeline import make_pipeline

        model = make_pipeline(StandardScaler(), self._make_decoder())
        cv = StratifiedKFold(n_splits=self.cv_folds, shuffle=True,
                             random_state=self.random_state)
        return cross_val_score(model, X, y, cv=cv, scoring='accuracy')

    def fit_and_evaluate(self, X: np.ndarray, y: np.ndarray) -> Dict:
        """
        Fit decoder and evaluate using cross-validation.

        Rows with NaN features or NaN labels are dropped first.

        Args:
            X: Feature matrix (n_samples, n_features); array-like is accepted
            y: Labels (n_samples,)

        Returns:
            On success: dict with 'success' True, CV accuracy mean/std/scores,
            'train_accuracy' (decoder refit on all standardized data),
            'chance_level' (1 / n_classes), sample/feature/class counts, and the
            fitted 'decoder' and 'scaler'.
            On failure (too few samples, fewer than 2 classes, or an error during
            fitting): dict with 'success' False, a 'reason' and 'n_samples'.
        """
        X_clean, y_clean = self._drop_nan_rows(X, y)
        n_samples = len(X_clean)

        if n_samples < self.cv_folds:
            return {
                'success': False,
                'reason': f'Insufficient samples: {n_samples} < {self.cv_folds}',
                'n_samples': n_samples
            }

        unique_classes = np.unique(y_clean)
        if len(unique_classes) < 2:
            return {
                'success': False,
                'reason': f'Need at least 2 classes, got {len(unique_classes)}',
                'n_samples': n_samples,
                'unique_classes': unique_classes
            }

        try:
            cv_scores = self._cv_scores(X_clean, y_clean)

            # Fit on the full data for the returned model and training accuracy
            scaler = StandardScaler()
            X_scaled = scaler.fit_transform(X_clean)
            decoder = self._make_decoder()
            decoder.fit(X_scaled, y_clean)
            train_accuracy = accuracy_score(y_clean, decoder.predict(X_scaled))
        except Exception as e:
            # Best-effort by design: callers check 'success' instead of catching.
            return {
                'success': False,
                'reason': str(e),
                'n_samples': n_samples
            }

        return {
            'success': True,
            'cv_accuracy_mean': cv_scores.mean(),
            'cv_accuracy_std': cv_scores.std(),
            'cv_scores': cv_scores,
            'train_accuracy': train_accuracy,
            'chance_level': 1.0 / len(unique_classes),
            'n_samples': n_samples,
            'n_features': X_scaled.shape[1],
            'n_classes': len(unique_classes),
            'unique_classes': unique_classes,
            'decoder': decoder,
            'scaler': scaler
        }

    def permutation_test(self, X: np.ndarray, y: np.ndarray) -> Dict:
        """
        Run a label-permutation test for the significance of the CV accuracy.

        The observed accuracy comes from fit_and_evaluate. The null distribution
        is built by shuffling the NaN-filtered labels ``n_permutations`` times and
        re-running only the cross-validation, with no extra full-data fit per
        permutation. Permutation ``k`` uses ``np.random.RandomState(random_state + k)``.
        This reproduces the permutations of re-seeding NumPy's global RNG without
        modifying global random state.

        The p-value is ``(#{null >= observed} + 1) / (n_permutations + 1)``, the
        standard permutation-test estimate, which never reports exactly 0.

        Args:
            X: Feature matrix (n_samples, n_features)
            y: Labels (n_samples,)

        Returns:
            Dict with permutation test results, a skipped message when
            ``run_permutation_tests`` is False, or 'success' False with a 'reason'.
        """
        if not self.run_permutation_tests:
            return {
                'success': True,
                'skipped': True,
                'reason': 'Permutation tests disabled',
                'observed_accuracy': None,
                'p_value': None,
                'significant': None
            }

        # Filter once so permutations shuffle only labels that are actually used
        X_clean, y_clean = self._drop_nan_rows(X, y)

        observed_results = self.fit_and_evaluate(X_clean, y_clean)
        if not observed_results['success']:
            return {
                'success': False,
                'reason': observed_results['reason']
            }

        observed_accuracy = observed_results['cv_accuracy_mean']
        chance_level = observed_results['chance_level']
        null_accuracies = []

        # NOTE(review): relies on a module-level `tqdm` import that is not visible in
        # this section of the file — confirm it exists.
        with tqdm(total=self.n_permutations, desc="Permutations", unit="perm",
                  file=sys.stderr, leave=False, dynamic_ncols=True) as pbar:

            for perm in range(self.n_permutations):
                rng = np.random.RandomState(self.random_state + perm)
                y_permuted = rng.permutation(y_clean)

                try:
                    null_accuracies.append(self._cv_scores(X_clean, y_permuted).mean())
                except Exception:
                    # A failed permutation counts as chance-level performance
                    null_accuracies.append(chance_level)

                if perm % 50 == 0 or perm == self.n_permutations - 1:
                    pbar.set_postfix({
                        'Null_μ': f'{np.mean(null_accuracies):.3f}',
                        'Obs': f'{observed_accuracy:.3f}'
                    })

                pbar.update(1)

        null_accuracies = np.asarray(null_accuracies, dtype=float)
        n_at_least_as_extreme = int(np.sum(null_accuracies >= observed_accuracy))
        p_value = (n_at_least_as_extreme + 1) / (null_accuracies.size + 1)
        has_null = null_accuracies.size > 0

        return {
            'success': True,
            'observed_accuracy': observed_accuracy,
            'null_accuracies': null_accuracies,
            'null_mean': null_accuracies.mean() if has_null else float('nan'),
            'null_std': null_accuracies.std() if has_null else float('nan'),
            'p_value': p_value,
            'significant': p_value < 0.05,
            'chance_level': chance_level
        }

# =============================================================================
# LABEL EXTRACTION FUNCTIONS
# =============================================================================

def get_rule_satisfaction_labels(hidden_activities: Dict, env, timepoint: int = 0) -> np.ndarray:
    """
    Per-trial rule-satisfaction labels, encoded as integers for decoding.

    Args:
        hidden_activities: Loaded hidden activities from .pt file
        env: CPRO environment instance
        timepoint: Which timepoint to use for stimulus info

    Returns:
        1-D int array with one entry per trial: 1 where the rule is satisfied and
        0 where it is violated. The boolean outcome comes from
        get_responses_for_hidden_activities.
    """
    _responses, satisfied, _collapse = get_responses_for_hidden_activities(
        hidden_activities, env, timepoint
    )
    as_binary = satisfied.astype(int)
    return as_binary

def get_motor_response_labels(hidden_activities: Dict, env, timepoint: int = 0,
                            response_type: str = 'all') -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Extract motor response labels.

    Responses are coded 0=LIND, 1=LMID, 2=RIND, 3=RMID.

    Args:
        hidden_activities: Loaded hidden activities from .pt file
        env: CPRO environment instance
        timepoint: Which timepoint to use for stimulus info
        response_type: 'all' (4-way), 'hand' (left=0, right=1),
                      'left_fingers' (LIND=0, LMID=1), 'right_fingers' (RIND=0, RMID=1)

    Returns:
        For 'all' and 'hand': a label array with one entry per trial.
        For 'left_fingers' / 'right_fingers': a tuple ``(finger_labels, trial_mask)``.
        ``finger_labels`` holds only the trials where ``trial_mask`` is True, so it is
        shorter than the full trial list.

    Raises:
        ValueError: For an unknown ``response_type``. It is raised before any
            response reconstruction is done.
    """
    # Validate up front so a typo fails fast instead of after reconstructing
    # every trial's response.
    if response_type not in ('all', 'hand', 'left_fingers', 'right_fingers'):
        raise ValueError(f"Unknown response_type: {response_type}")

    responses, _, _ = get_responses_for_hidden_activities(hidden_activities, env, timepoint)

    if response_type == 'all':
        return responses  # 0,1,2,3 for LIND,LMID,RIND,RMID

    if response_type == 'hand':
        # Left hand (0,1) -> 0, Right hand (2,3) -> 1
        return (responses >= 2).astype(int)

    if response_type == 'left_fingers':
        # Only left-hand trials (responses 0,1); already coded 0,1
        left_mask = responses < 2
        return responses[left_mask], left_mask

    # 'right_fingers': only right-hand trials (responses 2,3), mapped to 0,1
    right_mask = responses >= 2
    return responses[right_mask] - 2, right_mask

def get_motor_relevance_masks(hidden_activities: Dict, timepoint: int = 0) -> Dict[str, np.ndarray]:
    """
    Build boolean per-trial masks from each trial's motor context.

    Motor contexts 0 and 1 count as left-hand contexts, and 2 and 3 as right-hand
    contexts. The finger masks equal the hand masks but are separate array objects.

    Args:
        hidden_activities: Loaded hidden activities from .pt file
        timepoint: Which timepoint to use for task info

    Returns:
        Dict with keys 'left_hand_relevant', 'right_hand_relevant',
        'left_fingers_relevant' and 'right_fingers_relevant'. Each value is a
        boolean array with one entry per trial.
    """
    trials = hidden_activities['stimulus_info'][timepoint]
    motor_ctx = np.array([trial['task']['motor_ctx'] for trial in trials])

    left_side = motor_ctx < 2     # motor contexts 0,1
    right_side = motor_ctx >= 2   # motor contexts 2,3

    return {
        'left_hand_relevant': left_side,
        'right_hand_relevant': right_side,
        'left_fingers_relevant': left_side.copy(),
        'right_fingers_relevant': right_side.copy()
    }

def get_task_component_labels(hidden_activities: Dict, timepoint: int = 0,
                            component_type: str = 'all') -> np.ndarray:
    """
    Extract per-trial task-component labels.

    Args:
        hidden_activities: Loaded hidden activities from .pt file.
            ``['stimulus_info'][timepoint]`` must be a sequence of per-trial dicts
            whose 'task' entry holds 'logical_ctx', 'sensory_ctx' and 'motor_ctx'.
        timepoint: Which timepoint to use for task info
        component_type: Which label to extract:
            - 'logical', 'sensory' or 'motor': that context index per trial (4-way)
            - 'tasks_64': unique id ``logical*16 + sensory*4 + motor`` (64-way)
            - 'all': see note below

    Returns:
        1-D array with one label per trial.

    Raises:
        ValueError: For an unknown ``component_type``.

    Note:
        'all' was described as a combined 12-way classification, but each trial can
        carry only one label. The previous implementation concatenated
        logical / sensory+4 / motor+8 and kept the first N entries, which is exactly
        the logical-context labels. That behavior is preserved.
        NOTE(review): confirm whether a genuine combined (multi-label) target was intended.
    """
    stimulus_info = hidden_activities['stimulus_info'][timepoint]

    logical_contexts = np.array([trial_info['task']['logical_ctx'] for trial_info in stimulus_info])
    sensory_contexts = np.array([trial_info['task']['sensory_ctx'] for trial_info in stimulus_info])
    motor_contexts = np.array([trial_info['task']['motor_ctx'] for trial_info in stimulus_info])

    if component_type in ('logical', 'all'):
        return logical_contexts

    if component_type == 'sensory':
        return sensory_contexts

    if component_type == 'motor':
        return motor_contexts

    if component_type == 'tasks_64':
        # Each unique (L,S,M) combination gets a unique label 0-63
        return logical_contexts * 16 + sensory_contexts * 4 + motor_contexts

    raise ValueError(f"Unknown component_type: {component_type}")

def get_negation_labels(hidden_activities: Dict, timepoint: int = 0) -> np.ndarray:
    """
    Label each trial by whether its logical operator is negated.

    Logical contexts are coded AND=0, NAND=1, OR=2, NOR=3. The negated operators
    (NAND, NOR) are therefore exactly the odd codes.

    Args:
        hidden_activities: Loaded hidden activities from .pt file
        timepoint: Which timepoint to use for task info

    Returns:
        Int array with one entry per trial: 0 for AND/OR, 1 for NAND/NOR.
    """
    trials = hidden_activities['stimulus_info'][timepoint]
    operator_codes = np.array([trial['task']['logical_ctx'] for trial in trials])
    parity = operator_codes % 2
    return parity.astype(int)

def get_stimulus_labels(hidden_activities: Dict, timepoint: int = 0,
                       stimulus_type: str = 'all_features') -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Extract stimulus-feature labels for every trial.

    Each trial has two stimuli, ``['stimulus']['stimulus']['stim1']`` and ``['stim2']``.
    Each is a dict of features keyed 'VDim1', 'VDim2', 'ADim1', 'ADim2'.

    Args:
        hidden_activities: Loaded hidden activities from .pt file
        timepoint: Which timepoint to use for stimulus info
        stimulus_type: Which label to extract:
            - 'all_features': the 8 features of both stimuli packed into one integer.
              Bit i is the i-th feature in the order stim1 VDim1, VDim2, ADim1, ADim2,
              then the same for stim2. With 0/1 features this gives up to 256 classes,
              not 8.
            - 'vdim1' / 'vdim2' / 'adim1' / 'adim2': that feature's value on
              (stim1, stim2), encoded as ``stim1 * 2 + stim2`` (4 classes)
            - 'relevant_vs_irrelevant': the same 4-class encoding for the feature
              selected by the trial's sensory context, and for the next feature
              cyclically (an irrelevant control)

    Returns:
        A 1-D label array. For 'relevant_vs_irrelevant', a tuple
        ``(relevant_labels, irrelevant_labels)`` of 1-D arrays.

    Raises:
        ValueError: For an unknown ``stimulus_type``.
    """
    stimulus_info = hidden_activities['stimulus_info'][timepoint]

    # Feature order: index == sensory context, and also the bit order for 'all_features'
    feature_names = ('VDim1', 'VDim2', 'ADim1', 'ADim2')
    sensory_dims = dict(enumerate(feature_names))

    if stimulus_type == 'all_features':
        feature_labels = []
        for trial_info in stimulus_info:
            pair = trial_info['stimulus']['stimulus']
            bits = [pair[which][name] for which in ('stim1', 'stim2') for name in feature_names]
            feature_labels.append(sum(bit * (2 ** i) for i, bit in enumerate(bits)))
        return np.array(feature_labels)

    if stimulus_type in ('vdim1', 'vdim2', 'adim1', 'adim2'):
        dim_key = {
            'vdim1': 'VDim1',
            'vdim2': 'VDim2',
            'adim1': 'ADim1',
            'adim2': 'ADim2'
        }[stimulus_type]

        stim1_values = []
        stim2_values = []
        for trial_info in stimulus_info:
            pair = trial_info['stimulus']['stimulus']
            stim1_values.append(pair['stim1'][dim_key])
            stim2_values.append(pair['stim2'][dim_key])

        # (stim1_val, stim2_val) -> 4-way classification
        return np.array(stim1_values) * 2 + np.array(stim2_values)

    if stimulus_type == 'relevant_vs_irrelevant':
        relevant_labels = []
        irrelevant_labels = []

        for trial_info in stimulus_info:
            sensory_ctx = trial_info['task']['sensory_ctx']
            pair = trial_info['stimulus']['stimulus']

            relevant_dim = sensory_dims[sensory_ctx]
            relevant_labels.append(pair['stim1'][relevant_dim] * 2 + pair['stim2'][relevant_dim])

            # Irrelevant control: the next dimension cyclically
            irrelevant_dim = sensory_dims[(sensory_ctx + 1) % 4]
            irrelevant_labels.append(pair['stim1'][irrelevant_dim] * 2 + pair['stim2'][irrelevant_dim])

        return np.array(relevant_labels), np.array(irrelevant_labels)

    raise ValueError(f"Unknown stimulus_type: {stimulus_type}")

def get_collapse_integration_labels(hidden_activities: Dict, env, timepoint: int = 0) -> np.ndarray:
    """
    Per-trial collapse vs integration labels, encoded as integers for decoding.

    A trial is a "collapse" trial when its outcome is already determined by the
    first stimulus (see can_decide_after_stim1). Otherwise it is an "integration"
    trial.

    Args:
        hidden_activities: Loaded hidden activities from .pt file
        env: CPRO environment instance
        timepoint: Which timepoint to use for stimulus info

    Returns:
        1-D int array with one entry per trial: 1 for collapse, 0 for integration.
    """
    _responses, _satisfied, early_decision = get_responses_for_hidden_activities(
        hidden_activities, env, timepoint
    )
    as_binary = early_decision.astype(int)
    return as_binary

# =============================================================================
# HIGH-LEVEL DECODING ANALYSIS FUNCTIONS  
# =============================================================================

def _timepoint_accuracy_for_log(timepoint_result) -> Tuple[str, Optional[float]]:
    """
    Summarize one timepoint's decoding result for the progress log.

    Returns:
        ``(acc_str, headline_acc)``: the text shown after 'Acc=' and the top-level
        cross-validated accuracy. ``headline_acc`` is None when the result has no
        top-level accuracy (failed fits, or relevance analyses whose results are
        nested per condition).
    """
    if not isinstance(timepoint_result, dict):
        return "N/A", None

    if 'cv_accuracy_mean' in timepoint_result:
        acc = timepoint_result['cv_accuracy_mean']
        chance = timepoint_result.get('chance_level', 0.5)
        return f"{acc:.3f} (vs {chance:.3f})", acc

    # Relevance analyses store one result per condition. Tag each accuracy with
    # its own condition so an irrelevant-only result is not shown as "Rel".
    parts = []
    for cond, tag in (('relevant', 'Rel'), ('irrelevant', 'Irrel')):
        cond_result = timepoint_result.get(cond)
        if isinstance(cond_result, dict) and 'cv_accuracy_mean' in cond_result:
            parts.append(f"{tag}:{cond_result['cv_accuracy_mean']:.3f}")
    return (" ".join(parts) if parts else "N/A"), None


def run_single_decoding_analysis(hidden_file: Path, analysis_type: str,
                                decoder_config: Dict, env=None, verbose: bool = False) -> Dict:
    """
    Run a single decoding analysis across all timepoints with SLURM-friendly logging.

    Args:
        hidden_file: Path to hidden activities file
        analysis_type: Type of analysis to run (dispatched by _get_labels_and_decode)
        decoder_config: Keyword arguments for CPRODecoder
        env: CPRO environment (a CPRO(training_mode="minimal") is created if None)
        verbose: Currently unused; kept for API compatibility

    Returns:
        Dict with 'analysis_type', 'decoder_config', 'timepoint_results'
        (timepoint index -> result dict) and 'summary'.
    """
    if env is None:
        from cpro import CPRO
        env = CPRO(training_mode="minimal")

    hidden_activities = load_hidden_activities(hidden_file)
    decoder = CPRODecoder(**decoder_config)

    results = {
        'analysis_type': analysis_type,
        'decoder_config': decoder_config,
        'timepoint_results': {},
        'summary': {}
    }

    trial_length = len(hidden_activities['stimulus_hidden_states'])
    timepoint_labels = ['Rule', 'Stim1_e', 'Stim1_l', 'Dly1_e', 'Dly1_l',
                       'Stim2_e', 'Stim2_l', 'Dly2_e', 'Dly2_l', 'Resp']

    def _label(tp):
        # Named phase for known timepoints, generic 'T<n>' beyond them
        return timepoint_labels[tp] if tp < len(timepoint_labels) else f"T{tp}"

    timepoint_times = []
    # (timepoint, accuracy) for timepoints with a top-level accuracy. The timepoint
    # is kept with its accuracy so the peak is reported at the right timepoint
    # even when some timepoints have no accuracy.
    scored_timepoints = []

    for timepoint in range(trial_length):
        tp_start = datetime.now()

        activities = np.array(hidden_activities['stimulus_hidden_states'][timepoint])
        timepoint_result = _get_labels_and_decode(
            activities, hidden_activities, env, timepoint, analysis_type, decoder
        )
        results['timepoint_results'][timepoint] = timepoint_result

        tp_time = (datetime.now() - tp_start).total_seconds()
        timepoint_times.append(tp_time)

        acc_str, acc = _timepoint_accuracy_for_log(timepoint_result)
        if acc is not None:
            scored_timepoints.append((timepoint, acc))

        # Log every third timepoint and the last one
        if timepoint % 3 == 0 or timepoint == trial_length - 1:
            progress = f"{timepoint+1}/{trial_length}"
            avg_time = np.mean(timepoint_times[-3:]) if len(timepoint_times) >= 3 else tp_time
            eta_mins = (avg_time * (trial_length - timepoint - 1)) / 60

            message = f"  Timepoints {progress}: {_label(timepoint)}({tp_time:.0f}s, Acc={acc_str})"
            if eta_mins > 1:
                message += f" - ETA {eta_mins:.1f}min"
            log_message(message, "PROGRESS")

    results['summary'] = _create_decoding_summary(results['timepoint_results'])

    if scored_timepoints:
        # max() returns the first maximal item, so ties go to the earliest timepoint
        peak_tp, peak_acc = max(scored_timepoints, key=lambda item: item[1])
        log_message(f"  Timepoint summary: Peak={peak_acc:.3f} at {_label(peak_tp)}, Avg_time={np.mean(timepoint_times):.1f}s", "INFO")

    return results

def _get_labels_and_decode(activities: np.ndarray, hidden_activities: Dict, 
                          env, timepoint: int, analysis_type: str, 
                          decoder: CPRODecoder) -> Dict:
    """
    Helper function to get labels and run decoding for a specific analysis type.
    """
    
    if analysis_type == 'rule_satisfaction':
        labels = get_rule_satisfaction_labels(hidden_activities, env, timepoint)
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': 'binary', 'classes': ['violated', 'satisfied']}
        
    elif analysis_type == 'motor_response_all':
        labels = get_motor_response_labels(hidden_activities, env, timepoint, 'all')
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': '4-way', 'classes': ['LIND', 'LMID', 'RIND', 'RMID']}
        
    elif analysis_type == 'motor_response_hand':
        labels = get_motor_response_labels(hidden_activities, env, timepoint, 'hand')
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': 'binary', 'classes': ['left_hand', 'right_hand']}
        
    elif analysis_type == 'motor_response_left_fingers':
        finger_labels, trial_mask = get_motor_response_labels(hidden_activities, env, timepoint, 'left_fingers')
        relevance_masks = get_motor_relevance_masks(hidden_activities, timepoint)
        
        # Run for relevant and irrelevant trials
        result = {}
        for relevance in ['relevant', 'irrelevant']:
            if relevance == 'relevant':
                mask = trial_mask & relevance_masks['left_fingers_relevant']
            else:
                mask = trial_mask & ~relevance_masks['left_fingers_relevant']
            
            if mask.sum() > 0:
                masked_activities = activities[mask]
                masked_labels = finger_labels  # Already filtered by trial_mask
                
                if len(masked_labels) == masked_activities.shape[0]:
                    decode_result = decoder.fit_and_evaluate(masked_activities, masked_labels)
                    result[relevance] = decode_result
                else:
                    result[relevance] = {'success': False, 'reason': 'Label/activity mismatch'}
            else:
                result[relevance] = {'success': False, 'reason': 'No trials in mask'}
        
        result['analysis_details'] = {'type': 'binary_with_relevance', 'classes': ['LIND', 'LMID']}
        
    elif analysis_type == 'motor_response_right_fingers':
        finger_labels, trial_mask = get_motor_response_labels(hidden_activities, env, timepoint, 'right_fingers')
        relevance_masks = get_motor_relevance_masks(hidden_activities, timepoint)
        
        # Run for relevant and irrelevant trials
        result = {}
        for relevance in ['relevant', 'irrelevant']:
            if relevance == 'relevant':
                mask = trial_mask & relevance_masks['right_fingers_relevant']
            else:
                mask = trial_mask & ~relevance_masks['right_fingers_relevant']
            
            if mask.sum() > 0:
                masked_activities = activities[mask]
                masked_labels = finger_labels  # Already filtered by trial_mask
                
                if len(masked_labels) == masked_activities.shape[0]:
                    decode_result = decoder.fit_and_evaluate(masked_activities, masked_labels)
                    result[relevance] = decode_result
                else:
                    result[relevance] = {'success': False, 'reason': 'Label/activity mismatch'}
            else:
                result[relevance] = {'success': False, 'reason': 'No trials in mask'}
        
        result['analysis_details'] = {'type': 'binary_with_relevance', 'classes': ['RIND', 'RMID']}
        
    elif analysis_type == 'task_components_logical':
        labels = get_task_component_labels(hidden_activities, timepoint, 'logical')
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': '4-way', 'classes': ['AND', 'NAND', 'OR', 'NOR']}
        
    elif analysis_type == 'task_components_sensory':
        labels = get_task_component_labels(hidden_activities, timepoint, 'sensory')
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': '4-way', 'classes': ['RED', 'VERTICAL', 'HI-PITCH', 'CONSTANT']}
        
    elif analysis_type == 'task_components_motor':
        labels = get_task_component_labels(hidden_activities, timepoint, 'motor')
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': '4-way', 'classes': ['LIND', 'LMID', 'RIND', 'RMID']}
        
    elif analysis_type == 'task_components_all':
        labels = get_task_component_labels(hidden_activities, timepoint, 'all')
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': '12-way', 'classes': 'all_rule_components'}
        
    elif analysis_type == 'tasks_64':
        labels = get_task_component_labels(hidden_activities, timepoint, 'tasks_64')
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': '64-way', 'classes': 'all_tasks'}
        
    elif analysis_type == 'negation':
        labels = get_negation_labels(hidden_activities, timepoint)
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': 'binary', 'classes': ['AND/OR', 'NAND/NOR']}
        
    elif analysis_type == 'collapse_integration':
        labels = get_collapse_integration_labels(hidden_activities, env, timepoint)
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': 'binary', 'classes': ['integration', 'collapse']}
        
    elif analysis_type == 'stimulus_all_features':
        labels = get_stimulus_labels(hidden_activities, timepoint, 'all_features')
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': '256-way', 'classes': 'all_stimulus_combinations'}
        
    elif analysis_type.startswith('stimulus_'):
        # Handle individual stimulus dimensions
        dim_name = analysis_type.replace('stimulus_', '')
        labels = get_stimulus_labels(hidden_activities, timepoint, dim_name)
        result = decoder.fit_and_evaluate(activities, labels)
        result['analysis_details'] = {'type': '4-way', 'classes': f'{dim_name}_combinations'}
        
    else:
        result = {'success': False, 'reason': f'Unknown analysis_type: {analysis_type}'}
    
    return result

def _create_decoding_summary(timepoint_results: Dict) -> Dict:
    """
    Create summary statistics across timepoints.
    """
    summary = {
        'mean_accuracy_by_timepoint': {},
        'peak_accuracy': 0,
        'peak_timepoint': 0,
        'above_chance_timepoints': []
    }
    
    for timepoint, result in timepoint_results.items():
        if isinstance(result, dict) and 'cv_accuracy_mean' in result:
            # Simple analysis
            acc = result['cv_accuracy_mean']
            chance = result.get('chance_level', 0.5)
            
            summary['mean_accuracy_by_timepoint'][timepoint] = acc
            
            if acc > summary['peak_accuracy']:
                summary['peak_accuracy'] = acc
                summary['peak_timepoint'] = timepoint
            
            if acc > chance + 0.05:  # 5% above chance
                summary['above_chance_timepoints'].append(timepoint)
                
        elif isinstance(result, dict):
            # Analysis with relevance (motor fingers)
            for relevance in ['relevant', 'irrelevant']:
                if relevance in result and 'cv_accuracy_mean' in result[relevance]:
                    acc = result[relevance]['cv_accuracy_mean']
                    key = f'{relevance}_accuracy_by_timepoint'
                    if key not in summary:
                        summary[key] = {}
                    summary[key][timepoint] = acc
    
    return summary