from abc import abstractmethod
from typing import Tuple, Optional
import copy
import torch
import torch.nn as nn
import torch.distributions as distributions
from .base import BasePolicy

class ActorCriticPolicy(BasePolicy):
    """Actor-critic policy built from an actor network and a critic network.

    Combines actor and critic networks for algorithms like PPO, A2C.
    When a shared backbone is supplied, observations are passed through it
    first and the resulting features are fed to both heads.
    """

    def __init__(
        self,
        actor_network: Optional[nn.Module],
        critic_network: Optional[nn.Module],
        action_dim: Optional[int],
        shared_backbone: Optional[nn.Module] = None,
    ):
        """Initialize actor-critic policy.

        Args:
            actor_network: Actor network that outputs action distribution
                parameters. May be None when ``shared_backbone`` is given, in
                which case a linear head is created.
            critic_network: Critic network that outputs state values. May be
                None when ``shared_backbone`` is given, in which case a linear
                head with a single output is created.
            action_dim: Dimension of the action space. Required only when the
                actor head has to be created here.
            shared_backbone: Optional shared backbone network for feature
                extraction. Its ``out_dim`` attribute is read to size any
                head that is created here.

        Raises:
            ValueError: If the actor head must be created but ``action_dim``
                is None.
        """
        # NOTE(review): BasePolicy.__init__ is not invoked here, matching the
        # previous behavior; confirm the base class needs no initialization.
        self.shared_backbone = shared_backbone
        self.actor_network = actor_network
        self.critic_network = critic_network

        if shared_backbone is None:
            return

        # Only fill in the heads that were not supplied; a network passed in
        # by the caller must never be replaced.
        if actor_network is None:
            if action_dim is None:
                raise ValueError("action_dim must be specified if shared_backbone is not None")
            self.actor_network = nn.Linear(shared_backbone.out_dim, action_dim)
        if critic_network is None:
            self.critic_network = nn.Linear(shared_backbone.out_dim, 1)

    def _iter_networks(self):
        """Return the networks in use, ordered backbone (if any), actor, critic."""
        networks = [self.actor_network, self.critic_network]
        if self.shared_backbone is not None:
            networks.insert(0, self.shared_backbone)
        return networks

    def _extract_features(self, obs: torch.Tensor) -> torch.Tensor:
        """Apply the shared backbone to ``obs``, or return ``obs`` unchanged without one."""
        if self.shared_backbone is None:
            return obs
        return self.shared_backbone(obs)

    def parameters(self):
        """Get all policy parameters.

        Returns:
            List of parameters ordered backbone (if any), actor, critic.
        """
        return [param for net in self._iter_networks() for param in net.parameters()]

    def actor_forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Forward pass through actor network.

        Args:
            obs: Observation tensor

        Returns:
            Action distribution parameters
        """
        return self.actor_network(self._extract_features(obs))

    def critic_forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Forward pass through critic network.

        Args:
            obs: Observation tensor

        Returns:
            State value
        """
        return self.critic_network(self._extract_features(obs))

    def get_value(self, obs: torch.Tensor) -> torch.Tensor:
        """Get state value.

        Args:
            obs: Observation tensor

        Returns:
            Critic output flattened to one dimension
        """
        return self.critic_forward(obs).flatten()

    def to(self, device):
        """Move every network of the policy to ``device`` and return self."""
        for net in self._iter_networks():
            net.to(device)
        return self

    def _train(self):
        """Set policy to training mode."""
        for net in self._iter_networks():
            net.train()

    def _eval(self):
        """Set policy to evaluation mode."""
        for net in self._iter_networks():
            net.eval()

    def _compile(self, **kwargs):
        """Compile policy networks for performance.

        Args:
            **kwargs: Forwarded to ``torch.compile`` for every network.

        Returns:
            This policy, with its networks replaced by compiled versions.
        """
        if self.shared_backbone is not None:
            self.shared_backbone = torch.compile(self.shared_backbone, **kwargs)
        self.actor_network = torch.compile(self.actor_network, **kwargs)
        self.critic_network = torch.compile(self.critic_network, **kwargs)
        return self

    def state_dict(self):
        """Get state dict of policy.

        Returns:
            Dict with 'actor_network' and 'critic_network' entries, plus
            'shared_backbone' when a backbone is in use.
        """
        state = {}
        if self.shared_backbone is not None:
            state['shared_backbone'] = self.shared_backbone.state_dict()
        state['actor_network'] = self.actor_network.state_dict()
        state['critic_network'] = self.critic_network.state_dict()
        return state

    def load_state_dict(self, state_dict):
        """Load state dict into policy.

        Args:
            state_dict: Dict in the layout produced by ``state_dict``.

        Returns:
            This policy.

        Raises:
            ValueError: If ``state_dict`` carries backbone weights but this
                policy has no shared backbone.
        """
        if self.shared_backbone is not None:
            self.shared_backbone.load_state_dict(state_dict['shared_backbone'])
        elif 'shared_backbone' in state_dict:
            raise ValueError("State dict contains shared_backbone but this policy does not use it")

        self.actor_network.load_state_dict(state_dict['actor_network'])
        self.critic_network.load_state_dict(state_dict['critic_network'])
        return self


class DiscreteActorCriticPolicy(ActorCriticPolicy):
    """Actor-critic policy whose actions come from a Categorical distribution.

    The actor output is used as the logits of the distribution.
    """

    def forward(self, obs: torch.Tensor) -> distributions.Categorical:
        """Build the action distribution for the given observations.

        Args:
            obs: Observation tensor

        Returns:
            Categorical distribution parameterized by the actor's logits
        """
        return distributions.Categorical(logits=self.actor_forward(obs))

    def get_action(
        self,
        obs: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick an action and report its log probability.

        Args:
            obs: Observation tensor
            deterministic: If True, take the highest-logit action instead of
                sampling

        Returns:
            Tuple of (action, log_prob)
        """
        dist = self.forward(obs)
        chosen = dist.logits.argmax(dim=-1) if deterministic else dist.sample()
        return chosen, dist.log_prob(chosen)

    def get_action_and_value(
        self,
        obs: torch.Tensor,
        action: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate the policy and the critic on the same observations.

        Args:
            obs: Observation tensor
            action: Action to score; when None, one is sampled from the
                distribution

        Returns:
            Tuple of (action, log_prob, entropy, value)
        """
        dist = self.forward(obs)
        state_value = self.get_value(obs)

        # Sampling happens only after both forward passes, as before.
        picked = dist.sample() if action is None else action

        picked_log_prob = dist.log_prob(picked)
        dist_entropy = dist.entropy()
        return picked, picked_log_prob, dist_entropy, state_value


class ContinuousActorCriticPolicy(ActorCriticPolicy):
    """Actor-Critic policy for continuous action spaces.

    Uses a Normal distribution whose mean comes from the actor network and
    whose log standard deviation is a learnable, state-independent parameter
    (``actor_logstd``).
    """

    def __init__(
        self,
        actor_network: Optional[nn.Module],
        critic_network: Optional[nn.Module],
        action_dim: int,
        shared_backbone: Optional[nn.Module] = None,
    ):
        """Initialize continuous actor-critic policy.

        Args:
            actor_network: Actor network that outputs action means
            critic_network: Critic network that outputs state values
            action_dim: Dimension of the action space; also sizes
                ``actor_logstd``
            shared_backbone: Optional shared backbone network
        """
        super().__init__(actor_network, critic_network, action_dim, shared_backbone)
        # Log std starts at zero, i.e. a standard deviation of one per action dim.
        self.actor_logstd = nn.Parameter(torch.zeros(1, action_dim))

    def parameters(self):
        """Get all policy parameters, including the learnable log std.

        Returns:
            List of the network parameters followed by ``actor_logstd``.
        """
        # actor_logstd lives outside the networks, so it must be added
        # explicitly or an optimizer built from this list would never update it.
        return list(super().parameters()) + [self.actor_logstd]

    def to(self, device):
        """Move policy to device.

        The log-std parameter is moved in place so that it stays the same
        ``nn.Parameter`` object (and leaf tensor) that optimizers may already
        reference.
        """
        super().to(device)
        with torch.no_grad():
            self.actor_logstd.data = self.actor_logstd.data.to(device)
            if self.actor_logstd.grad is not None:
                self.actor_logstd.grad.data = self.actor_logstd.grad.data.to(device)
        return self

    def state_dict(self):
        """Get state dict of policy, including an 'actor_logstd' entry."""
        state = super().state_dict()
        state['actor_logstd'] = self.actor_logstd.detach()
        return state

    def load_state_dict(self, state_dict):
        """Load state dict into policy.

        Args:
            state_dict: Dict in the layout produced by ``state_dict``. The
                'actor_logstd' entry is optional so that state dicts saved
                without it can still be loaded.

        Returns:
            This policy.
        """
        super().load_state_dict(state_dict)
        if 'actor_logstd' in state_dict:
            with torch.no_grad():
                self.actor_logstd.copy_(state_dict['actor_logstd'])
        return self

    def forward(self, obs: torch.Tensor) -> distributions.Normal:
        """Forward pass through actor network.

        Args:
            obs: Observation tensor

        Returns:
            Normal action distribution
        """
        action_mean = self.actor_forward(obs)

        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)

        return distributions.Normal(action_mean, action_std)

    def get_action(
        self,
        obs: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get action and log probability.

        Args:
            obs: Observation tensor
            deterministic: Whether to use deterministic action (mean)

        Returns:
            Tuple of (action, log_prob), with log_prob summed over the last
            (action) dimension
        """
        action_dist = self.forward(obs)

        if deterministic:
            action = action_dist.mean
        else:
            action = action_dist.sample()

        log_prob = action_dist.log_prob(action).sum(dim=-1)
        return action, log_prob

    def get_action_and_value(
        self,
        obs: torch.Tensor,
        action: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Get action, log probability, entropy, and value.

        Args:
            obs: Observation tensor
            action: Optional action tensor for computing log_prob; sampled
                when None

        Returns:
            Tuple of (action, log_prob, entropy, value), with log_prob and
            entropy summed over the last (action) dimension
        """
        action_dist = self.forward(obs)
        value = self.get_value(obs)

        if action is None:
            action = action_dist.sample()

        # Reduce over the last dimension, consistent with get_action, so
        # inputs with extra leading batch dimensions are handled correctly.
        log_prob = action_dist.log_prob(action).sum(dim=-1)
        entropy = action_dist.entropy().sum(dim=-1)

        return action, log_prob, entropy, value