# src/benchrl/algorithms/pbac.py
import copy
from typing import Dict, Any, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import gymnasium as gym

from omegaconf import OmegaConf

from benchrl.algorithms.base import BaseAlgorithm
from benchrl.policies.pbac_policy import ContinuousPBACPolicy, DiscretePBACPolicy
from benchrl.utils.buffers.replay_buffer import ReplayBuffer
from benchrl.utils._functions import _build_network

class PBAC(BaseAlgorithm):
    """PAC-Bayesian Actor-Critic (PBAC) algorithm.

    Implementation based on Tasdighi et al. "Deep Exploration with PAC-Bayes"
    Uses ensemble critics with PAC-Bayesian loss and multi-head actors for exploration.

    The critic is an ensemble of ``n_critics`` networks trained with a
    PAC-Bayes style loss (see ``_update_critics_pac_bayes``). The actor is a
    single network built with ``n_critics`` output heads. Head selection and
    posterior sampling are delegated to the policy object
    (``update_posterior_sampling`` / ``set_episode_status``).
    """

    def __init__(
        self,
        env,
        algo_config: Dict[str, Any],
        device: str = "auto",
        writer: Optional[SummaryWriter] = None,
    ):
        """Build networks, optimizers, the replay buffer and training state.

        Args:
            env: Vectorized Gymnasium environment (``num_envs``,
                ``single_observation_space`` and ``single_action_space`` are used).
            algo_config: Algorithm configuration. Also must expose
                ``model_bindings`` for ``build_module``.
            device: Device spec forwarded to ``BaseAlgorithm``.
            writer: Optional TensorBoard writer forwarded to ``BaseAlgorithm``.

        Raises:
            ValueError: If the action space is neither ``Discrete`` nor ``Box``,
                or if the actor/critic networks are missing from the bindings.
        """
        super().__init__(env, algo_config, device, writer)

        # Determine action space type
        self.is_discrete = isinstance(self.action_space, gym.spaces.Discrete)
        self.is_continuous = isinstance(self.action_space, gym.spaces.Box)

        if not (self.is_discrete or self.is_continuous):
            raise ValueError("PBAC only supports Discrete or Box action spaces")

        # Extract config parameters
        self.total_timesteps = algo_config.get('total_timesteps', 1000000)
        self.buffer_size = algo_config.get('buffer_size', 100000)
        self.learning_starts = algo_config.get('learning_starts', 10000)
        self.batch_size = algo_config.get('batch_size', 256)
        self.tau = algo_config.get('tau', 0.005)
        self.gamma = algo_config.get('gamma', 0.99)
        self.train_freq = algo_config.get('train_freq', 1)
        self.gradient_steps = algo_config.get('gradient_steps', 1)
        self.policy_frequency = algo_config.get('policy_frequency', 1)

        # PBAC-specific parameters
        self.n_critics = algo_config.get('n_critics', 10)
        self.posterior_sampling_rate = algo_config.get('posterior_sampling_rate', 5)
        # Probability that a (critic, sample) entry is masked out of the loss.
        self.bootstrap_rate = algo_config.get('bootstrap_rate', 0.05)

        # PAC-Bayes loss parameters (matching paper)
        self.prior_variance = algo_config.get('prior_variance', 1.0)
        # NOTE(review): coherence_weight / propagation_weight are read from the
        # config but not referenced anywhere in this class; the critic loss
        # currently weights all terms by 1.
        self.coherence_weight = algo_config.get('coherence_weight', 1.0)  # Weight for coherence term
        self.propagation_weight = algo_config.get('propagation_weight', 1.0)  # Weight for propagation term

        # Learning rates
        self.actor_lr = algo_config.get('actor_lr', 3e-4)
        self.critic_lr = algo_config.get('critic_lr', 3e-4)

        # Entropy coefficient ('auto' enables temperature tuning below)
        self.ent_coeff = algo_config.get('ent_coeff', 'auto')
        self.target_entropy = algo_config.get('target_entropy', None)

        # Build networks using the modular architecture
        print("Building networks with modular architecture...")
        modules = self.build_module()
        print(f"Networks: {list(modules.keys())}")

        actor_network = modules.get('actor_network', None)
        critic_network = modules.get('critic_network', None)

        if actor_network is None or critic_network is None:
            raise ValueError("PBAC requires actor_network and critic_network")

        # Build PBAC policy based on action space
        action_dim = self.action_space.n if self.is_discrete else self.action_space.shape[0]

        if self.is_discrete:
            self.policy = DiscretePBACPolicy(
                actor_network=actor_network,
                critic_network=critic_network,
                action_dim=action_dim,
                n_critics=self.n_critics,
            )
        else:
            self.policy = ContinuousPBACPolicy(
                actor_network=actor_network,
                critic_network=critic_network,
                action_dim=action_dim,
                n_critics=self.n_critics,
            )
            # Action rescaling constants, registered as buffers so they follow
            # the module across devices and are included in the state_dict.
            self.policy.actor_network.register_buffer(
                "action_scale",
                torch.tensor(
                    (env.single_action_space.high - env.single_action_space.low) / 2.0,
                    dtype=torch.float32,
                ),
            )
            self.policy.actor_network.register_buffer(
                "action_bias",
                torch.tensor(
                    (env.single_action_space.high + env.single_action_space.low) / 2.0,
                    dtype=torch.float32,
                ),
            )

        # Create target policy (frozen copy updated by polyak averaging)
        self.target_policy = copy.deepcopy(self.policy)
        self.target_policy._requires_grad(False)

        # Move to device
        self.policy.to(self.device)
        self.target_policy.to(self.device)

        # Setup optimizers
        self.actor_optimizer = optim.Adam(
            self.policy.actor_parameters(),
            lr=self.actor_lr
        )
        self.critic_optimizer = optim.Adam(
            self.policy.critic_ensemble_parameters(),
            lr=self.critic_lr
        )

        # Automatic entropy tuning
        if self.ent_coeff == 'auto':
            if self.target_entropy is None:
                if self.is_discrete:
                    self.target_entropy = -np.log(1.0 / self.action_space.n) * 0.98
                else:
                    self.target_entropy = -np.prod(self.action_space.shape).item()

            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha = self.log_alpha.exp().item()
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=self.actor_lr)
        else:
            self.alpha = self.ent_coeff
            self.log_alpha = None
            self.alpha_optimizer = None

        # Initialize replay buffer (following SAC pattern)
        self.replay_buffer = ReplayBuffer(
            buffer_size=self.buffer_size,
            observation_space=self.env.single_observation_space,
            action_space=self.env.single_action_space,
            device=self.device,
            n_envs=self.env.num_envs,
            handle_timeout_termination=False,
        )

        # Initialize environment state
        self.last_obs = None

        # Training state
        self.global_step = 0
        self.policy_update_count = 0
        self.episode_count = 0
        self.actor_losses = []
        self.critic_losses = []
        self.alpha_losses = []

    def build_module(self):
        """
        Build models based on configuration bindings.

        Reads ``self.algo_config.model_bindings``, which maps a role name to a
        module-config name (string). Any ``shared_backbone`` entry is dropped.

        * Roles containing ``"critic"`` get an ``nn.ModuleList`` of
          ``n_critics`` independent networks. For discrete actions they map
          observations to one Q-value per action. For continuous actions they
          map ``[obs, action]`` to a single Q-value.
        * Roles containing ``"actor"`` get one network with ``n_critics`` heads
          (logits for discrete actions, mean and log_std for continuous ones).
        * Other roles are ignored.

        Side effect: sets ``self.in_dim`` / ``self.out_dim`` to the dimensions
        of the last network built.

        Returns:
            Dictionary mapping roles to model instances

        Raises:
            ValueError: If a binding value is not a string.
        """
        model_bindings = OmegaConf.to_container(self.algo_config.model_bindings, resolve=True)

        # Remove shared backbone if present
        model_bindings.pop('shared_backbone', None)

        # Build role-specific models
        modules = {}
        for role, model_spec in model_bindings.items():
            if not isinstance(model_spec, str):
                raise ValueError(f"Model spec for role '{role}' must be a string refering to module config name, got {type(model_spec)}")

            if "critic" in role:
                if self.is_discrete:
                    # Discrete: critics only take observations as input
                    self.in_dim = self.flat_in_dim
                    self.out_dim = self.action_space.n  # Q-value for each action
                else:
                    # Continuous: critics take observations + actions
                    self.in_dim = self.flat_in_dim + self.action_dim
                    self.out_dim = 1  # Single Q-value

                modules[role] = nn.ModuleList([
                    _build_network(model_spec, self.in_dim, self.out_dim) for _ in range(self.n_critics)
                ])

            elif "actor" in role:
                self.in_dim = self.flat_in_dim
                if self.is_discrete:
                    # Discrete: output logits for each action, one list entry per head
                    self.out_dim = [self.action_space.n] * self.n_critics
                else:
                    # Continuous: output mean and log_std for each action, per head
                    self.out_dim = [2 * self.action_dim] * self.n_critics

                modules[role] = _build_network(model_spec, self.in_dim, self.out_dim, n_heads=self.n_critics, layer_normalization=True)

        return modules

    def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """Return the policy's action for a single (unbatched) observation.

        Args:
            obs: One observation; a batch dimension is added internally.
            deterministic: Forwarded to ``policy.get_action``.

        Returns:
            The action for that observation as a NumPy array (batch dim removed).
        """
        obs_tensor = torch.FloatTensor(obs).unsqueeze(0).to(self.device)

        with torch.no_grad():
            action, _, _ = self.policy.get_action(obs_tensor, deterministic=deterministic)

        return action.cpu().numpy()[0]

    def collect_rollouts(self, num_steps: int = 1) -> Dict[str, float]:
        """Step the vectorized env ``num_steps`` times and store transitions.

        Until ``global_step`` reaches ``learning_starts`` actions are sampled
        uniformly at random. Afterwards they come from the stochastic policy.
        ``global_step`` advances by ``num_envs`` per env step.

        Args:
            num_steps: Number of vectorized env steps to take.

        Returns:
            Empty dict if no episode finished. Otherwise
            ``rollout/episodic_return`` and ``rollout/episodic_length`` (lists
            with one entry per episode finished during this call) and
            ``rollout/episodes`` (how many episodes that is).
        """
        if self.last_obs is None:
            self.last_obs, _ = self.env.reset()

        episode_returns = []
        episode_lengths = []
        # Accumulated over the whole call, not just the last env step.
        rollout_episodes = 0

        # Pre-allocate tensors to avoid repeated allocation
        if not hasattr(self, '_obs_tensor'):
            self._obs_tensor = torch.zeros((self.env.num_envs, *self.env.single_observation_space.shape),
                                         dtype=torch.float32, device=self.device)

        for _ in range(num_steps):
            self.global_step += self.env.num_envs

            # Copy observations to pre-allocated tensor
            self._obs_tensor.copy_(torch.from_numpy(self.last_obs))

            with torch.no_grad():
                if self.global_step < self.learning_starts:
                    # Random action sampling during initial exploration
                    if self.is_discrete:
                        actions = np.random.randint(0, self.action_space.n, size=(self.env.num_envs,))
                    else:
                        actions = np.random.uniform(
                            self.action_space.low, self.action_space.high,
                            size=(self.env.num_envs, self.action_space.shape[0])
                        )
                else:
                    # Use policy for action selection
                    actions, _, _ = self.policy.get_action(self._obs_tensor, deterministic=False)
                    actions = actions.cpu().numpy()

            # Take environment step
            next_obs, rewards, terminations, truncations, infos = self.env.step(actions)

            # Track episode metrics
            if "final_info" in infos:
                final_info = infos["final_info"]
                eps_final_info = final_info['episode']
                # When the vector env supplies an "_episode" mask, only the
                # masked slots hold finished episodes; otherwise every entry
                # is treated as a finished episode.
                done_mask = final_info.get("_episode")
                if done_mask is not None:
                    finished = np.flatnonzero(done_mask)
                else:
                    finished = range(len(eps_final_info["r"]))
                for i in finished:
                    episode_returns.append(eps_final_info["r"][i])
                    episode_lengths.append(eps_final_info["l"][i])
                n_finished = len(finished)
                rollout_episodes += n_finished
                self.episode_count += n_finished
                if n_finished > 0:
                    # Signal episode end for posterior sampling (once per env step)
                    self.policy.set_episode_status(True)

            # Truncated envs were auto-reset: store their true final observation.
            real_next_obs = next_obs.copy()
            for idx, trunc in enumerate(truncations):
                if trunc:
                    real_next_obs[idx] = infos["final_obs"][idx]

            # Store transitions in replay buffer (only terminations cut bootstrapping)
            self.replay_buffer.add(
                obs=self.last_obs,
                next_obs=real_next_obs,
                action=actions,
                reward=rewards,
                done=terminations,
                infos=infos
            )

            self.last_obs = next_obs

        # Return rollout metrics
        metrics = {}
        if episode_returns:
            metrics.update({
                'rollout/episodic_return': episode_returns,
                'rollout/episodic_length': episode_lengths,
                'rollout/episodes': rollout_episodes
            })

        return metrics

    def train_step(self) -> Dict[str, float]:
        """Collect ``train_freq`` env steps, then run ``gradient_steps`` updates.

        Each gradient step samples one batch and updates the critic ensemble.
        When ``global_step`` is a multiple of ``policy_frequency`` it also does
        the following on that same batch:

        * runs ``policy_frequency`` actor updates;
        * updates the temperature (if auto-tuned);
        * polyak-updates the target policy;
        * advances posterior sampling.

        Returns:
            Rollout metrics merged with the latest training metrics. Only
            rollout metrics are returned before ``learning_starts``.
        """
        # Collect experience
        rollout_metrics = self.collect_rollouts(self.train_freq)

        # Skip training if not enough samples
        if self.global_step < self.learning_starts:
            return rollout_metrics

        # Prepare training metrics
        training_metrics = {}

        for _ in range(self.gradient_steps):
            # Sample batch from replay buffer
            batch = self.replay_buffer.sample(self.batch_size)

            # Update critics with PAC-Bayes loss
            critic_loss = self._update_critics_pac_bayes(batch)
            training_metrics['train/critic_loss'] = critic_loss

            # Delayed actor update (TD3-style: several actor steps at once)
            if self.global_step % self.policy_frequency == 0:
                for _ in range(self.policy_frequency):
                    actor_loss = self._update_actor(batch)
                training_metrics['train/actor_loss'] = actor_loss

                # Update alpha if auto-tuning
                if self.log_alpha is not None:
                    alpha_loss = self._update_alpha(batch)
                    training_metrics['train/alpha_loss'] = alpha_loss
                    training_metrics['train/alpha'] = self.alpha

                # Update target networks
                self._update_targets()

                # Update posterior sampling for actor heads
                self.policy.update_posterior_sampling(self.posterior_sampling_rate, global_step=self.global_step)

                self.policy_update_count += 1

        # Combine rollout and training metrics
        all_metrics = {**rollout_metrics, **training_metrics}
        return all_metrics

    def _critic_q(self, critic, observations, actions, actions_are_indices):
        """Evaluate one critic at (observations, actions) and return shape ``(B,)``.

        Discrete critics output one value per action; the column selected by
        ``actions`` is gathered. Continuous critics take ``[obs, action]``.
        ``squeeze(-1)`` (not bare ``squeeze()``) keeps the batch dim when B == 1.
        """
        if self.is_discrete:
            index = actions.long()
            if not actions_are_indices:
                # Policy-sampled discrete actions arrive without a trailing dim.
                index = index.unsqueeze(-1)
            return critic(observations).gather(1, index).squeeze(-1)
        return critic(torch.cat([observations, actions], dim=1)).squeeze(-1)

    def _update_critics_pac_bayes(self, batch) -> float:
        """Take one optimizer step on the critic ensemble with the PAC-Bayes loss.

        Let ``q`` be the ``(n_critics, B)`` ensemble predictions and ``y`` the
        ``(n_critics, B)`` soft TD targets, where member ``k`` uses target
        critic ``k``. Let ``mask`` drop each entry with probability
        ``bootstrap_rate``. The loss is the mean over ensemble and batch of:

        * ``((q - y) * mask)**2``, the masked squared TD error;
        * ``kl_term = -0.5*log(var) + 0.5*mean_k(err_0**2) / prior_variance``,
          where ``var`` is the ensemble variance of ``q * mask`` (clamped at
          1e-8) and ``err_0`` is the masked deviation of ``q`` from the
          ensemble-mean target;
        * ``-gamma**2 * log(var)``, the variance offset.

        Args:
            batch: Replay-buffer sample exposing ``observations``, ``actions``,
                ``next_observations``, ``rewards`` and ``dones``.

        Returns:
            The scalar loss value.
        """
        current_q_values = [
            self._critic_q(self.policy.critic_ensemble[k], batch.observations,
                           batch.actions, actions_are_indices=True)
            for k in range(self.n_critics)
        ]
        q = torch.stack(current_q_values, dim=0)  # [n_critics, batch_size]

        # Get targets
        with torch.no_grad():
            next_actions, next_log_probs, next_action_probs = self.target_policy.get_action(
                batch.next_observations, deterministic=False
            )

            target_q_values = []
            for k in range(self.n_critics):
                if self.is_discrete:
                    # Expected soft value under the target policy's action probabilities.
                    # NOTE(review): next_log_probs is broadcast over actions via
                    # unsqueeze(1), i.e. it is treated as one log-prob per sample;
                    # the usual discrete-SAC value uses per-action log-probs.
                    # Confirm against DiscretePBACPolicy.get_action.
                    q_all = self.target_policy.critic_ensemble[k](batch.next_observations)
                    entropy_regularized_q = q_all - self.alpha * next_log_probs.unsqueeze(1)
                    q_k = (next_action_probs * entropy_regularized_q).sum(dim=1)
                else:
                    # For continuous: use sampled next actions
                    q_k = self._critic_q(self.target_policy.critic_ensemble[k],
                                         batch.next_observations, next_actions,
                                         actions_are_indices=False)
                target_q_values.append(q_k)

            target_q_stack = torch.stack(target_q_values, dim=0)  # [n_critics, batch_size]

            # Entropy bonus for continuous actions (discrete handled above)
            if not self.is_discrete:
                target_q_stack = target_q_stack - self.alpha * next_log_probs.squeeze(-1)

            rewards = batch.rewards.squeeze(-1)
            dones = batch.dones.squeeze(-1)
            y = rewards + (1 - dones) * self.gamma * target_q_stack

        mu_0 = y.mean(dim=0)  # ensemble-mean target, [batch_size]
        sig2_0 = self.prior_variance

        # Randomly drop (critic, sample) entries from the loss
        bootstrap_mask = (torch.rand_like(q) >= self.bootstrap_rate).float()

        sig2 = (q * bootstrap_mask).var(dim=0).clamp(1e-8, None)
        logsig2 = sig2.log()

        err_0 = (q - mu_0.unsqueeze(0)) * bootstrap_mask

        term1 = -0.5 * logsig2
        term2 = 0.5 * (err_0.pow(2)).mean(dim=0) / sig2_0
        kl_term = term1 + term2

        var_offset = -self.gamma**2 * logsig2
        emp_loss = ((q - y) * bootstrap_mask).pow(2)

        q_loss = emp_loss + kl_term.unsqueeze(0) + var_offset.unsqueeze(0)

        total_loss = q_loss.mean()

        self.critic_optimizer.zero_grad()
        total_loss.backward()
        self.critic_optimizer.step()

        return total_loss.item()

    def _update_actor(self, batch) -> float:
        """Take one actor step maximizing the ensemble-mean Q minus ``alpha * log_prob``.

        Args:
            batch: Replay-buffer sample; only ``observations`` is used.

        Returns:
            The scalar actor loss.
        """
        # Get actions from current policy
        actions, log_probs, _ = self.policy.get_action(batch.observations, deterministic=False)

        # NOTE(review): for discrete actions the sampled indices are presumably
        # non-differentiable, so the gathered Q-values pass no gradient to the
        # actor and only the entropy term drives it. Confirm against
        # DiscretePBACPolicy.get_action.
        q_values = [
            self._critic_q(self.policy.critic_ensemble[k], batch.observations,
                           actions, actions_are_indices=False)
            for k in range(self.n_critics)
        ]

        # Average Q-values across ensemble
        q_value_mean = torch.stack(q_values, dim=0).mean(dim=0)

        # Actor loss: maximize average Q-value minus entropy
        actor_loss = (self.alpha * log_probs.squeeze(-1) - q_value_mean).mean()

        # No gradient clipping is applied
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        return actor_loss.item()

    def _update_alpha(self, batch) -> float:
        """Take one step on ``log_alpha`` toward ``target_entropy``; refresh ``self.alpha``.

        Args:
            batch: Replay-buffer sample; only ``observations`` is used.

        Returns:
            The scalar temperature loss (also appended to ``self.alpha_losses``).
        """
        with torch.no_grad():
            _, log_probs, _ = self.policy.get_action(batch.observations)

        alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy)).mean()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        self.alpha = self.log_alpha.exp().item()
        self.alpha_losses.append(alpha_loss.item())

        return alpha_loss.item()

    def _update_targets(self):
        """Polyak-average every policy parameter into the target policy (rate ``tau``)."""
        for param, target_param in zip(
            self.policy.parameters(),
            self.target_policy.parameters()
        ):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )

    def save(self, path: str) -> None:
        """Save networks, optimizers, step counters and (if tuned) the temperature to ``path``."""
        state = {
            'policy_state_dict': self.policy.state_dict(),
            'target_policy_state_dict': self.target_policy.state_dict(),
            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
            'global_step': self.global_step,
            'policy_update_count': self.policy_update_count,
        }

        if self.log_alpha is not None:
            state['log_alpha'] = self.log_alpha
            state['alpha_optimizer_state_dict'] = self.alpha_optimizer.state_dict()

        torch.save(state, path)

    def load(self, path: str) -> None:
        """Restore state written by ``save``.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on
        another device can still be loaded. ``log_alpha`` is updated in place
        because ``alpha_optimizer`` holds a reference to this exact tensor.
        Rebinding it would leave the optimizer stepping a stale tensor.
        """
        state = torch.load(path, map_location=self.device)

        self.policy.load_state_dict(state['policy_state_dict'])
        self.target_policy.load_state_dict(state['target_policy_state_dict'])
        self.actor_optimizer.load_state_dict(state['actor_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(state['critic_optimizer_state_dict'])
        self.global_step = state['global_step']
        self.policy_update_count = state['policy_update_count']

        if self.log_alpha is not None and 'log_alpha' in state:
            with torch.no_grad():
                self.log_alpha.copy_(state['log_alpha'])
            self.alpha_optimizer.load_state_dict(state['alpha_optimizer_state_dict'])
            self.alpha = self.log_alpha.exp().item()