import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from dvrl.decoder_network import DecoderNetwork
from dvrl.encoder_network import EncoderNetwork
from dvrl.particle_aggregator import ParticleAggregator
from dvrl.particle_filter import ParticleFilter
from dvrl.policy_network import PolicyNetwork
from dvrl.state_update_rnn import StateUpdateRNN
from dvrl.transition_network import TransitionNetwork
from dvrl.value_network import ValueNetwork


class DVRL(nn.Module):
    """
    # & Deep Variational Reinforcement Learning with fixed dimension handling
    # &
    # & Wires an encoder, decoder, transition network and state-update RNN into a
    # & particle filter over (h, z) particles, aggregates the weighted particles
    # & into a belief vector, and feeds that vector to policy and value heads.
    """

    # Module logger; replaces the unconditional print() diagnostics.
    _logger = logging.getLogger(__name__)
    # 0.5 * log(2*pi): per-dimension constant of a diagonal Gaussian log-density.
    _HALF_LOG_2PI = 0.5 * math.log(2.0 * math.pi)

    def __init__(self, obs_dim: int, action_dim: int, h_dim: int = 256, z_dim: int = 64,
                 n_particles: int = 16, continuous_actions: bool = False,
                 hidden_layers: int = 1, action_factor: float = 0.5,
                 rnn_type: str = 'gru'):
        """
        # & Build every sub-network used by the model.
        # &
        # & Args:
        # &     obs_dim: Observation size.
        # &     action_dim: Action size (number of discrete actions or continuous dims).
        # &     h_dim: Hidden-state particle size.
        # &     z_dim: Requested latent particle size (see _internal_z_dim note).
        # &     n_particles: Number of particles in the filter.
        # &     continuous_actions: Gaussian policy if True, categorical otherwise.
        # &     hidden_layers, action_factor, rnn_type: Forwarded to sub-networks.
        """
        super().__init__()

        # Store original dimensions
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.n_particles = n_particles
        self.continuous_actions = continuous_actions

        # Workaround kept for shape/checkpoint compatibility: when obs_dim < 100
        # the latent size is forced to obs_dim and the ``z_dim`` argument is
        # ignored. NOTE(review): confirm this override is still required.
        self._internal_z_dim = obs_dim if obs_dim < 100 else z_dim

        self._logger.info(
            "DVRL initialized with obs_dim=%s, z_dim=%s, using internal z_dim=%s",
            obs_dim, z_dim, self._internal_z_dim,
        )

        # Core networks for DVRL - all use the internal z_dim for compatibility
        self.encoder_net = EncoderNetwork(h_dim, self._internal_z_dim, action_dim, obs_dim, hidden_layers, action_factor)
        self.decoder_net = DecoderNetwork(h_dim, self._internal_z_dim, action_dim, obs_dim, hidden_layers, action_factor)
        self.transition_net = TransitionNetwork(h_dim, self._internal_z_dim, action_dim, hidden_layers, action_factor)
        self.state_update_rnn = StateUpdateRNN(h_dim, self._internal_z_dim, action_dim, obs_dim, rnn_type, action_factor)

        # Particle filter
        self.particle_filter = ParticleFilter(n_particles, self.transition_net,
                                              self.encoder_net, self.decoder_net,
                                              self.state_update_rnn)

        # Particle aggregator
        self.particle_aggregator = ParticleAggregator(h_dim, self._internal_z_dim)

        # Policy and value networks
        self.policy_net = PolicyNetwork(h_dim, action_dim, h_dim, continuous_actions)
        self.value_net = ValueNetwork(h_dim, h_dim)

    def _match_z_dim(self, z_particles: torch.Tensor, where: str = "") -> torch.Tensor:
        """
        # & Coerce the last dimension of ``z_particles`` to ``self._internal_z_dim``.
        # &
        # & Returns the input unchanged when it already matches, zero-pads on the
        # & right when it is too small, and truncates when it is too large. Padding
        # & goes through ``F.pad`` so dtype, device and autograd history are kept
        # & and any number of leading dimensions is supported.
        """
        current = z_particles.shape[-1]
        target = self._internal_z_dim
        if current == target:
            return z_particles

        self._logger.warning(
            "Dimension mismatch%s: z_particles has shape %s, but expected last dim %s; %s",
            f" in {where}" if where else "",
            tuple(z_particles.shape),
            target,
            "zero-padding" if current < target else "truncating",
        )
        if current < target:
            return F.pad(z_particles, (0, target - current))
        return z_particles[..., :target]

    def init_belief(self, batch_size: int, device: torch.device):
        """
        # & Initialize belief state (particles) via the particle filter.
        # &
        # & Returns:
        # &     Whatever ParticleFilter.init_particles returns; callers unpack it as
        # &     (h_particles, z_particles, weights).
        """
        return self.particle_filter.init_particles(batch_size, self.h_dim, self._internal_z_dim, device)

    def update_belief(self, h_particles: torch.Tensor, z_particles: torch.Tensor,
                      weights: torch.Tensor, a_prev: torch.Tensor, o_curr: torch.Tensor):
        """
        # & Update belief state given new action and observation
        # &
        # & Args:
        # &     h_particles: Hidden state particles [batch_size, n_particles, h_dim]
        # &     z_particles: Latent state particles [batch_size, n_particles, z_dim]
        # &     weights: Particle weights [batch_size, n_particles]
        # &     a_prev: Previous action [batch_size, action_dim]
        # &     o_curr: Current observation [batch_size, obs_dim]
        # &
        # & Returns:
        # &     tuple: Updated (h_particles, z_particles, weights)
        """
        # Debug-level only: this runs on every filter step.
        self._logger.debug(
            "update_belief input shapes: h=%s, z=%s, w=%s, a=%s, o=%s",
            tuple(h_particles.shape), tuple(z_particles.shape), tuple(weights.shape),
            tuple(a_prev.shape), tuple(o_curr.shape),
        )
        z_particles = self._match_z_dim(z_particles, "update_belief()")
        return self.particle_filter.update(h_particles, z_particles, weights, a_prev, o_curr)

    def act(self, h_particles: torch.Tensor, z_particles: torch.Tensor,
            weights: torch.Tensor, deterministic: bool = False):
        """
        # & Select action based on current belief state
        # &
        # & Returns:
        # &     tuple: (action, log_prob, entropy, value). For continuous actions,
        # &     log_prob and entropy are summed over action dimensions.
        """
        z_particles = self._match_z_dim(z_particles, "act()")

        # Aggregate particles into a single representation
        belief_vector = self.particle_aggregator(h_particles, z_particles, weights)

        if self.continuous_actions:
            mean, log_std = self.policy_net(belief_vector)
            std = torch.exp(log_std)

            if deterministic:
                action = mean
            else:
                action = mean + std * torch.randn_like(std)

            # Score-function log-prob: treat the sampled action as a constant.
            # Without detach, (action - mean) / std == eps and the gradient with
            # respect to the mean vanishes.
            z_score = (action.detach() - mean) / std
            log_prob = (-0.5 * z_score.pow(2) - log_std - self._HALF_LOG_2PI).sum(-1)
            # Diagonal Gaussian entropy: the constant applies once per action dim.
            entropy = (0.5 + self._HALF_LOG_2PI + log_std).sum(-1)
        else:
            logits = self.policy_net(belief_vector)
            # Categorical(logits=...) normalizes with log-softmax, avoiding log(p + eps).
            dist = torch.distributions.Categorical(logits=logits)
            action = torch.argmax(logits, dim=-1) if deterministic else dist.sample()
            log_prob = dist.log_prob(action)
            entropy = dist.entropy()

        value = self.value_net(belief_vector)
        return action, log_prob, entropy, value

    def get_value(self, h_particles: torch.Tensor, z_particles: torch.Tensor, weights: torch.Tensor):
        """
        # & Get value estimate for current belief state
        """
        z_particles = self._match_z_dim(z_particles, "get_value()")
        belief_vector = self.particle_aggregator(h_particles, z_particles, weights)
        return self.value_net(belief_vector)

    def forward(self, obs, prev_action=None, hidden_state=None):
        """
        # & Forward pass through the DVRL model
        # &
        # & Args:
        # &     obs (torch.Tensor): Current observation
        # &     prev_action (torch.Tensor, optional): Previous action; when None the
        # &         belief is not updated with obs
        # &     hidden_state (tuple, optional): Previous belief state (h_particles, z_particles, weights)
        # &
        # & Returns:
        # &     tuple: (belief_vector, hidden_state)
        """
        if hidden_state is None:
            h_particles, z_particles, weights = self.init_belief(obs.shape[0], obs.device)
        else:
            h_particles, z_particles, weights = hidden_state
            z_particles = self._match_z_dim(z_particles, "forward()")

        if prev_action is not None:
            h_particles, z_particles, weights = self.update_belief(
                h_particles, z_particles, weights, prev_action, obs)

        belief_vector = self.particle_aggregator(h_particles, z_particles, weights)
        # Return belief vector and hidden state for recurrent usage
        return belief_vector, (h_particles, z_particles, weights)

    def sample(self, n_samples=None):
        """
        # & Sample latent particles from a freshly initialized belief (batch of 1)
        # &
        # & Args:
        # &     n_samples (int, optional): Number of samples; defaults to n_particles
        # &
        # & Returns:
        # &     torch.Tensor: z particles drawn in proportion to the initial weights
        """
        if n_samples is None:
            n_samples = self.n_particles

        device = next(self.parameters()).device
        _, z_particles, weights = self.init_belief(1, device)

        # Sampling with replacement works for any n_samples, fewer or more than
        # the particle count. NOTE(review): torch.multinomial requires
        # non-negative weights; confirm init_particles does not return log-weights.
        indices = torch.multinomial(weights[0], n_samples, replacement=True)
        return z_particles[0, indices]
