import math

import torch

from dvrl.decoder_network import DecoderNetwork
from dvrl.encoder_network import EncoderNetwork
from dvrl.state_update_rnn import StateUpdateRNN
from dvrl.transition_network import TransitionNetwork


class ParticleFilter:
    """
    Particle-based belief representation and its one-step update.

    A belief is a set of ``n_particles`` weighted particles per batch element.
    Each particle carries a deterministic state ``h`` and a stochastic latent
    ``z``. ``update`` performs one resample -> propose -> reweight step and
    also returns the per-step ELBO term computed from the unnormalised
    importance weights.
    """

    def __init__(self, n_particles: int, transition_net: "TransitionNetwork",
                 encoder_net: "EncoderNetwork", decoder_net: "DecoderNetwork",
                 state_update_rnn: "StateUpdateRNN"):
        """
        Store the particle count and the four networks used by ``update``.

        Args:
            n_particles: Number of particles K per batch element (>= 1).
            transition_net: Called as ``transition_net(h, a)``; must return
                ``(mean, logvar)`` of the prior over z.
            encoder_net: Called as ``encoder_net(h, a, o)``; must return
                ``(mean, logvar)`` of the proposal over z.
            decoder_net: Called as ``decoder_net(h, z, a)``; its output is
                treated as Bernoulli logits over the observation.
            state_update_rnn: Called as ``state_update_rnn(h, z, a, o)``;
                must return the next h.

        Raises:
            ValueError: If ``n_particles`` is smaller than 1.
        """
        if n_particles < 1:
            raise ValueError(f"n_particles must be >= 1, got {n_particles}")
        self.n_particles = n_particles
        self.transition_net = transition_net
        self.encoder_net = encoder_net
        self.decoder_net = decoder_net
        self.state_update_rnn = state_update_rnn

    def init_particles(self, batch_size: int, h_dim: int, z_dim: int, device: torch.device):
        """
        Create the initial particle set.

        Returns:
            Tuple ``(h_particles, z_particles, weights)`` with shapes
            ``[batch_size, n_particles, h_dim]``,
            ``[batch_size, n_particles, z_dim]`` and
            ``[batch_size, n_particles]``.
        """
        # & h starts at zero, z is drawn from a standard normal
        h_particles = torch.zeros(batch_size, self.n_particles, h_dim, device=device)
        z_particles = torch.randn(batch_size, self.n_particles, z_dim, device=device)

        # & Every particle starts with the same weight 1/K
        weights = torch.ones(batch_size, self.n_particles, device=device) / self.n_particles

        return h_particles, z_particles, weights

    def update(self, h_particles: torch.Tensor, z_particles: torch.Tensor, weights: torch.Tensor,
            a_prev: torch.Tensor, o_curr: torch.Tensor):
        """
        Update particles given the previous action and the new observation.

        Args:
            h_particles: Shape [batch_size, n_particles, h_dim]
            z_particles: Shape [batch_size, n_particles, z_dim]. Not read by
                this method (new latents are proposed from the resampled h);
                kept for interface compatibility.
            weights: Shape [batch_size, n_particles]
            a_prev: Shape [batch_size, action_dim]
            o_curr: Shape [batch_size, obs_dim]

        Returns:
            Tuple ``(h_new, z_new, weights_new, elbo)``:
            ``h_new`` [batch_size, n_particles, h_dim],
            ``z_new`` [batch_size, n_particles, z_dim],
            ``weights_new`` [batch_size, n_particles] (each row sums to 1),
            ``elbo`` scalar: batch mean of ``log(1/K * sum_k w_k)`` over the
            unnormalised importance weights. It is differentiable; detach it
            at the call site if it is only logged.
        """
        batch_size = h_particles.shape[0]
        n_particles = self.n_particles
        h_dim = h_particles.shape[-1]

        # & Repeat action and observation for every particle, then merge the
        # & batch and particle axes so the networks see a plain 2-D batch
        a_flat = a_prev.unsqueeze(1).expand(-1, n_particles, -1).reshape(-1, a_prev.shape[-1])
        o_curr_expanded = o_curr.unsqueeze(1).expand(-1, n_particles, -1)
        o_flat = o_curr_expanded.reshape(-1, o_curr.shape[-1])

        # & Step 1: Resample ancestors according to the current weights
        ancestor_indices = self._resample(weights)
        gather_index = ancestor_indices.unsqueeze(-1).expand(-1, -1, h_dim)
        h_resampled_flat = torch.gather(h_particles, 1, gather_index).reshape(-1, h_dim)

        # & Step 2: Propose new latent states z_t from the encoder
        # & (reparameterised sample so gradients reach mean and logvar)
        encoder_mean, encoder_logvar = self.encoder_net(h_resampled_flat, a_flat, o_flat)
        encoder_std = torch.exp(0.5 * encoder_logvar)
        z_flat = encoder_mean + encoder_std * torch.randn_like(encoder_std)
        z_new = z_flat.reshape(batch_size, n_particles, -1)
        z_flat = z_new.reshape(-1, z_new.shape[-1])

        # & Step 3: Update h using RNN
        h_new_flat = self.state_update_rnn(h_resampled_flat, z_flat, a_flat, o_flat)
        h_new = h_new_flat.reshape(batch_size, n_particles, -1)

        # & Step 4: Importance weights
        # & Prior p(z_t | h_{t-1}, a_{t-1}) from the transition model
        prior_mean, prior_logvar = self.transition_net(h_resampled_flat, a_flat)
        prior_mean = prior_mean.reshape(batch_size, n_particles, -1)
        prior_std = torch.exp(0.5 * prior_logvar.reshape(batch_size, n_particles, -1))

        # & Likelihood p(o_t | h_{t-1}, z_t, a_{t-1}) from the decoder
        decoder_params = self.decoder_net(h_resampled_flat, z_flat, a_flat)
        decoder_params = decoder_params.reshape(batch_size, n_particles, -1)

        # & The helpers return per-dimension log-densities; a particle's
        # & log-probability is their sum over the feature axis. Without the
        # & sum the three terms have incompatible trailing dimensions and the
        # & weights would not be [batch_size, n_particles].
        log_prior = self._log_normal_pdf(z_new, prior_mean, prior_std).sum(dim=-1)
        log_encoder = self._log_normal_pdf(
            z_new,
            encoder_mean.reshape(batch_size, n_particles, -1),
            encoder_std.reshape(batch_size, n_particles, -1),
        ).sum(dim=-1)

        # & Observations are modelled as Bernoulli (adjust based on your data)
        log_decoder = self._log_bernoulli_pdf(
            o_curr_expanded, torch.sigmoid(decoder_params)
        ).sum(dim=-1)

        # & log w = log p(z|h,a) + log p(o|h,z,a) - log q(z|h,a,o)
        log_weights = log_prior + log_decoder - log_encoder

        # & softmax is shift-invariant and numerically stable on its own
        weights_new = torch.softmax(log_weights, dim=1)

        # & ELBO term log(1/K * sum_k w_k), taken from the *unshifted*
        # & log-weights via logsumexp so the value is the actual estimate and
        # & gradients can flow to the networks
        elbo = (torch.logsumexp(log_weights, dim=1) - math.log(n_particles)).mean()

        return h_new, z_new, weights_new, elbo

    def _resample(self, weights: torch.Tensor):
        """
        Draw ancestor indices by systematic resampling.

        Args:
            weights: Non-negative weights, shape [batch_size, n_particles].
                A 1-D tensor is treated as a single batch row; extra trailing
                dimensions are flattened into the particle axis. Rows do not
                need to sum to one.

        Returns:
            Long tensor [batch_size, n_particles] of indices in
            ``[0, n_particles - 1]``.
        """
        # & Tolerate the same off-shape inputs as before
        if weights.dim() > 2:
            weights = weights.reshape(weights.shape[0], -1)
        elif weights.dim() == 1:
            weights = weights.unsqueeze(0)

        batch_size, n_particles = weights.shape
        device = weights.device

        # & Normalised CDF: dividing by the row total makes the last entry
        # & exactly 1 for any positive total, so rounding in cumsum cannot
        # & push a query past the end
        cdf = torch.cumsum(weights, dim=1)
        total = cdf[:, -1:].clamp_min(torch.finfo(cdf.dtype).tiny)
        cdf = cdf / total

        # & One uniform offset per row, then K evenly spaced query points
        offset = torch.rand(batch_size, 1, device=device, dtype=cdf.dtype)
        grid = torch.arange(n_particles, device=device, dtype=cdf.dtype)
        u = (offset + grid) / n_particles
        # & (offset + K - 1) can round up to K; keep u strictly below 1
        u = u.clamp(max=1.0 - torch.finfo(cdf.dtype).eps)

        # & Batched search. right=True returns the first index whose CDF is
        # & strictly greater than u, so zero-weight particles are never chosen
        ancestors = torch.searchsorted(cdf, u, right=True)

        # & Safety net for degenerate rows (e.g. all-zero weights)
        return ancestors.clamp(max=n_particles - 1)

    def _log_normal_pdf(self, x, mean, std):
        """
        Element-wise log-density of ``x`` under Normal(mean, std).

        The result has the broadcast shape of the inputs; callers sum over
        the feature axis to get a joint log-probability.
        """
        return -0.5 * ((x - mean) / std).pow(2) - 0.5 * math.log(2.0 * math.pi) - torch.log(std)

    def _log_bernoulli_pdf(self, x, p):
        """
        Element-wise log-probability of ``x`` under Bernoulli(p).

        A small epsilon inside the logs avoids ``log(0)`` when ``p`` saturates.
        """
        return x * torch.log(p + 1e-8) + (1 - x) * torch.log(1 - p + 1e-8)

