# src/benchrl/policies/pbac_policy.py
"""
PAC-Bayesian Actor-Critic Policy Implementation for BenchRL.

Based on Tasdighi et al. "Deep Exploration with PAC-Bayes"
Implements posterior sampling via multi-head actor selection.
"""

from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distributions
import copy
import numpy as np

from benchrl.policies.base import BasePolicy


class ContinuousPBACPolicy(BasePolicy):
    """
    PBAC policy specialized for continuous action spaces.

    Matches BenchRL pattern with separate continuous/discrete implementations.

    The actor is called as ``actor_network(obs)`` and must return a mapping
    indexed by the names in ``actor_network.head_names``. Each head output is
    split in half along the last dimension into a mean and a raw log-std.
    While training (and not deterministic) only the currently active head is
    used; otherwise the outputs of all heads are averaged.
    """
    # Bounds of the interval the tanh-squashed log-std is mapped into.
    LOG_STD_MAX = 2
    LOG_STD_MIN = -20

    def __init__(
        self,
        actor_network: nn.Module,
        critic_network: nn.Module,
        action_dim: int,
        n_critics: int = 10,
    ):
        """
        Store the actor/critics and initialise posterior-sampling state.

        Args:
            actor_network: Multi-head actor exposing ``n_heads`` and
                ``head_names`` attributes.
            critic_network: Critic ensemble; must be an ``nn.ModuleList``.
            action_dim: Dimensionality of the action space (stored only).
            n_critics: Requested ensemble size. Overridden (with a printed
                warning) by ``actor_network.n_heads`` when the two differ.

        Raises:
            ValueError: If the actor lacks ``n_heads``/``head_names`` or the
                critic is not an ``nn.ModuleList``.
        """
        super().__init__()

        self.actor_network = actor_network
        self.action_dim = action_dim
        self.n_critics = n_critics

        # Validate multi-head actor
        if not hasattr(actor_network, 'n_heads'):
            raise ValueError("Actor network must be a multi-head network (MHMLPModule)")
        if not hasattr(actor_network, 'head_names'):
            raise ValueError("Actor network must have 'head_names' attribute")

        self.n_actor_heads = actor_network.n_heads
        if n_critics != self.n_actor_heads:
            print(f"Warning: n_critics ({n_critics}) doesn't match "
                  f"actor_network.n_heads ({self.n_actor_heads}). "
                  f"Using n_actor_heads={self.n_actor_heads} for both.")
            self.n_critics = self.n_actor_heads

        # The ensemble is used as passed in; it is never constructed here.
        if not isinstance(critic_network, nn.ModuleList):
            raise ValueError("Critic network must be a nn.ModuleList")
        self.critic_ensemble = critic_network

        # Posterior sampling state
        self.interaction_iter = 0
        self.idx_active_head = 0
        self.is_episode_end = False
        self.training = True

    def _action_affine(self):
        """
        Return ``(action_scale, action_bias)`` read from the actor network.

        Falls back to ``(1.0, 0.0)`` when the actor does not define them, so
        the squashed action is returned unchanged in that case.
        """
        scale = getattr(self.actor_network, 'action_scale', 1.0)
        bias = getattr(self.actor_network, 'action_bias', 0.0)
        return scale, bias

    @staticmethod
    def _squeeze_head_output(output: torch.Tensor) -> torch.Tensor:
        """Drop a singleton dimension 1 from a 3-D head output, if present."""
        if output.dim() == 3 and output.shape[1] == 1:
            return output.squeeze(1)
        return output

    def get_action(
        self,
        obs: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample action from policy using posterior sampling strategy.

        Args:
            obs: Observation tensor forwarded to the actor network.
            deterministic: If True, return the squashed mean action instead
                of sampling.

        Returns:
            Tuple ``(action, log_prob, mean)`` where ``action`` and ``mean``
            are tanh-squashed and mapped through the actor's
            ``action_scale``/``action_bias``. In stochastic mode ``log_prob``
            is summed over the last dimension with ``keepdim=True``. In
            deterministic mode it is an all-zero placeholder summed over the
            last dimension *without* ``keepdim`` (shape kept as before for
            backward compatibility).
        """
        mean, log_std = self.forward(obs, deterministic)
        action_scale, action_bias = self._action_affine()

        # Same squash + affine map in both modes so that evaluation actions
        # live in the same range as the actions sampled during training.
        squashed_mean = torch.tanh(mean) * action_scale + action_bias

        if deterministic:
            action = squashed_mean
            # For deterministic actions, log_prob is not meaningful
            log_prob = torch.zeros_like(action).sum(dim=-1)
            return action, log_prob, squashed_mean

        # Sample using reparameterization trick
        normal = distributions.Normal(mean, log_std.exp())
        x_t = normal.rsample()
        y_t = torch.tanh(x_t)
        action = y_t * action_scale + action_bias

        # Change-of-variables correction for the tanh squashing and scaling;
        # the 1e-6 keeps the log finite when |y_t| approaches 1.
        log_prob = normal.log_prob(x_t)
        log_prob = log_prob - torch.log(action_scale * (1 - y_t.pow(2)) + 1e-6)
        # Sum over the action dimension (the last one, as in chunk(dim=-1)).
        log_prob = log_prob.sum(-1, keepdim=True)

        return action, log_prob, squashed_mean

    def forward(self, obs: torch.Tensor, deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through actor network.

        Returns:
            ``(mean, log_std)`` where ``mean`` is the unsquashed first half of
            the actor output and ``log_std`` is the second half mapped into
            ``[LOG_STD_MIN, LOG_STD_MAX]``.
        """
        outputs = self._get_actor_outputs(obs, deterministic)
        mean, log_std = outputs.chunk(2, dim=-1)

        # Soft clamp: tanh maps to [-1, 1], then rescale to the log-std range.
        log_std = torch.tanh(log_std)
        log_std = self.LOG_STD_MIN + 0.5 * (self.LOG_STD_MAX - self.LOG_STD_MIN) * (log_std + 1)
        return mean, log_std

    def _get_actor_outputs(self, obs: torch.Tensor, deterministic: bool) -> torch.Tensor:
        """
        Get outputs from actor network based on training/eval mode.

        Uses the active head when ``self.training`` is set and
        ``deterministic`` is False; otherwise averages all heads.

        Raises:
            RuntimeError: If the active head's output contains NaN.
        """
        outputs = self.actor_network(obs)

        if self.training and not deterministic:
            # Use currently selected head during training
            head_name = self.actor_network.head_names[self.idx_active_head]
            head_output = outputs[head_name]
            if torch.any(torch.isnan(head_output)):
                # Fail loudly instead of dropping into an interactive
                # debugger, which would hang unattended training runs.
                raise RuntimeError(
                    f"NaN detected in output of actor head '{head_name}' "
                    f"(index {self.idx_active_head})"
                )
            return self._squeeze_head_output(head_output)

        # Average across all heads during evaluation
        head_outputs = [
            self._squeeze_head_output(outputs[head_name])
            for head_name in self.actor_network.head_names
        ]
        return torch.stack(head_outputs).mean(dim=0)

    def update_posterior_sampling(self, sampling_rate: int, global_step: int = 0) -> None:
        """
        Update the active actor head for posterior sampling.

        Increments the interaction counter and draws a new head index
        uniformly at random when an episode end was flagged or the counter is
        a multiple of ``sampling_rate``.

        Args:
            sampling_rate: Resample every this many calls.
            global_step: Unused; accepted for interface compatibility.
        """
        self.interaction_iter += 1
        should_resample = (
            self.is_episode_end or
            (self.interaction_iter % sampling_rate == 0)
        )
        if should_resample:
            self.idx_active_head = np.random.randint(0, self.n_actor_heads)
            # The episode-end flag is consumed by the resample it triggers.
            if self.is_episode_end:
                self.is_episode_end = False

    def set_episode_status(self, is_end: bool) -> None:
        """Set episode termination status."""
        self.is_episode_end = is_end

    def critic_ensemble_parameters(self):
        """Get all critic parameters as a list."""
        return list(self.critic_ensemble.parameters())

    def critic_parameters(self, idx):
        """Get the parameters of the critic at position ``idx`` as a list."""
        return list(self.critic_ensemble[idx].parameters())

    def actor_parameters(self):
        """Get all actor parameters as a list."""
        return list(self.actor_network.parameters())

    def parameters(self):
        """Get all parameters (actor first, then critics) as a list."""
        params = []
        params.extend(self.actor_network.parameters())
        params.extend(self.critic_ensemble.parameters())
        return params

    def to(self, device):
        """Move actor and critic ensemble to ``device``; returns ``self``."""
        self.actor_network.to(device)
        self.critic_ensemble.to(device)
        return self

    def _train(self):
        """Set training mode on the policy and its networks."""
        self.training = True
        self.actor_network.train()
        self.critic_ensemble.train()
        return self

    def _eval(self):
        """Set evaluation mode on the policy and its networks."""
        self.training = False
        self.actor_network.eval()
        self.critic_ensemble.eval()
        return self

    def state_dict(self):
        """Get state dict with ``actor_network`` and ``critic_ensemble`` entries."""
        return {
            'actor_network': self.actor_network.state_dict(),
            'critic_ensemble': self.critic_ensemble.state_dict()
        }

    def load_state_dict(self, state_dict):
        """Load a state dict in the format produced by ``state_dict``."""
        self.actor_network.load_state_dict(state_dict['actor_network'])
        self.critic_ensemble.load_state_dict(state_dict['critic_ensemble'])
        return self


class DiscretePBACPolicy(BasePolicy):
    """
    PBAC policy specialized for discrete action spaces.

    The actor is called as ``actor_network(obs)`` and must return a mapping
    indexed by the names in ``actor_network.head_names``; each entry is
    treated as the logits of a categorical distribution. While training (and
    not deterministic) only the active head is used; otherwise the logits of
    all heads are averaged.
    """

    def __init__(
        self,
        actor_network: nn.Module,
        critic_network: nn.Module,
        action_dim: int,
        n_critics: int = 10,
    ):
        """
        Store the actor/critics and reset the posterior-sampling state.

        Args:
            actor_network: Multi-head actor exposing ``n_heads`` and
                ``head_names``.
            critic_network: Critic ensemble; must be an ``nn.ModuleList``.
            action_dim: Number of actions (stored only).
            n_critics: Requested ensemble size; replaced by the actor's head
                count when the two differ.

        Raises:
            ValueError: If the actor is missing a required attribute or the
                critic is not an ``nn.ModuleList``.
        """
        super().__init__()

        self.actor_network = actor_network
        self.action_dim = action_dim
        self.n_critics = n_critics

        # The actor must advertise its heads before anything else is set up.
        if not hasattr(actor_network, 'n_heads'):
            raise ValueError("Actor network must be a multi-head network (MHMLPModule)")
        if not hasattr(actor_network, 'head_names'):
            raise ValueError("Actor network must have 'head_names' attribute")

        head_count = actor_network.n_heads
        self.n_actor_heads = head_count
        # The head count takes precedence over a mismatching n_critics.
        if n_critics != head_count:
            self.n_critics = head_count

        # The ensemble has to arrive ready-made as a ModuleList.
        if not isinstance(critic_network, nn.ModuleList):
            raise ValueError("Critic network must be a nn.ModuleList")
        self.critic_ensemble = critic_network

        # Bookkeeping for posterior sampling over actor heads.
        self.interaction_iter = 0
        self.idx_active_head = 0
        self.is_episode_end = False
        self.training = True

    def forward(self, obs: torch.Tensor) -> distributions.Categorical:
        """Build a categorical distribution from the non-deterministic actor logits."""
        return distributions.Categorical(
            logits=self._get_actor_outputs(obs, deterministic=False)
        )

    def get_action(
        self,
        obs: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Choose a discrete action.

        Args:
            obs: Observation tensor forwarded to the actor network.
            deterministic: If True, take the arg-max logit instead of sampling.

        Returns:
            Tuple ``(action, log_prob, action_probs)``: the action cast to
            ``long``, its log-probability under the categorical distribution,
            and the softmax of the logits over the last dimension.
        """
        logits = self._get_actor_outputs(obs, deterministic)
        policy_dist = distributions.Categorical(logits=logits)

        chosen = logits.argmax(dim=-1) if deterministic else policy_dist.sample()

        probs = F.softmax(logits, dim=-1)
        return chosen.long(), policy_dist.log_prob(chosen), probs

    def _get_actor_outputs(self, obs: torch.Tensor, deterministic: bool) -> torch.Tensor:
        """
        Compute actor logits.

        Returns the active head's logits when ``self.training`` is set and
        ``deterministic`` is False; otherwise the mean of all heads' logits.
        """
        per_head = self.actor_network(obs)

        single_head = self.training and not deterministic
        if not single_head:
            # Ensemble path: average the logits of every head.
            collected = []
            for name in self.actor_network.head_names:
                logits = per_head[name]
                if logits.dim() == 3:
                    logits = logits.squeeze(1)
                collected.append(logits)
            return torch.stack(collected).mean(dim=0)

        # Single-head path: only the currently active head is consulted.
        active_name = self.actor_network.head_names[self.idx_active_head]
        logits = per_head[active_name]
        if logits.dim() == 3 and logits.shape[1] == 1:
            logits = logits.squeeze(1)
        return logits

    def update_posterior_sampling(self, sampling_rate: int, global_step: int = 0) -> None:
        """
        Advance the interaction counter and possibly switch the active head.

        A new head index is drawn uniformly at random when an episode end was
        flagged or the counter is a multiple of ``sampling_rate``.
        ``global_step`` is accepted but not used.
        """
        self.interaction_iter += 1

        episode_ended = self.is_episode_end
        if not episode_ended and self.interaction_iter % sampling_rate != 0:
            return

        self.idx_active_head = np.random.randint(0, self.n_actor_heads)
        if episode_ended:
            # The flag is consumed by the resample it triggered.
            self.is_episode_end = False

    def set_episode_status(self, is_end: bool) -> None:
        """Record whether the current episode has ended."""
        self.is_episode_end = is_end

    def critic_ensemble_parameters(self):
        """Return the parameters of the whole critic ensemble as a list."""
        return [p for p in self.critic_ensemble.parameters()]

    def critic_parameters(self, idx):
        """Return the parameters of the critic at position ``idx`` as a list."""
        return list(self.critic_ensemble[idx].parameters())

    def actor_parameters(self):
        """Return the actor's parameters as a list."""
        return [p for p in self.actor_network.parameters()]

    def parameters(self):
        """Return actor parameters followed by critic parameters as a list."""
        return [
            *self.actor_network.parameters(),
            *self.critic_ensemble.parameters(),
        ]

    def to(self, device):
        """Move both networks to ``device`` and return ``self``."""
        for module in (self.actor_network, self.critic_ensemble):
            module.to(device)
        return self

    def train(self, mode: bool = True):
        """Switch the policy and both networks to training (or eval) mode."""
        self.training = mode
        for module in (self.actor_network, self.critic_ensemble):
            module.train(mode)
        return self

    def eval(self):
        """Switch to evaluation mode; shorthand for ``train(False)``."""
        return self.train(False)

    def state_dict(self):
        """Return the actor and critic state dicts keyed by component name."""
        snapshot = {}
        snapshot['actor_network'] = self.actor_network.state_dict()
        snapshot['critic_ensemble'] = self.critic_ensemble.state_dict()
        return snapshot

    def load_state_dict(self, state_dict):
        """Restore both networks from a dict produced by ``state_dict``."""
        actor_state = state_dict['actor_network']
        self.actor_network.load_state_dict(actor_state)
        critic_state = state_dict['critic_ensemble']
        self.critic_ensemble.load_state_dict(critic_state)
        return self