import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Import the original DVRL implementation
from dvrl.dvrl import DVRL as OriginalDVRL

class DVRL:
    """
    # & Adapter for DVRL that makes it compatible with the Kidnapped Robot experiment script.
    # & 
    # & This wrapper properly maps parameters between the experiment runner and the original DVRL
    # & implementation, handling shape differences and tensor conversions.
    # &
    # & The wrapped model is built lazily (first ``update`` or ``select_action``). If it cannot be
    # & built, or fails at runtime, the adapter sets ``fallback_mode`` and tracks the belief with a
    # & plain NumPy particle filter in ``belief_particles`` until ``reset`` is called.
    """

    def __init__(self, state_dim, belief_dim=10, n_particles=100, learning_rate=1e-3,
                 device=None, discrete_actions=True, verbose=False):
        """
        # & Initialize the DVRL adapter.
        # & 
        # & Args:
        # &     state_dim (int): Dimensionality of the state space (mapped to obs_dim and z_dim)
        # &     belief_dim (int): Stored for API compatibility; the model is built with z_dim = state_dim
        # &     n_particles (int): Number of particles returned by get_belief_estimate
        # &     learning_rate (float): Learning rate for the Adam optimizer
        # &     device: PyTorch device (if None, will use CUDA if available)
        # &     discrete_actions (bool): Whether actions are discrete or continuous
        # &     verbose (bool): Whether to print detailed debug info
        """
        self.state_dim = state_dim
        self.belief_dim = belief_dim
        self.n_particles = n_particles
        self.discrete_actions = discrete_actions
        self.verbose = verbose

        # Initialize device
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Overwritten once the real action dimension is known (_initialize_model / select_action).
        self.action_dim = 4  # Default to 4 actions for kidnapped robot (up, down, left, right)

        # Original DVRL model, built lazily
        self.model = None

        # (action, observation) pairs passed to update()
        self.history = []

        # DVRL belief state as returned by model.init_belief / the model's update methods
        self.h_particles = None
        self.z_particles = None
        self.weights = None

        # Optimization parameters
        self.optimizer = None
        self.learning_rate = learning_rate

        # Not written to anywhere in this class
        self.training_loss = []

        # NumPy particles used in fallback mode (kept near the DVRL mean while DVRL works)
        self.belief_particles = np.random.normal(0.5, 0.1, (self.n_particles, self.state_dim))

        # Set whenever DVRL is unavailable; cleared again by reset()
        self.fallback_mode = False

        # DVRL internal particle count (capped at 20)
        self.dvrl_n_particles = min(20, self.n_particles)

    def log(self, message):
        """Helper method to conditionally print debug messages"""
        if self.verbose:
            print(message)

    def _initialize_model(self, action_dim):
        """
        # & Initialize the DVRL model when action dimension is known.
        # & 
        # & Args:
        # &     action_dim: Dimension of the action space (stored in self.action_dim even on failure)
        # &
        # & Returns:
        # &     True on success. On any exception the error is printed, fallback_mode is set and
        # &     False is returned (self.model may already be assigned if a later step failed).
        """
        self.action_dim = action_dim

        try:
            h_dim = 64  # Using a smaller hidden dimension than the default 256
            # Tie z_dim to state_dim so aggregated latent particles are state-sized.
            z_dim = self.state_dim

            self.log(f"Initializing DVRL with state_dim={self.state_dim}, setting z_dim={z_dim}")

            self.model = OriginalDVRL(
                obs_dim=self.state_dim,
                action_dim=self.action_dim,
                h_dim=h_dim,
                z_dim=z_dim,
                n_particles=self.dvrl_n_particles,
                continuous_actions=not self.discrete_actions,
                hidden_layers=1,
                action_factor=0.5,
                rnn_type='gru'
            ).to(self.device)

            self.optimizer = optim.Adam(list(self.model.parameters()), lr=self.learning_rate)

            batch_size = 1  # For inference we use batch size of 1
            self.h_particles, self.z_particles, self.weights = self.model.init_belief(batch_size, self.device)

            self.log(f"Successfully initialized DVRL with shapes: h={self.h_particles.shape}, z={self.z_particles.shape}, weights={self.weights.shape}")
            self.fallback_mode = False
            return True
        except Exception as e:
            print(f"Error creating DVRL model: {e}")
            print("Using fallback belief representation")
            self.fallback_mode = True
            return False

    def _tensor_to_np(self, tensor):
        """
        # & Convert PyTorch tensor to NumPy array safely (detached, on CPU); other values pass through
        """
        if isinstance(tensor, torch.Tensor):
            return tensor.detach().cpu().numpy()
        return tensor

    @staticmethod
    def _detach(value):
        """
        # & Detach a tensor from the autograd graph; other values pass through unchanged
        """
        return value.detach() if isinstance(value, torch.Tensor) else value

    @staticmethod
    def _as_2d_weights(weights):
        """
        # & Bring particle weights to 2-D: average trailing dims, or add a missing batch dim
        """
        ndim = len(weights.shape)
        if ndim > 2:
            return weights.mean(dim=tuple(range(2, ndim)))
        if ndim == 1:
            return weights.unsqueeze(0)
        return weights

    @staticmethod
    def _fit_length(values, length):
        """
        # & Flatten ``values`` (scalar or array-like) and zero-pad / truncate it to ``length`` floats
        """
        flat = np.asarray(values, dtype=float).reshape(-1)
        if flat.shape[0] >= length:
            return flat[:length]
        padded = np.zeros(length)
        padded[:flat.shape[0]] = flat
        return padded

    def _random_action(self):
        """
        # & Uniformly random action, used whenever the DVRL policy is unavailable
        """
        if self.discrete_actions:
            return np.random.randint(0, self.action_dim)
        return np.random.uniform(-1, 1, self.action_dim)

    def _infer_action_dim(self, action):
        """
        # & Guess the action dimension from the first action seen by update()
        """
        if self.discrete_actions:
            if isinstance(action, (int, np.integer)):
                return max(action, 3) + 1  # Ensure at least 4 actions
            return len(action) if hasattr(action, '__len__') else 4
        return len(action) if hasattr(action, '__len__') else 1

    def _action_dim_from_space(self, action_space):
        """
        # & Read the action dimension from a list/array, a gym-style space (.n / .shape), or default
        """
        if self.discrete_actions:
            if isinstance(action_space, (list, np.ndarray)):
                return len(action_space)
            if hasattr(action_space, 'n'):
                return action_space.n
            return 4  # Default to 4 discrete actions
        if hasattr(action_space, 'shape'):
            return action_space.shape[0]
        return len(action_space[0]) if isinstance(action_space, (list, tuple)) else 1

    @staticmethod
    def _is_discrete_index(action):
        """
        # & True for a scalar integer action: Python int, NumPy integer, or 0-d integer array
        """
        if isinstance(action, (int, np.integer)):
            return True
        return (isinstance(action, np.ndarray) and action.ndim == 0
                and np.issubdtype(action.dtype, np.integer))

    def _action_to_tensor(self, action):
        """
        # & Encode ``action`` as a float32 tensor of shape [1, action_dim]
        # & 
        # & Integer actions in discrete mode become one-hot vectors (out-of-range indices raise
        # & ValueError instead of silently wrapping); anything else is flattened, padded/truncated.
        """
        if self.discrete_actions and self._is_discrete_index(action):
            index = int(action)
            if not 0 <= index < self.action_dim:
                raise ValueError(f"Discrete action {index} outside [0, {self.action_dim})")
            a_tensor = torch.zeros(1, self.action_dim, device=self.device)
            a_tensor[0, index] = 1.0
            return a_tensor
        vector = self._fit_length(action, self.action_dim)
        return torch.as_tensor(vector, dtype=torch.float32, device=self.device).reshape(1, -1)

    def _observation_to_tensor(self, observation):
        """
        # & Encode ``observation`` (scalar or array-like) as a float32 tensor of shape [1, state_dim]
        """
        vector = self._fit_length(observation, self.state_dim)
        return torch.as_tensor(vector, dtype=torch.float32, device=self.device).reshape(1, -1)

    def update(self, action, observation, transition_model=None, observation_model=None):
        """
        # & Update belief using new action and observation.
        # & 
        # & Args:
        # &     action: Action taken (int index for discrete actions, or array-like)
        # &     observation: Observation received (scalar or array-like)
        # &     transition_model: f(particle, action) -> particle, used only by the fallback filter
        # &     observation_model: g(particle, observation) -> likelihood, used only by the fallback filter
        """
        self.history.append((action, observation))

        # Build the model lazily. After a failed build stay in fallback mode until reset()
        # rather than re-constructing (and re-reporting) the network on every step.
        if self.model is None and not self.fallback_mode:
            self._initialize_model(self._infer_action_dim(action))

        if self.model is None or self.fallback_mode:
            self._fallback_update(action, observation, transition_model, observation_model)
            return

        try:
            a_tensor = self._action_to_tensor(action)
            o_tensor = self._observation_to_tensor(observation)
        except Exception as e:
            self.log(f"Error processing inputs: {e}")
            self.fallback_mode = True
            self._fallback_update(action, observation, transition_model, observation_model)
            return

        try:
            self._advance_belief(a_tensor, o_tensor)
        except Exception as e:
            self.log(f"Warning: Error in DVRL belief update: {e}. Using fallback.")
            self.fallback_mode = True
            self._fallback_update(action, observation, transition_model, observation_model)
            return

        self._refresh_belief_particles()

    def _advance_belief(self, a_tensor, o_tensor):
        """
        # & Run one DVRL filtering step on (h_particles, z_particles, weights).
        # & 
        # & model.particle_filter.update is tried first; if it raises (including a tuple result with
        # & fewer than three items), model.update_belief is used instead. Exceptions from that second
        # & path propagate to the caller.
        """
        try:
            result = self.model.particle_filter.update(
                self.h_particles, self.z_particles, self.weights, a_tensor, o_tensor
            )
            self._store_update_result(result)
            self.weights = self._as_2d_weights(self.weights)
        except Exception as e:
            self.log(f"particle_filter.update failed ({e}); trying update_belief")
            self.h_particles, self.z_particles, self.weights = self._update_via_update_belief(
                a_tensor, o_tensor
            )

        # Nothing in this adapter back-propagates through the belief, so cut the autograd graph
        # here; otherwise it would keep growing with every step of the episode.
        self.h_particles = self._detach(self.h_particles)
        self.z_particles = self._detach(self.z_particles)
        self.weights = self._detach(self.weights)

    def _store_update_result(self, update_result):
        """
        # & Copy (h, z, weights) out of a particle_filter.update result (tuple or dict).
        # & 
        # & Tuple results with 3-D weights are reduced to 2-D and softmax-normalised. Dict results
        # & are looked up under several common key names; missing keys leave the old value.
        # & NOTE(review): 2-D weights are stored exactly as returned (possibly log-weights) while the
        # & 3-D tuple case is softmax-normalised — confirm which convention OriginalDVRL uses.
        """
        if isinstance(update_result, tuple):
            if len(update_result) < 3:
                raise ValueError("Expected at least 3 elements in update result")
            self.h_particles, self.z_particles, weights = update_result[:3]
            if len(weights.shape) == 3:
                reduced = weights[:, :, 0]
                if weights.shape[2] > 1 and torch.all(reduced <= 0):
                    # First slice doesn't look like usable weights; average over the last dim instead.
                    reduced = torch.mean(weights, dim=2)
                weights = F.softmax(reduced, dim=1)
            self.weights = weights
        elif isinstance(update_result, dict):
            lookups = (
                ('h_particles', ('h', 'h_particles', 'hidden')),
                ('z_particles', ('z', 'z_particles', 'latent')),
                ('weights', ('w', 'weights', 'particle_weights')),
            )
            for attr, keys in lookups:
                key = next((k for k in keys if k in update_result), None)
                if key is None:
                    continue
                value = update_result[key]
                if attr == 'weights' and len(value.shape) == 3:
                    value = value[:, :, 0]
                setattr(self, attr, value)

    def _update_via_update_belief(self, a_tensor, o_tensor):
        """
        # & Secondary update path through model.update_belief; returns (h, z, weights).
        # & 
        # & 3-D weights keep only their first slice. If the result is not a tuple of at least three
        # & items, a fresh belief from model.init_belief is returned instead.
        """
        h = self.h_particles
        result = self.model.update_belief(h, self.z_particles, self.weights, a_tensor, o_tensor)
        if isinstance(result, tuple) and len(result) >= 3:
            h_new, z_new, w_new = result[:3]
            if len(w_new.shape) == 3:
                w_new = w_new[:, :, 0]  # Take first slice
            return h_new, z_new, w_new
        return self.model.init_belief(h.shape[0], h.device)

    def _belief_center(self, weights_2d):
        """
        # & Weighted mean of z_particles as a NumPy vector of length state_dim (batch row 0).
        # & 
        # & Assumes z_particles is [batch, n_particles, z_dim] and weights_2d is [batch, n_particles]
        # & (TODO confirm against OriginalDVRL). Falls back to model.particle_aggregator if the
        # & manual weighted sum fails.
        """
        try:
            expanded = weights_2d.unsqueeze(-1).expand(-1, -1, self.z_particles.shape[-1])
            belief_vector = torch.sum(self.z_particles * expanded, dim=1)
        except Exception as e:
            self.log(f"Manual aggregation failed: {e}")
            belief_vector = self.model.particle_aggregator(
                self.h_particles, self.z_particles, weights_2d
            )
        return self._fit_length(self._tensor_to_np(belief_vector)[0], self.state_dim)

    def _refresh_belief_particles(self):
        """
        # & Re-sample belief_particles around the current DVRL belief mean.
        # & 
        # & Keeps the NumPy particles near the DVRL estimate so a later switch to fallback mode starts
        # & from it. Best effort: on failure the existing particles are only diffused with noise.
        """
        try:
            center = self._belief_center(self._as_2d_weights(self.weights))
            self.belief_particles = np.random.normal(
                center, 0.1, (self.n_particles, self.state_dim)
            )
        except Exception as e:
            self.log(f"Error updating fallback particles: {e}")
            self.belief_particles += np.random.normal(0, 0.1, self.belief_particles.shape)

    def _fallback_update(self, action, observation, transition_model, observation_model):
        """
        # & Fallback belief update using a simple bootstrap particle filter on belief_particles.
        # & 
        # & Each particle is propagated with transition_model, weighted with observation_model,
        # & resampled and jittered. Without both models the particles are only diffused with noise.
        """
        self.log("Using fallback particle filter update...")

        if transition_model is None or observation_model is None:
            # Without models, just add random noise to particles
            self.belief_particles += np.random.normal(0, 0.1, self.belief_particles.shape)
            return

        try:
            predicted = np.zeros_like(self.belief_particles)
            for i, particle in enumerate(self.belief_particles):
                predicted[i] = transition_model(particle, action)

            weights = np.zeros(len(predicted))
            for i, particle in enumerate(predicted):
                try:
                    weights[i] = observation_model(particle, observation)
                except Exception:
                    weights[i] = 1e-10  # Small positive value if the model rejects this particle

            total = np.sum(weights)
            if total > 0:
                weights = weights / total
            else:
                weights = np.ones(len(predicted)) / len(predicted)

            try:
                indices = np.random.choice(len(predicted), len(predicted), p=weights, replace=True)
                self.belief_particles = predicted[indices]
            except Exception:
                # Invalid probabilities (e.g. negative likelihoods): keep the unresampled prediction
                self.belief_particles = predicted

            self.belief_particles += np.random.normal(0, 0.05, self.belief_particles.shape)
        except Exception as e:
            self.log(f"Error in fallback update: {e}")
            # Just add noise to current particles
            self.belief_particles += np.random.normal(0, 0.1, self.belief_particles.shape)

    def get_belief_estimate(self):
        """
        # & Get the current belief estimate as particles.
        # & 
        # & In DVRL mode the particles are Gaussian samples (std 0.1) around the weighted latent mean;
        # & if aggregation fails, the raw z particles are resampled by weight. Any other failure, or
        # & fallback mode, returns belief_particles.
        # & 
        # & Returns:
        # &     numpy array of particles, shape (n_particles, state_dim)
        """
        if self.fallback_mode or self.model is None:
            return self.belief_particles

        try:
            with torch.no_grad():
                try:
                    center = self._belief_center(self._as_2d_weights(self.weights))
                except Exception as e:
                    self.log(f"Error in particle aggregation: {e}")
                    if self.z_particles is not None:
                        return self._sample_z_particles()
                else:
                    return center + np.random.normal(0, 0.1, (self.n_particles, self.state_dim))
        except Exception as e:
            self.log(f"Error generating belief particles: {e}")

        # Fallback to directly tracked particles
        return self.belief_particles

    def _sample_z_particles(self):
        """
        # & Resample rows of z_particles (batch 0) by weight into an (n_particles, state_dim) array.
        # & 
        # & Uses uniform weights when the weight count does not match the particle count; if weighted
        # & sampling fails, falls back to uniform sampling plus noise.
        """
        particles = np.zeros((self.n_particles, self.state_dim))
        z_np = self._tensor_to_np(self.z_particles)[0]
        n_z = z_np.shape[0]

        try:
            weights_np = self._tensor_to_np(self._as_2d_weights(self.weights))[0]
            if len(weights_np) != n_z:
                self.log(f"Weight shape mismatch: weights={len(weights_np)}, particles={n_z}")
                weights_np = np.ones(n_z) / n_z
            weights_np = weights_np / np.sum(weights_np)

            indices = np.random.choice(
                n_z, size=min(self.n_particles, n_z), p=weights_np, replace=True
            )
            taken = len(indices)
            particles[:taken] = z_np[indices]

            # Fill remaining slots by re-drawing from the particles already taken
            if taken < self.n_particles:
                particles[taken:] = particles[:taken][
                    np.random.choice(taken, self.n_particles - taken)
                ]
        except Exception as e:
            self.log(f"Error sampling particles: {e}")
            indices = np.random.choice(n_z, min(self.n_particles, n_z))
            particles[:len(indices)] = z_np[indices]
            particles += np.random.normal(0, 0.1, particles.shape)

        return particles

    def select_action(self, action_space=None, deterministic=False):
        """
        # & Select action based on current belief.
        # & 
        # & Args:
        # &     action_space: Action space; required on the first call unless already in fallback mode
        # &     deterministic: Whether to use deterministic selection
        # &     
        # & Returns:
        # &     Selected action (int for discrete actions, numpy array otherwise); uniformly random
        # &     when DVRL is unavailable.
        # & 
        # & Raises:
        # &     ValueError: if the model has not been built, fallback mode is off and action_space is None
        """
        if self.model is None:
            if action_space is not None:
                action_dim = self._action_dim_from_space(action_space)
                if self.fallback_mode:
                    # A build already failed; don't rebuild every step, but keep random actions
                    # inside the real action space.
                    self.action_dim = action_dim
                elif not self._initialize_model(action_dim):
                    return self._random_action()
            elif not self.fallback_mode:
                raise ValueError("Action space must be provided on first call to select_action")

        if not self.fallback_mode:
            try:
                return self._policy_action(deterministic)
            except Exception as e:
                self.log(f"Error selecting action: {e}")
                self.fallback_mode = True

        # Fallback - random action
        return self._random_action()

    def _policy_action(self, deterministic):
        """
        # & Ask the DVRL policy for an action; uses the mean-belief policy if model.act fails
        """
        weights_2d = self._as_2d_weights(self.weights)
        with torch.no_grad():
            try:
                action, _, _, _ = self.model.act(
                    self.h_particles,
                    self.z_particles,
                    weights_2d,
                    deterministic=deterministic
                )
            except Exception as e:
                self.log(f"Error in model.act: {e}")
                action = self._act_from_mean_belief(weights_2d, deterministic)
            action_np = np.asarray(self._tensor_to_np(action))
        return action_np.item() if self.discrete_actions else action_np.squeeze()

    def _act_from_mean_belief(self, weights_2d, deterministic):
        """
        # & Choose an action by feeding the weighted-mean latent belief to model.policy_net.
        # & 
        # & Discrete: argmax or a Categorical sample over the logits. Continuous: expects policy_net
        # & to return (mean, log_std) and returns the mean or a Gaussian sample.
        """
        expanded = weights_2d.unsqueeze(-1).expand(-1, -1, self.z_particles.shape[-1])
        belief_vector = torch.sum(self.z_particles * expanded, dim=1)

        if self.discrete_actions:
            logits = self.model.policy_net(belief_vector)
            if deterministic:
                return torch.argmax(logits, dim=-1)
            return torch.distributions.Categorical(F.softmax(logits, dim=-1)).sample()

        mean, log_std = self.model.policy_net(belief_vector)
        if deterministic:
            return mean
        std = torch.exp(log_std)
        return mean + std * torch.randn_like(std)

    def reset(self):
        """
        # & Reset the agent's history and belief
        # & 
        # & Leaves fallback mode so the next step retries DVRL (rebuilding the model if it was never
        # & built); an existing model is kept and only its belief is re-initialized.
        """
        self.history = []
        self.belief_particles = np.random.normal(0.5, 0.1, (self.n_particles, self.state_dim))
        self.fallback_mode = False  # Try non-fallback mode again on reset

        if self.model is None:
            return
        try:
            batch_size = 1  # For inference we use batch size of 1
            self.h_particles, self.z_particles, self.weights = self.model.init_belief(batch_size, self.device)
        except Exception as e:
            self.log(f"Error resetting DVRL belief: {e}")
            self.fallback_mode = True
