from abc import ABC, abstractmethod
from typing import Tuple, Optional
import torch
import torch.nn as nn
from torch.distributions import Distribution


class BasePolicy(ABC):
    """Abstract interface shared by the policy networks of every algorithm.

    Concrete policies must implement ``forward`` and ``get_action``. A generic
    ``get_action_and_value`` is supplied on top of ``forward`` for actor-critic
    style callers.

    Note:
        ``_requires_grad`` iterates ``self.parameters()``, which this class
        does not define. Subclasses are expected to provide it, for example by
        also inheriting from ``torch.nn.Module``.
    """

    @abstractmethod
    def forward(self, obs: torch.Tensor) -> Distribution:
        """Run the policy network on a batch of observations.

        Args:
            obs: Observation tensor.

        Returns:
            The action distribution for ``obs``.
        """
        ...

    @abstractmethod
    def get_action(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select an action and report its log probability.

        Args:
            obs: Observation tensor.
            deterministic: If True, select the action deterministically
                instead of sampling.

        Returns:
            A ``(action, log_prob)`` pair.
        """
        ...

    def get_action_and_value(
        self, obs: torch.Tensor, action: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return action, log probability, entropy and a value estimate.

        Generic fallback for actor-critic algorithms. This base version has no
        critic, so the returned value is a zero placeholder; policies that own
        a value head should override this method.

        Args:
            obs: Observation tensor. Its first dimension sets the length of
                the placeholder value tensor.
            action: Action to score. When omitted, one is sampled from the
                distribution produced by ``forward``.

        Returns:
            A ``(action, log_prob, entropy, value)`` tuple.
        """
        dist = self.forward(obs)

        # Score the caller's action when given; otherwise draw a fresh sample.
        chosen = dist.sample() if action is None else action

        chosen_log_prob = dist.log_prob(chosen)
        dist_entropy = dist.entropy()

        # No critic at this level: one zero per leading-dimension entry of obs,
        # created on the same device as obs.
        placeholder_value = torch.zeros(obs.shape[0], device=obs.device)

        return chosen, chosen_log_prob, dist_entropy, placeholder_value

    def _requires_grad(self, requires_grad: bool):
        """Set the ``requires_grad`` flag on every parameter.

        Args:
            requires_grad: Flag value to assign to each parameter yielded by
                ``self.parameters()``.

        Returns:
            This policy instance, so calls can be chained.
        """
        for p in self.parameters():
            p.requires_grad = requires_grad
        return self