import numpy as np
from sklearn.cluster import DBSCAN

class POMCPOW:
    """Particle-filter belief tracker and action selector for Light-Dark 10D.

    Named after Partially Observable Monte Carlo Planning with Observation
    Widening (POMCPOW). The state is 10-dimensional: five position
    components followed by five velocity components. Positions are kept
    inside the ``[0, 10]`` map by clipping.

    Note:
        The planning hyper-parameters (``max_depth``, ``n_simulations``,
        ``exploration_const``, ``discount_factor`` and the progressive
        widening ``alpha``/``k`` values) are stored on the instance, but no
        method of this class reads them; ``select_action`` uses a
        goal-directed heuristic on the mean of the particle belief.
    """

    def __init__(self, action_space, n_particles=100, max_depth=3, n_simulations=50,
                 exploration_const=10.0, alpha_action=0.5, k_action=4.0,
                 alpha_obs=0.5, k_obs=4.0, discount_factor=0.95, verbose=False):
        """Store the configuration and draw the initial particle belief.

        Args:
            action_space: Either a sized collection of actions (copied into a
                list), an object exposing an integer ``n`` attribute (expanded
                to ``range(n)``), or anything else (falls back to ten actions).
            n_particles: Number of particles maintained in the belief.
            max_depth: Maximum planning depth (stored only).
            n_simulations: Number of simulations per action selection
                (stored only).
            exploration_const: Exploration constant for UCB (stored only).
            alpha_action: Progressive-widening exponent for actions
                (stored only).
            k_action: Progressive-widening coefficient for actions
                (stored only).
            alpha_obs: Progressive-widening exponent for observations
                (stored only).
            k_obs: Progressive-widening coefficient for observations
                (stored only).
            discount_factor: Discount factor for future rewards (stored only).
            verbose: When true, errors swallowed during belief updates are
                printed.
        """
        # Algorithm parameters
        self.n_particles = n_particles
        self.max_depth = max_depth
        self.n_simulations = n_simulations
        self.exploration_const = exploration_const
        self.alpha_action = alpha_action
        self.k_action = k_action
        self.alpha_obs = alpha_obs
        self.k_obs = k_obs
        self.discount_factor = discount_factor
        self.verbose = verbose

        # Action ids are laid out as (2 * dim) for a positive push and
        # (2 * dim + 1) for a negative push, which gives 10 actions for 5 dims.
        if hasattr(action_space, '__len__'):
            self.action_space = list(action_space)
        elif hasattr(action_space, 'n'):
            self.action_space = list(range(action_space.n))
        else:
            self.action_space = list(range(10))  # Default for Light-Dark 10D

        # State layout: indices 0-4 are position, indices 5-9 are velocity.
        self.state_dim = 10
        self.particles = self._init_particles()

        # Models are supplied later through update().
        self.transition_model = None
        self.observation_model = None

        # Most recent observation passed to update().
        self.last_observation = None

    def _init_particles(self):
        """Draw a fresh particle set.

        Returns:
            Array of shape ``(n_particles, state_dim)`` with positions sampled
            uniformly from ``[0, 10)`` and velocities sampled from a normal
            distribution with mean 0 and standard deviation 0.1.
        """
        particles = np.zeros((self.n_particles, self.state_dim))
        particles[:, :5] = np.random.uniform(0, 10, (self.n_particles, 5))
        particles[:, 5:] = np.random.normal(0, 0.1, (self.n_particles, 5))
        return particles

    def _jitter_and_clip(self, particles, position_std, velocity_std):
        """Add Gaussian noise to ``particles`` in place and clip positions.

        Args:
            particles: Array of shape ``(n_particles, 10)``; modified in place.
            position_std: Standard deviation of the position noise.
            velocity_std: Standard deviation of the velocity noise.

        Returns:
            The same array, with positions clipped to ``[0, 10]``.
        """
        particles[:, :5] += np.random.normal(0, position_std, (self.n_particles, 5))
        particles[:, 5:] += np.random.normal(0, velocity_std, (self.n_particles, 5))
        particles[:, :5] = np.clip(particles[:, :5], 0, 10)
        return particles

    def update(self, action, observation, transition_model=None, observation_model=None):
        """Update the belief with a new action and observation.

        Args:
            action: Action taken.
            observation: Observation received.
            transition_model: Optional callable ``(state, action) -> next
                state``; when given it replaces the stored model.
            observation_model: Optional callable ``(state, observation) ->
                likelihood``; when given it replaces the stored model.
        """
        if transition_model is not None:
            self.transition_model = transition_model

        if observation_model is not None:
            self.observation_model = observation_model

        self.last_observation = observation

        # Sequential importance resampling step.
        self.particles = self._sir_update(action, observation)

    def _sir_update(self, action, observation):
        """Run one Sequential Importance Resampling step.

        This is deliberately best-effort: failures in the user-supplied
        models are caught (and printed when ``verbose``) instead of raised,
        so a belief is always returned.

        Args:
            action: Action taken.
            observation: Observation received.

        Returns:
            The updated particle array.
        """
        if self.transition_model is None or self.observation_model is None:
            # Without both models the belief can only be diffused.
            return self._jitter_and_clip(self.particles.copy(), 0.2, 0.1)

        new_particles = np.zeros_like(self.particles)

        try:
            # Propagate every particle through the transition model.
            for i in range(self.n_particles):
                try:
                    new_particles[i] = self._apply_transition(self.particles[i], action)
                except Exception as e:
                    if self.verbose:
                        print(f"Error in transition model: {e}")
                    # Fall back to the built-in dynamics for this particle.
                    new_particles[i] = self._simple_transition(self.particles[i], action)

            # Weight every propagated particle by the observation likelihood.
            weights = np.zeros(self.n_particles)
            for i in range(self.n_particles):
                try:
                    weights[i] = self.observation_model(new_particles[i], observation)
                except Exception as e:
                    if self.verbose:
                        print(f"Error in observation model: {e}")
                    weights[i] = 1e-10

            # Floor the weights so that no particle has zero probability.
            weights = np.maximum(weights, 1e-10)

            weights_sum = np.sum(weights)
            if weights_sum > 0:
                weights = weights / weights_sum
            else:
                # After flooring, this is only reached when the sum is NaN.
                weights = np.ones(self.n_particles) / self.n_particles

            try:
                indices = np.random.choice(
                    self.n_particles, self.n_particles, p=weights, replace=True
                )
                resampled_particles = new_particles[indices]
            except Exception as e:
                if self.verbose:
                    print(f"Error in resampling: {e}")
                # Keep the propagated particles when resampling is impossible.
                resampled_particles = new_particles

            # Small jitter keeps duplicated particles from collapsing.
            return self._jitter_and_clip(resampled_particles, 0.05, 0.02)

        except Exception as e:
            if self.verbose:
                print(f"Error in SIR update: {e}")
            # Last resort: diffuse the previous belief.
            return self._jitter_and_clip(self.particles.copy(), 0.1, 0.05)

    def _apply_transition(self, state, action):
        """Propagate one state through the stored transition model.

        Args:
            state: Current state (10D).
            action: Action taken.

        Returns:
            The next state. When no transition model is stored, the result
            of ``_simple_transition`` is returned instead.
        """
        # Compare against None so a callable object that happens to be falsy
        # is still used.
        if self.transition_model is not None:
            return self.transition_model(state, action)
        return self._simple_transition(state, action)

    def _simple_transition(self, state, action):
        """Apply the built-in damped point-mass dynamics.

        Args:
            state: Current state; expected to hold 5 position values followed
                by 5 velocity values.
            action: Integer action id in ``0..9``. ``action // 2`` selects the
                dimension; even ids push in the positive direction and odd ids
                in the negative direction. Any other value applies no force.

        Returns:
            The next 10D state. If ``state`` does not have length 10, a
            zero-filled 10D state carrying over at most the first five input
            values is returned and the action is ignored.
        """
        if len(state) != 10:
            new_state = np.zeros(10)
            n_copy = min(5, len(state))
            new_state[:n_copy] = state[:n_copy]
            return new_state

        position = state[:5].copy()
        velocity = state[5:].copy()

        force = np.zeros(5)
        # The lower bound matters: a negative id would otherwise wrap around
        # through negative indexing and push an unrelated dimension.
        if isinstance(action, (int, np.integer)) and 0 <= action < 10:
            dim, is_negative = divmod(int(action), 2)
            force[dim] = -0.1 if is_negative else 0.1

        # Velocity update with damping.
        damping = 0.1
        new_velocity = velocity + (force - damping * velocity)

        # Position is advanced with the velocity from before this step.
        dt = 0.1
        new_position = np.clip(position + velocity * dt, 0, 10)

        return np.concatenate([new_position, new_velocity])

    def get_belief_estimate(self):
        """Return the current particle set.

        Returns:
            The internal particle array itself (not a copy).
        """
        return self.particles

    def select_action(self, observation=None):
        """Pick the action that pushes the belief mean towards the goal.

        The goal is taken to be ``8.0`` in every position dimension. The
        dimension whose mean position is furthest from the goal is selected,
        and the push direction is chosen to reduce that gap.

        Args:
            observation: Accepted for interface compatibility; not used.

        Returns:
            The action id as a plain ``int``: ``2 * dim`` to increase the
            position in ``dim``, ``2 * dim + 1`` to decrease it.
        """
        mean_pos = np.mean(self.particles[:, :5], axis=0)

        goal = np.array([8.0, 8.0, 8.0, 8.0, 8.0])  # Approximate goal position
        direction = goal - mean_pos

        max_dim = int(np.argmax(np.abs(direction)))

        if direction[max_dim] > 0:
            return 2 * max_dim
        return 2 * max_dim + 1

    def reset(self):
        """Re-draw the particle belief and forget the last observation."""
        self.particles = self._init_particles()
        self.last_observation = None
