import numpy as np
from pomcpow.pomcpow import POMCPOW as BasePOMCPOW

class POMCPOWAdapterMTT:
    """
    # & Adapter class for POMCPOW that handles the 20D state space of the Multi-Target Tracking environment
    # & while using the existing POMCPOW implementation internally.

    The adapter keeps its own 20D particle belief and talks to the base POMCPOW
    through a compressed 10D state:

    * 20D layout: ``[x, y, vx, vy]`` for the agent (indices 0:4), followed by the
      same 4-value block for each of the ``N_TARGETS`` targets (target ``i``,
      0-based, starts at index ``4 + 4 * i``).
    * 10D layout (see ``_compress_state``): agent (0:4), first target (4:8), and
      the mean ``(x, y)`` position of the remaining targets (8:10).
    """

    # Number of tracked targets; the 20D state is the agent block plus one
    # 4-value [x, y, vx, vy] block per target.
    N_TARGETS = 4
    STATE_DIM = 4 + 4 * N_TARGETS
    # Upper bound of the uniform range used to sample initial positions.
    MAP_SIZE = 10.0
    # Time step used by the built-in constant-velocity motion models.
    DT = 0.1
    # x-index of every target except the first (targets 2..N); y is at x + 1.
    _OTHER_TARGET_POS = tuple(range(8, STATE_DIM, 4))

    def __init__(self, action_space, n_particles=100, max_depth=5, n_simulations=100,
                 exploration_const=50.0, alpha_action=0.5, k_action=4.0,
                 alpha_obs=0.5, k_obs=4.0, discount_factor=0.95):
        """
        # & Initialize POMCPOWAdapter for the Multi-Target Tracking environment

        All arguments are forwarded unchanged to the base POMCPOW;
        ``n_particles`` also sizes the adapter's own 20D belief.
        """
        # Create the base POMCPOW instance
        self.base_pomcpow = BasePOMCPOW(
            action_space=action_space,
            n_particles=n_particles,
            max_depth=max_depth,
            n_simulations=n_simulations,
            exploration_const=exploration_const,
            alpha_action=alpha_action,
            k_action=k_action,
            alpha_obs=alpha_obs,
            k_obs=k_obs,
            discount_factor=discount_factor
        )

        # State dimension for Multi-Target Tracking (agent + 4 targets)
        self.state_dim = self.STATE_DIM

        # Remember the configured belief size so reset() restores it even if an
        # update changed the number of particles.
        self.n_particles = n_particles

        # Models are supplied per update(); start with none so the adapters fall
        # back to the built-in models if they are created before any update.
        self.transition_model = None
        self.observation_model = None

        # Initialize particle set for 20D states
        self.particles = self._init_particles(n_particles)

    def _init_particles(self, n_particles):
        """Sample an initial belief of shape ``(n_particles, STATE_DIM)``.

        Every block (agent, then each target) gets a position drawn uniformly
        from ``[0, MAP_SIZE)`` and a velocity drawn from N(0, 0.2).
        """
        particles = np.zeros((n_particles, self.STATE_DIM))
        for start in range(0, self.STATE_DIM, 4):
            particles[:, start:start + 2] = np.random.uniform(0, self.MAP_SIZE, (n_particles, 2))
            particles[:, start + 2:start + 4] = np.random.normal(0, 0.2, (n_particles, 2))
        return particles

    def update(self, action, observation, transition_model=None, observation_model=None):
        """
        # & Update belief with new action and observation for Multi-Target Tracking environment

        Args:
            action: Action that was executed.
            observation: Observation received after ``action``.
            transition_model: Optional callable ``(state_20d, action) -> next_state_20d``.
                A constant-velocity model is used when ``None``.
            observation_model: Optional callable ``(state_20d, observation) -> likelihood``.
                A constant likelihood of 1.0 is used when ``None``.

        If the base planner fails for any reason, including returning an empty
        belief, the current particles are propagated with ``_noisy_update``.
        """
        # Store the original models (read by the adapter factories below)
        self.transition_model = transition_model
        self.observation_model = observation_model

        # Create wrapped models that adapt 20D to 10D for the base POMCPOW
        adapted_transition_model = self._adapt_transition_model()
        adapted_observation_model = self._adapt_observation_model()

        try:
            self.base_pomcpow.update(
                action,
                observation,
                adapted_transition_model,
                adapted_observation_model
            )

            # Get the updated particles from base POMCPOW (10D)
            base_particles = self.base_pomcpow.get_belief_estimate()

            # An empty belief would replace self.particles with an empty array
            # and turn every later mean into NaN; treat it as a failed update.
            if base_particles is None or len(base_particles) == 0:
                raise ValueError("base POMCPOW returned an empty belief")

            # Expand the 10D particles back to 20D
            self._update_from_base_particles(base_particles)

        except Exception as e:
            # Deliberate best-effort boundary: keep tracking with a noisy
            # prediction rather than aborting the episode.
            print(f"POMCPOW update error: {e}")
            self._noisy_update(action)

    def _adapt_transition_model(self):
        """Wrap ``self.transition_model`` as a 10D -> 10D transition for the base POMCPOW."""
        original_model = self.transition_model
        dt = self.DT

        def adapted_transition(state_10d, action):
            # Lift to 20D, step the full model, then project back to 10D.
            state_20d = self._expand_state(state_10d)

            if original_model is not None:
                next_state_20d = original_model(state_20d, action)
            else:
                # Default: constant-velocity motion for the agent and every
                # target, plus a small random walk on the velocities.
                next_state_20d = state_20d.copy()
                for start in range(0, self.STATE_DIM, 4):
                    next_state_20d[start:start + 2] += next_state_20d[start + 2:start + 4] * dt
                    next_state_20d[start + 2:start + 4] += np.random.normal(0, 0.01, 2)

            return self._compress_state(next_state_20d)

        return adapted_transition

    def _adapt_observation_model(self):
        """Wrap ``self.observation_model`` so it accepts the base POMCPOW's 10D states."""
        original_model = self.observation_model

        def adapted_observation(state_10d, observation):
            state_20d = self._expand_state(state_10d)
            if original_model is not None:
                return original_model(state_20d, observation)
            # No model supplied: every state is equally likely.
            return 1.0

        return adapted_observation

    def _mean_particle(self):
        """Return the mean 20D particle of the current belief, or ``None`` if it is empty."""
        if len(self.particles) == 0:
            return None
        return np.mean(self.particles, axis=0)

    def _expand_state(self, state_10d, mean_particle=None):
        """Expand a 10D state (``_compress_state`` layout) into a 20D state.

        The agent and the first target (dims 0:8) are copied verbatim. All other
        dims start from the belief mean, or zeros if the belief is empty. If the
        10D state carries the summary dims 8:10, targets 2..N are shifted
        rigidly so that their mean (x, y) position equals that summary. This
        makes ``_compress_state(_expand_state(s))`` reproduce ``s``.

        Args:
            state_10d: Compressed state; shorter inputs are copied as far as they go.
            mean_particle: Optional precomputed belief mean. Pass it when
                expanding many states against the same belief.
        """
        state_10d = np.asarray(state_10d, dtype=float).ravel()
        if mean_particle is None:
            mean_particle = self._mean_particle()

        if mean_particle is None:
            state_20d = np.zeros(self.STATE_DIM)
        else:
            state_20d = np.array(mean_particle, dtype=float)  # copy; never alias the mean

        n_direct = min(8, len(state_10d))
        state_20d[:n_direct] = state_10d[:n_direct]

        if len(state_10d) >= 10:
            # Dims 8:10 are a summary of the other targets, not target 2's raw
            # position, so distribute them instead of copying them in.
            xs = np.array(self._OTHER_TARGET_POS)
            current_mean = np.array([state_20d[xs].mean(), state_20d[xs + 1].mean()])
            shift = state_10d[8:10] - current_mean
            state_20d[xs] += shift[0]
            state_20d[xs + 1] += shift[1]

        return state_20d

    def _compress_state(self, state_20d):
        """Compress a 20D state into the 10D layout used by the base POMCPOW.

        Keeps the agent (0:4) and the first target (4:8) verbatim. The remaining
        targets are summarized by their mean (x, y) position in dims 8:10; those
        dims stay zero if the input is shorter than ``STATE_DIM``.
        """
        state_20d = np.asarray(state_20d, dtype=float)
        state_10d = np.zeros(10)

        # Agent and first target position/velocity
        state_10d[0:8] = state_20d[0:8]

        if len(state_20d) >= self.STATE_DIM:
            xs = np.array(self._OTHER_TARGET_POS)
            # y is the element right after x in each target block; x + 2 would
            # be the target's x-velocity.
            state_10d[8] = np.mean(state_20d[xs])
            state_10d[9] = np.mean(state_20d[xs + 1])

        return state_10d

    def _update_from_base_particles(self, base_particles):
        """Replace the 20D belief with expansions of the base POMCPOW's 10D particles.

        Targets 2..N receive small Gaussian jitter (position sd 0.05, velocity
        sd 0.01). Those dims are only summarized in 10D, so without jitter every
        particle would share identical values there.
        """
        # self.particles is not modified until the end, so its mean can be
        # computed once instead of once per particle.
        mean_particle = self._mean_particle()
        new_particles = np.array(
            [self._expand_state(p, mean_particle) for p in base_particles],
            dtype=float,
        ).reshape(-1, self.STATE_DIM)

        n = len(new_particles)
        for x in self._OTHER_TARGET_POS:
            new_particles[:, x:x + 2] += np.random.normal(0, 0.05, (n, 2))
            new_particles[:, x + 2:x + 4] += np.random.normal(0, 0.01, (n, 2))

        self.particles = new_particles

    def _noisy_update(self, action):
        """Fallback prediction step applied in place to ``self.particles``.

        Each block (agent and targets) takes a constant-velocity step of ``DT``.
        It then gets Gaussian noise on position (sd 0.05) and velocity
        (sd 0.01). ``action`` is not used.
        """
        n = len(self.particles)
        dt = self.DT
        for start in range(0, self.STATE_DIM, 4):
            pos = slice(start, start + 2)
            vel = slice(start + 2, start + 4)
            self.particles[:, pos] += self.particles[:, vel] * dt
            self.particles[:, pos] += np.random.normal(0, 0.05, (n, 2))
            self.particles[:, vel] += np.random.normal(0, 0.01, (n, 2))

    def get_belief_estimate(self):
        """Return the current belief estimate (20D particles)"""
        return self.particles

    def select_action(self, observation=None):
        """Pick a discrete action that moves toward the nearest target.

        "Nearest" means the target whose mean position is closest to the agent's
        mean position. The step is taken along the dominant axis of the
        displacement. The encoding is assumed to be 0=up, 1=right, 2=down,
        3=left (TODO confirm against the environment's action_space).
        ``observation`` is unused.

        Returns a random action in [0, 4) if the belief is empty or anything fails.
        """
        if len(self.particles) == 0:
            # The mean of an empty belief is NaN, which would silently map to action 3.
            return np.random.randint(0, 4)

        try:
            mean_state = np.mean(self.particles, axis=0)
            mean_agent_pos = mean_state[0:2]
            target_positions = [mean_state[start:start + 2]
                                for start in range(4, self.STATE_DIM, 4)]

            distances = [np.linalg.norm(mean_agent_pos - pos) for pos in target_positions]
            direction = target_positions[int(np.argmin(distances))] - mean_agent_pos

            if abs(direction[1]) > abs(direction[0]):
                # More significant movement in y-direction
                return 0 if direction[1] > 0 else 2
            # More significant movement in x-direction
            return 1 if direction[0] > 0 else 3

        except Exception as e:
            print(f"Error in POMCPOW select_action: {e}")
            # Random action as fallback
            return np.random.randint(0, 4)

    def reset(self):
        """Reset the base planner and resample the configured number of particles."""
        self.base_pomcpow.reset()
        self.particles = self._init_particles(self.n_particles)

