from abc import abstractmethod
from typing import Tuple, Optional
import copy
import torch
import torch.nn as nn
import torch.distributions as distributions
from .actor_critic_policy import ActorCriticPolicy


class SACPolicy(ActorCriticPolicy):
    """SAC actor-critic policy with twin Q-networks.

    Extends ActorCriticPolicy to support off-policy algorithms like SAC by
    adding a second critic (``critic2_network``) and threading it through
    device moves, train/eval switching, compilation and (de)serialisation.

    This class does not create target networks; if the algorithm needs
    them they have to be managed by the caller.
    """

    def __init__(self, actor_network: nn.Module, critic_network: nn.Module, action_dim: int):
        """Initialize SAC actor-critic policy.

        Args:
            actor_network: Actor network that outputs action distribution parameters
            critic_network: Critic network template for creating twin Q-networks
            action_dim: Dimension of the action space
        """
        super().__init__(actor_network, critic_network, action_dim)

        # The second critic is a deep copy of the first, so both critics
        # share architecture AND initial weights until they are trained.
        self.critic2_network = copy.deepcopy(self.critic_network)

    def get_q_values(
        self, obs: torch.Tensor, action: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get Q-values from both critic networks.

        Args:
            obs: Observation tensor
            action: Optional action tensor. When given it is concatenated to
                ``obs`` along the last dimension before being fed to the
                critics; when ``None`` the critics receive ``obs`` alone.

        Returns:
            Tuple of Q-values from critic1 and critic2, each flattened to 1-D
        """
        critic_input = obs if action is None else torch.cat([obs, action], dim=-1)

        q1 = self.critic_network(critic_input)
        q2 = self.critic2_network(critic_input)

        return q1.flatten(), q2.flatten()

    def parameters(self):
        """Get all policy parameters (actor, critic1, critic2) as a list."""
        return [
            *self.actor_network.parameters(),
            *self.critic_network.parameters(),
            *self.critic2_network.parameters(),
        ]

    def critic_parameters(self):
        """Get only critic parameters (for separate optimizer)."""
        return [
            *self.critic_network.parameters(),
            *self.critic2_network.parameters(),
        ]

    def actor_parameters(self):
        """Get only actor parameters (for separate optimizer)."""
        return list(self.actor_network.parameters())

    def to(self, device):
        """Move policy (including the second critic) to device.

        Returns:
            ``self``, to allow call chaining
        """
        super().to(device)
        self.critic2_network.to(device)
        return self

    def _train(self):
        """Set policy to training mode."""
        super()._train()
        self.critic2_network.train()

    def _eval(self):
        """Set policy to evaluation mode."""
        super()._eval()
        self.critic2_network.eval()

    def _compile(self, **kwargs):
        """Compile policy networks for performance.

        Args:
            **kwargs: Forwarded unchanged to the parent ``_compile`` and to
                ``torch.compile`` for the second critic.

        Returns:
            ``self``, to allow call chaining
        """
        super()._compile(**kwargs)
        self.critic2_network = torch.compile(self.critic2_network, **kwargs)
        return self

    def state_dict(self):
        """Get state dict of policy.

        The parent's ``critic_network`` entry is renamed to
        ``critic1_network`` and a ``critic2_network`` entry is added.
        """
        state = super().state_dict()
        state['critic1_network'] = state.pop('critic_network')
        state['critic2_network'] = self.critic2_network.state_dict()
        return state

    def load_state_dict(self, state_dict):
        """Load state dict into policy.

        Args:
            state_dict: Mapping with ``actor_network``, ``critic1_network``
                and ``critic2_network`` entries, as produced by
                :meth:`state_dict`.

        Returns:
            ``self``, to allow call chaining

        Raises:
            ValueError: If any of the three required entries is missing.
        """
        if 'critic1_network' not in state_dict or 'critic2_network' not in state_dict:
            raise ValueError("State dict must contain critic1_network and critic2_network")
        # Validate the actor entry as well, so a truncated checkpoint fails
        # with the same exception type before any network is modified.
        if 'actor_network' not in state_dict:
            raise ValueError("State dict must contain actor_network")

        self.actor_network.load_state_dict(state_dict['actor_network'])
        self.critic_network.load_state_dict(state_dict['critic1_network'])
        self.critic2_network.load_state_dict(state_dict['critic2_network'])
        return self


class DiscreteSACPolicy(SACPolicy):
    """SAC Actor-Critic policy for discrete action spaces.

    The actor output is used as the logits of a Categorical distribution.
    Actions are drawn with ``Categorical.sample()``; no reparameterised
    (e.g. Gumbel-Softmax) sampling is performed in this class.
    """

    def forward(self, obs: torch.Tensor) -> distributions.Categorical:
        """Forward pass through actor network.

        Returns:
            Categorical distribution built from the actor's logits
        """
        logits = self.actor_forward(obs)
        return distributions.Categorical(logits=logits)

    def get_action(
        self,
        obs: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Get action, its log probability and the action probabilities.

        Args:
            obs: Observation tensor
            deterministic: If True pick the highest-logit action instead of
                sampling from the distribution

        Returns:
            Tuple of (action, log_prob of that action, probabilities of all
            actions under the current policy)
        """
        action_dist = self.forward(obs)

        if deterministic:
            action = torch.argmax(action_dist.logits, dim=-1)
        else:
            action = action_dist.sample()

        # The log-probability is computed the same way for both modes.
        log_prob = action_dist.log_prob(action)

        return action, log_prob, action_dist.probs

    def get_min_q_value(self, obs: torch.Tensor) -> torch.Tensor:
        """Get minimum Q-value from twin critics (for SAC actor loss).

        The critics are evaluated on ``obs`` only (no action input) and the
        element-wise minimum of the two outputs is returned.
        """
        # NOTE(review): get_q_values() flattens each critic output to 1-D, so
        # if the critics emit one value per action the per-action dimension
        # is merged into the batch dimension here -- confirm callers expect it.
        q1, q2 = self.get_q_values(obs)
        return torch.min(q1, q2)

    def get_log_prob(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Get log probability of actions under the current policy."""
        action_dist = self.forward(obs)
        return action_dist.log_prob(actions)


class ContinuousSACPolicy(SACPolicy):
    """SAC Actor-Critic policy for continuous action spaces.

    Uses the reparameterization trick with tanh squashing for bounded
    actions. The squashed sample is mapped to the action range with the
    actor network's ``action_scale`` and ``action_bias`` attributes.
    """

    # Bounds of the log standard deviation produced by ``forward``.
    LOG_STD_MAX = 2
    LOG_STD_MIN = -20

    def __init__(self, actor_network: nn.Module, critic_network: nn.Module, action_dim: int):
        """Initialize continuous SAC policy.

        The actor network must expose exactly two output heads; the first is
        renamed to ``'mean'`` and the second to ``'log_std'`` (the network is
        modified in place).

        Args:
            actor_network: Two-headed actor network with ``head_names``,
                ``_output_layers`` and ``output_dims`` attributes
            critic_network: Critic network template for creating twin Q-networks
            action_dim: Dimension of the action space; both heads must
                output exactly this many values

        Raises:
            ValueError: If a required attribute is missing, the number of
                heads is not 2, a head dimension differs from ``action_dim``
                or a named head is absent from ``_output_layers``.
        """
        # Validate required attributes exist
        for attr in ('head_names', '_output_layers', 'output_dims'):
            if not hasattr(actor_network, attr):
                raise ValueError(f"Actor network must have '{attr}' attribute")

        # Validate head_names structure
        if len(actor_network.head_names) != 2:
            raise ValueError(f"Actor network must have exactly 2 heads, got {len(actor_network.head_names)}")

        if actor_network.output_dims[0] != action_dim or actor_network.output_dims[1] != action_dim:
            raise ValueError(f"Actor network output dimensions do not match action dimension {action_dim}")

        head_0 = actor_network.head_names[0]
        head_1 = actor_network.head_names[1]

        # Validate that heads exist in _output_layers
        output_layers = actor_network._output_layers
        if head_0 not in output_layers:
            raise ValueError(f"Head '{head_0}' not found in _output_layers")
        if head_1 not in output_layers:
            raise ValueError(f"Head '{head_1}' not found in _output_layers")

        # Remove BOTH heads before re-inserting them under their canonical
        # names. Renaming one at a time would overwrite (and lose) a layer
        # when the incoming names are e.g. ['log_std', 'mean'].
        mean_layer = output_layers.pop(head_0)
        log_std_layer = output_layers.pop(head_1)
        output_layers["mean"] = mean_layer
        output_layers["log_std"] = log_std_layer
        actor_network.head_names = ['mean', 'log_std']

        super().__init__(actor_network, critic_network, action_dim)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through actor network.

        Returns:
            Tuple of (action_mean, action_log_std), where ``action_log_std``
            lies in ``[LOG_STD_MIN, LOG_STD_MAX]``
        """
        actor_output = self.actor_forward(obs)

        # Squash the raw head output to [-1, 1] with tanh, then rescale it
        # linearly to [LOG_STD_MIN, LOG_STD_MAX] for numerical stability.
        squashed = torch.tanh(actor_output['log_std'])
        log_std = self.LOG_STD_MIN + 0.5 * (self.LOG_STD_MAX - self.LOG_STD_MIN) * (squashed + 1)

        return actor_output['mean'], log_std

    def get_action(
        self,
        obs: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Get action and log probability using reparameterization trick.

        Args:
            obs: Observation tensor
            deterministic: If True return the squashed, scaled mean action
                instead of a sample

        Returns:
            Tuple of (action, log_prob, mean). ``action`` and ``mean`` are
            both tanh-squashed and mapped through ``action_scale`` /
            ``action_bias``. ``log_prob`` keeps a trailing dimension of size
            1; in deterministic mode it is all zeros because a log
            probability is not meaningful there.
        """
        mean, log_std = self.forward(obs)
        scale = self.actor_network.action_scale
        bias = self.actor_network.action_bias

        # Squashed and rescaled mean; this is also the deterministic action.
        squashed_mean = torch.tanh(mean) * scale + bias

        if deterministic:
            action = squashed_mean
            # Placeholder with the same shape the stochastic branch returns.
            log_prob = torch.zeros_like(action).sum(dim=-1, keepdim=True)
        else:
            # Sample using reparameterization trick
            normal = distributions.Normal(mean, log_std.exp())
            x_t = normal.rsample()
            y_t = torch.tanh(x_t)
            action = y_t * scale + bias
            # Gaussian log-density minus the log-Jacobian of the
            # tanh-and-scale transform (enforces the action bound).
            log_prob = normal.log_prob(x_t)
            log_prob -= torch.log(scale * (1 - y_t.pow(2)) + 1e-6)
            # Sum over the action dimension; dim=-1 also works for
            # unbatched observations.
            log_prob = log_prob.sum(dim=-1, keepdim=True)

        return action, log_prob, squashed_mean

    def get_min_q_value(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Get minimum Q-value from twin critics (for SAC actor loss).

        Returns:
            Element-wise minimum of the two critics' flattened Q-values for
            the given observation/action pairs
        """
        q1, q2 = self.get_q_values(obs, action)
        return torch.min(q1, q2)

    def get_log_prob(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Get log probability of actions under the actor network.

        Args:
            obs: Observation tensor
            actions: Actions in the scaled action range, i.e. in the form
                returned by :meth:`get_action`

        Returns:
            Log probability summed over the last dimension, with that
            dimension kept as size 1
        """
        mean, log_std = self.forward(obs)
        scale = self.actor_network.action_scale
        bias = self.actor_network.action_bias

        # Undo the affine scaling to get back to the tanh output range.
        raw_action = (actions - bias) / scale

        # atanh is infinite at +/-1, so keep the value strictly inside.
        raw_action = torch.clamp(raw_action, -0.999, 0.999)

        # Recover the pre-squash Gaussian sample.
        x_t = torch.atanh(raw_action)
        normal = distributions.Normal(mean, log_std.exp())
        log_prob = normal.log_prob(x_t)

        # Stable tanh correction
        # NOTE(review): epsilon here (1e-8) differs from get_action (1e-6).
        log_prob -= torch.log(scale * (1 - raw_action.pow(2)) + 1e-8)
        log_prob = log_prob.sum(dim=-1, keepdim=True)

        return log_prob