#!/usr/bin/env python

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import product, combinations
import os
import random

# Add imports for CPRO if needed
try:
    from cpro import CPRO, CPROConfig
except ImportError:
    print("Warning: cpro module not found. Using simulated responses.")
    CPRO = None
    
import sys
sys.path.append('./src')
from analysis_utils import getDimensionality

def normalize_rsm(rsm):
    """
    Min-max scale an RSM onto the interval [0, 1].

    Args:
        rsm: RSM matrix (array supporting .min(), .max() and arithmetic)

    Returns:
        Array of the same shape in which the smallest entry maps to 0 and the
        largest to 1. A constant matrix has no spread to scale, so an
        all-zeros array shaped (and typed) like ``rsm`` is returned instead.
    """
    lo, hi = rsm.min(), rsm.max()
    spread = hi - lo

    if spread != 0:
        return (rsm - lo) / spread

    # Degenerate case: every entry is identical, avoid dividing by zero.
    return np.zeros_like(rsm)

def compute_rsm_dimensionalities(results):
    """
    Min-max normalize each of the three RSMs and compute its dimensionality.

    The dimensionality value itself comes from ``getDimensionality`` in
    ``analysis_utils``; its exact definition lives there.

    Args:
        results: Dictionary from create_unified_rsms(); must contain the keys
            'task_rsm', 'stimulus_rsm' and 'response_rsm'.

    Returns:
        Dictionary mapping each of those keys to its dimensionality value.
    """
    print("\n4. Computing RSM Dimensionalities...")
    print("=" * 40)

    # (results key, human-readable name used in the report)
    rsm_specs = (
        ('task_rsm', 'Task (Rule Overlap)'),
        ('stimulus_rsm', 'Stimulus (Feature Overlap)'),
        ('response_rsm', 'Response (Same/Different)'),
    )

    dims = {}
    for key, label in rsm_specs:
        raw = results[key]
        scaled = normalize_rsm(raw)          # rescale to [0, 1] first
        dim_value = getDimensionality(scaled)
        dims[key] = dim_value

        print(f"{label} RSM:")
        print(f"  Original range: {raw.min():.2f} to {raw.max():.2f}")
        print(f"  Normalized range: {scaled.min():.2f} to {scaled.max():.2f}")
        print(f"  Dimensionality: {dim_value:.4f}")
        print("-" * 40)

    return dims

def select_strategic_tasks(n_tasks=12, seed=42):
    """
    Strategically select tasks to maximize rule overlap diversity.

    Candidates are all (logical, sensory, motor) context triples with each
    context in range(4), excluding the four "diagonal" triples where all
    three contexts are equal -- 60 candidates in total. The first task is
    drawn at random; each later pick is the unselected candidate whose
    rule-overlap counts against the tasks chosen so far are most balanced,
    scored as ``len(selected) - max(overlap_counts)``. Ties go to the
    earliest candidate in pool order.

    Args:
        n_tasks: Number of tasks to select (default: 12). Must satisfy
            0 <= n_tasks <= 60.
        seed: Random seed for reproducibility. Reseeds both the global
            ``random`` and ``numpy.random`` generators.

    Returns:
        Tuple ``(selected_tasks, task_labels)``: task dicts with keys
        'logical_ctx', 'sensory_ctx', 'motor_ctx', and matching labels of the
        form "L{l}S{s}M{m}".

    Raises:
        ValueError: If n_tasks is negative or larger than the candidate pool.
            Without this check the selection loop could never finish once
            every candidate had been taken.
    """
    np.random.seed(seed)
    random.seed(seed)

    context_keys = ('logical_ctx', 'sensory_ctx', 'motor_ctx')

    # All context triples in logical/sensory/motor nested order, minus the
    # diagonal tasks (all three contexts equal), excluded "for maximal training".
    all_tasks = [
        {'logical_ctx': logical, 'sensory_ctx': sensory, 'motor_ctx': motor}
        for logical, sensory, motor in product(range(4), repeat=3)
        if not (logical == sensory == motor)
    ]

    if not 0 <= n_tasks <= len(all_tasks):
        raise ValueError(
            f"n_tasks must be between 0 and {len(all_tasks)}, got {n_tasks}"
        )
    if n_tasks == 0:
        return [], []

    def count_overlaps(task1, task2):
        # Number of context dimensions (0-3) the two tasks share.
        return sum(task1[key] == task2[key] for key in context_keys)

    # Start with a random task
    selected_tasks = [random.choice(all_tasks)]

    while len(selected_tasks) < n_tasks:
        best_task = None
        best_score = -1

        for candidate in all_tasks:
            if candidate in selected_tasks:
                continue

            # Histogram of how many rules (0, 1, 2, 3) the candidate shares
            # with each already-selected task.
            overlap_counts = [0, 0, 0, 0]
            for selected in selected_tasks:
                overlap_counts[count_overlaps(candidate, selected)] += 1

            # Balanced histograms score higher: penalize the most crowded level.
            diversity_score = len(selected_tasks) - max(overlap_counts)

            # Strict '>' keeps the earliest candidate on ties.
            if diversity_score > best_score:
                best_score = diversity_score
                best_task = candidate

        # The range check guarantees an unselected candidate exists, and every
        # score is >= 0 > -1, so best_task is always set at this point.
        selected_tasks.append(best_task)

    task_labels = [
        f"L{t['logical_ctx']}S{t['sensory_ctx']}M{t['motor_ctx']}"
        for t in selected_tasks
    ]
    return selected_tasks, task_labels

def stratified_stimulus_sampling(n_samples=20, seed=42):
    """
    Sample binary stimuli stratified by how many features are set to 1.

    The stimulus space is every length-4 tuple of 0/1 values (2**4 = 16
    stimuli). Stimuli are grouped into strata by their number of 1s (0-4,
    with 1, 4, 6, 4 and 1 members). Each stratum gets an equal quota of
    ``n_samples // 5``; the first ``n_samples % 5`` strata (fewest 1s) get one
    extra. A stratum smaller than its quota contributes all of its members.
    Any shortfall is then filled with random not-yet-used stimuli, and, once
    all 16 are used, with random stimuli drawn with replacement.

    Args:
        n_samples: Number of stimuli to sample per task (default: 20)
        seed: Random seed for reproducibility. Reseeds both the global
            ``random`` and ``numpy.random`` generators.

    Returns:
        List of exactly ``n_samples`` stimulus tuples (duplicates only occur
        when n_samples > 16).
    """
    np.random.seed(seed)
    random.seed(seed)

    # Every binary vector over the 4 stimulus dimensions: 2**4 = 16 stimuli.
    all_stimuli = list(product([0, 1], repeat=4))

    # Stratum key = number of dimensions set to 1.
    stim_groups = {}
    for stim in all_stimuli:
        stim_groups.setdefault(sum(stim), []).append(stim)

    patterns = sorted(stim_groups)
    samples_per_pattern, remaining_samples = divmod(n_samples, len(patterns))

    # Equal quota per stratum; a too-small stratum is taken whole.
    sampled_stimuli = []
    for i, pattern in enumerate(patterns):
        n_from_pattern = samples_per_pattern + (1 if i < remaining_samples else 0)
        group = stim_groups[pattern]
        if len(group) >= n_from_pattern:
            sampled_stimuli.extend(random.sample(group, n_from_pattern))
        else:
            sampled_stimuli.extend(group)

    # Top up the shortfall: prefer unused stimuli; once all 16 have been used,
    # fall back to sampling with replacement from the full set.
    while len(sampled_stimuli) < n_samples:
        chosen = set(sampled_stimuli)
        unused = [s for s in all_stimuli if s not in chosen]
        sampled_stimuli.append(random.choice(unused if unused else all_stimuli))

    # Debug info
    print(f"Debug: Requested {n_samples} stimuli, got {len(sampled_stimuli)}")

    return sampled_stimuli[:n_samples]

def create_unified_rsms(n_tasks=12, n_trials_per_task=20, seed=42,
                        output_dir='results/analysis_outputs_fixed'):
    """
    Create three aligned trial-by-trial RSMs: task, stimulus, and response.

    Each row/column is one (task, stimulus) trial. Trials are laid out in task
    blocks -- every sampled stimulus for task 0, then for task 1, and so on --
    so the matrices are (n_tasks * n_trials_per_task) square (240x240 with the
    defaults). The same stimulus set is reused for every task.

    Matrix entries for trials i and j:
        task_rsm      -- number of rule contexts (logical, sensory, motor) the
                         two tasks share (0-3)
        stimulus_rsm  -- number of stimulus dimensions with equal values
        response_rsm  -- 1 if the two trials produce the same response, else 0

    Responses come from the CPRO environment when it was importable, otherwise
    from simulate_response(). The three RSMs are also saved as task_rsm.npy,
    stimulus_rsm.npy and response_rsm.npy inside ``output_dir``, which is
    created if it does not already exist.

    Args:
        n_tasks: Number of tasks to select (default: 12)
        n_trials_per_task: Number of trials per task (default: 20)
        seed: Random seed for reproducibility
        output_dir: Directory that receives the .npy files
            (default: 'results/analysis_outputs_fixed')

    Returns:
        Dictionary with keys 'task_rsm', 'stimulus_rsm', 'response_rsm',
        'task_labels', 'block_labels', 'n_tasks', 'n_trials_per_task',
        'total_trials', 'trial_info' and 'responses'.
    """
    # Step 1: Select strategic tasks
    selected_tasks, task_labels = select_strategic_tasks(n_tasks, seed)
    print(f"Selected {len(selected_tasks)} tasks:")
    for i, (task, label) in enumerate(zip(selected_tasks, task_labels)):
        print(f"  {i}: {label} -> {task}")

    # Step 2: Sample one stimulus set, shared by every task
    sampled_stimuli = stratified_stimulus_sampling(n_trials_per_task, seed)
    print(f"\nSampled {len(sampled_stimuli)} stimuli per task")

    # Step 3: Initialize CPRO environment (fall back to simulation)
    if CPRO is not None:
        env = CPRO()
    else:
        env = None
        print("Warning: Using simulated responses")

    # Step 4: Build (task, stimulus) trials in task-block order;
    # global_idx equals the trial's row/column index in the RSMs.
    trial_info = []
    for task_idx, task in enumerate(selected_tasks):
        for stim in sampled_stimuli:
            trial_info.append({
                'task_idx': task_idx,
                'task': task,
                'task_label': task_labels[task_idx],
                'stimulus': stim,
                'global_idx': len(trial_info),
            })
    # Size the RSMs from the trials actually built so shapes always agree.
    total_trials = len(trial_info)

    # Step 5: Compute responses for all trials
    responses = []
    for trial in trial_info:
        task = trial['task']
        stim = trial['stimulus']

        if env is not None:
            # CPRO format: both stimulus slots receive the same features.
            stims = {
                'stim1': {'VDim1': stim[0], 'VDim2': stim[1], 'ADim1': stim[2], 'ADim2': stim[3]},
                'stim2': {'VDim1': stim[0], 'VDim2': stim[1], 'ADim1': stim[2], 'ADim2': stim[3]}
            }

            # Evaluate rule and get response
            rule_met = env._evaluate_rule(stims, task)
            response = env._get_motor_response(task, rule_met)
        else:
            response = simulate_response(task, stim)

        responses.append(response)

    # Step 6: Fill the three RSMs
    task_rsm = np.zeros((total_trials, total_trials))
    stimulus_rsm = np.zeros((total_trials, total_trials))
    response_rsm = np.zeros((total_trials, total_trials))

    context_keys = ('logical_ctx', 'sensory_ctx', 'motor_ctx')
    for i, trial_i in enumerate(trial_info):
        task_i, stim_i = trial_i['task'], trial_i['stimulus']
        for j, trial_j in enumerate(trial_info):
            task_j, stim_j = trial_j['task'], trial_j['stimulus']

            # Task RSM: rule overlap between tasks
            task_rsm[i, j] = sum(task_i[key] == task_j[key] for key in context_keys)

            # Stimulus RSM: feature overlap between stimuli
            stimulus_rsm[i, j] = sum(a == b for a, b in zip(stim_i, stim_j))

            # Response RSM: same response = 1, different = 0
            response_rsm[i, j] = 1 if responses[i] == responses[j] else 0

    # Step 7: One label per task block, for visualization
    block_labels = list(task_labels)

    # The output directory may not exist yet on a fresh checkout; np.save
    # does not create parent directories.
    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, 'task_rsm.npy'), task_rsm)
    np.save(os.path.join(output_dir, 'stimulus_rsm.npy'), stimulus_rsm)
    np.save(os.path.join(output_dir, 'response_rsm.npy'), response_rsm)

    return {
        'task_rsm': task_rsm,
        'stimulus_rsm': stimulus_rsm,
        'response_rsm': response_rsm,
        'task_labels': task_labels,
        'block_labels': block_labels,
        'n_tasks': n_tasks,
        'n_trials_per_task': n_trials_per_task,
        'total_trials': total_trials,
        'trial_info': trial_info,
        'responses': responses
    }

def simulate_response(task, stim):
    """
    Produce a stand-in motor response when the CPRO environment is unavailable.

    The logical rule is evaluated as though both stimuli of the trial were the
    same stimulus, so AND/OR reduce to "target feature present" and NAND/NOR
    reduce to "target feature absent".

    Args:
        task: Task dictionary with logical_ctx, sensory_ctx, motor_ctx
        stim: Stimulus tuple (VDim1, VDim2, ADim1, ADim2)

    Returns:
        Motor response (0-3): ``motor_ctx`` when the rule is met, otherwise
        ``(motor_ctx + 2) % 4``.
    """
    # Sensory context -> index of the stimulus dimension it inspects
    # (original note: RED->VDim2, VERTICAL->VDim1, etc.)
    feature_index = {0: 1, 1: 0, 2: 2, 3: 3}[task['sensory_ctx']]
    has_feature = stim[feature_index] == 1

    # With identical stimuli AND(x, x) == OR(x, x) == x, while NAND and NOR
    # both give not x. Contexts 0 (AND) and 2 (OR) need the feature; every
    # other value (1 NAND, 3 NOR) needs it absent.
    rule_met = has_feature if task['logical_ctx'] in (0, 2) else not has_feature

    motor = task['motor_ctx']
    return motor if rule_met else (motor + 2) % 4

def plot_block_rsm(rsm, title, filename, block_labels, n_trials_per_task,
                   legend_text=None, cmap='viridis',
                   output_dir='results/analysis_outputs_fixed'):
    """
    Plot an RSM as a heatmap with task-block boundaries, then save and show it.

    Gold lines separate consecutive task blocks of ``n_trials_per_task``
    rows/columns, and each block is labeled at its center. The figure is
    saved at 300 dpi to ``output_dir/filename``, shown, and then closed.
    Summary statistics of ``rsm`` are printed afterwards.

    Args:
        rsm: The RSM matrix to plot
        title: Plot title
        filename: Output filename (relative to ``output_dir``)
        block_labels: Labels for each task block
        n_trials_per_task: Number of trials per task (for block boundaries)
        legend_text: Optional legend text
        cmap: Colormap to use
        output_dir: Directory the figure is saved into, created if missing
            (default: 'results/analysis_outputs_fixed')
    """
    # Create directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Set seaborn style with better parameters for readability
    # (note: this updates matplotlib's global rcParams).
    sns.set_style("white")  # Clean background
    plt.rcParams.update({
        'font.size': 10,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 16
    })

    # Create figure with appropriate size
    fig, ax = plt.subplots(figsize=(12, 10))

    # Integer color limits covering the full data range. floor/ceil (rather
    # than int(), which truncates toward zero) keeps fractional RSMs such as a
    # 0-1 normalized matrix from being clipped; integer-valued RSMs get the
    # same limits and ticks as before.
    vmin = int(np.floor(rsm.min()))
    vmax = int(np.ceil(rsm.max()))
    cbar_ticks = list(range(vmin, vmax + 1))

    # Create heatmap without individual labels
    im = ax.imshow(rsm, cmap=cmap, vmin=vmin, vmax=vmax, aspect='equal')

    # Add colorbar with proper styling
    cbar = plt.colorbar(im, ax=ax, shrink=0.8)
    cbar.set_label('Similarity', fontsize=32)
    cbar.set_ticks(cbar_ticks)
    cbar.ax.tick_params(labelsize=26)

    # Add block boundaries
    n_tasks = len(block_labels)
    block_size = n_trials_per_task

    # Draw block boundary lines (-0.5 puts them on pixel edges, not centers)
    for i in range(1, n_tasks):
        pos = i * block_size - 0.5
        ax.axhline(y=pos, color='gold', linewidth=4, alpha=0.8)
        ax.axvline(x=pos, color='gold', linewidth=4, alpha=0.8)

    # Set up ticks for block centers
    block_centers = [i * block_size + block_size / 2 - 0.5 for i in range(n_tasks)]

    # Set ticks and labels
    ax.set_xticks(block_centers)
    ax.set_yticks(block_centers)
    ax.set_xticklabels(block_labels, rotation=45, ha='right', fontsize=26)
    ax.set_yticklabels(block_labels, rotation=0, fontsize=26)

    # Set title and axis labels with proper spacing
    ax.set_title(title, fontsize=40, pad=20)
    ax.set_xlabel('Context Blocks', fontsize=32, labelpad=10)
    ax.set_ylabel('Context Blocks', fontsize=32, labelpad=10)

    # Hide all four spines for a cleaner look
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)

    # Add legend if provided
    if legend_text:
        # Create text box for legend
        legend_props = dict(boxstyle='round,pad=0.5', facecolor='lightgray', alpha=0.8)
        ax.text(1.02, 0.98, legend_text, transform=ax.transAxes, fontsize=9,
                verticalalignment='top', bbox=legend_props)

    # Adjust layout to prevent overlapping
    plt.tight_layout()

    # Save figure
    full_filename = os.path.join(output_dir, filename)
    plt.savefig(full_filename, dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    # (plt.show() does not close them under non-interactive backends).
    plt.close(fig)

    print(f"Saved {filename}")
    print(f"RSM shape: {rsm.shape}")
    print(f"Value range: {rsm.min():.2f} to {rsm.max():.2f}")
    print(f"Mean similarity: {rsm.mean():.2f}")
    print("-" * 50)

def main(*, compute_dimensionality=False):
    """
    Generate, plot, and summarize all three unified theoretical RSMs.

    Builds the RSMs for 8 tasks x 10 stimuli, then plots each one to
    results/analysis_outputs_fixed/ and prints summary statistics.

    Args:
        compute_dimensionality: If True, also run
            compute_rsm_dimensionalities() and store its output under
            results['dimensionalities'] (default: False, so by default the
            dimensionality step is skipped).

    Returns:
        The results dictionary produced by create_unified_rsms().
    """
    print("Generating Unified Theoretical RSMs")
    print("=" * 60)

    # Create unified RSMs with smaller dimensions
    results = create_unified_rsms(n_tasks=8, n_trials_per_task=10, seed=42)

    n_tasks = results['n_tasks']
    n_trials = results['n_trials_per_task']
    subtitle = f'{n_tasks} Example Contexts × {n_trials} Random Trials'

    # (progress message, results key, title heading, output file, colormap)
    plot_specs = [
        ("1. Creating Task RSM...", 'task_rsm',
         'Context RSM (Rule Overlap)', 'unified_context_rsm.png', 'Reds'),
        ("2. Creating Stimulus RSM...", 'stimulus_rsm',
         'Stimulus RSM (Feature Overlap)', 'unified_stimulus_rsm.png', 'Blues'),
        ("3. Creating Response RSM...", 'response_rsm',
         'Response RSM (Same/Different Response)', 'unified_response_rsm.png', 'Greens'),
    ]

    for step_msg, key, heading, filename, cmap in plot_specs:
        print(f"\n{step_msg}")
        plot_block_rsm(results[key],
                       f'{heading}\n{subtitle}',
                       filename,
                       results['block_labels'],
                       n_trials,
                       cmap=cmap)

    print("\nAll unified RSMs generated successfully!")
    print("Files saved in: results/analysis_outputs_fixed/")
    # List the file names actually written (derived from plot_specs so the
    # report cannot drift from what was saved).
    for spec in plot_specs:
        print(f"  - {spec[3]}")

    # Print summary statistics
    print("\nSummary Statistics:")
    print(f"Total matrix size: {results['total_trials']}×{results['total_trials']}")
    print(f"Block structure: {n_tasks} blocks of {n_trials}×{n_trials}")
    print(f"Selected tasks: {results['task_labels']}")

    if compute_dimensionality:
        dimensionalities = compute_rsm_dimensionalities(results)
        results['dimensionalities'] = dimensionalities

        print("\nDimensionality Summary:")
        print(f"Task RSM dimensionality: {dimensionalities['task_rsm']:.4f}")
        print(f"Stimulus RSM dimensionality: {dimensionalities['stimulus_rsm']:.4f}")
        print(f"Response RSM dimensionality: {dimensionalities['response_rsm']:.4f}")

    return results

if __name__ == '__main__':
    # Run the full pipeline and keep its output bound at module scope, so it
    # can be inspected afterwards (e.g. when run with `python -i`).
    results = main()