import numpy as np
from pomcpow.pomcpow import POMCPOW as BasePOMCPOW

class POMCPOWAdapter:
    """
    Adapter that runs the base POMCPOW planner on the 20-D Kidnapped Robot state.

    The adapter owns a 20-D particle belief (``self.particles``). When the base
    planner is updated, the user-supplied transition/observation models are
    wrapped so that they always receive full 20-D states, while the base planner
    is only handed the first 10 dimensions. Dimensions dropped in between are
    re-filled from the mean of the current particle set.

    NOTE(review): the 10-D interface of the base planner is taken from this
    file's comments; confirm against ``pomcpow.pomcpow.POMCPOW``.

    State layout, as initialised by ``_init_particles``:
        [0:2]    x, y position, uniform in [0, 20)
        [2]      orientation in radians, uniform in [0, 2*pi)
        [3:5]    velocity and steering, N(0, 0.5)
        [5:10]   sensor calibration, N(0, 0.1)
        [10:20]  feature descriptor, unit L2 norm
    """

    def __init__(self, action_space, n_particles=100, max_depth=5, n_simulations=100,
                 exploration_const=50.0, alpha_action=0.5, k_action=4.0,
                 alpha_obs=0.5, k_obs=4.0, discount_factor=0.95):
        """
        Create the wrapped base POMCPOW planner and an initial 20-D belief.

        Args:
            action_space: Forwarded unchanged to the base POMCPOW.
            n_particles: Number of particles in the initial belief; also the
                size restored by ``reset``.
            max_depth, n_simulations, exploration_const, alpha_action,
            k_action, alpha_obs, k_obs, discount_factor: Forwarded unchanged
                to the base POMCPOW.
        """
        self.base_pomcpow = BasePOMCPOW(
            action_space=action_space,
            n_particles=n_particles,
            max_depth=max_depth,
            n_simulations=n_simulations,
            exploration_const=exploration_const,
            alpha_action=alpha_action,
            k_action=k_action,
            alpha_obs=alpha_obs,
            k_obs=k_obs,
            discount_factor=discount_factor
        )

        # State dimension for Kidnapped Robot.
        self.state_dim = 20
        # Remember the configured size so reset() restores it even if the base
        # planner returned a belief with a different number of particles.
        self.n_particles = n_particles
        # Set by update(); initialised so the _adapt_* helpers are always safe to call.
        self.transition_model = None
        self.observation_model = None

        self.particles = self._init_particles(n_particles)

    @staticmethod
    def _normalize_rows(vectors):
        """Return a copy of ``vectors`` with every non-zero row scaled to unit L2 norm.

        Rows whose norm is zero are returned unchanged (no division by zero).
        """
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        result = np.array(vectors, dtype=float, copy=True)
        return np.divide(result, norms, out=result, where=norms > 0)

    def _init_particles(self, n_particles):
        """
        Draw ``n_particles`` random 20-D states for the initial belief.

        Args:
            n_particles: Number of particles to draw (0 yields an empty (0, 20) array).

        Returns:
            ndarray of shape (n_particles, self.state_dim).
        """
        particles = np.zeros((n_particles, self.state_dim))

        # Position, uniformly distributed in the map.
        particles[:, 0:2] = np.random.uniform(0, 20, (n_particles, 2))
        # Orientation, uniformly distributed in [0, 2*pi).
        particles[:, 2] = np.random.uniform(0, 2 * np.pi, n_particles)
        # Velocity and steering.
        particles[:, 3:5] = np.random.normal(0, 0.5, (n_particles, 2))
        # Sensor calibration.
        particles[:, 5:10] = np.random.normal(0, 0.1, (n_particles, 5))
        # Feature descriptors: random directions with unit length.
        particles[:, 10:20] = self._normalize_rows(np.random.normal(0, 1, (n_particles, 10)))

        return particles

    def update(self, action, observation, transition_model=None, observation_model=None):
        """
        Update the belief after executing ``action`` and receiving ``observation``.

        The models are stored, wrapped for the base planner, and passed to
        ``base_pomcpow.update``. On success the base planner's belief is
        expanded back into 20-D particles. If any step raises, the error is
        printed and the belief is propagated with ``_noisy_update`` instead.
        This is a deliberate best-effort fallback.

        Args:
            action: Executed action. In the built-in fallback dynamics,
                0 = forward, 1 = turn left, 2 = turn right; other values leave
                the pose unchanged apart from noise.
            observation: Received observation, forwarded to the base planner.
            transition_model: Optional callable ``(state_20d, action) -> next_state_20d``.
            observation_model: Optional callable ``(state_20d, observation) -> likelihood``.
        """
        self.transition_model = transition_model
        self.observation_model = observation_model

        adapted_transition_model = self._adapt_transition_model()
        adapted_observation_model = self._adapt_observation_model()

        try:
            self.base_pomcpow.update(
                action,
                observation,
                adapted_transition_model,
                adapted_observation_model
            )
            # Reduced-dimension belief from the base planner.
            base_particles = self.base_pomcpow.get_belief_estimate()
            self._update_from_base_particles(base_particles)
        except Exception as e:
            print(f"POMCPOW update error: {e}")
            self._noisy_update(action)

    @staticmethod
    def _default_transition(state_20d, action):
        """Simple motion model used when no transition model is supplied.

        Action 0 moves 0.1 along the current heading, 1/2 rotate by +/-0.1 rad.
        Gaussian noise (sigma 0.01) is then added to dimensions 0-4.
        Returns a new array; ``state_20d`` is not modified.
        """
        next_state = state_20d.copy()
        if action == 0:  # Forward
            next_state[0] += np.cos(next_state[2]) * 0.1
            next_state[1] += np.sin(next_state[2]) * 0.1
        elif action == 1:  # Left
            next_state[2] += 0.1
        elif action == 2:  # Right
            next_state[2] -= 0.1
        next_state[:5] += np.random.normal(0, 0.01, 5)
        return next_state

    def _adapt_transition_model(self):
        """Wrap ``self.transition_model`` as a reduced-state transition for the base planner.

        The returned callable pads its input to 20-D and applies the user
        model, or ``_default_transition`` when none was given. It then
        truncates the result back with ``_compress_state``.
        """
        original_model = self.transition_model

        def adapted_transition(state_10d, action):
            state_20d = self._expand_state(state_10d)
            if original_model is not None:
                next_state_20d = original_model(state_20d, action)
            else:
                next_state_20d = self._default_transition(state_20d, action)
            return self._compress_state(next_state_20d)

        return adapted_transition

    def _adapt_observation_model(self):
        """Wrap ``self.observation_model`` as a reduced-state likelihood for the base planner.

        The returned callable evaluates the user model on the 20-D expansion of
        its state, or returns a constant likelihood of 1.0 when no model was given.
        """
        original_model = self.observation_model

        def adapted_observation(state_10d, observation):
            if original_model is None:
                # Uninformative likelihood when no model is provided.
                return 1.0
            return original_model(self._expand_state(state_10d), observation)

        return adapted_observation

    def _mean_particle(self):
        """
        Return the per-dimension mean of the current particles, or None if there are none.

        Orientation (index 2) uses the circular mean, wrapped to [0, 2*pi).
        A plain arithmetic mean is wrong for angles straddling 0/2*pi, or
        angles that grew past 2*pi after repeated turns: 0.1 and 2*pi - 0.1
        would average to pi.
        """
        if len(self.particles) == 0:
            return None
        mean = np.mean(self.particles, axis=0)
        theta = self.particles[:, 2]
        mean[2] = np.arctan2(np.mean(np.sin(theta)), np.mean(np.cos(theta))) % (2 * np.pi)
        return mean

    def _pad_state(self, state_10d, mean_particle):
        """Build a 20-D state from the leading entries of ``state_10d``.

        At most 10 entries are copied. The remaining dimensions come from
        ``mean_particle``, or stay zero when it is None.
        """
        state_20d = np.zeros(self.state_dim)
        n_copy = min(10, len(state_10d))
        state_20d[:n_copy] = state_10d[:n_copy]
        if mean_particle is not None:
            state_20d[n_copy:] = mean_particle[n_copy:]
        return state_20d

    def _expand_state(self, state_10d):
        """Expand a reduced state to 20-D using the current particle mean for missing dimensions."""
        return self._pad_state(state_10d, self._mean_particle())

    def _compress_state(self, state_20d):
        """Return the first 10 dimensions of ``state_20d`` (pose, velocity, calibration)."""
        return state_20d[:10]

    def _update_from_base_particles(self, base_particles):
        """
        Replace the 20-D belief with an expansion of the base planner's particles.

        Each base particle is padded with the mean of the *previous* belief.
        Small noise (sigma 0.01) is added to the padded feature dimensions to
        keep diversity, and the feature descriptors are re-normalised to unit length.

        Raises:
            ValueError: If ``base_particles`` is empty. Replacing the belief
                with nothing would make every later mean NaN; ``update``
                catches this and uses its noisy fallback instead.
        """
        n_particles = len(base_particles)
        if n_particles == 0:
            raise ValueError("base POMCPOW returned an empty belief")

        # self.particles does not change inside this method, so compute the
        # padding mean once instead of once per particle.
        mean_particle = self._mean_particle()
        new_particles = np.array([self._pad_state(p, mean_particle) for p in base_particles])

        new_particles[:, 10:] += np.random.normal(0, 0.01, (n_particles, 10))
        new_particles[:, 10:20] = self._normalize_rows(new_particles[:, 10:20])

        self.particles = new_particles

    def _noisy_update(self, action):
        """Fallback belief update: apply simple dynamics and diffusion noise in place."""
        n = len(self.particles)

        if action == 0:  # Forward
            theta = self.particles[:, 2]
            self.particles[:, 0] += np.cos(theta) * 0.1
            self.particles[:, 1] += np.sin(theta) * 0.1
        elif action == 1:  # Left
            self.particles[:, 2] += 0.1
        elif action == 2:  # Right
            self.particles[:, 2] -= 0.1

        # Diffuse pose, velocity and calibration; features are only re-normalised.
        self.particles[:, 0:2] += np.random.normal(0, 0.05, (n, 2))  # Position
        self.particles[:, 2] += np.random.normal(0, 0.02, n)  # Orientation
        self.particles[:, 3:5] += np.random.normal(0, 0.01, (n, 2))  # Velocity
        self.particles[:, 5:10] += np.random.normal(0, 0.005, (n, 5))  # Calibration

        self.particles[:, 10:20] = self._normalize_rows(self.particles[:, 10:20])

    def get_belief_estimate(self):
        """Return the current belief as the (n, 20) particle array (not a copy)."""
        return self.particles

    def select_action(self, observation=None):
        """
        Select an action, preferring the base planner.

        Falls back to a heuristic policy when the base planner's
        ``select_action`` raises. ``except Exception`` is used so that
        KeyboardInterrupt/SystemExit still propagate.
        """
        try:
            return self.base_pomcpow.select_action(observation)
        except Exception:
            return self._fallback_action()

    def _fallback_action(self):
        """
        Heuristic policy that steers the mean pose toward the map centre (10, 10).

        Returns:
            3 (stay) when within 1.0 of the centre or when the belief is empty.
            0 (forward) when the heading error is below 0.2 rad.
            Otherwise 1 (left) or 2 (right), whichever reduces the heading error.
        """
        mean = self._mean_particle()
        if mean is None:
            return 3

        direction = np.array([10, 10]) - mean[0:2]
        if np.linalg.norm(direction) < 1.0:
            return 3  # Stay if close to centre

        target_angle = np.arctan2(direction[1], direction[0])
        # mean[2] is a circular mean, so wrap-around headings average correctly.
        angle_diff = (target_angle - mean[2] + np.pi) % (2 * np.pi) - np.pi

        if abs(angle_diff) < 0.2:
            return 0  # Forward
        if angle_diff > 0:
            return 1  # Turn left
        return 2  # Turn right

    def reset(self):
        """Reset the base planner and redraw a fresh belief of the configured size."""
        self.base_pomcpow.reset()
        self.particles = self._init_particles(self.n_particles)
