"""
CPRO (Concrete Permuted Rule Operations, C-PRO) Environment

This module provides the environment for the C-PRO task, including:
- CPROConfig: Configuration dataclass for CPRO parameters
- CPRO: Main environment class implementing the task logic
"""

import itertools
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Any

import numpy as np
import torch

@dataclass
class CPROConfig:
    """Hyper-parameters for the CPRO environment.

    Every field has a default, so any subset can be overridden by keyword,
    e.g. ``CPROConfig(sigma=0.1)``.
    """

    # Time-step length (not read by the CPRO class in this module).
    dt: int = 100
    # Input-noise scale; the noise step in CPRO.generate_batch is disabled.
    sigma: float = 0.5
    # Input width: 8 stimulus units followed by three 4-way context one-hots.
    input_size: int = 20
    # Number of response classes: 4 choices (fixation is not a response).
    output_size: int = 4

class CPRO:
    """
    Concrete Permuted Rule Operations (C-PRO) task environment.

    Key Features:
    - 4 stimulus dimensions (VDim1, VDim2, ADim1, ADim2)
    - Each dimension has 2 possible values
    - 4 dim x 2 val = 8 dedicated input nodes, reused between Stim1 and Stim2
    - Logical rules: AND, NAND, OR, NOR
    - Sensory contexts: RED, VERTICAL, HI-PITCH, CONSTANT
    - Motor contexts: LIND, LMID, RIND, RMID

    Input layout produced by ``generate_batch`` (last axis):
    - 0-7:   stimulus units, two per dimension (``2 * dim_idx + value``)
    - 8-11:  logical context one-hot
    - 12-15: sensory context one-hot
    - 16-19: motor context one-hot

    Training modes (``training_mode``):
    - ``"minimal"``: the 4 diagonal tasks (same index for all three rules).
    - ``"maximal"``: the 60 non-diagonal tasks.
    - ``"balanced_<N>"``: the first ``N // 16`` rule-balanced blocks of 16 tasks.
    Test tasks are always all 64 rule combinations.
    """

    # Stimulus dimensions in input-encoding order; index == sensory context id.
    STIM_DIMS = ('VDim1', 'VDim2', 'ADim1', 'ADim2')
    # Number of values per rule type, and tasks per balanced block.
    _NUM_RULES = 4
    _TASKS_PER_BLOCK = 16

    def __init__(self, config: Optional["CPROConfig"] = None, training_mode: str = "minimal"):
        """
        Args:
            config: Environment hyper-parameters; a default CPROConfig is used
                when omitted (or falsy).
            training_mode: ``"minimal"``, ``"maximal"`` or ``"balanced_<N>"``.

        Raises:
            ValueError: if ``training_mode`` is not recognised.
        """
        self.config = config or CPROConfig()
        self.training_mode = training_mode

        # All 4 x 4 x 4 = 64 (logical, sensory, motor) rule combinations.
        self.all_tasks = [
            {'logical_ctx': l, 'sensory_ctx': s, 'motor_ctx': m}
            for l in range(4) for s in range(4) for m in range(4)
        ]
        self.diagonal_tasks = [
            {'logical_ctx': i, 'sensory_ctx': i, 'motor_ctx': i}
            for i in range(4)
        ]

        self._blocks = self._generate_balanced_blocks() if "balanced" in training_mode else None
        self.training_tasks = self._get_training_tasks()
        self.test_tasks = self._get_test_tasks()

        # Define contexts
        self.logical_rules = ['AND', 'NAND', 'OR', 'NOR']
        self.sensory_rules = ['RED', 'VERTICAL', 'HI-PITCH', 'CONSTANT']
        self.motor_rules = ['LIND', 'LMID', 'RIND', 'RMID']

        # Generate all possible stimulus combinations
        self.all_stim_combinations = self._generate_all_stim_combinations()

    def _generate_balanced_blocks(self, max_attempts: int = 10_000) -> List[List[Dict]]:
        """
        Partition all 64 tasks into 4 rule-balanced blocks of 16 tasks each.

        In every block each logical, sensory and motor rule value appears
        exactly 4 times; beyond that constraint the combinations are drawn at
        random (from the global ``random`` state).

        A randomized greedy draw can dead-end before a block is full (no
        remaining task fits every quota). Such a partial partition is discarded
        and the whole partition is rebuilt, so the result is always complete.

        Args:
            max_attempts: Upper bound on greedy restarts before giving up.

        Raises:
            RuntimeError: if no complete partition is found in ``max_attempts``.
        """
        for _ in range(max_attempts):
            blocks = self._try_balanced_partition()
            if blocks is not None:
                return blocks
        raise RuntimeError(
            f"Could not build balanced task blocks in {max_attempts} attempts"
        )

    def _try_balanced_partition(self) -> Optional[List[List[Dict]]]:
        """Run one randomized greedy partition; return None if a block dead-ends."""
        quota = self._TASKS_PER_BLOCK // self._NUM_RULES  # occurrences per rule value per block
        remaining = list(self.all_tasks)
        blocks = []
        for _ in range(len(remaining) // self._TASKS_PER_BLOCK):
            l_counts = [0] * self._NUM_RULES
            s_counts = [0] * self._NUM_RULES
            m_counts = [0] * self._NUM_RULES
            block = []
            while len(block) < self._TASKS_PER_BLOCK:
                candidates = [
                    t for t in remaining
                    if l_counts[t['logical_ctx']] < quota
                    and s_counts[t['sensory_ctx']] < quota
                    and m_counts[t['motor_ctx']] < quota
                ]
                if not candidates:
                    return None  # dead end: no remaining task fits the quotas
                task = random.choice(candidates)
                block.append(task)
                remaining.remove(task)
                l_counts[task['logical_ctx']] += 1
                s_counts[task['sensory_ctx']] += 1
                m_counts[task['motor_ctx']] += 1
            # A full block of 16 with every count <= 4 has every count == 4.
            blocks.append(block)
        return blocks

    def _get_training_tasks(self) -> List[Dict]:
        """Select training tasks according to ``self.training_mode``.

        Raises:
            ValueError: for an unknown mode or a malformed ``balanced_<N>``.
        """
        if self.training_mode == "minimal":
            return self.diagonal_tasks
        if self.training_mode == "maximal":
            return [task for task in self.all_tasks if task not in self.diagonal_tasks]
        if self._blocks is None:
            raise ValueError(
                f"Unknown training_mode {self.training_mode!r}; expected "
                "'minimal', 'maximal' or 'balanced_<num_tasks>'"
            )
        try:
            num_tasks = int(self.training_mode.split('_')[1])
        except (IndexError, ValueError) as err:
            raise ValueError(
                f"Balanced training_mode must look like 'balanced_<num_tasks>', "
                f"got {self.training_mode!r}"
            ) from err
        num_blocks = num_tasks // self._TASKS_PER_BLOCK
        return [task for block in self._blocks[:num_blocks] for task in block]

    def _get_test_tasks(self) -> List[Dict]:
        """Return the test tasks: all 64 rule combinations in every mode."""
        return self.all_tasks

    def _generate_all_stim_combinations(self) -> List[Dict]:
        """Generate all 2**8 = 256 (stim1, stim2) value combinations.

        Ordering matches nested loops over stim1's dimensions then stim2's,
        with the last one (stim2 ADim2) varying fastest.
        """
        n_dims = len(self.STIM_DIMS)
        return [
            {
                'stim1': dict(zip(self.STIM_DIMS, bits[:n_dims])),
                'stim2': dict(zip(self.STIM_DIMS, bits[n_dims:])),
            }
            for bits in itertools.product((0, 1), repeat=2 * n_dims)
        ]

    def _evaluate_rule(self, stims: Dict, task: Dict) -> bool:
        """Evaluate whether the task's logical rule holds for the stimulus pair.

        The sensory context selects one dimension (RED -> VDim1,
        VERTICAL -> VDim2, HI-PITCH -> ADim1, CONSTANT -> ADim2); a stimulus
        has the feature when that dimension's value equals 1.
        """
        dim = self.STIM_DIMS[task['sensory_ctx']]
        # Target value is 1 for all sensory rules.
        target = 1
        first = stims['stim1'][dim] == target
        second = stims['stim2'][dim] == target

        logical_ctx = task['logical_ctx']
        if logical_ctx == 0:  # AND
            return first and second
        if logical_ctx == 1:  # NAND
            return not (first and second)
        if logical_ctx == 2:  # OR
            return first or second
        return not (first or second)  # NOR

    def _get_motor_response(self, task: Dict, rule_met: bool) -> int:
        """Map the rule outcome to a response index for the task's motor context.

        Responses 0-1 are left hand (LIND, LMID), 2-3 right hand (RIND, RMID).
        If the rule is met the cued finger responds, otherwise the other finger
        of the same hand.
        """
        motor_ctx = task['motor_ctx']
        if rule_met:
            return motor_ctx
        # Other finger of the same hand: 0 <-> 1, 2 <-> 3.
        return 1 - motor_ctx if motor_ctx < 2 else 5 - motor_ctx

    def _convert_stim_to_input(self, stim: Dict) -> List[int]:
        """Return the active input-unit indices (0-7) for one stimulus.

        Dimension ``k`` of ``STIM_DIMS`` owns units ``2k`` and ``2k + 1``; the
        value (0 or 1) picks which one is active. The fixed order makes the
        encoding independent of the dict's key insertion order.
        """
        return [dim_idx * 2 + stim[dim] for dim_idx, dim in enumerate(self.STIM_DIMS)]

    def generate_batch(self, batch_size: int, training: bool = True, batch_idx: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate one batch from the deterministic task x stimulus dataset.

        The dataset lists every stimulus combination for each task, task-major,
        in the order of ``training_tasks`` / ``test_tasks`` and
        ``all_stim_combinations``. Batch ``batch_idx`` is the slice
        ``[batch_idx * batch_size : (batch_idx + 1) * batch_size]`` of it.

        Trial timeline (10 steps): context units are on at every step,
        stimulus 1 at steps 1-2, stimulus 2 at steps 5-6, target at step 9.

        Args:
            batch_size: Size of batch to generate.
            training: If True, use training tasks, else test tasks.
            batch_idx: Which batch to generate (0-based index).

        Returns:
            ``(inputs, labels)`` shaped ``(10, batch_size, config.input_size)``
            (default float dtype) and ``(10, batch_size)`` (long). Labels are
            set only at step 9; other steps stay 0.
            NOTE(review): if fewer than ``batch_size`` samples remain, trailing
            columns are left all-zero with label 0 — callers should avoid or
            mask them.
        """
        tasks_list = self.training_tasks if training else self.test_tasks
        stim_list = self.all_stim_combinations
        n_stims = len(stim_list)

        # Index the virtual dataset arithmetically instead of materializing
        # len(tasks_list) * n_stims entries on every call; slicing a range has
        # the same semantics as slicing the equivalent list.
        start_idx = batch_idx * batch_size
        sample_ids = range(len(tasks_list) * n_stims)[start_idx:start_idx + batch_size]

        inputs = torch.zeros(10, batch_size, self.config.input_size)
        labels = torch.zeros(10, batch_size, dtype=torch.long)

        for i, sample_id in enumerate(sample_ids):
            task = tasks_list[sample_id // n_stims]
            stims = stim_list[sample_id % n_stims]

            # Context one-hots stay on for the whole trial.
            inputs[:, i, 8 + task['logical_ctx']] = 1
            inputs[:, i, 12 + task['sensory_ctx']] = 1
            inputs[:, i, 16 + task['motor_ctx']] = 1

            # Both stimuli share input units 0-7 but occupy different steps.
            for idx in self._convert_stim_to_input(stims['stim1']):
                inputs[1:3, i, idx] = 1
            for idx in self._convert_stim_to_input(stims['stim2']):
                inputs[5:7, i, idx] = 1

            # NOTE: config.sigma is currently unused; input noise is disabled.

            rule_met = self._evaluate_rule(stims, task)
            labels[9, i] = self._get_motor_response(task, rule_met)

        return inputs, labels

    def get_task_overlap(self, test_task: Dict) -> int:
        """Return the max number of rules ``test_task`` shares with any training task (0 if none)."""
        return max(
            (
                sum(train_task[k] == test_task[k] for k in train_task)
                for train_task in self.training_tasks
            ),
            default=0,
        )

    def categorize_test_tasks(self) -> Dict[int, List[Dict]]:
        """
        Categorize test tasks by overlap with training tasks.

        Keys 1, 2 and 3 are always present; training tasks themselves have
        overlap 3 (maximum overlap). Key 0 is added only when some test task
        shares no rule with any training task (e.g. no training tasks at all).
        """
        categorized: Dict[int, List[Dict]] = {1: [], 2: [], 3: []}
        for task in self.test_tasks:
            # Overlap can be 0, which has no preset key.
            categorized.setdefault(self.get_task_overlap(task), []).append(task)

        # Also add training tasks to overlap=3 category if not already there
        for task in self.training_tasks:
            if task not in categorized[3]:
                categorized[3].append(task)
        return categorized

# Public API exported by `from <module> import *`.
__all__ = [
    'CPRO',
    'CPROConfig',
]